# Code of CNN Casper

- The program can implement the LeNet-5, VGG-16 as ConvNet, The fully connective layer can be replaced by a casper tower or a casper layer without cascading the hidden neurons. 

- Datasets: EMNIST or MNIST. 
 
- The codes are modified from lab4.1 and the example code of casper.py. Reference is inclued in the paper.

In [38]:
# import libraries
from __future__ import print_function
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [39]:
class args:
    # defaults
    batch_size = 64
    test_batch_size = 1000    
    lr = 0.01 # for SGD
    momentum = 0.5 # for SGD
    no_cuda = False
    seed = 1
    log_interval = 500    
    
    """paras for net structure"""
    LeNet5 = True  # True for input->cnn->fc/tower->output, False for input->fc/tower->output
    vgg16 = False  # True for input->cnn->fc/tower->output, False for input->fc/tower->output
    minimal_net = True  # True for input->output, False for input->hidden->output
    add_fc = True  # if True, fc hidden nuerons will be added, Casper layer without cascading the hidden neurons. 
    add_tower = True  # if True, tower will be added 
    """datasets"""
    EMNIST = True
    MNIST = False
    
    num_fc_0 = 800  # number of fixed fc hidden neurons of initial net, will not increase as the stage grows 
    num_fc = 5  # number of hidden neurons of casper fc added at one time, Casper layer without cascading the hidden neurons. 
    num_hidden = 5  # number of hidden neurons of Casper tower added at a time
    
    """initial learning rate (Rprop/Adadelta/Adam)"""
    # Lr_in_before = 0.01  # hidden in before for Rprop/Adadelta
    Lr_in_before = 0.001  # hidden in before for Adam, default 0.001
    Lr_out_before = 0.001  # hidden out before
    Lr_in_after = 0.0005  # hidden in after
    Lr_out_after = 0.0005  # hidden out after
    
    weight_decay_before = 1e-5
    weight_decay_after = 2e-5
    drop_out_rate = 0
    
    epochs = 10
    stage = 3 
    if minimal_net and not add_tower and not add_fc:
        stage = 1 # if minimal_net, keep it 1
    """theshold for adding nuerons by correlation"""
    num_not_decrease = 1800
    
    """input, output"""
    if LeNet5:
        num_input = 7*7*64  # size after cnn
    elif vgg16:
        num_input = 3*3*64  # shrink the size to save time
    else:
        num_input = 28*28*1  # size of EMNIST/MNIST images
    if EMNIST:
        number_of_classes = 47  # 47 for EMNIST 
    elif MNIST:
        number_of_classes = 10  # 10 for MNIST
        
    """activation fuction"""
    af = nn.RReLU()  # nn.PReLU(), nn.ReLU(), nn.RReLU()
    
    """batch normalization"""
    bn = True
    """optimizer"""
    optimiser = "adam"  # rprop, adadelta, adam


In [40]:
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [41]:
if args.EMNIST:
    # download EMNIST dataset, EMNIST Balanced Dataset is used for classification
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.EMNIST(
            root='./data',
            split='balanced',
            train=True, 
            download=True,
            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.EMNIST(
            root='./data', 
            split='balanced',
            train=False, 
            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
elif args.MNIST:
    # download EMNIST dataset
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            root='./data-MNIST',
    #         split='balanced',
            train=True, 
            download=True,
            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            root='./data-MNIST', 
    #         split='balanced',
            train=False, 
            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)


In [42]:
# define a casper network 
class Net(nn.Module):
    def __init__(self, n_feature, n_output):
        super(Net, self).__init__()        
        if args.LeNet5 == True:
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=0)
            self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=0)
            # self.conv2_drop = nn.Dropout2d()
            self.zero_pad = nn.ZeroPad2d(2) # 2 for left, right, up, down
        elif args.vgg16 == True:  
            out = 8 # default 64
            self.conv3_64_1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1)
            self.conv3_64_2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, stride=1)
            self.conv3_128_1 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
            self.conv3_128_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1)
            self.conv3_256_1 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)
            self.conv3_256_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1)
            self.conv3_256_3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1)
            self.conv3_512_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
            self.conv3_512_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
            self.conv3_512_3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
            self.conv3_512_4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
            self.conv3_512_5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
            self.conv3_512_6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
            # self.conv2_drop = nn.Dropout2d()
            self.zero_pad1 = nn.ZeroPad2d(1)
            self.zero_pad2 = nn.ZeroPad2d((1,0,1,0)) # left 1, right 0, up 1, down 0
        else:
            """
            Define an artificial neuro network. Transform the pics by takeing advantage of Conv2d so that it 
            can be introduce to the fully connected layer
            """
            self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1)  
        if args.minimal_net:
            self.layer = torch.nn.Linear(n_feature, n_output)   # origin layer input->output
        else:
            self.layer_in = torch.nn.Linear(n_feature, args.num_fc_0)
            self.layer_out = torch.nn.Linear(args.num_fc_0, n_output)
        self.fc = {}
        self.hidden = {}
        for i in range(1, args.stage):
            if args.add_fc:
                self.fc[str(i)+"_in"] = torch.nn.Linear(n_feature, args.num_fc)
                self.fc[str(i)+"_out"] = torch.nn.Linear(args.num_fc, n_output)
            if args.add_tower:
                self.hidden[str(i) + "_in"] = torch.nn.Linear(n_feature+(i-1)*args.num_hidden, args.num_hidden) # the ith hidden neural
                self.hidden[str(i) + "_out"] = torch.nn.Linear(args.num_hidden, n_output) 
        self.drop_out = torch.nn.Dropout(p = args.drop_out_rate)
        if args.bn:
            self.bn = nn.BatchNorm1d(args.num_hidden)  # batch normalization for inputs of tower 
            # self.bn_fc_0 = nn.BatchNorm1d(args.num_fc_0)
            self.bn_fc = nn.BatchNorm1d(args.num_fc)  # batch normalization for inputs of fc

    def forward(self, x, stage):

        """
        define the forward pass of the cnn(only for transformation) with a relu activation function for 
        the hidden layer.
        """
        if args.LeNet5:
            x = F.relu(F.max_pool2d(self.conv1(self.zero_pad(x)), 2))
            # x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(self.zero_pad(x))), 2))
            x = F.relu(F.max_pool2d(self.conv2(self.zero_pad(x)), 2))
            # print(x.data.cpu().numpy().shape) # check the dimension of x
        elif args.vgg16:
            # 28*28*1
            x = F.relu(F.max_pool2d(self.conv3_64_2(self.zero_pad1(self.conv3_64_1(self.zero_pad1(x)))),2))
            # 14*14*out*2
            x = F.relu(F.max_pool2d(self.conv3_128_2(self.zero_pad1(self.conv3_128_1(self.zero_pad1(x)))),2))
            # 7*7*out*4
            x = F.relu(F.max_pool2d(self.conv3_256_3(self.zero_pad1(self.conv3_256_2(self.zero_pad1(self.conv3_256_1(self.zero_pad2(x)))))),2))
            # 4*4*out*8
            # x = F.relu(F.max_pool2d(self.conv3_512_3(self.zero_pad1(self.conv3_512_2(self.zero_pad1(self.conv3_512_1(self.zero_pad1(x)))))),2))  # without pooling
            x = F.relu(self.conv3_512_3(self.zero_pad1(self.conv3_512_2(self.zero_pad1(self.conv3_512_1(self.zero_pad1(x)))))))
            # max_pool2d-> 1*1*out*8 / without max_pool2d -> 3*3*out*8
            x = F.relu(self.conv3_512_6(self.zero_pad1(self.conv3_512_5(self.zero_pad1(self.conv3_512_4(self.zero_pad1(x)))))))  # without pooling
            # 1*1*out*8 
        else:
            x = args.af(self.conv(x))
        x_in = x.view(-1, args.num_input) # for casper net
                
        xx = {}
        node_in = [x_in,]
        if stage == 1:
            if args.minimal_net:   
                x = self.layer(x_in)
            else:
                # x = rrelu(bn_fc_0(self.layer_in(x_in))) # batch normalization
                x = args.af(self.layer_in(x_in)) # F.relu is better than F.rrelu for the layer between input and output
                x = self.layer_out(x)
        elif stage == 2:
            if args.minimal_net:
                xx["0"] = self.layer(x_in)  
                x = xx["0"]
            else:
                # xx["0"] = rrelu(bn_fc_0(self.layer_in(x_in))) # batch normalization
                xx["0"] = args.af(self.layer_in(x_in))
                xx["0"] = self.layer_out(xx["0"])
                x = xx["0"]
            if args.add_fc:
                xx["fc_1_in"] = args.af(self.bn_fc(self.fc["1_in"](x_in)))  # input - fc_hidden
                xx["fc_1_out"] = self.fc["1_out"](xx["fc_1_in"]) # fc_hidden - output
                x += xx["fc_1_out"]
            if args.add_tower:
                if args.bn:
                    xx["1_in"] = args.af(self.bn(self.hidden["1_in"](x_in)))
                else:
                    xx["1_in"] = args.af(self.hidden["1_in"](x_in))
                xx["1_out"] = self.hidden["1_out"](xx["1_in"])
                x += xx["1_out"]
            # x = xx["0"] + xx["1_out"] + xx["fc_1_out"]
        else:
            if args.minimal_net:
                xx["0"] = self.layer(x_in)  
                x = xx["0"]
            else:
                # xx["0"] = rrelu(bn_fc_0(self.layer_in(x_in))) # batch normalization
                xx["0"] = args.af(self.layer_in(x_in))
                xx["0"] = self.layer_out(xx["0"]) 
                x = xx["0"]
            if args.add_fc:
                if args.bn:
                    xx["fc_1_in"] = args.af(self.bn_fc(self.fc["1_in"](x_in)))  # input - fc_hidden
                else:
                    xx["fc_1_in"] = args.af(self.fc["1_in"](x_in))
                xx["fc_1_out"] = self.fc["1_out"](xx["fc_1_in"]) # fc_hidden - output
                x += xx["fc_1_out"]
            if args.add_tower:
                if args.bn:
                    xx["1_in"] = args.af(self.bn(self.hidden["1_in"](x_in)))
                else:
                    xx["1_in"] = args.af(self.hidden["1_in"](x_in))
                xx["1_out"] = self.hidden["1_out"](xx["1_in"])
                x += xx["1_out"]
            # x = xx["0"] + xx["1_out"] + xx["fc_1_out"]
            for i in range(2, stage):   
                if args.add_fc:
                    # node_in.append(xx["fc_"+str(i-1)+"_in"]) 
                    if args.bn:
                        xx["fc_"+str(i)+"_in"] = args.af(self.bn_fc(self.fc[str(i)+"_in"](x_in)))  # input - fc_hidden
                    else:
                        xx["fc_"+str(i)+"_in"] = args.af(self.fc[str(i)+"_in"](x_in))
                    xx["fc_"+str(i)+"_out"] = self.fc[str(i)+"_out"](xx["fc_"+str(i)+"_in"]) # fc_hidden - output
                    x += xx["fc_"+str(i)+"_out"]
                if args.add_tower:
                    node_in.append(xx[str(i-1)+"_in"])
                    if args.bn:
                        xx[str(i)+"_in"] = args.af(self.bn(self.hidden[str(i)+"_in"](torch.cat(node_in, 1))))
                    else:
                        xx[str(i)+"_in"] = args.af(self.hidden[str(i)+"_in"](torch.cat(node_in, 1)))
                    xx[str(i)+"_out"] = self.hidden[str(i)+"_out"](xx[str(i)+"_in"])
                    x += xx[str(i)+"_out"]
                # x += xx[str(i)+"_out"] + xx["fc_"+str(i)+"_out"]
        
        return F.log_softmax(x, dim=1)
        # return x
        
# optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

In [43]:
net = Net(args.num_input, args.number_of_classes)

if args.cuda:
    net.cuda()

In [44]:
optimiser = {}
if args.LeNet5:
    optimiser_conv1 = torch.optim.Adam(net.conv1.parameters(), lr=args.Lr_in_before)
    optimiser_conv2 = torch.optim.Adam(net.conv2.parameters(), lr=args.Lr_in_before)
elif args.vgg16:
    optimiser_conv3_64_1 = torch.optim.Adam(net.conv3_64_1.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_64_2 = torch.optim.Adam(net.conv3_64_2.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_128_1 = torch.optim.Adam(net.conv3_128_1.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_128_2 = torch.optim.Adam(net.conv3_128_2.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_256_1 = torch.optim.Adam(net.conv3_256_1.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_256_2 = torch.optim.Adam(net.conv3_256_2.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_256_3 = torch.optim.Adam(net.conv3_256_3.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_512_1 = torch.optim.Adam(net.conv3_512_1.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_512_2 = torch.optim.Adam(net.conv3_512_2.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_512_3 = torch.optim.Adam(net.conv3_512_3.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_512_4 = torch.optim.Adam(net.conv3_512_4.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_512_5 = torch.optim.Adam(net.conv3_512_5.parameters(), lr=args.Lr_in_before)
    optimiser_conv3_512_6 = torch.optim.Adam(net.conv3_512_6.parameters(), lr=args.Lr_in_before)

if args.minimal_net:
    if args.optimiser == "rprop":
        optimiser_layer_before = torch.optim.Rprop(net.layer.parameters(), lr=args.Lr_in_before*10)
        optimiser_layer_after = torch.optim.Rprop(net.layer.parameters(), lr=args.Lr_in_after*10)
    elif args.optimiser == "adadelta":
        optimiser_layer_before = torch.optim.Adadelta(net.layer.parameters(), lr=args.Lr_in_before*10)
        optimiser_layer_after = torch.optim.Adadelta(net.layer.parameters(), lr=args.Lr_in_after*10)
    elif args.optimiser == "adam":
        optimiser_layer_before = torch.optim.Adam(net.layer.parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)
        optimiser_layer_after = torch.optim.Adam(net.layer.parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_after)
else:
    if args.optimiser == "rprop":
        optimiser_layer_in_before = torch.optim.Rprop(net.layer_in.parameters(), lr=args.Lr_in_before*10)  # input - fc_hidden
        optimiser_layer_in_after = torch.optim.Rprop(net.layer_in.parameters(), lr=args.Lr_in_after*10)  # input - fc_hidden
        optimiser_layer_out_before = torch.optim.Rprop(net.layer_out.parameters(), lr=args.Lr_out_before*10)  # fc_hidden - output
        optimiser_layer_out_after = torch.optim.Rprop(net.layer_out.parameters(), lr=args.Lr_out_after*10)  # fc_hidden - output
    elif args.optimiser == "adadelta":
        optimiser_layer_in_before = torch.optim.Adadelta(net.layer_in.parameters(), lr=args.Lr_in_before*10)  # input - fc_hidden
        optimiser_layer_in_after = torch.optim.Adadelta(net.layer_in.parameters(), lr=args.Lr_in_after*10)  # input - fc_hidden
        optimiser_layer_out_before = torch.optim.Adadelta(net.layer_out.parameters(), lr=args.Lr_out_before*10)  # fc_hidden - output
        optimiser_layer_out_after = torch.optim.Adadelta(net.layer_out.parameters(), lr=args.Lr_out_after*10)  # fc_hidden - output
    # optimiser_layer_in_before = torch.optim.Adam(net.layer_in.parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)  # input - fc_hidden
    # optimiser_layer_in_after = torch.optim.Adam(net.layer_in.parameters(), lr=args.Lr_in_after, weight_decay=args.weight_decay_after)  # input - fc_hidden
    # optimiser_layer_out_before = torch.optim.Adam(net.layer_out.parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_before)  # fc_hidden - output
    # optimiser_layer_out_after = torch.optim.Adam(net.layer_out.parameters(), lr=args.Lr_out_after, weight_decay=args.weight_decay_after)  # fc_hidden - output
    elif args.optimiser == "adam":
        optimiser_layer_in_before = torch.optim.Adam(net.layer_in.parameters(), lr=args.Lr_in_before)  # input - fc_hidden
        optimiser_layer_in_after = torch.optim.Adam(net.layer_in.parameters(), lr=args.Lr_in_after)  # input - fc_hidden
        optimiser_layer_out_before = torch.optim.Adam(net.layer_out.parameters(), lr=args.Lr_out_before)  # fc_hidden - output
        optimiser_layer_out_after = torch.optim.Adam(net.layer_out.parameters(), lr=args.Lr_out_after)  # fc_hidden - output
    # optimiser = torch.optim.Adam(net.parameters(), lr=args.Lr_in_before)
    """SDG"""
    # optimiser_layer_in_before = optim.SGD(net.layer_in.parameters(), lr=args.lr, momentum=args.momentum)
    # optimiser_layer_in_after = optim.SGD(net.layer_in.parameters(), lr=args.lr, momentum=args.momentum)
    # optimiser_layer_out_before = optim.SGD(net.layer_out.parameters(), lr=args.lr, momentum=args.momentum)
    # optimiser_layer_out_after = optim.SGD(net.layer_out.parameters(), lr=args.lr, momentum=args.momentum)

for i in range(1, args.stage):
    """Rprop"""
    # optimiser[str(i)+"_in_before"] = torch.optim.Rprop(net.hidden[str(i)+"_in"].parameters(), lr=args.Lr_in_before)
    # optimiser[str(i) + "_in_after"] = torch.optim.Rprop(net.hidden[str(i)+"_in"].parameters(), lr=args.Lr_in_after)
    # optimiser[str(i) + "_out_before"] = torch.optim.Rprop(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_before)
    # optimiser[str(i) + "_out_after"] = torch.optim.Rprop(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_after)
    # adaptive according to weights
#     optimiser[str(i)+"_in_before"] = torch.optim.Rprop(net.hidden[str(i)+"_in"].parameters(), lr=args.L1*np.mean([abs(w) for w in net.hidden[str(i) + "_in"].weight.data.numpy()]))
#     optimiser[str(i) + "_in_after"] = torch.optim.Rprop(net.hidden[str(i) + "_in"].parameters(), lr=args.L3*np.mean([abs(w) for w in net.hidden[str(i) + "_in"].weight.data.numpy()]))
#     optimiser[str(i) + "_out_before"] = torch.optim.Rprop(net.hidden[str(i) + "_out"].parameters(), lr=args.L2*np.mean([abs(w) for w in net.hidden[str(i) + "_out"].weight.data.numpy()]))
#     optimiser[str(i) + "_out_after"] = torch.optim.Rprop(net.hidden[str(i) + "_out"].parameters(), lr=args.L4*np.mean([abs(w) for w in net.hidden[str(i) + "_out"].weight.data.numpy()])) 
    """Adadelta"""
#     optimiser[str(i)+"_in_before"] = torch.optim.Adadelta(net.hidden[str(i)+"_in"].parameters(), lr=args.L1)
#     optimiser[str(i) + "_in_after"] = torch.optim.Adadelta(net.hidden[str(i) + "_in"].parameters(), lr=args.L3)
#     optimiser[str(i) + "_out_before"] = torch.optim.Adadelta(net.hidden[str(i) + "_out"].parameters(), lr=args.L2)
#     optimiser[str(i) + "_out_after"] = torch.optim.Adadelta(net.hidden[str(i) + "_out"].parameters(), lr=args.L4)
    """ Adam """
    if args.add_fc: 
        if args.optimiser == "rprop":
            optimiser["fc_"+str(i)+"_in_before"] = torch.optim.Rprop(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)  # input - fc_hidden
            # optimiser["fc_"+str(i)+"_in_after"] = torch.optim.Rprop(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_after*np.mean([abs(w) for w in net.fc[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after)  # input - fc_hidden # adaptive according to weights
            optimiser["fc_"+str(i)+"_in_after"] = torch.optim.Rprop(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_after, weight_decay=args.weight_decay_after)  # input - fc_hidden     
            optimiser["fc_"+str(i)+"_out_before"] = torch.optim.Rprop(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_before)  # fc_hidden - output
            # optimiser["fc_"+str(i)+"_out_after"] = torch.optim.Rprop(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_after*np.mean([abs(w) for w in net.fc[str(i)+"_out"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after)  # fc_hidden - output # adaptive according to weights
            optimiser["fc_"+str(i)+"_out_after"] = torch.optim.Rprop(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_after, weight_decay=args.weight_decay_after)  # fc_hidden - output
        elif args.optimiser == "adadelta":
            optimiser["fc_"+str(i)+"_in_before"] = torch.optim.Adadelta(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)  # input - fc_hidden
            # optimiser["fc_"+str(i)+"_in_after"] = torch.optim.Adadelta(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_after*np.mean([abs(w) for w in net.fc[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after)  # input - fc_hidden # adaptive according to weights
            optimiser["fc_"+str(i)+"_in_after"] = torch.optim.Adadelta(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_after, weight_decay=args.weight_decay_after)  # input - fc_hidden     
            optimiser["fc_"+str(i)+"_out_before"] = torch.optim.Adadelta(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_before)  # fc_hidden - output
            # optimiser["fc_"+str(i)+"_out_after"] = torch.optim.Adadelta(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_after*np.mean([abs(w) for w in net.fc[str(i)+"_out"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after)  # fc_hidden - output # adaptive according to weights
            optimiser["fc_"+str(i)+"_out_after"] = torch.optim.Adadelta(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_after, weight_decay=args.weight_decay_after)  # fc_hidden - output
        elif args.optimiser == "adam":
            optimiser["fc_"+str(i)+"_in_before"] = torch.optim.Adam(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)  # input - fc_hidden
            # optimiser["fc_"+str(i)+"_in_after"] = torch.optim.Adam(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_after*np.mean([abs(w) for w in net.fc[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after)  # input - fc_hidden # adaptive according to weights
            optimiser["fc_"+str(i)+"_in_after"] = torch.optim.Adam(net.fc[str(i)+"_in"].parameters(), lr=args.Lr_in_after, weight_decay=args.weight_decay_after)  # input - fc_hidden     
            optimiser["fc_"+str(i)+"_out_before"] = torch.optim.Adam(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_before)  # fc_hidden - output
            # optimiser["fc_"+str(i)+"_out_after"] = torch.optim.Adam(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_after*np.mean([abs(w) for w in net.fc[str(i)+"_out"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after)  # fc_hidden - output # adaptive according to weights
            optimiser["fc_"+str(i)+"_out_after"] = torch.optim.Adam(net.fc[str(i)+"_out"].parameters(), lr=args.Lr_out_after, weight_decay=args.weight_decay_after)  # fc_hidden - output
    if args.add_tower:
        if args.optimiser == "rprop":
            optimiser[str(i)+"_in_before"] = torch.optim.Rprop(net.hidden[str(i)+"_in"].parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)
            # optimiser[str(i) + "_in_after"] = torch.optim.Rprop(net.hidden[str(i) + "_in"].parameters(), lr=args.Lr_in_after*np.mean([abs(w) for w in net.hidden[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after) # adaptive according to weights
            optimiser[str(i) + "_in_after"] = torch.optim.Rprop(net.hidden[str(i) + "_in"].parameters(), lr=args.Lr_in_after, weight_decay=args.weight_decay_after)
            optimiser[str(i) + "_out_before"] = torch.optim.Rprop(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_before)
            # optimiser[str(i) + "_out_after"] = torch.optim.Rprop(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_after*np.mean([abs(w) for w in net.hidden[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after) # adaptive according to weights
            optimiser[str(i) + "_out_after"] = torch.optim.Rprop(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_after, weight_decay=args.weight_decay_after)
        elif args.optimiser == "adadelta":
            optimiser[str(i)+"_in_before"] = torch.optim.Adadelta(net.hidden[str(i)+"_in"].parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)
            # optimiser[str(i) + "_in_after"] = torch.optim.Adadelta(net.hidden[str(i) + "_in"].parameters(), lr=args.Lr_in_after*np.mean([abs(w) for w in net.hidden[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after) # adaptive according to weights
            optimiser[str(i) + "_in_after"] = torch.optim.Adadelta(net.hidden[str(i) + "_in"].parameters(), lr=args.Lr_in_after, weight_decay=args.weight_decay_after)
            optimiser[str(i) + "_out_before"] = torch.optim.Adadelta(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_before)
            # optimiser[str(i) + "_out_after"] = torch.optim.Adadelta(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_after*np.mean([abs(w) for w in net.hidden[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after) # adaptive according to weights
            optimiser[str(i) + "_out_after"] = torch.optim.Adadelta(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_after, weight_decay=args.weight_decay_after)
        elif args.optimiser == "adam":
            optimiser[str(i)+"_in_before"] = torch.optim.Adam(net.hidden[str(i)+"_in"].parameters(), lr=args.Lr_in_before, weight_decay=args.weight_decay_before)
            # optimiser[str(i) + "_in_after"] = torch.optim.Adam(net.hidden[str(i) + "_in"].parameters(), lr=args.Lr_in_after*np.mean([abs(w) for w in net.hidden[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after) # adaptive according to weights
            optimiser[str(i) + "_in_after"] = torch.optim.Adam(net.hidden[str(i) + "_in"].parameters(), lr=args.Lr_in_after, weight_decay=args.weight_decay_after)
            optimiser[str(i) + "_out_before"] = torch.optim.Adam(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_before, weight_decay=args.weight_decay_before)
            # optimiser[str(i) + "_out_after"] = torch.optim.Adam(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_after*np.mean([abs(w) for w in net.hidden[str(i)+"_in"].weight.data.cpu().numpy()]), weight_decay=args.weight_decay_after) # adaptive according to weights
            optimiser[str(i) + "_out_after"] = torch.optim.Adam(net.hidden[str(i) + "_out"].parameters(), lr=args.Lr_out_after, weight_decay=args.weight_decay_after)


In [45]:
# loss_func = torch.nn.CrossEntropyLoss() # no need to use softmax
loss_func = torch.nn.NLLLoss()
all_losses = []

In [46]:
def train(epochs):
    net.train()
    stage = 1
    num_loss_without_decrease = 0
    count = 0
    
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)

            # if num_loss_without_decrease < args.num_not_decrease*(args.stage if args.stage==1 else args.stage*0.3): # check optimal threshold
            if num_loss_without_decrease < args.num_not_decrease:
                Y_pred = net(data, stage)
            elif stage < args.stage:
                num_loss_without_decrease = 0
                count = 0
                stage += 1
                print("\t\tstage" + str(stage))                
                Y_pred = net(data, stage)
            elif stage == args.stage:
                Y_pred = net(data, stage)
                
            # Compute loss
            # Here we pass Variables containing the predicted and true values of Y,
            # and the loss function returns a Variable containing the loss.
            loss = loss_func(Y_pred, target)
            # loss_num = loss.data[0]
            all_losses.append(float(loss.data[0]))
            count += 1
            if count >= 2:
                if all_losses[len(all_losses) - 1] >= all_losses[len(all_losses) - 2]:
                    num_loss_without_decrease += 1

            # Clear the gradients before running the backward pass.
            net.zero_grad()

            # Perform backward pass: compute gradients of the loss with respect to
            # all the learnable parameters of the model.
            loss.backward()

            # Calling the step function on an Optimiser makes an update to its parameters
            if stage == 1:
                if args.LeNet5:
                    optimiser_conv1.step()
                    optimiser_conv2.step()
                elif args.vgg16:
                    optimiser_conv3_64_1.step()
                    optimiser_conv3_64_2.step()
                    optimiser_conv3_128_1.step()
                    optimiser_conv3_128_2.step()
                    optimiser_conv3_256_1.step()
                    optimiser_conv3_256_2.step()
                    optimiser_conv3_256_3.step()
                    optimiser_conv3_512_1.step()
                    optimiser_conv3_512_2.step()
                    optimiser_conv3_512_3.step()
                    optimiser_conv3_512_4.step()
                    optimiser_conv3_512_5.step()
                    optimiser_conv3_512_6.step()
                """minimal net"""
                if args.minimal_net:
                    optimiser_layer_before.step()
                else:
                    optimiser_layer_in_before.step()
                    optimiser_layer_out_before.step()
                    # optimiser.step()
            elif stage == 2:
                if args.LeNet5:
                    optimiser_conv1.step()
                    optimiser_conv2.step()
                elif args.vgg16:
                    optimiser_conv3_64_1.step()
                    optimiser_conv3_64_2.step()
                    optimiser_conv3_128_1.step()
                    optimiser_conv3_128_2.step()
                    optimiser_conv3_256_1.step()
                    optimiser_conv3_256_2.step()
                    optimiser_conv3_256_3.step()
                    optimiser_conv3_512_1.step()
                    optimiser_conv3_512_2.step()
                    optimiser_conv3_512_3.step()
                    optimiser_conv3_512_4.step()
                    optimiser_conv3_512_5.step()
                    optimiser_conv3_512_6.step()
                """minimal net"""
                if args.minimal_net:
                    optimiser_layer_after.step()
                else:
                    optimiser_layer_in_after.step()
                    optimiser_layer_out_after.step()
                """"fc""" 
                if args.add_fc:
                    optimiser["fc_1_out_before"].step()  # fc_hidden -> output 
                    optimiser["fc_1_in_before"].step()  # input -> fc_hidden
                """tower"""
                if args.add_tower:
                    optimiser["1_out_before"].step()
                    optimiser["1_in_before"].step()
            else:
                if args.LeNet5:
                    optimiser_conv1.step()
                    optimiser_conv2.step()
                elif args.vgg16:
                    optimiser_conv3_64_1.step()
                    optimiser_conv3_64_2.step()
                    optimiser_conv3_128_1.step()
                    optimiser_conv3_128_2.step()
                    optimiser_conv3_256_1.step()
                    optimiser_conv3_256_2.step()
                    optimiser_conv3_256_3.step()
                    optimiser_conv3_512_1.step()
                    optimiser_conv3_512_2.step()
                    optimiser_conv3_512_3.step()
                    optimiser_conv3_512_4.step()
                    optimiser_conv3_512_5.step()
                    optimiser_conv3_512_6.step()
                """minimal net"""
                if args.minimal_net:
                    optimiser_layer_after.step()
                else:
                    optimiser_layer_in_after.step()
                    optimiser_layer_out_after.step()                
                for i in range(1, stage-1):
                    """"fc""" 
                    if args.add_fc:
                        optimiser["fc_"+str(i)+"_out_after"].step()
                        optimiser["fc_"+str(i)+"_in_after"].step()
                    """tower"""
                    if args.add_tower:
                        optimiser[str(i)+"_out_after"].step()
                        optimiser[str(i)+"_in_after"].step()
                """"fc"""
                if args.add_fc:
                    optimiser["fc_"+str(stage-1)+"_out_before"].step()
                    optimiser["fc_"+str(stage-1)+"_in_before"].step()
                """tower"""
                if args.add_tower:
                    optimiser[str(stage-1)+"_out_before"].step()
                    optimiser[str(stage-1)+"_in_before"].step()
    
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch+1, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))
        
        """ test """
        net.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            output = net(data, stage)            
            test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
            # test_loss += loss_func(output, target).data[0]
            pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

        test_loss /= len(test_loader.dataset)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * float(correct) / len(test_loader.dataset)))
        # print("hidden_1_in weight", np.mean([abs(w) for w in net.hidden["1_in"].weight.data.cpu().numpy()[0]]))
  

In [47]:
# for epoch in range(1, args.epochs + 1):
train(args.epochs)
#     test()













Test set: Average loss: 0.4152, Accuracy: 16161/18800 (85.96%)











Test set: Average loss: 0.3746, Accuracy: 16368/18800 (87.06%)





		stage2








Test set: Average loss: 0.4850, Accuracy: 15827/18800 (84.19%)











Test set: Average loss: 0.4617, Accuracy: 16030/18800 (85.27%)





		stage3








Test set: Average loss: 0.4593, Accuracy: 15988/18800 (85.04%)











Test set: Average loss: 0.4743, Accuracy: 15905/18800 (84.60%)











Test set: Average loss: 0.4795, Accuracy: 15795/18800 (84.02%)











Test set: Average loss: 0.4760, Accuracy: 15910/18800 (84.63%)











Test set: Average loss: 0.4932, Accuracy: 15846/18800 (84.29%)











Test set: Average loss: 0.4979, Accuracy: 15962/18800 (84.90%)



In [48]:
# save the model
# torch.save(net, 'model_casper.pth')  

In [49]:
# load the model
# net = torch.load('model_casper.pth')