# Pretraining a big CNN classifier

This script simply defines and trains a large CNN classifier on the whole dataset for the students to use in lab 6.1 (introspection).

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, Subset
import copy
#from codecarbon import track_emissions

# Define the data repository
data_dir = 'data/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Process on: {}'.format(device))

Process on: cpu


In [2]:
# Initialization function
def init_weights(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.01)
    return

In [3]:
# Load the Fashion MNIST dataset
train_data = datasets.MNIST(data_dir, train=True, download=True, transform=transforms.ToTensor())
num_classes = len(train_data.classes)

# Split in training / validation
n_train_examples = int(len(train_data)*0.9)
n_valid_examples = len(train_data) - n_train_examples
train_data, valid_data = random_split(train_data, [n_train_examples, n_valid_examples])

# Create dataloaders
batch_size = 8
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size)

In [4]:
# CNN classifier (with batch normalization)
class CNNClassif_bnorm(nn.Module):
    def __init__(self, input_size_linear, num_channels1=16, num_channels2=32, num_classes=10):
        super(CNNClassif_bnorm, self).__init__()
        
        self.num_channels1 = num_channels1
        self.num_channels2 = num_channels2
        self.num_classes = num_classes
        
        self.cnn_layer1 = nn.Sequential(
            nn.Conv2d(1, num_channels1, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(num_channels1),
            nn.MaxPool2d(kernel_size=2))
            
        self.cnn_layer2 = nn.Sequential(
            nn.Conv2d(num_channels1, num_channels2, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(num_channels2),
            nn.MaxPool2d(kernel_size=2))
        
        self.lin_layer = nn.Linear(input_size_linear, num_classes)
    
    def forward(self, x):
        
        out = self.cnn_layer1(x)
        out = self.cnn_layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.lin_layer(out)
        
        return out

In [5]:
# Eval function
def eval_cnn_classifier(model, eval_dataloader, device='cpu'):

    # Set the model in evaluation mode
    model.eval()
    model.to(device)

    # In test phase, we don't need to compute gradients (for memory efficiency)
    with torch.no_grad():
        # initialize the total and correct number of labels to compute the accuracy
        correct = 0
        total = 0
        for images, labels in eval_dataloader:
            labels = labels.to(device)
            images = images.to(device)
            y_predicted = model(images).to(device)
            _, label_predicted = torch.max(y_predicted.data, 1)
            total += labels.size(0)
            correct += (label_predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    
    return accuracy

In [8]:
# Training function
#@track_emissions(offline=True, country_iso_code="FRA")
def training_val_cnn_classifier(model, train_dataloader, valid_dataloader, num_epochs, loss_fn, learning_rate, device='cpu', verbose=True):

    # Make a copy of the model (avoid changing the model outside this function)
    model_tr = copy.deepcopy(model)
    
    # Set the model in 'training' mode (ensures all parameters' gradients are computed)
    model_tr.train()
    model_tr.to(device)
    
    # define the optimizer
    optimizer = torch.optim.Adam(model_tr.parameters(), lr=learning_rate)
    
    # Initialize lists to store the training loss and validation accuracy over epochs
    train_losses = []
    val_accuracies = []
    
    val_acc_opt = 0
    
    # Training loop
    for epoch in range(num_epochs):
        tr_loss = 0
        
        for batch_index, (images, labels) in enumerate(train_dataloader):

            # forward pass
            images = images.to(device)
            labels = labels.to(device)
            y_predicted = model_tr(images).to(device)
            loss = loss_fn(y_predicted, labels)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Update the current epoch loss
            tr_loss += loss.item() * images.shape[0]

        # At the end of each epoch, get the average training loss and store it
        tr_loss = tr_loss/len(train_dataloader.dataset)
        train_losses.append(tr_loss)
        
        # Compute the accuracy on the validation set and store it
        val_acc = eval_cnn_classifier(model_tr, valid_dataloader)
        val_accuracies.append(val_acc)
        
        # Display the training loss and validation accuracy
        if verbose:
            print('Epoch [{}/{}], Training loss: {:.4f} ; Validation accuracy: {:.4f}'
                   .format(epoch+1, num_epochs, tr_loss, val_acc))
            
        # If the validation accuracy is higher than the optimal value, record the model and update the optimal value
        if val_acc > val_acc_opt:
            model_opt = copy.deepcopy(model_tr)
            val_acc_opt = val_acc
        
    return model_opt, train_losses, val_accuracies

In [9]:
# Network parameters
num_channels1 = 16
num_channels2 = 32
num_classes = 10
input_size_linear = 7*7*num_channels2 

model = CNNClassif_bnorm(input_size_linear, num_channels1, num_channels2, num_classes)
# Optimizer
num_epochs = 10
loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.001

# Training
model_tr, loss_total, val_acc = training_val_cnn_classifier(model, train_dataloader, valid_dataloader, num_epochs, loss_fn, learning_rate, device=device, verbose=True)

# Record the trained model
model_tr.to('cpu')
torch.save(model_tr.state_dict(), 'model_cnn_classif_introspection.pt')

Epoch [1/10], Training loss: 0.1192 ; Validation accuracy: 98.3000
Epoch [2/10], Training loss: 0.0520 ; Validation accuracy: 98.3667
Epoch [3/10], Training loss: 0.0369 ; Validation accuracy: 98.8833
Epoch [4/10], Training loss: 0.0282 ; Validation accuracy: 98.8167
Epoch [5/10], Training loss: 0.0212 ; Validation accuracy: 98.8833
Epoch [6/10], Training loss: 0.0183 ; Validation accuracy: 98.3000
Epoch [7/10], Training loss: 0.0148 ; Validation accuracy: 99.1333
Epoch [8/10], Training loss: 0.0127 ; Validation accuracy: 98.9667
Epoch [9/10], Training loss: 0.0111 ; Validation accuracy: 99.1000
Epoch [10/10], Training loss: 0.0095 ; Validation accuracy: 98.9167
