# Create a simple CNN
Notebook to create and train a simple CNN on MNIST data.

## 1. Create a model

In [1]:
import numpy as np
import torch

from xai.data_handlers.mnist import load_mnist
from xai.models.simple_cnn import CNNClassifier

In [2]:
model = CNNClassifier()

In [3]:
model

CNNClassifier(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)

## 2. Load data

In [4]:
# Load corpus and test inputs
batch_size = 64
# batch_size_test = 1000

corpus_loader = load_mnist(subset_size=1024, train=True, batch_size=batch_size) # MNIST train loader
test_loader = load_mnist(subset_size=1024, train=False, batch_size=batch_size) # MNIST test loader
corpus_inputs, corpus_labels = next(iter(corpus_loader)) # A tensor of corpus inputs
test_inputs, test_labels = next(iter(test_loader)) # A set of inputs to explain

In [5]:
corpus_loader

<torch.utils.data.dataloader.DataLoader at 0x7f658b6cea90>

## 3. Train the model

### 3.1. Learner class

In [6]:
from xai.models.training import Learner

In [7]:
learn = Learner(model, corpus_loader, test_loader, 10)

In [8]:
learn.fit()

Epoch: 1 | Training loss: 2.321478843688965 | Validation loss: 2.307582139968872
Epoch: 2 | Training loss: 2.303701400756836 | Validation loss: 2.293464183807373
Epoch: 3 | Training loss: 2.258002758026123 | Validation loss: 2.2805933952331543
Epoch: 4 | Training loss: 2.271474599838257 | Validation loss: 2.262023448944092
Epoch: 5 | Training loss: 2.2004315853118896 | Validation loss: 2.2400355339050293
Epoch: 6 | Training loss: 2.2040114402770996 | Validation loss: 2.1976354122161865
Epoch: 7 | Training loss: 2.090068817138672 | Validation loss: 2.132516384124756
Epoch: 8 | Training loss: 1.9447423219680786 | Validation loss: 2.046426773071289
Epoch: 9 | Training loss: 1.9964561462402344 | Validation loss: 1.8926242589950562
Epoch: 10 | Training loss: 1.7329294681549072 | Validation loss: 1.6882061958312988


### 3.1 Development

In [None]:
import torch.nn as nn
import torch.optim as optim

# Use default values from 
# https://github.com/vanderschaarlab/Simplex/blob/0af504927122d59dfc1378b73d0292244213e982/src/simplexai/experiments/mnist.py#L83
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5, weight_decay=0.01)

In [None]:
corpus_inputs.shape

In [None]:
def calculate_one_epoch(model, corpus_loader, loss_function):
    running_loss = 0.
    
    for batch_idx, data in enumerate(corpus_loader):
        # Unpack inputs and labels from data loader
        inputs, labels = data
        
        # Zero your learning weight gradients for every batch
        optimizer.zero_grad()

        # Make predictions for this batch and compute the loss
        outputs = model(inputs)  
        loss = loss_function(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

    return loss



In [120]:
def calculate_validation_loss(model, test_loader, loss_function, debug=False):
    validation_losses = []
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            inputs, labels = data
            batch_outputs = model(inputs)
            batch_loss = loss_function(batch_outputs, labels)
            validation_losses.append(batch_loss)
    
            # TODO GJ: For debugging
            if debug:
                print(batch_loss)

    return np.mean(validation_losses)
    

In [121]:
calculate_validation_loss(model, test_loader, loss_function, debug=True)

tensor(1.0683)
tensor(0.9434)
tensor(0.7376)
tensor(0.8664)
tensor(0.7844)
tensor(0.9880)
tensor(0.7252)
tensor(0.7040)
tensor(0.6760)
tensor(0.8815)
tensor(0.6401)
tensor(0.9979)
tensor(0.9591)
tensor(0.9345)
tensor(0.9426)
tensor(0.8515)


0.8562723

In [128]:
model.fc1.weight

Parameter containing:
tensor([[-0.0308, -0.0207,  0.0073,  ...,  0.0191,  0.0399, -0.0210],
        [ 0.0270, -0.0175,  0.0477,  ..., -0.0223,  0.0352, -0.0158],
        [ 0.0239,  0.0503,  0.0451,  ..., -0.0223,  0.0228,  0.0351],
        ...,
        [-0.0174, -0.0443,  0.0373,  ..., -0.0226,  0.0417, -0.0141],
        [ 0.0460,  0.0286, -0.0199,  ...,  0.0331, -0.0020,  0.0243],
        [ 0.0344,  0.0418, -0.0292,  ...,  0.0194, -0.0385,  0.0231]],
       requires_grad=True)

In [129]:
calculate_one_epoch(model, corpus_loader, loss_function)

0.6170090436935425


tensor(0.6170, grad_fn=<NllLossBackward0>)

In [130]:
model.fc1.weight

Parameter containing:
tensor([[-0.0304, -0.0198,  0.0075,  ...,  0.0191,  0.0397, -0.0209],
        [ 0.0290, -0.0157,  0.0474,  ..., -0.0223,  0.0351, -0.0154],
        [ 0.0262,  0.0527,  0.0485,  ..., -0.0221,  0.0227,  0.0352],
        ...,
        [-0.0194, -0.0479,  0.0323,  ..., -0.0223,  0.0417, -0.0138],
        [ 0.0466,  0.0283, -0.0214,  ...,  0.0330, -0.0018,  0.0245],
        [ 0.0337,  0.0398, -0.0310,  ...,  0.0196, -0.0384,  0.0227]],
       requires_grad=True)

In [142]:
def train_model(model, corpus_loader, test_loader, loss_function, num_epochs):
    print(f"Training model {model}")
    for epoch in range(num_epochs):  # loop over the dataset multiple times

        training_loss = calculate_one_epoch(model, corpus_loader, loss_function)
        validation_loss = calculate_validation_loss(model, test_loader, loss_function)
        print(f"Epoch: {epoch+1} | Training loss: {training_loss} | Validation loss: {validation_loss}")

    print(f"Model training complete.")

In [143]:
model.fc1.weight

Parameter containing:
tensor([[-0.0250, -0.0123,  0.0075,  ...,  0.0179,  0.0358, -0.0174],
        [ 0.0265, -0.0147,  0.0532,  ..., -0.0211,  0.0329, -0.0092],
        [ 0.0187,  0.0396,  0.0472,  ..., -0.0198,  0.0194,  0.0339],
        ...,
        [-0.0416, -0.0560,  0.0262,  ..., -0.0199,  0.0381, -0.0092],
        [ 0.0637,  0.0455,  0.0038,  ...,  0.0292, -0.0004,  0.0254],
        [ 0.0352,  0.0495, -0.0494,  ...,  0.0171, -0.0336,  0.0191]],
       requires_grad=True)

In [144]:
train_model(model, corpus_loader, test_loader, loss_function, num_epochs=10)

Training model CNNClassifier(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)
Epoch: 1 | Training loss: 0.3185777962207794 | Validation loss: 0.4822332262992859
Epoch: 2 | Training loss: 0.34704187512397766 | Validation loss: 0.48406344652175903
Epoch: 3 | Training loss: 0.4384196400642395 | Validation loss: 0.46865567564964294
Epoch: 4 | Training loss: 0.47031792998313904 | Validation loss: 0.4788334369659424
Epoch: 5 | Training loss: 0.48285216093063354 | Validation loss: 0.4831583499908447
Epoch: 6 | Training loss: 0.3725970983505249 | Validation loss: 0.47777703404426575
Epoch: 7 | Training loss: 0.2947460412979126 | Validation loss: 0.47649532556533813
Epoch: 8 | Training loss: 0.45903638005256653 | Validation loss: 0.4886394739151001
Epoch:

In [145]:
model.fc1.weight

Parameter containing:
tensor([[-0.0204, -0.0137,  0.0075,  ...,  0.0175,  0.0337, -0.0172],
        [ 0.0241, -0.0114,  0.0545,  ..., -0.0208,  0.0321, -0.0084],
        [ 0.0194,  0.0362,  0.0464,  ..., -0.0190,  0.0171,  0.0330],
        ...,
        [-0.0482, -0.0530,  0.0266,  ..., -0.0193,  0.0372, -0.0079],
        [ 0.0643,  0.0429,  0.0051,  ...,  0.0280, -0.0003,  0.0256],
        [ 0.0277,  0.0438, -0.0544,  ...,  0.0168, -0.0321,  0.0175]],
       requires_grad=True)

## 4. Save a model

In [9]:
import pathlib

In [11]:
MODEL_DIR = pathlib.Path("/home/gurp/workspace/xai/xai/models/saved_models")
MODEL_FNAME = 'simple_cnn_test2.pth'
MODEL_FPATH = MODEL_DIR / MODEL_FNAME
MODEL_FPATH

PosixPath('/home/gurp/workspace/xai/xai/models/saved_models/simple_cnn_test2.pth')

In [12]:
# This is essentially a wrapper around torch.save(model.state_dict(), MODEL_FPATH)
learn.save_model(MODEL_FPATH)

Saving/loading models

## 5. Load a model

In [13]:
model2 = CNNClassifier()
model2.load_state_dict(torch.load(MODEL_FPATH))

<All keys matched successfully>

In [14]:
model2.fc1.weight

Parameter containing:
tensor([[-0.0296, -0.0352,  0.0463,  ...,  0.0346, -0.0150, -0.0504],
        [ 0.0276, -0.0086,  0.0576,  ...,  0.0687,  0.0058, -0.0284],
        [-0.0418,  0.0038, -0.0282,  ...,  0.0667,  0.0281, -0.0422],
        ...,
        [ 0.0340, -0.0411,  0.0304,  ..., -0.0296, -0.0237,  0.0355],
        [ 0.0451,  0.0145, -0.0184,  ...,  0.0350,  0.0128, -0.0334],
        [ 0.0292,  0.0033, -0.0106,  ..., -0.0463, -0.0378, -0.0383]],
       requires_grad=True)

In [15]:
model.fc1.weight

Parameter containing:
tensor([[-0.0296, -0.0352,  0.0463,  ...,  0.0346, -0.0150, -0.0504],
        [ 0.0276, -0.0086,  0.0576,  ...,  0.0687,  0.0058, -0.0284],
        [-0.0418,  0.0038, -0.0282,  ...,  0.0667,  0.0281, -0.0422],
        ...,
        [ 0.0340, -0.0411,  0.0304,  ..., -0.0296, -0.0237,  0.0355],
        [ 0.0451,  0.0145, -0.0184,  ...,  0.0350,  0.0128, -0.0334],
        [ 0.0292,  0.0033, -0.0106,  ..., -0.0463, -0.0378, -0.0383]],
       requires_grad=True)

## 6. Model predictions
From https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [None]:
# correct = 0
# total = 0
# # since we're not training, we don't need to calculate the gradients for our outputs
# with torch.no_grad():
#     for data in testloader:
#         images, labels = data
#         # calculate outputs by running images through the network
#         outputs = net(images)
#         # the class with the highest energy is what we choose as prediction
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

# print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')