# Wide Residual Networks: notes on parameters and implementation

There are three key parameters associated with a Wide Residual Networks (WRNs), $k$, $n$ and $N$. The architecture WRN-$n$-$k$ will have $n$ convolutional layers and a widening factor $k$. The most important section of the model is composed of 3 blocks where the number of convolutional filters in block $i$ where $i \in \{0, 1, 2\}$ is given by $16\times 2^i \times k$. 

#### Table 1 from the paper showing the structure of a basic WRN classifier (excluding the final dense layer)
<img src = 'wrn_table1.png'/>

What confused me for bit from looking at the [official implementation](https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py) was that the depth was required to be $6*N + 4$. The code actually states `6n + 4` but from the fact that it uses `n` to determine the number of pairs of convolutional layers per block it is evident that lower-case `n` in the code corresponds to the upper-case $N$ in the paper. What the code describes as `depth` is the number of convolutional layers $n$. The table above does not illustrate the skip connections in the residual blocks. If $k > 1$ then first skip connection in each block will have convolutional layer to make the number of channels of the skip connection equal to the block output. 

Thus for $k > 1$, there is $1$ initial convolutional layer and $N \times 2$ layers and $1$ layer in the skip connection for each of $3$ blocks so $3\cdot(2\cdot N + 1) + 1 = 6\cdot N + 4$

If $k = 1$ then there will be one less skip convolutional layer so the total number will not actually be $ 6\cdot N + 4$. However since *wide* residual networks are defined in the paper as those where $k > 1$, this formula is true for all WRNs.

In [13]:
from keras.models import Model
from keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
from keras.layers.convolutional import Convolution2D, MaxPooling2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras import backend as K

weight_decay = 0.0005

class WRNBlock(object):
    def __init__(self, num, k=1, N=2, channel_axis=-1, dropout=0.0):
        self.num = num
        self.k = k
        self.N = N
        self.channel_axis = channel_axis
        self.dropout = dropout
        
    @property
    def base(self):
        return 16*(2**self.num)
    
    def _conv(self, x, kernel_size=(3, 3), strides=(1, 1)):
        return Convolution2D(self.base * self.k, 
                             kernel_size=kernel_size, 
                    strides=strides,
                    padding='same', 
                      kernel_initializer='he_normal',
                      W_regularizer=l2(weight_decay),
                      use_bias=False)(x)
        
    def _bn_relu_conv(self, x, strides=(1, 1)):
        x = BatchNormalization(axis=self.channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(x)
        x = Activation('relu')(x)
        x = self._conv(x, strides=strides)
        return x
        
    
    def _res_block(self, x, strides1=(1, 1)):
        skip = x
        x = self._bn_relu_conv(x, strides1)
        x = self._bn_relu_conv(x)
        
        if K.int_shape(x)[self.channel_axis] != K.int_shape(skip)[self.channel_axis]:
            skip = self._conv(skip, kernel_size=(1, 1), strides=strides1)
        return Add()([x, skip])
    
    def __call__(self, x):
        for i in range(self.N):
            strides1 = (2, 2) if ((i == 0) and (self.num>0)) else (1, 1)
            x = self._res_block(x, strides1=strides1)
        return x
            
def create_wide_residual_network(input_dim, n_classes=100, N=2, k=1, dropout=0.0, verbose=1):
    """
    Creates a Wide Residual Network with specified parameters
    :param input_dim: tuple input dimensions to be passed into Keras Input
    :param nb_classes: Number of output classes
    :param N: Depth of the network. Compute N = (n - 4) / 6.
              Example : For a depth of 16, n = 16, N = (16 - 4) / 6 = 2
              Example2: For a depth of 28, n = 28, N = (28 - 4) / 6 = 4
              Example3: For a depth of 40, n = 40, N = (40 - 4) / 6 = 6
    :param k: Width of the network.
    :param dropout: Adds dropout if value is greater than 0.0
    :param verbose: Debug info to describe created WRN
    :return:
    """
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    inpt = Input(shape=input_dim)
    
    #     def conv_params(ni, no, k=1):
    #         return kaiming_normal_(torch.Tensor(no, ni, k, k)) -> kernel_initializer='he_normal'
    #     Note there is no bias
    x =  Convolution2D(16, (3, 3), padding='same', kernel_initializer='he_normal',
                      W_regularizer=l2(weight_decay),
                      use_bias=False)(inpt)
    
    for i in range(3):
        x = WRNBlock(num=i, k=k, N=N, channel_axis=channel_axis, dropout=dropout)(x)
    
    # Official 
    #     def bnparams(n):
    #         return {'weight': torch.rand(n), -> gamma_initializer='uniform'
    #                 'bias': torch.zeros(n),
    #                 'running_mean': torch.zeros(n),
    #                 'running_var': torch.ones(n)}
    #
    # From PyTorch docs (https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html)
    #     eps: a value added to the denominator for numerical stability.
    #     Default: 1e-5
    #     momentum: the value used for the running_mean and running_var
    #         computation. Can be set to ``None`` for cumulative moving average
    #         (i.e. simple average). Default: 0.1
    x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(x)
    x = Activation('relu')(x)

    x = AveragePooling2D((8, 8))(x)
    x = Flatten()(x)

    #     def linear_params(ni, no):
    #         return {'weight': kaiming_normal_(torch.Tensor(no, ni)), -> kernel_initializer='he_normal' 
    #                 'bias': torch.zeros(no)}
    x = Dense(n_classes, W_regularizer=l2(weight_decay), activation='softmax', kernel_initializer='he_normal')(x)

    model = Model(inpt, x)
        
    return model
    
if __name__ == "__main__":
    from keras.utils import plot_model
    from keras.layers import Input
    from keras.models import Model

    init = (32, 32, 3)

    wrn_28_10 = create_wide_residual_network(init, n_classes=10, N=10, k=4, dropout=0.0)

    wrn_28_10.summary()

    plot_model(wrn_28_10, "WRN-28-10.png", show_shapes=True, show_layer_names=True)



__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
conv2d_65 (Conv2D)              (None, 32, 32, 16)   432         input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_62 (BatchNo (None, 32, 32, 16)   64          conv2d_65[0][0]                  
__________________________________________________________________________________________________
activation_62 (Activation)      (None, 32, 32, 16)   0           batch_normalization_62[0][0]     
__________________________________________________________________________________________________
conv2d_66 

> In all our experiments we use SGD with Nesterov momentum and cross-entropy loss. The initial learning rate is set to 0.1, weight decay to 0.0005, dampening to 0, momentum to 0.9 and minibatch size to 128. On CIFAR learning rate dropped by 0.2 at 60, 120 and 160 epochs and we train for total 200 epochs. On

In [14]:
import numpy as np
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils

In [15]:
(x_trainval, y_trainval), (x_test, y_test) = cifar10.load_data()
y_trainval = np_utils.to_categorical(y_trainval)
y_test = np_utils.to_categorical(y_test)
np.random.seed(456)
n_images = len(x_trainval)
n_train = 40000
n_valid = n_images - n_train
n_test = len(x_test)
splits = np.split(np.random.permutation(len(x_trainval)), [n_train])
(x_train, y_train), (x_val, y_val) = [(x_trainval[inds], y_trainval[inds])
                                     for inds in splits]

mean_train = np.mean(x_train, axis=(0, 1, 2), keepdims=True)[0]


>We reproduce the results of Zagoruyko &Komodakis (2016) with the same settings except that i) we subtract per-pixel mean only and do not use ZCA whitening

In [16]:
def random_crop(image):
    padding = [(4, 4), (4, 4), (0, 0)]
    image = np.pad(image, pad_width=padding, mode='reflect')
    start = np.random.randint(0, image.shape[0]-32)
    slc = slice(start, start+32)
    return image[slc, slc]

In [17]:
train_datagen = ImageDataGenerator(featurewise_center=True, 
                                   horizontal_flip=True, 
                                   preprocessing_function=lambda x: random_crop,
                                   rescale=1/255.)

test_datagen = ImageDataGenerator(featurewise_center=True, rescale=1/255.)

train_datagen.fit(x_train)
test_datagen.fit(x_train)

In [18]:
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import SGD

In [24]:
def step_schedule(index, lr):
    if index == 0:
        assert np.close(lr, 0.1)
    assert index < 200
    if (index + 1) in [60, 120, 160]:
        return lr*0.2
    return lr

wrn_28_10.compile(optimizer=SGD(lr=0.1, momentum=0.9, nesterov=False),
                  loss='categorical_crossentropy',
                  metrics=['acc'])
                  

In [25]:
batch_size = 128
train_gen = train_datagen.flow(x_train, y_train, batch_size=batch_size, shuffle=True)
valid_gen = test_datagen.flow(x_val, y_val, batch_size=batch_size, shuffle=False)
wrn_28_10.fit_generator(
        train_gen,
        steps_per_epoch = np.ceil(n_train/batch_size).astype('int'),
        epochs = 200,
        validation_data=valid_gen,
        validation_steps = np.ceil(n_valid/batch_size).astype('int'),
        callbacks=[ModelCheckpoint('WRN-28-10-basic', monitor='val_acc', save_best_only=True),
                            LearningRateScheduler(step_schedule)])

    

Epoch 1/200
0.10000000149011612


AssertionError: 