In [None]:
!pip install torch torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import copy

## Loading Data

In [None]:
# Most of the code in the present notebook was taken and adapted from the following tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

def get_data(use_valid=False, num_folds=5, fold=0):
    '''
    Load the CIFAR10 dataset and build the data loaders, along with any preprocessing needed.

    Every image is converted to a tensor, normalized per channel and flattened into a single
    vector (necessary for a fully-connected network).

    use_valid = False - Whether to create a validation set from the original training dataset
    num_folds = 5 - Number of cross-validation folds
    fold = 0 - Fold index to use as validation data (must satisfy 0 <= fold < num_folds)

    Returns (trainloader, validloader, testloader, classes) when use_valid is True,
    and (trainloader, testloader, classes) otherwise.

    Raises ValueError when use_valid is True and num_folds / fold do not describe an existing fold.
    num_folds and fold are ignored when use_valid is False.
    '''

    if use_valid:
        # Fail early: an out-of-range fold would otherwise produce validation indices past the
        # end of the dataset (or an empty split) and only blow up later, inside a loader worker
        if num_folds < 1:
            raise ValueError(f"num_folds must be at least 1, got {num_folds}")
        if not 0 <= fold < num_folds:
            raise ValueError(f"fold must satisfy 0 <= fold < {num_folds}, got {fold}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # Normalize the image by subtracting by the mean and dividing by the standard deviation for each channel
        transforms.Lambda(lambda t: t.view(-1)) # Flatten the image, including channels (necessary for fully-connected NN)
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)

    num_examples = len(trainset)

    # Split into a training and validation set, given the current fold
    if use_valid:
        num_valid = int(num_examples // num_folds)
        valid_start = num_valid * fold
        valid_end = valid_start + num_valid
        # Everything outside [valid_start, valid_end) is used for training
        # (this includes the leftover examples when num_examples is not a multiple of num_folds)
        train_idx = list(range(0, valid_start)) + list(range(valid_end, num_examples))
        valid_idx = list(range(valid_start, valid_end))
    else:
        train_idx = list(range(num_examples))

    trainsampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, # From the paper
                                              sampler=trainsampler, num_workers=2)

    if use_valid:
        # The validation examples come from the same underlying dataset as the training ones,
        # so the dataset object is shared; only the sampler (i.e. the index subset) differs
        validsampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx)
        validloader = torch.utils.data.DataLoader(trainset, batch_size=256,
                                                  sampler=validsampler, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat',
                 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    if use_valid:
        return trainloader, validloader, testloader, classes
    else:
        return trainloader, testloader, classes

## Defining the network (Section 6.1 in paper)

**Quote from the paper, describing the network architecture for the CIFAR10 fully-connected experiment:**
> We first experiment with a basic fully-connected architecture that takes as input the flattened image of size 3072. Input data is normalized by subtracting mean and dividing by standard deviation independently for each channel. The first linear layer is of size 3072 × 500. We then consider p layers 500 × 500, p being an architecture parameter for the sake of the analysis. The last classification is of size 500 × 10. The weights are initialized with He’s scheme. We train for 60 epochs using SGD with no momentum, a batch size of 256 and weight decay of $10^{-3}$. Cross validation is used to pick an initial learning rate in {0.0005, 0.001, 0.005, 0.01, 0.05, 0.1}. PathSGD, GN and WN are learned as detailed in Section 6.1. All results are the average test accuracies over 5 training runs.

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Running on: {device}")

In [None]:
# Build the fully-connected network
def create_net(p, batch_norm=False):
    '''
    Build a bias-free fully-connected ReLU network 3072 -> 500 -> (p times 500) -> 10
    and move it to the global `device`.

    p: number of intermediary 500 x 500 layers
    batch_norm = False: when True, a BatchNorm1d(500) follows every linear layer except the output layer

    The weights of every linear layer are (re-)initialized with He's scheme (kaiming_uniform_).
    Returns the network as an nn.Sequential.
    '''

    def hidden_block(in_features):
        # Linear (no bias) -> optional BatchNorm -> ReLU
        block = [nn.Linear(in_features, 500, bias=False)]
        if batch_norm:
            block.append(nn.BatchNorm1d(500))
        block.append(nn.ReLU())
        return block

    # Input layer, followed by the p intermediary layers
    modules = hidden_block(3072)
    for _ in range(p):
        modules.extend(hidden_block(500))

    # Output layer (no batchnorm / activation)
    modules.append(nn.Linear(500, 10, bias=False))

    net = nn.Sequential(*modules)
    net.to(device)

    # He's initialization, applied to the linear layers in network order
    for module in net.modules():
        if isinstance(module, nn.Linear):
            torch.nn.init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.fill_(0)

    return net

## Equi-Normalization
Algorithm 1: Equi Normalization Code

In [None]:
# Apply Equi-normalization directly to the linear layers in a sequential network
# Be careful here, the weight matrices stored in Linear layers have size output_size x input_size, which is the opposite of what seems to be assumed in the algorithm
def equi_norm(seq_net, c=1.2, C=1, p=2):
    '''
    Apply equi-normalization, in place, to the linear layers in a sequential neural network.

    For every pair of consecutive nn.Linear layers, each hidden unit i gets a positive factor
        d_i = sqrt((1 / c) * ||outgoing weights of i||_p / ||incoming weights of i||_p)
    Its incoming weights (row i of the first layer, and its bias if any) are multiplied by d_i
    and its outgoing weights (column i of the second layer) are divided by d_i.

    seq_net: network whose direct children are scanned; only nn.Linear children are rescaled,
             every other child (activations, batchnorm, ...) is ignored
    c: scaling factor
    C: number of enorm cycles to do (usually 1)
    p: norm to use (p-norm)

    Hidden units whose incoming or outgoing weight norm is zero are left untouched, since
    their factor would be 0 or infinite and would fill the weights with inf / NaN.
    Returns None; the weights are modified in place.
    '''

    with torch.no_grad():
        layers = [l for l in seq_net.children() if isinstance(l, nn.Linear)]
        for _ in range(C):
            for prev_layer, layer in zip(layers, layers[1:]):
                W_prev = prev_layer.weight  # (hidden, in): row i = incoming weights of unit i
                W_k = layer.weight          # (out, hidden): column i = outgoing weights of unit i

                R = torch.norm(W_k, p, dim=0)     # outgoing norm of each hidden unit
                L = torch.norm(W_prev, p, dim=1)  # incoming norm of each hidden unit
                d = torch.sqrt(1 / c * torch.div(R, L)) # element wise division

                # Degenerate (dead) units: keep them as they are by using a factor of 1
                usable = (R > 0) & (L > 0)
                d = torch.where(usable, d, torch.ones_like(d))

                # Update the weights. Broadcasting is equivalent to multiplying by diag(d) and
                # diag(1 / d), without building the dense diagonal matrices
                # (operations are flipped compared to the Algorithm in the paper, see above)
                W_prev.mul_(d.unsqueeze(1))
                W_k.mul_((1 / d).unsqueeze(0))

                # Apply the bias normalization from section 3.5 of the paper
                if prev_layer.bias is not None:
                    prev_layer.bias.mul_(d)
                
    

## Training
Algorithm 2: Training with Equi Normalization

In [None]:
# Train the network
EPOCHS = 60
def train(net, criterion, optimizer, lr_scheduler, trainloader, testloader=None, validloader=None, use_enorm=True):
    '''
    Train net for EPOCHS epochs and return it.

    net: network to train (inputs and labels are moved to the global `device` before the forward pass)
    criterion: loss function, called as criterion(outputs, labels)
    optimizer: optimizer updating the parameters of net
    lr_scheduler: learning rate scheduler, stepped once per epoch
    trainloader: iterable of (inputs, labels) training batches
    testloader = None: if given (and validloader is not), test accuracy is printed after every epoch
    validloader = None: if given, validation accuracy is computed after every epoch and the weights
                        with the best validation accuracy are loaded back into net before returning
    use_enorm = True: whether to run an equi-normalization cycle after every optimizer step
    '''
    print('Starting training')

    top_valid_acc = -np.inf
    top_model_weights = None

    for epoch in range(EPOCHS):

        net.train()

        running_loss = 0.0
        since = time.time()
        for inputs, labels in trainloader:

            # Use GPU if possible
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # ENorm cycle
            if use_enorm:
                equi_norm(net, c=1.2)

            # TODO: If using SGD with momentum, we would need to update the momentum buffers here using the D_k matrices from Algorithm 1
            # However, for the experiments that we ran, this was not necessary and was thus not implemented

            running_loss += loss.item()

        # Advance the schedule only once this epoch's updates are done. Since PyTorch 1.1 the
        # scheduler already holds the epoch-0 learning rate after construction, so stepping it
        # before the updates skips the initial learning rate and runs the last epoch with the
        # value meant for the end of the schedule (0 for a linear decay of lr / EPOCHS)
        lr_scheduler.step()

        # Save network with best validation performance
        if validloader is not None:
            valid_acc = test(net, validloader)
            if valid_acc > top_valid_acc:
                top_valid_acc = valid_acc
                top_model_weights = copy.deepcopy(net.state_dict())

        # print statistics
        time_elapsed = time.time() - since
        if validloader is not None:
            print(f"Finished epoch {epoch + 1} / {EPOCHS}. Loss: {running_loss}. Validation accuracy: {valid_acc}. Time taken: {round(time_elapsed // 60)}m {round(time_elapsed % 60)}s")
        elif testloader is not None:
            print(f"Finished epoch {epoch + 1} / {EPOCHS}. Loss: {running_loss}. Test accuracy: {test(net, testloader)}. Time taken: {round(time_elapsed // 60)}m {round(time_elapsed % 60)}s")
        else:
            print(f"Finished epoch {epoch + 1} / {EPOCHS}. Loss: {running_loss}. Time taken: {round(time_elapsed // 60)}m {round(time_elapsed % 60)}s")

    if validloader is not None:
        print(f'Finished Training. Top validation accuracy: {top_valid_acc}')
    else:
        print('Finished Training.')

    # Restore the best weights seen on the validation set (only set when validloader was given)
    if top_model_weights is not None:
        net.load_state_dict(top_model_weights)

    return net

In [None]:
# Define a learning rate scheduler that linearly decays the learning rate from an initial value and a decay/epoch value
class LinearLRDecay(optim.lr_scheduler._LRScheduler):
    '''
    Linear learning rate decay: lr(epoch) = max(0, base_lr - epoch * decay).

    optimizer: wrapped optimizer; its initial learning rates are used as base_lr
    decay: amount subtracted from the learning rate at every scheduler step
    last_epoch = -1: index of the last epoch, forwarded to the base scheduler
    '''
    def __init__(self, optimizer, decay, last_epoch=-1):
        # decay must be set before the base constructor runs, since it computes the first learning rate
        self.decay = decay
        super(LinearLRDecay, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Clamp at 0: a negative learning rate would make the optimizer move *up* the gradient
        return [max(0.0, base_lr - self.last_epoch * self.decay) for base_lr in self.base_lrs]

## Testing

In [None]:
# Evaluate accuracy on a dataset
def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return 100.0 * correct / total

## Running it all

In [None]:
trainloader, testloader, classes = get_data()

criterion = nn.CrossEntropyLoss()

# Values of p (number of intermediary layers) evaluated for every configuration
DEPTHS = range(1, 20, 2)

# Initial learning rate for each depth p, per configuration
optimal_baseline_lrs = { 1: 0.1, 3: 0.1, 5: 0.1, 7: 0.1, 9: 0.05, 11: 0.05, 13: 0.05, 15: 0.05, 17: 0.05, 19: 0.05 } # All learning rate values were given to us by the authors
optimal_lrs = { 1: 0.05, 3: 0.05, 5: 0.05, 7: 0.05, 9: 0.05, 11: 0.05, 13: 0.05, 15: 0.05, 17: 0.01, 19: 0.01 }
optimal_bn_lr = 0.1
optimal_en_bn_lrs = { 1: 0.1, 3: 0.1, 5: 0.1, 7: 0.1, 9: 0.1, 11: 0.1, 13: 0.1, 15: 0.05, 17: 0.05, 19: 0.05 } # The last 3 learning rate values here were different than what was given to us by the authors and were found to give better results

# One entry per configuration: (message printed before each run, learning rate per depth, batch_norm, use_enorm)
experiments = [
    ("Running baseline w/ p={p}", optimal_baseline_lrs, False, False),                                   # Baseline
    ("Running ENorm with p = {p}", optimal_lrs, False, True),                                            # ENorm
    ("Running Batch Norm baseline with p = {p}", {depth: optimal_bn_lr for depth in DEPTHS}, True, False), # Batch Norm baseline
    ("Running Batch Norm + ENorm with p = {p}", optimal_en_bn_lrs, True, True),                          # ENorm & Batch Norm
]

for message, lrs, batch_norm, use_enorm in experiments:
    for p in DEPTHS:
        lr = lrs[p]
        print(message.format(p=p))
        net = create_net(p, batch_norm=batch_norm)
        optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=1e-3)
        # The per-epoch decay is always derived from this run's own initial learning rate.
        # Using the decay of a larger learning rate (e.g. 0.1 / EPOCHS for a run starting at 0.05)
        # would drive the learning rate to 0 halfway through training and negative afterwards
        lr_scheduler = LinearLRDecay(optimizer, lr / EPOCHS)
        net = train(net, criterion, optimizer, lr_scheduler, trainloader, testloader=testloader, use_enorm=use_enorm)
        test_acc = test(net, testloader)

        print(f'Accuracy of the network on the 10000 test images: {test_acc} %')