# Cleaner implementation of Difference Target Propagation (DTP)

In this notebook, I propose a more elegant and generalizable implementation of DTP, built on custom ```backward``` methods (subclasses of ```torch.autograd.Function```).
As a result, the updates computed by DTP are obtained simply by calling ```loss.backward()```.

Note: although the code here is complete, I have not been able to test it yet, because loading the ```cudnn_convolution_transpose``` extension does not work.

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
import numpy as np
from torch.utils.cpp_extension import load

# Load the PyTorch C++ extensions: torch.utils.cpp_extension.load compiles the
# listed C++ sources and imports the resulting module (verbose=True prints the build log).
cudnn_convolution = load(name="cudnn_convolution", sources=["cudnn_convolution.cpp"], verbose=True)
# NOTE: the Python variable is `cudnn_transpose_convolution`, while the extension
# module itself is named "cudnn_convolution_transpose". The notebook intro reports
# that this load currently does not work.
cudnn_transpose_convolution = load(name="cudnn_convolution_transpose", sources=["cudnn_convolution_transpose.cpp"],verbose=True)

### Debugging cudnn imports

Here, we just want to make sure the imported ```cudnn_convolution``` and ```cudnn_transpose_convolution``` work. Of course, this is just a sanity check: the final goal is to wrap them with a subclass of ```nn.Module``` to write customized layers.

In [None]:
# Sanity check of the cuDNN convolution extension on all-zero CUDA tensors.
# create dummy input, convolutional weights and bias
# weight: (out_channels=64, in_channels=3, 5, 5)
weight = torch.zeros(64, 3, 5, 5).to('cuda')
bias   = torch.zeros(64).to('cuda')

stride   = (2, 2)
padding  = (0, 0)
# output_padding is only used by the disabled transpose-convolution calls below.
output_padding  = (0, 0)
dilation = (1, 1)
groups   = 1


input  = torch.zeros(128, 3, 32, 32).to('cuda')

# compute the result of convolution
# NOTE(review): the two trailing booleans are presumably the extension's
# benchmark / deterministic flags — confirm against cudnn_convolution.cpp.
output = cudnn_convolution.convolution(input, weight, bias, stride, padding, dilation, groups, False, False)

# create dummy gradient w.r.t. the output
# Spatial size of the conv output: (32 - 5) // 2 + 1 = 14 (kernel 5, stride 2, no padding).
grad_output = torch.zeros(128, 64, 14, 14).to('cuda')

# compute the gradient w.r.t. the weights and input
# (the backward calls take one more trailing boolean than `convolution` —
# presumably an extra cuDNN flag; confirm against the .cpp source)
grad_weight = cudnn_convolution.convolution_backward_weight(input, weight.shape, grad_output, stride, padding, dilation, groups, False, False, False)
grad_input  = cudnn_convolution.convolution_backward_input(input.shape, weight, grad_output, stride, padding, dilation, groups, False, False, False)
#grad_input = cudnn_convolution.convolution_transpose(grad_output, weight, bias, stride, padding, output_padding, dilation, groups, False, False)

# Disabled transpose-convolution check. Note that this argument order
# (padding, output_padding, stride, ...) differs from the calls in
# neuron_convpool (stride, padding, ...) — TODO confirm the real signature.
'''
transpose_conv_output = cudnn_transpose_convolution.convolution_transpose(
    grad_output,
    weight,
    bias,
    padding,
    output_padding,
    stride,
    dilation,
    groups,
    False,
    False,
)
'''
# Expected shapes: output (128, 64, 14, 14); grad_weight matches weight
# (64, 3, 5, 5); grad_input matches input (128, 3, 32, 32).
print(output.shape)
print(grad_weight.shape)
print(grad_input.shape)


### Customizing fully connected autoencoders 

Here, we define the ```neuron_fc``` class, in which we customize ```backward``` for DTP, and then wrap it in the ```layer_fc``` class.

**Important**: note that depending on whether we want to train the feedback weights (```w_b_learning=True```) or the feedforward weights (```back=True```), the ```backward``` method of ```neuron_fc``` behaves differently. 

In [None]:
#Define the neuron_fc and layer_fc classes

class neuron_fc(torch.autograd.Function):
    '''
    Fully connected autoencoder with a customized backward pass, where:
    F: x -> Linear layer (feedforward parameters w_f, bias_f)
    G: y -> Linear layer (feedback parameters w_b, bias_b)
    No activation function is applied, since this layer is meant to compute
    the logits of the final classification layer (before the softmax).

    The flags passed to forward select one of three modes:
      - w_b_learning=True: returns (y, r) with r = G(y); backward only
        produces gradients for the feedback parameters (w_b, bias_b).
      - back=True: returns y; backward produces the linear-layer gradients
        of w_f / bias_f and, as the "input gradient", the difference
        G(y + beta * grad_y) - G(y) (difference target propagation).
      - neither flag: plain inference, returns y; backward is not supported.
    '''

    @staticmethod
    def forward(ctx, input, w_f, bias_f, w_b, bias_b, beta, w_b_learning, back, extra_input):
        '''
        input: tensor whose first dimension is the batch; it is flattened to
            (batch, -1) before the linear layer, and r is reshaped back to
            input.size()
        w_f, bias_f: feedforward parameters
        w_b, bias_b: feedback parameters
        beta: nudging strength in the output layer (used when back=True)
        w_b_learning: whether we train the feedback weights
        back: whether we train the feedforward weights
        extra_input: if given (w_b_learning mode only), used instead of F(input)
            as the input of G (needed for feedback weights training)
        '''
        # Record the mode in every branch so that backward can dispatch on it,
        # and fail with a clear message in inference mode.
        ctx.w_b_learning = w_b_learning
        ctx.back = back

        if w_b_learning:
            if extra_input is not None:
                y = extra_input
            else:
                y = F.linear(input.view(input.size(0), -1), w_f, bias_f)
            r = F.linear(y, w_b, bias_b).view(input.size())
            ctx.save_for_backward(y)
            return y, r

        y = F.linear(input.view(input.size(0), -1), w_f, bias_f)
        if back:
            r = F.linear(y, w_b, bias_b).view(input.size())
            ctx.save_for_backward(input, y, r, w_b, bias_b)
            ctx.beta = beta
        return y

    @staticmethod
    def backward(ctx, grad_y, *args):
        '''
        The backward pass depends on the mode recorded in forward:
        - w_b_learning: args[0] is the gradient w.r.t. r; only w_b and bias_b
          receive gradients.
        - back: w_f / bias_f receive the linear-layer gradients; the gradient
          sent to the previous layer is G(y + beta * grad_y) - G(y).
        Raises RuntimeError in inference mode (neither flag set).
        '''
        grad_input = grad_w_f = grad_bias_f = None
        grad_w_b = grad_bias_b = None

        if ctx.w_b_learning:
            (y,) = ctx.saved_tensors
            grad_r = args[0]
            grad_r_flat = grad_r.reshape(grad_r.size(0), -1)
            grad_w_b = grad_r_flat.t().mm(y)
            # No squeeze: it would turn a size-1 bias gradient into a 0-d
            # tensor, which autograd rejects.
            grad_bias_b = grad_r_flat.sum(0)

        elif ctx.back:
            input, y, r, w_b, bias_b = ctx.saved_tensors

            # Computation of the first target: nudge the output along grad_y.
            t = y + ctx.beta * grad_y
            r_perturb = F.linear(t, w_b, bias_b).view(input.size())
            grad_input = r_perturb - r

            grad_w_f = grad_y.t().mm(input.view(input.size(0), -1))
            grad_bias_f = grad_y.sum(0)

        else:
            raise RuntimeError(
                'neuron_fc.backward called in inference mode; '
                'pass w_b_learning=True or back=True to forward'
            )

        # beta, w_b_learning, back and extra_input receive no gradient.
        return grad_input, grad_w_f, grad_bias_f, grad_w_b, grad_bias_b, None, None, None, None
                   
    

class layer_fc(nn.Module):
    '''
    Fully connected autoencoder layer wrapping neuron_fc. It holds the
    feedforward (w_f, bias_f) and feedback (w_b, bias_b) parameters, runs
    the forward pass and trains the feedback weights.
    '''

    def __init__(self, in_size, out_size, beta, iter, noise):
        '''
        in_size, out_size: flattened input / output sizes of the linear layer
        beta: nudging strength forwarded to neuron_fc
        iter: number of feedback-training steps per call to weight_b_train
        noise: standard deviation of the Gaussian noise used in weight_b_train
        '''
        super(layer_fc, self).__init__()

        self.beta = beta
        self.iter = iter
        self.noise = noise

        self.w_f = nn.Parameter(torch.Tensor(out_size, in_size))
        self.bias_f = nn.Parameter(torch.Tensor(out_size))
        nn.init.normal_(self.w_f, 0, 0.01)
        nn.init.constant_(self.bias_f, 0)

        # Feedback weights map the output back to the input: shape (in, out).
        self.w_b = nn.Parameter(torch.Tensor(in_size, out_size))
        self.bias_b = nn.Parameter(torch.Tensor(in_size))
        nn.init.normal_(self.w_b, 0, 0.01)
        nn.init.constant_(self.bias_b, 0)

    def forward(self, input, w_b_learning = False, back = False, extra_input = None):
        return neuron_fc.apply(input,
                               self.w_f, self.bias_f,
                               self.w_b, self.bias_b,
                               self.beta, w_b_learning, back, extra_input
                              )

    def weight_b_train(self, input, optimizer, arg_return = False):
        '''
        Train the feedback weights for self.iter steps on the loss
            -2 * <noise, dr> + ||dr_y||^2   (batch-averaged)
        where dr is the change of the reconstruction r when the input is
        perturbed by Gaussian noise of std self.noise, and dr_y is the change
        of r when the layer output y is perturbed instead. In this mode
        neuron_fc returns no gradient for w_f / bias_f.

        Returns the last loss if arg_return is True (None if self.iter < 1).
        '''
        loss_b = None
        for step in range(1, self.iter + 1):

            #Uncomment for sanity check: see if the angle between feedforward and feedback weights is decreasing
            # if step % 50 == 0:
            #     dist, angle = self.compute_dist_angle()
            #     print('\n Step {}: Distance = {}, angle = {} \n'.format(step, dist, angle))

            y_temp, r_temp = self(input, w_b_learning = True)

            # Perturb the input and measure the change of the reconstruction.
            noise = self.noise * torch.randn_like(input)
            _, r_noise = self(input + noise, w_b_learning = True)
            dr = r_noise - r_temp

            # Perturb the output directly and measure the change of the reconstruction.
            noise_y = self.noise * torch.randn_like(y_temp)
            _, r_noise_y = self(input, extra_input = y_temp + noise_y, w_b_learning = True)
            dr_y = r_noise_y - r_temp

            '''Loss of interest for the feedback weights'''
            loss_b = (-2 * (noise * dr).view(dr.size(0), -1).sum(1).mean()
                      + (dr_y ** 2).view(dr_y.size(0), -1).sum(1).mean())

            optimizer.zero_grad()
            loss_b.backward()
            optimizer.step()

        if arg_return:
            return loss_b

    def compute_dist_angle(self, *args):
        '''
        Return (dist, angle) between the feedforward weights w_f and the
        transposed feedback weights w_b^T:
          dist: ||w_f - w_b^T|| / ||w_f|| (0-d tensor)
          angle: mean over rows of the angle between matching rows, in degrees
        Extra positional arguments are accepted and ignored.
        '''
        with torch.no_grad():
            w_fwd = self.w_f
            w_bwd = self.w_b.t()

            dist = torch.sqrt(((w_fwd - w_bwd) ** 2).sum() / (w_fwd ** 2).sum())

            fwd_flat = torch.reshape(w_fwd, (w_fwd.size(0), -1))
            bwd_flat = torch.reshape(w_bwd, (w_bwd.size(0), -1))
            cos_angle = (fwd_flat * bwd_flat).sum(1) / torch.sqrt(
                (fwd_flat ** 2).sum(1) * (bwd_flat ** 2).sum(1))
            # Rounding can push |cos| slightly above 1, where acos returns NaN.
            cos_angle = cos_angle.clamp(-1.0, 1.0)
            angle = (180.0 / np.pi) * (torch.acos(cos_angle).mean().item())

        return dist, angle

Below, we test the ```layer_fc``` class. We check the forward pass, and we make sure that when the algorithm is run on the feedback weights (**keeping the feedforward weights fixed**), the feedback weights come into alignment with the feedforward weights.

In [None]:
#Test the layer_fc class
device = torch.device('cuda')

batch_size = 128
input_dim = 512
output_dim = 10
beta = 0.5
n_iter_b = 200*50  # feedback-training steps (avoid shadowing the builtin `iter`)
noise = 0.03
lr = 0.035

#1 - Instantiate layer_fc class
dummy_layer = layer_fc(input_dim, output_dim, beta, n_iter_b, noise)
dummy_layer.to(device)

#2 - Forward pass: the input is (batch, input_dim, 1, 1) and is flattened by the layer
dummy_input = torch.rand(batch_size, input_dim, 1, 1, requires_grad=True, device=device)
print(dummy_input.shape)
out = dummy_layer(dummy_input)

#3 - Build optimizer for feedback weights only
params_b = []
for name, param in dummy_layer.named_parameters():
    if name in ('w_b', 'bias_b'):
        params_b.append(param)
        print(name + ' has mean {}'.format(param.mean()))

optim_params_b = [{'params': params_b, 'lr': lr}]
optimizer_b = torch.optim.SGD(optim_params_b, momentum = 0.9)

#4 - Train feedback weights
dummy_layer.weight_b_train(dummy_input, optimizer_b)

### Customizing convolutional autoencoders

We proceed in the exact same way to customize the conv layer.

**Important remarks**: 
+ Some hyperparameters have been fixed within the class for simplicity for now, but in the future we will pass an argument parser into the constructor of the class to make it more general.
+ In particular, ELU is used here because, in our experiments, it made the feedback-weight training algorithm work much better than ReLU. Note that ```neuron_convpool``` currently hard-codes ELU, whatever ```activation``` is passed to ```layer_convpool```.

In [None]:
#Define the neuron_convpool and layer_convpool classes

class neuron_convpool(torch.autograd.Function):
    '''
    Convolutional autoencoder with a customized backward pass, where:
    F: x -> conv -> ELU -> 2x2 max-pooling
    G: y -> 2x2 max-unpooling (with F's pooling indices) -> ELU -> transposed conv

    The flags passed to forward select one of three modes:
      - w_b_learning=True: returns (y, r, ind); backward only produces
        gradients for the feedback parameters (w_b, bias_b).
      - back=True: returns y; backward builds the target t = y + grad_y and
        produces difference-based updates for w_f / bias_f and, when feedback
        weights exist, G(t) - G(y) as the gradient for the previous layer.
      - neither flag: plain inference, returns y; backward is not supported.

    The activation is hard-coded to ELU; convolutions use dilation (1, 1) and
    groups = 1, and the transposed convolution reuses the forward padding.
    '''

    @staticmethod
    def _conv_pool(input, w_f, bias_f, stride, padding):
        '''F: conv -> ELU -> 2x2 max-pool; returns (y, pooling indices).'''
        y = cudnn_convolution.convolution(
            input, w_f, bias_f, stride, padding, (1, 1), 1, False, False
        )
        return F.max_pool2d(F.elu(y), 2, stride = 2, return_indices=True)

    @staticmethod
    def _unpool_elu(y, ind, size):
        '''First stage of G: 2x2 max-unpool back to the spatial size of `size`, then ELU.'''
        return F.elu(F.max_unpool2d(y, ind, 2, stride = 2, output_size = size))

    @staticmethod
    def _conv_transpose(h, w_b, bias_b, stride, padding):
        '''Second stage of G: transposed convolution with the feedback weights.'''
        # NOTE(review): the positional argument order is kept from the original
        # calls; the disabled debug call earlier in the notebook passes padding
        # before stride — confirm against cudnn_convolution_transpose.cpp.
        return cudnn_transpose_convolution.convolution_transpose(
            h, w_b, bias_b, stride, padding, (1, 1), 1, False, False
        )

    @staticmethod
    def forward(ctx, input, w_f, bias_f,
                w_b, bias_b,
                stride, padding, w_b_learning, back, extra_input):
        '''
        input: layer input, assumed (batch, channels, H, W) as required by the
            2-D convolution / pooling
        w_f, bias_f: feedforward convolution parameters
        w_b, bias_b: feedback (transposed convolution) parameters, or None
        stride, padding: convolution stride / padding (shared by G)
        w_b_learning: whether we train the feedback weights
        back: whether we train the feedforward weights
        extra_input: optional (y, ind) pair used instead of F(input) when
            w_b_learning=True (needed for feedback weights training)

        Depending on the mode, different variables are stored for backward.
        '''
        ctx.w_b_learning = w_b_learning
        ctx.back = back
        ctx.stride = stride
        ctx.padding = padding
        ctx.has_feedback = w_b is not None

        if w_b_learning:
            if extra_input is not None:
                y, ind = extra_input
            else:
                y, ind = neuron_convpool._conv_pool(input, w_f, bias_f, stride, padding)

            r_prev = neuron_convpool._unpool_elu(y, ind, input.size())
            r = neuron_convpool._conv_transpose(r_prev, w_b, bias_b, stride, padding)

            ctx.save_for_backward(r_prev, w_b)
            return y, r, ind

        y, ind = neuron_convpool._conv_pool(input, w_f, bias_f, stride, padding)

        if back:
            r_1 = neuron_convpool._unpool_elu(y, ind, input.size())
            if ctx.has_feedback:
                r_2 = neuron_convpool._conv_transpose(r_1, w_b, bias_b, stride, padding)
                ctx.save_for_backward(input, ind, r_1, r_2, w_f, w_b, bias_b, y)
            else:
                ctx.save_for_backward(input, ind, r_1, w_f, y)

        return y

    @staticmethod
    def backward(ctx, grad_y, *args):
        '''
        - w_b_learning: args[0] is the gradient w.r.t. r; only w_b and bias_b
          receive gradients.
        - back: the target t = y + grad_y goes through the first stage of G;
          its difference with the unperturbed r_1 drives the w_f / bias_f
          updates. If feedback weights exist, G(t) - G(y) is returned as the
          gradient for the previous layer; otherwise the input gets None.
        Raises RuntimeError in inference mode (neither flag set).
        '''
        stride = ctx.stride
        padding = ctx.padding

        grad_input = grad_w_f = grad_bias_f = None
        grad_w_b = grad_bias_b = None

        #If we train the feedback weights
        if ctx.w_b_learning:
            r_prev, w_b = ctx.saved_tensors
            grad_r = args[0]
            grad_w_b = cudnn_transpose_convolution.convolution_transpose_backward_weight(
                r_prev, w_b.shape, grad_r, stride, padding, (1, 1), 1, False, False, False
            )
            # No squeeze: it would turn a single-channel bias gradient into a 0-d tensor.
            grad_bias_b = torch.sum(grad_r, dim=[0, 2, 3])

        #If we train the feedforward weights (and therefore compute the targets)
        elif ctx.back:
            # Dispatch on an explicit flag: both saved-tensor lists hold more
            # than 4 tensors, so their length cannot tell the cases apart.
            if ctx.has_feedback:
                input, ind, r_1, r_2, w_f, w_b, bias_b, y = ctx.saved_tensors
            else:
                input, ind, r_1, w_f, y = ctx.saved_tensors

            #Compute target and pull it back through unpooling + ELU
            t = y + grad_y
            r_perturb_1 = neuron_convpool._unpool_elu(t, ind, input.size())
            delta_post_w_f = r_perturb_1 - r_1

            #Compute parameter gradient
            grad_w_f = cudnn_convolution.convolution_backward_weight(
                input, w_f.shape, delta_post_w_f, stride, padding, (1, 1), 1, False, False, False
            )
            grad_bias_f = torch.sum(delta_post_w_f, dim=[0, 2, 3])

            #Compute input gradient (only possible with feedback weights)
            if ctx.has_feedback:
                r_perturb_2 = neuron_convpool._conv_transpose(r_perturb_1, w_b, bias_b, stride, padding)
                grad_input = r_perturb_2 - r_2

        else:
            raise RuntimeError(
                'neuron_convpool.backward called in inference mode; '
                'pass w_b_learning=True or back=True to forward'
            )

        # stride, padding, w_b_learning, back and extra_input receive no gradient.
        return (grad_input, grad_w_f, grad_bias_f, grad_w_b, grad_bias_b,
                None, None, None, None, None)
          

class layer_convpool(nn.Module):
    '''
    Convolutional autoencoder layer wrapping neuron_convpool. It holds the
    feedforward (w_f, bias_f) and optional feedback (w_b, bias_b) parameters,
    runs the forward pass and trains the feedback weights.
    '''
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 activation,
                 iter=None,
                 noise=None
                ):
        '''
        in_channels, out_channels, kernel_size, stride, padding: convolution
            hyperparameters (ints are expanded to pairs)
        activation: 'elu' or 'relu', stored as self.rho (note: neuron_convpool
            hard-codes ELU, so forward does not use self.rho)
        iter: number of feedback-training steps per call to weight_b_train;
            if None, the layer has no feedback weights (w_b = bias_b = None)
        noise: standard deviation of the noise used in weight_b_train
        '''
        super(layer_convpool, self).__init__()

        kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)

        self.w_f = nn.Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
        self.bias_f = nn.Parameter(torch.Tensor(out_channels))
        nn.init.kaiming_normal_(self.w_f, mode='fan_in', nonlinearity='relu')
        nn.init.constant_(self.bias_f, 0)

        if iter is not None:
            # Same shape as w_f: used as transposed-convolution weights mapping
            # out_channels back to in_channels (hence bias_b has in_channels entries).
            self.w_b = nn.Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
            self.bias_b = nn.Parameter(torch.Tensor(in_channels))
            nn.init.kaiming_normal_(self.w_b, mode='fan_in', nonlinearity='relu')
            nn.init.constant_(self.bias_b, 0)
        else:
            self.w_b = None
            self.bias_b = None

        self.iter = iter
        self.noise = noise

        if activation == 'elu':
            self.rho = nn.ELU()
        elif activation == 'relu':
            self.rho = nn.ReLU()

    def forward(self, input, w_b_learning = False, back = False, extra_input = None):
        return neuron_convpool.apply(
            input,
            self.w_f,
            self.bias_f,
            self.w_b,
            self.bias_b,
            self.stride,
            self.padding,
            w_b_learning,
            back,
            extra_input
        )

    def weight_b_train(self, input, optimizer, arg_return = False):
        '''
        Train the feedback weights for self.iter steps on the loss
            -2 * <noise, dr> + ||dr_y||^2   (batch-averaged)
        where dr is the change of the reconstruction r when the input is
        perturbed by Gaussian noise of std self.noise, and dr_y is the change
        of r when the pooled output y is perturbed instead (same pooling
        indices). Prints the distance / angle to w_f every 50 steps.

        Returns the last loss if arg_return is True (None if self.iter < 1).
        Raises ValueError if the layer has no feedback weights (iter=None).
        '''
        if self.iter is None:
            raise ValueError('layer_convpool has no feedback weights to train (iter=None)')

        loss_b = None
        for step in range(1, self.iter + 1):

            if step % 50 == 0:
                dist, angle = self.compute_dist_angle()
                print('\n Step {}: Distance = {}, angle = {} \n'.format(step, dist, angle))

            y_temp, r_temp, ind = self(input, w_b_learning = True)

            # Perturb the input and measure the change of the reconstruction.
            noise = self.noise * torch.randn_like(input)
            _, r_noise, _ = self(input + noise, w_b_learning = True)
            dr = r_noise - r_temp

            # Perturb the output directly (keeping the pooling indices).
            noise_y = self.noise * torch.randn_like(y_temp)
            _, r_noise_y, _ = self(input,
                                   extra_input = (y_temp + noise_y, ind),
                                   w_b_learning = True
                                  )
            dr_y = r_noise_y - r_temp

            loss_b = (-2 * (noise * dr).view(dr.size(0), -1).sum(1).mean()
                      + (dr_y ** 2).view(dr_y.size(0), -1).sum(1).mean())

            optimizer.zero_grad()
            loss_b.backward()
            optimizer.step()

        if arg_return:
            return loss_b

    def compute_dist_angle(self):
        '''
        Return (dist, angle) between the feedforward weights w_f and the
        feedback weights w_b (both shaped (out, in, kH, kW)):
          dist: ||w_f - w_b|| / ||w_f|| (0-d tensor)
          angle: mean over output channels of the angle between the
                 flattened filters, in degrees
        Raises ValueError if the layer has no feedback weights.
        '''
        if self.w_b is None:
            raise ValueError('layer_convpool has no feedback weights (iter=None)')

        with torch.no_grad():
            w_fwd = self.w_f
            w_bwd = self.w_b

            dist = torch.sqrt(((w_fwd - w_bwd) ** 2).sum() / (w_fwd ** 2).sum())

            fwd_flat = torch.reshape(w_fwd, (w_fwd.size(0), -1))
            bwd_flat = torch.reshape(w_bwd, (w_bwd.size(0), -1))
            cos_angle = (fwd_flat * bwd_flat).sum(1) / torch.sqrt(
                (fwd_flat ** 2).sum(1) * (bwd_flat ** 2).sum(1))
            # Rounding can push |cos| slightly above 1, where acos returns NaN.
            cos_angle = cos_angle.clamp(-1.0, 1.0)
            angle = (180.0 / np.pi) * (torch.acos(cos_angle).mean().item())

        return dist, angle

Likewise, we run sanity checks here: we debug the forward pass and make sure that the algorithm on the feedback weights works (keeping the feedforward weights fixed).

In [None]:
#Test the layer_convpool class
#torch.manual_seed(0)
batch_size = 128
padding = 1
stride = 1
kernel_size = 3
in_channels = 128
out_channels = 128*2
in_size = 16
activation = 'elu'
n_iter_b = 200*30  # feedback-training steps (avoid shadowing the builtin `iter`)
noise = 0.4
lr = 1e-4

#1 - Instantiate layer_convpool class
dummy_layer = layer_convpool(
    in_channels,
    out_channels,
    kernel_size,
    stride,
    padding,
    activation,
    n_iter_b,
    noise
    )

dummy_layer.to('cuda')

#2 - Forward pass
dummy_input = torch.randn(batch_size, in_channels, in_size, in_size).to('cuda')
out = dummy_layer(dummy_input)

#3 - Build optimizer for feedback weights only
params_b = [param for name, param in dummy_layer.named_parameters()
            if name in ('w_b', 'bias_b')]
optimizer_b = torch.optim.SGD([{'params': params_b, 'lr': lr}], momentum = 0.9)

#4 - Train feedback weights (prints distance/angle to w_f every 50 steps)
dummy_layer.weight_b_train(dummy_input, optimizer_b)

### Stacking several autoencoders 

Now, we stack several convolutional autoencoders and a final fully connected autoencoder to form a VGG-like architecture to be used on CIFAR-10.

In [None]:
#Define the small_VGG class

class small_VGG(nn.Module):
    '''
    VGG-like stack of DTP autoencoders for 32x32 inputs: five layer_convpool
    layers (channels 3 -> 128 -> 128 -> 256 -> 256 -> 512, each halving the
    spatial size) followed by a layer_fc producing 10 logits.
    '''

    # Per-layer feedback-training steps used when `iter` is not given
    # (entry 0 is None: the first conv layer gets no feedback weights).
    DEFAULT_ITER = (None, 20, 30, 35, 55, 20)

    def __init__(self, iter=None):
        '''
        iter: sequence of 6 per-layer feedback-training step counts (5 conv
            layers followed by the fc layer). An entry of None creates a layer
            without feedback weights; entry 0 is normally None since
            weight_b_train never trains the first layer. Defaults to DEFAULT_ITER.
        Raises ValueError if iter does not have one entry per layer.
        '''
        super(small_VGG, self).__init__()

        if iter is None:
            iter = list(self.DEFAULT_ITER)

        size = 32

        C = [3, 128, 128, 256, 256, 512]
        noise = [None, 0.4, 0.4, 0.2, 0.2, 0.08]
        padding = 1
        stride = 1
        kernel_size = 3
        activation = 'elu'
        beta = 0.7

        if len(iter) != len(C):
            raise ValueError('iter must have {} entries (one per layer), got {}'.format(len(C), len(iter)))

        layers = []
        for i in range(len(C) - 1):
            layers.append(layer_convpool(C[i],
                                         C[i + 1],
                                         kernel_size,
                                         stride,
                                         padding,
                                         activation,
                                         iter[i],
                                         noise[i]))
            # Each conv layer ends with a 2x2 max-pool.
            size //= 2

        layers.append(layer_fc((size ** 2) * C[-1], 10, beta, iter[-1], noise[-1]))

        self.layers = nn.ModuleList(layers)

    def weight_b_train(self, x, optimizer_b):
        '''
        Train the feedback weights of layers 1..N-1 in turn. Each layer is
        trained on the output of the layers below it, computed in inference
        mode without gradient tracking; layer 0 is not trained.
        '''
        with torch.no_grad():
            y = self.layers[0](x)

        for k, layer in enumerate(self.layers[1:], start=1):
            print('Optimizing feedback weights of layer {}...'.format(k))
            layer.weight_b_train(y, optimizer_b)

            with torch.no_grad():
                y = layer(y)

    def forward(self, x, back = False):
        '''Run x through every layer; `back` is forwarded to each layer.'''
        s = x
        for layer in self.layers:
            s = layer(s, back = back)
        return s


Here, we debug the resulting architecture (forward pass, backward pass and feedback-weight training).

In [None]:
#Test the small_VGG class
device = torch.device('cuda:0')
# Seed before building the net so that the weight initialization is reproducible.
torch.manual_seed(1)
# Few feedback-training steps per layer: this cell is a quick smoke test.
net = small_VGG([None, 20, 30, 35, 55, 20])
net.to(device)

#Check model parameters
for name, p in net.named_parameters():
    print(name + ': {}'.format(p.size()))

#Test forward pass
batch_size = 128
input_channels = 3
input_size = 32
dummy_input = torch.randn(batch_size, input_channels, input_size, input_size).to(device)
output = net(dummy_input, back=True)
print(output.size())


#Test backward pass (10 classes: labels 0..9)
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
dummy_target = torch.randint(0, 10, (batch_size,)).to(device)
loss = criterion(output.float(), dummy_target)
print(loss)
loss.backward()

for name, p in net.named_parameters():
    if p.grad is not None:
        print(name + ' has mean gradient: {}'.format(p.grad.mean()))

#Test feedback weights training with an optimizer over *this* net's feedback
#parameters (layers 1..N-1), using the per-layer rates of the CIFAR-10 run
lr_b = [1e-4, 3.5e-4, 8e-3, 8e-3, 0.18]
optim_params_b = [
    {'params': [net.layers[i + 1].w_b, net.layers[i + 1].bias_b], 'lr': lr_b[i]}
    for i in range(len(net.layers) - 1)
]
optimizer_b = torch.optim.SGD(optim_params_b, momentum = 0.9)
net.weight_b_train(dummy_input, optimizer_b)
print('Done!')

### Training by DTP on CIFAR-10

Lastly, we present the final piece: training the previous architecture on CIFAR-10 with our algorithm.

In [None]:
#Load CIFAR-10 (train with random flip + crop augmentation, test without)
batch_size = 128
import torchvision
import torchvision.transforms as transforms

# Per-channel normalization statistics (std deliberately scaled by 3).
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (3*0.2023, 3*0.1994, 3*0.2010)

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomCrop(size=[32, 32], padding=4, padding_mode='edge'),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std),
])

data_root = './cifar10_pytorch'
cifar10_train_dset = torchvision.datasets.CIFAR10(data_root, train=True, transform=transform_train, download=True)
cifar10_test_dset = torchvision.datasets.CIFAR10(data_root, train=False, transform=transform_test, download=True)

train_loader = torch.utils.data.DataLoader(cifar10_train_dset, batch_size=batch_size, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(cifar10_test_dset, batch_size=200, shuffle=False, num_workers=1)

In [None]:
#Test VGG training on CIFAR-10

#1 - Build the net
device = torch.device('cuda:0')
n_iter_b = [None, 200, 300, 350, 550, 200]  # feedback-training steps per layer
net = small_VGG(n_iter_b)
net.to(device)

# Number of training epochs; also the period of the cosine learning-rate schedule.
epochs = 85


#2 - Build optimizer for *feedback* weights (one parameter group per layer 1..N-1)
lr_b = [1e-4, 3.5e-4, 8e-3, 8e-3, 0.18]

optim_params_b = []
for i in range(len(net.layers) - 1):
    wanted = {'layers.{}.w_b'.format(i + 1), 'layers.{}.bias_b'.format(i + 1)}
    params_b = [param for name, param in net.named_parameters() if name in wanted]
    optim_params_b.append({'params': params_b, 'lr': lr_b[i]})

optimizer_b = torch.optim.SGD(optim_params_b, momentum = 0.9)

#3 - Build optimizer for *forward* weights (one parameter group per layer)
lr_f = 0.08
wdecay = 1e-4

optim_params_f = []
for i in range(len(net.layers)):
    wanted = {'layers.{}.w_f'.format(i), 'layers.{}.bias_f'.format(i)}
    params_f = [param for name, param in net.named_parameters() if name in wanted]
    optim_params_f.append({'params': params_f, 'lr': lr_f})

optimizer_f = torch.optim.SGD(optim_params_f, momentum = 0.9, weight_decay = wdecay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_f, epochs, eta_min=1e-5)

#4 - Define training criterion
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

#5 - Train the net on CIFAR-10 (should reach ~30% train accuracy at the end of the first epoch)
net.train()


for epoch in range(epochs):
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)

        #Optimize feedback weights, then clear their gradients
        net.weight_b_train(data, optimizer_b)
        optimizer_b.zero_grad()

        #Forward pass (back=True so that backward computes the DTP targets)
        output = net(data, back=True)

        #Optimize forward weights
        loss = criterion(output.float(), target)
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_f.step()

        #Compute current accuracy
        train_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        percent_trainset = (batch_idx + 1)/len(train_loader)*100
        train_acc = (correct/total)*100
        print('Train accuracy : {:.2f} % ({:.2f} % of training set)'.format(train_acc, percent_trainset))

    #Update scheduler once per epoch
    scheduler.step()