## The story of neural net layers that traveled back in time to the point when they weren't trained.

----

Experiment setup:

0. Randomly initialize a fully connected (FC) neural network and save a copy of it.
1. Train until good performance is achieved.
2. Load:

    a. the random network and replace the weights of one of its layers with the trained ones.
    
    b. the trained network and replace the weights of one of its layers with the initial weights.
3. Get ready to have your mind blown! Once ready, test the network and see the accuracy results.
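
The steps above can be sketched on a toy model before running the real thing. This is a minimal illustration, not the notebook's actual code: `TinyNet` and `swap_layer` are hypothetical names, and "training" is faked by adding noise to every parameter so the example runs in a fraction of a second on CPU.

```python
import copy

import torch
import torch.nn as nn


def swap_layer(base_model, donor_model, layer_name):
    """Return a copy of base_model whose `layer_name` parameters come from donor_model."""
    new_model = copy.deepcopy(base_model)
    donor_layer = getattr(donor_model, layer_name)
    # load_state_dict copies the values, so the new model shares no tensors
    # with either of the original models.
    getattr(new_model, layer_name).load_state_dict(donor_layer.state_dict())
    return new_model


class TinyNet(nn.Module):  # hypothetical stand-in for the FC network
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
init_net = TinyNet()                      # step 0: random initialization
trained_net = copy.deepcopy(init_net)
with torch.no_grad():                     # stand-in for step 1: perturb every parameter
    for p in trained_net.parameters():
        p.add_(0.1 * torch.randn_like(p))

variant_a = swap_layer(init_net, trained_net, "fc1")   # step 2a
variant_b = swap_layer(trained_net, init_net, "fc1")   # step 2b
```

Copying through `state_dict` rather than assigning the `Parameter` objects keeps the three models independent, which matters if any of them is trained further afterwards.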


Credits: https://lilianweng.github.io/lil-log/2019/03/14/are-deep-neural-networks-dramatically-overfitted.html#experiments and The lottery ticket hypothesis paper.

Training code is adapted from: https://docs.wandb.com/docs/frameworks/pytorch-example.html

In [1]:
from __future__ import print_function
import copy
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import wandb

In [2]:
class FC_Net(nn.Module):
    def __init__(self):
        super(FC_Net, self).__init__()
#         fc_sizes = [1024,512,256,128,64]
        fc_sizes = [2048,1024,512,256,128]
        self.fc1 = nn.Linear(1*28*28, fc_sizes[0])
        self.fc2 = nn.Linear(fc_sizes[0], fc_sizes[1])
        self.fc3 = nn.Linear(fc_sizes[1], fc_sizes[2])
        self.fc4 = nn.Linear(fc_sizes[2], fc_sizes[3])
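        # Note: the output layer has fc_sizes[4] = 128 units, although MNIST has only 10 classes.
        # nll_loss still works because the labels 0-9 are valid indices, but the
        # network can also predict one of the 118 unused indices.
        self.fc5 = nn.Linear(fc_sizes[3], fc_sizes[4])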

    def forward(self, x):
        x = x.view(-1, 1*28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return F.log_softmax(x, dim=1)

In [3]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    example_images = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            # Save the first input tensor in each test batch as an example image
            example_images.append(wandb.Image(data[0], caption="Pred: {} Truth: {}".format(pred[0].item(), target[0].item())))

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    # Log the images and metrics
    wandb.log({
            "Examples": example_images,
            "Test Accuracy": 100. * correct / len(test_loader.dataset),
            "Test Loss": test_loss})

In [4]:
wandb.init(project='generalisation')
# We load all of the arguments into config to save as hyperparameters
# wandb.config.update(args)

# Fall back to the CPU when no GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

kwargs = {'num_workers': 4, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=256, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=256, shuffle=True, **kwargs)

W&B Run: https://app.wandb.ai/maks/generalisation/runs/9tda6m2a


In [5]:
model = FC_Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [6]:
torch.save(model,'init_model.pt')



In [7]:
# This magic line lets us save the PyTorch model and track all of the gradients and, optionally, parameters
wandb.watch(model)

[<wandb.wandb_torch.TorchGraph at 0x7f6e7b836be0>]

In [8]:
epochs = 25
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)






Test set: Average loss: 2.1434, Accuracy: 3649/10000 (36%)


Test set: Average loss: 0.8945, Accuracy: 7702/10000 (77%)


Test set: Average loss: 0.5519, Accuracy: 8246/10000 (82%)


Test set: Average loss: 0.4028, Accuracy: 8850/10000 (88%)


Test set: Average loss: 0.3566, Accuracy: 8902/10000 (89%)


Test set: Average loss: 0.3234, Accuracy: 9042/10000 (90%)


Test set: Average loss: 0.2932, Accuracy: 9126/10000 (91%)


Test set: Average loss: 0.2763, Accuracy: 9192/10000 (92%)


Test set: Average loss: 0.2555, Accuracy: 9227/10000 (92%)


Test set: Average loss: 0.2464, Accuracy: 9264/10000 (93%)


Test set: Average loss: 0.2220, Accuracy: 9352/10000 (94%)


Test set: Average loss: 0.2118, Accuracy: 9379/10000 (94%)


Test set: Average loss: 0.1957, Accuracy: 9441/10000 (94%)


Test set: Average loss: 0.1903, Accuracy: 9471/10000 (95%)


Test set: Average loss: 0.1730, Accuracy: 9496/10000 (95%)


Test set: Average loss: 0.1693, Accuracy: 9494/10000 (95%)


Test set: Average loss:

In [9]:
#Checkpoint
torch.save(model,'trained_model.pt')



---

## Loading initial and trained model

In [11]:
init_model = torch.load('init_model.pt')
trained_model = torch.load('trained_model.pt')

In [20]:
init_model.fc1.weight[0][:20]

tensor([-0.0189, -0.0237,  0.0143, -0.0071,  0.0112,  0.0074, -0.0261,  0.0316,
         0.0152, -0.0341,  0.0211, -0.0329,  0.0156,  0.0017,  0.0337,  0.0040,
        -0.0022, -0.0355,  0.0055, -0.0091], device='cuda:0',
       grad_fn=<SliceBackward>)

In [21]:
trained_model.fc1.weight[0][:20]

tensor([-0.0197, -0.0244,  0.0136, -0.0079,  0.0105,  0.0067, -0.0269,  0.0309,
         0.0145, -0.0348,  0.0204, -0.0336,  0.0149,  0.0010,  0.0330,  0.0033,
        -0.0029, -0.0362,  0.0048, -0.0098], device='cuda:0',
       grad_fn=<SliceBackward>)

In [24]:
# Different but very close (each weight shown moved by less than 1e-3); let's look into deeper layers

In [25]:
init_model.fc4.weight[0][:20]

tensor([-0.0411,  0.0225, -0.0391, -0.0073,  0.0105,  0.0398,  0.0214,  0.0348,
        -0.0152, -0.0315, -0.0428, -0.0353,  0.0102, -0.0082,  0.0332,  0.0086,
         0.0048,  0.0239,  0.0248, -0.0066], device='cuda:0',
       grad_fn=<SliceBackward>)

In [26]:
trained_model.fc4.weight[0][:20]

tensor([-0.0368,  0.0205, -0.0280, -0.0112,  0.0098,  0.0497,  0.0182,  0.0475,
        -0.0152, -0.0353, -0.0462, -0.0329,  0.0099, -0.0082,  0.0324,  0.0087,
         0.0050,  0.0239,  0.0415, -0.0066], device='cuda:0',
       grad_fn=<SliceBackward>)

In [None]:
# Still mostly within ±1e-2
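
To put numbers on "very close", the printed slices can be compared directly. The sketch below only uses the 20 values shown above for `fc1` and `fc4` (row 0), so it says nothing about the rest of each weight matrix; `abs_diffs` is a helper name introduced here for illustration.

```python
# First 20 weights of row 0, copied from the printed tensors above.
fc1_init = [-0.0189, -0.0237, 0.0143, -0.0071, 0.0112, 0.0074, -0.0261, 0.0316,
            0.0152, -0.0341, 0.0211, -0.0329, 0.0156, 0.0017, 0.0337, 0.0040,
            -0.0022, -0.0355, 0.0055, -0.0091]
fc1_trained = [-0.0197, -0.0244, 0.0136, -0.0079, 0.0105, 0.0067, -0.0269, 0.0309,
               0.0145, -0.0348, 0.0204, -0.0336, 0.0149, 0.0010, 0.0330, 0.0033,
               -0.0029, -0.0362, 0.0048, -0.0098]
fc4_init = [-0.0411, 0.0225, -0.0391, -0.0073, 0.0105, 0.0398, 0.0214, 0.0348,
            -0.0152, -0.0315, -0.0428, -0.0353, 0.0102, -0.0082, 0.0332, 0.0086,
            0.0048, 0.0239, 0.0248, -0.0066]
fc4_trained = [-0.0368, 0.0205, -0.0280, -0.0112, 0.0098, 0.0497, 0.0182, 0.0475,
               -0.0152, -0.0353, -0.0462, -0.0329, 0.0099, -0.0082, 0.0324, 0.0087,
               0.0050, 0.0239, 0.0415, -0.0066]


def abs_diffs(before, after):
    """Element-wise absolute change between two equally long weight lists."""
    return [abs(a - b) for b, a in zip(before, after)]


fc1_diffs = abs_diffs(fc1_init, fc1_trained)
fc4_diffs = abs_diffs(fc4_init, fc4_trained)

for name, diffs in [("fc1", fc1_diffs), ("fc4", fc4_diffs)]:
    above = sum(d > 1e-2 for d in diffs)
    print(f"{name}: max |delta| = {max(diffs):.4f}, "
          f"{above} of {len(diffs)} weights moved by more than 1e-2")
```

On these slices the `fc1` weights all moved by less than 1e-3, while three of the twenty `fc4` weights moved by more than 1e-2.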

In [27]:
# Verifying the accuracies on the test set for both of the models before changing the weights

In [28]:
test(init_model, device, test_loader)




Test set: Average loss: 4.8498, Accuracy: 0/10000 (0%)
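
A useful reference point for the accuracies that follow: because `fc5` has 128 outputs, chance level here is not the usual 10%. The arithmetic below assumes that an untrained network's argmax is uniform over all 128 outputs, which is a simplification (the 0/10000 above suggests the real initial predictions are not uniform).

```python
num_outputs = 128   # width of fc5 in FC_Net
num_classes = 10    # MNIST labels 0-9
test_size = 10000   # size of the MNIST test set

# Assumption: the predicted index is uniform over all 128 outputs.
chance_accuracy = 1 / num_outputs              # probability of hitting the true label
expected_correct = chance_accuracy * test_size
share_valid = num_classes / num_outputs        # predictions that are a digit label at all

print(f"chance accuracy: {100 * chance_accuracy:.2f}%")
print(f"expected correct out of {test_size}: {expected_correct:.0f}")
print(f"share of predictions landing on a real digit label: {100 * share_valid:.1f}%")
```

Under that assumption, anything well above 1% accuracy already reflects learned structure rather than luck.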



In [29]:
test(trained_model, device, test_loader)




Test set: Average loss: 0.1147, Accuracy: 9656/10000 (97%)



In [64]:
# Replace the weights of one layer of the trained model at a time with its initial weights
for fc_id in range(1,6):
    new_model = copy.deepcopy(trained_model)
    new_model._modules[f'fc{fc_id}'].weight = init_model._modules[f'fc{fc_id}'].weight
    new_model._modules[f'fc{fc_id}'].bias = init_model._modules[f'fc{fc_id}'].bias
    print(f'replacing fc{fc_id}')
    test(new_model, device, test_loader)

replacing fc1





Test set: Average loss: 0.3071, Accuracy: 9194/10000 (92%)

replacing fc2

Test set: Average loss: 0.8474, Accuracy: 9412/10000 (94%)

replacing fc3

Test set: Average loss: 0.5871, Accuracy: 9601/10000 (96%)

replacing fc4

Test set: Average loss: 0.6045, Accuracy: 9419/10000 (94%)

replacing fc5

Test set: Average loss: 1.1906, Accuracy: 8819/10000 (88%)



In [66]:
# Replace the weights of one layer of the initial model at a time with the trained weights
for fc_id in range(1,6):
    new_model = copy.deepcopy(init_model)
    new_model._modules[f'fc{fc_id}'].weight = trained_model._modules[f'fc{fc_id}'].weight
    new_model._modules[f'fc{fc_id}'].bias = trained_model._modules[f'fc{fc_id}'].bias
    print(f'replacing fc{fc_id}')
    test(new_model, device, test_loader)

replacing fc1





Test set: Average loss: 4.8028, Accuracy: 29/10000 (0%)

replacing fc2

Test set: Average loss: 4.6531, Accuracy: 3743/10000 (37%)

replacing fc3

Test set: Average loss: 4.7018, Accuracy: 3113/10000 (31%)

replacing fc4

Test set: Average loss: 4.7312, Accuracy: 2709/10000 (27%)

replacing fc5

Test set: Average loss: 4.6319, Accuracy: 1536/10000 (15%)
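
The two sweeps are easier to compare side by side. The snippet below just tabulates the correct-prediction counts printed above (out of 10000, with the fully trained model at 9656); the variable names are chosen here for illustration.

```python
TOTAL = 10000
BASELINE = 9656  # correct predictions of the fully trained model

# Trained model with one layer reset to its initial weights.
reset_to_init = {"fc1": 9194, "fc2": 9412, "fc3": 9601, "fc4": 9419, "fc5": 8819}
# Initial model with one layer replaced by its trained weights.
trained_layer_only = {"fc1": 29, "fc2": 3743, "fc3": 3113, "fc4": 2709, "fc5": 1536}

print(f"{'layer':<6}{'reset acc %':>12}{'drop (pts)':>12}{'lone acc %':>12}")
for layer, correct in reset_to_init.items():
    reset_acc = 100.0 * correct / TOTAL
    drop = 100.0 * (BASELINE - correct) / TOTAL
    lone_acc = 100.0 * trained_layer_only[layer] / TOTAL
    print(f"{layer:<6}{reset_acc:>12.2f}{drop:>12.2f}{lone_acc:>12.2f}")

most_robust = max(reset_to_init, key=reset_to_init.get)
most_sensitive = min(reset_to_init, key=reset_to_init.get)
best_lone_layer = max(trained_layer_only, key=trained_layer_only.get)
print(f"least affected by a reset: {most_robust}")
print(f"most affected by a reset: {most_sensitive}")
print(f"best single trained layer: {best_lone_layer}")
```

On this single run, resetting `fc3` costs about half a percentage point, while resetting the output layer `fc5` costs more than eight. These are results from one seed, so the ordering of the layers may not be stable.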



## A single trained layer, cropped out and placed in the randomly initialized network, produces roughly 27-37% accuracy for fc2 to fc4 (15% for fc5, about 0% for fc1)!
# 🤯