In [None]:
def resnet_v2(input_shape, depth, n_classes=10):
    """Build a ResNet v2 classifier made of bottleneck residual units.

    The network is an initial conv layer (with BN-ReLU), three stages of
    ``n`` bottleneck residual units each, and a BN-ReLU -> average-pool ->
    flatten -> softmax Dense head.

    Args:
        input_shape: Shape of a single input sample, passed to ``Input``.
        depth: Total layer count; must be ``9n + 2`` with ``n >= 1``
            (3 conv layers per unit x 3 stages x n units, plus 2).
        n_classes: Number of units in the final softmax layer.

    Returns:
        The ``Model`` mapping ``inputs`` to the softmax ``outputs``.

    Raises:
        ValueError: If ``depth`` is not of the form ``9n + 2`` with ``n >= 1``.
    """
    # depth <= 2 would give zero (or negative) units per stage; the stage
    # loop would then never assign n_filters_out, so reject it up front.
    if (depth - 2) % 9 != 0 or depth < 11:
        raise ValueError(
            'Depth must be 9n + 2 with n >= 1 '
            '(e.g. 56, 65, 74, 83, 92, 101, 110...)')
    n_filters_in = 16
    n_res_blocks = (depth - 2) // 9
    inputs = Input(shape=input_shape)

    # v2 performs Conv2D with BN-ReLU on the input before splitting into
    # the residual and shortcut paths.
    x = resnet_layer(
        inputs=inputs, num_filters=n_filters_in, conv_first=True)

    # Instantiate the three stages of residual units.
    for stage in range(3):
        for res_block in range(n_res_blocks):
            activation = 'relu'
            batch_normalization = True
            strides = 1
            if stage == 0:
                n_filters_out = 4 * n_filters_in
                if res_block == 0:
                    # The stem above already applied BN-ReLU to x, so the
                    # very first unit skips its own BN and activation.
                    activation = None
                    batch_normalization = False
            else:
                n_filters_out = 2 * n_filters_in
                if res_block == 0:
                    strides = 2  # downsample at the start of stages 1 and 2

            # Bottleneck unit: 1x1 conv -> default-kernel conv -> 1x1 conv.
            # The first conv must use the same strides as the shortcut
            # projection below; otherwise the two paths end up with
            # different spatial sizes when downsampling and add() fails.
            y = resnet_layer(inputs=x,
                             num_filters=n_filters_in,
                             kernel_size=1,
                             strides=strides,
                             activation=activation,
                             batch_normalization=batch_normalization,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=n_filters_in,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=n_filters_out,
                             kernel_size=1,
                             conv_first=False)
            if res_block == 0:
                # Linear projection shortcut so x matches y's channel count
                # (and its spatial size when strides == 2).
                x = resnet_layer(inputs=x,
                                 num_filters=n_filters_out,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = add([x, y])
        # The next stage's bottleneck width is this stage's output width.
        n_filters_in = n_filters_out

    # Classifier head; v2 applies BN-ReLU before pooling.
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # NOTE(review): pool_size=8 presumably assumes an 8x8 final feature map
    # (e.g. 32x32 inputs after the two stride-2 stages) -- confirm for other
    # input sizes.
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(
        n_classes,
        activation='softmax',
        kernel_initializer='he_normal',
    )(y)

    # Instantiate the model.
    model = Model(inputs=inputs, outputs=outputs)
    return model