
# Exploring gradients with Weights & Biases

First we have to install a couple of additional packages (`wandb` and `tqdm`) into the Colab runtime for the training loop to work.

In [0]:
%%capture
!pip install wandb tqdm

If you don't have a W&B account, you can [create a free account](https://www.wandb.com/) to visualize your model training.

Next, log in from the notebook so that you can log live results.

In [0]:
!wandb login

You can find your API keys in your browser here: https://app.wandb.ai/authorize
Paste an API key from your profile: [redacted]
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
Successfully logged in to Weights & Biases!


## Defining our model and train/test loops

Here we'll set up the code we need to train a model that classifies the [MNIST digits](http://yann.lecun.com/exdb/mnist/) dataset. We borrow much of the boilerplate from the [PyTorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py).


Our model has two major components, which illustrate a couple of different advantages of tracking gradients while training a deep learning model. The first component is a fairly basic 2D CNN → fully connected model that does the heavy lifting of making the actual prediction. The second component feeds in 10 random values per example, passes them through a fully connected layer (`rand_fc`), and concatenates the result to the flattened output of the second 2D CNN layer. Because these inputs are pure noise, the `rand_fc` parameters carry no real value for the prediction task at hand. Check out how the gradients flowing to these parameters (which appear as `gradients/rand_fc.weight` and `gradients/rand_fc.bias` in your W&B dashboard) compare to those of the other model parameters.
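
To see where the `(4*4*50) + 10` input size of `fc1` in the code below comes from, here is a small pure-Python sketch of the shape arithmetic. It assumes 28x28 MNIST inputs and uses the standard output-size formula for convolution and pooling with no padding and no dilation (the PyTorch defaults used here); the helper names `out_size` and `cnn_flat_features` are hypothetical.

```python
def out_size(size, kernel, stride=1):
    # Spatial output size with padding=0, dilation=1 (defaults of nn.Conv2d and F.max_pool2d)
    return (size - kernel) // stride + 1


def cnn_flat_features(input_size=28, out_channels=50):
    s = out_size(input_size, kernel=5)    # conv1 (5x5, stride 1): 28 -> 24
    s = out_size(s, kernel=2, stride=2)   # max_pool2d(x, 2, 2):   24 -> 12
    s = out_size(s, kernel=5)             # conv2 (5x5, stride 1): 12 -> 8
    s = out_size(s, kernel=2, stride=2)   # max_pool2d(x, 2, 2):   8 -> 4
    return s * s * out_channels           # flattened size before torch.cat


RAND_FEATURES = 10  # width of rand_fc's output
print(cnn_flat_features(), cnn_flat_features() + RAND_FEATURES)  # 800 810
```

The 800 flattened CNN features plus the 10 random features give the 810 inputs that `fc1` expects.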

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import wandb
from tqdm import tqdm


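class CNN_Net(nn.Module):
    def __init__(self, device):
        super().__init__()

        # Device on which the random input tensor is placed inside forward()
        self.device = device

        # Main branch: two 5x5 conv layers, then two fully connected layers.
        # For 28x28 inputs: 28 -> conv1 -> 24 -> pool -> 12 -> conv2 -> 8 -> pool -> 4
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # Random branch: maps 10 random values to 10 features that carry no signal
        self.rand_fc = nn.Linear(10, 10)
        # Inputs: 4*4*50 flattened conv features plus the 10 random features
        self.fc1 = nn.Linear((4*4*50) + 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Fresh noise for every example in every batch
        rand_x = torch.randn(x.shape[0], 10).to(self.device)
        rand_x = F.relu(self.rand_fc(rand_x))

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)  # flatten to (batch, 800)

        # Append the random features to the image features
        x = torch.cat((x, rand_x), dim=1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Log-probabilities, paired with F.nll_loss in train() and test()
        return F.log_softmax(x, dim=1)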
      
    
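def train(model, device, train_loader, optimizer, epoch):
    # `epoch` is unused here; it is kept to match the PyTorch MNIST example
    model.train()

    n_batches = len(train_loader)  # number of batches, not examples

    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=n_batches):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()  # wandb.watch hooks record the gradients computed here
        optimizer.step()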

        
def test(model, device, test_loader, WANDB):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() 
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    tqdm.write('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    if WANDB:
        wandb.log({'test_loss': test_loss,
                   'accuracy': correct / len(test_loader.dataset)})
        
        
def main(config):
    
    if config['WANDB']:
        wandb.init(project='explore-gradients', reinit=True, config=config)
  
    use_cuda = torch.cuda.is_available()

    torch.manual_seed(config['SEED'])

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=config['BATCH_SIZE'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=config['TEST_BATCH_SIZE'], shuffle=True, **kwargs)

    model = CNN_Net(device).to(device)
    
    
    if config['WANDB']:
        wandb.watch(model, log='all')
    
    optimizer = optim.SGD(model.parameters(),
                          lr=config['LR'],
                          momentum=config['MOMENTUM'])

    for epoch in range(1, config['EPOCHS'] + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, config['WANDB'])
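
`test` sums per-example losses with `reduction='sum'` and divides by `len(test_loader.dataset)` only at the end, so the reported loss is a per-example average rather than an average of batch means. Below is a pure-Python sketch of the same bookkeeping; the toy logits stand in for model outputs, and the helper names `log_softmax_list` and `evaluate` are hypothetical.

```python
import math


def log_softmax_list(logits):
    # Numerically stable log-softmax over one example's class scores
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def evaluate(batches):
    """batches: list of (logits_per_example, targets) pairs, a toy stand-in for test_loader."""
    total_loss, correct, n = 0.0, 0, 0
    for logits_batch, targets in batches:
        for logits, t in zip(logits_batch, targets):
            log_probs = log_softmax_list(logits)
            total_loss += -log_probs[t]  # like F.nll_loss(..., reduction='sum')
            pred = max(range(len(log_probs)), key=log_probs.__getitem__)  # argmax
            correct += int(pred == t)
            n += 1
    return total_loss / n, correct / n  # per-example loss, accuracy


batches = [
    ([[2.0, 0.0, 0.0], [1.0, 0.0, 0.5]], [0, 1]),  # batch of 2
    ([[0.0, 3.0, 0.0]], [2]),                       # final, smaller batch of 1
]
avg_loss, acc = evaluate(batches)
print(avg_loss, acc)
```

Averaging the two per-batch means instead would weight the single example in the last batch twice as heavily. With 10,000 test images and `TEST_BATCH_SIZE = 1000` every batch is full, so both approaches agree here, but summing keeps the metric correct for any batch size.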

## Training the model

Here you can edit the configuration dictionary to see how changing hyperparameters such as the learning rate or momentum affects the gradients. To turn off W&B experiment tracking, set `WANDB` to `False`.
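
Strictly speaking, `LR` and `MOMENTUM` do not change how a gradient is computed; they change how each gradient becomes a parameter update, which then shapes the gradients you see later in training. The sketch below applies the update rule given in the PyTorch docs for `torch.optim.SGD` (with the defaults `dampening=0` and `nesterov=False`) to a single scalar parameter under a constant gradient. The helper name `sgd_trajectory` is hypothetical.

```python
def sgd_trajectory(grads, lr, momentum, p0=0.0):
    """Apply SGD with momentum (dampening=0, nesterov=False) to a scalar
    parameter and return its value after each step."""
    p, v = p0, 0.0
    values = []
    for g in grads:
        v = momentum * v + g  # velocity buffer; equals g on the first step
        p = p - lr * v        # parameter update
        values.append(p)
    return values


constant_grad = [1.0] * 50
plain = sgd_trajectory(constant_grad, lr=0.01, momentum=0.0)
heavy = sgd_trajectory(constant_grad, lr=0.01, momentum=0.9)
print(plain[-1], heavy[-1])
```

With a constant gradient `g`, the velocity approaches `g / (1 - momentum)`, so after enough steps `MOMENTUM = 0.9` takes steps roughly 10x larger than `MOMENTUM = 0` at the same `LR`.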

In [0]:
config = {
    'BATCH_SIZE'         : 64,
    'TEST_BATCH_SIZE'    : 1000,
    'EPOCHS'             : 30,
    'LR'                 : 0.01,
    'MOMENTUM'           : 0,
    'SEED'               : 17,
    'WANDB'              : True,
}

main(config=config)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


100%|██████████| 938/938 [00:17<00:00, 53.28it/s]
Test set: Average loss: 0.1554, Accuracy: 9522/10000 (95%)

100%|██████████| 938/938 [00:17<00:00, 53.84it/s]
Test set: Average loss: 0.0924, Accuracy: 9715/10000 (97%)

100%|██████████| 938/938 [00:17<00:00, 52.97it/s]
Test set: Average loss: 0.0766, Accuracy: 9762/10000 (98%)

100%|██████████| 938/938 [00:17<00:00, 55.90it/s]
Test set: Average loss: 0.0556, Accuracy: 9819/10000 (98%)

100%|██████████| 938/938 [00:17<00:00, 53.63it/s]
Test set: Average loss: 0.0543, Accuracy: 9834/10000 (98%)

100%|██████████| 938/938 [00:17<00:00, 53.11it/s]
Test set: Average loss: 0.0461, Accuracy: 9851/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.20it/s]
Test set: Average loss: 0.0405, Accuracy: 9861/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 49.18it/s]
Test set: Average loss: 0.0411, Accuracy: 9864/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.27it/s]
Test set: Average loss: 0.0358, Accuracy: 9883/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.43it/s]
Test set: Average loss: 0.0467, Accuracy: 9841/10000 (98%)

100%|██████████| 938/938 [00:17<00:00, 53.36it/s]
Test set: Average loss: 0.0331, Accuracy: 9888/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.33it/s]
Test set: Average loss: 0.0318, Accuracy: 9891/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 56.66it/s]
Test set: Average loss: 0.0297, Accuracy: 9903/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.20it/s]
Test set: Average loss: 0.0262, Accuracy: 9911/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.45it/s]
Test set: Average loss: 0.0282, Accuracy: 9905/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.98it/s]
Test set: Average loss: 0.0275, Accuracy: 9904/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.54it/s]
Test set: Average loss: 0.0267, Accuracy: 9905/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.32it/s]
Test set: Average loss: 0.0281, Accuracy: 9903/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.17it/s]
Test set: Average loss: 0.0267, Accuracy: 9903/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.56it/s]
Test set: Average loss: 0.0315, Accuracy: 9880/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.87it/s]
Test set: Average loss: 0.0282, Accuracy: 9906/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.45it/s]
Test set: Average loss: 0.0295, Accuracy: 9896/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.57it/s]
Test set: Average loss: 0.0239, Accuracy: 9913/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.54it/s]
Test set: Average loss: 0.0285, Accuracy: 9907/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.90it/s]
Test set: Average loss: 0.0284, Accuracy: 9900/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.72it/s]
Test set: Average loss: 0.0250, Accuracy: 9906/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.76it/s]
Test set: Average loss: 0.0254, Accuracy: 9911/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 52.56it/s]
Test set: Average loss: 0.0256, Accuracy: 9912/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 48.74it/s]
Test set: Average loss: 0.0263, Accuracy: 9907/10000 (99%)

100%|██████████| 938/938 [00:17<00:00, 53.47it/s]
Test set: Average loss: 0.0251, Accuracy: 9909/10000 (99%)
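
The progress bars above report 938 batches per epoch. That count follows from the configuration: the 60,000-image MNIST training split in batches of `BATCH_SIZE = 64`, with the final partial batch kept (the `DataLoader` default, `drop_last=False`). The same arithmetic gives a rough count of how often `wandb.watch` logs gradient histograms over the run, assuming its default `log_freq` of 1000 batches; treat that value as an assumption and check it against the docs for your wandb version. The helper name `batches_per_epoch` is hypothetical.

```python
import math


def batches_per_epoch(n_examples, batch_size):
    # DataLoader's default drop_last=False keeps the final partial batch
    return math.ceil(n_examples / batch_size)


MNIST_TRAIN, MNIST_TEST = 60_000, 10_000        # standard MNIST split sizes
BATCH_SIZE, TEST_BATCH_SIZE, EPOCHS = 64, 1000, 30  # values from the config cell
LOG_FREQ = 1000  # assumption: wandb.watch's default log_freq

train_batches = batches_per_epoch(MNIST_TRAIN, BATCH_SIZE)
test_batches = batches_per_epoch(MNIST_TEST, TEST_BATCH_SIZE)
total_steps = train_batches * EPOCHS
print(train_batches, test_batches, total_steps, total_steps // LOG_FREQ)
# 938 10 28140 28
```

Under that assumption you get only on the order of one gradient snapshot per epoch; passing a smaller `log_freq` to `wandb.watch` would give a finer-grained view of how the `rand_fc` gradients evolve.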

