In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import random_split
from collections import Counter

# Seed PyTorch's global random number generator so that runs are repeatable.
torch.manual_seed(123)
# Make float64 the default floating-point dtype for tensors and module
# parameters created after this call.
torch.set_default_dtype(torch.double) 

In [None]:
def load_cifar(train_val_split=0.9, data_path='../data/', preprocessor=None):
    """Download CIFAR-10 and return it as (train, validation, test) datasets.

    Args:
        train_val_split: Fraction of the official training set kept for
            training; the remainder becomes the validation set.
        data_path: Directory the dataset is downloaded to / read from.
        preprocessor: Transform applied to every image. When None, images are
            converted to tensors and normalized with the per-channel
            mean/std constants hard-coded below.

    Returns:
        Tuple ``(data_train, data_val, data_test)``. The first two are random
        subsets of the official training set (fixed split seed 123), the last
        one is the official test set.
    """
    if preprocessor is None:
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            (0.4915, 0.4823, 0.4468),
            (0.2470, 0.2435, 0.2616),
        )
        preprocessor = transforms.Compose([to_tensor, normalize])

    # Both official splits share the same download / transform settings.
    shared_kwargs = dict(download=True, transform=preprocessor)
    full_train = datasets.CIFAR10(data_path, train=True, **shared_kwargs)
    held_out_test = datasets.CIFAR10(data_path, train=False, **shared_kwargs)

    # Carve the validation set out of the official training set.
    total = len(full_train)
    train_count = int(total * train_val_split)
    split_sizes = [train_count, total - train_count]
    # Dedicated generator: the split does not depend on the global RNG state.
    split_rng = torch.Generator().manual_seed(123)
    train_part, val_part = random_split(full_train, split_sizes, generator=split_rng)

    report = (
        ("Size of the train dataset:        ", train_part),
        ("Size of the validation dataset:   ", val_part),
        ("Size of the test dataset:         ", held_out_test),
    )
    for caption, part in report:
        print(caption, len(part))

    return (train_part, val_part, held_out_test)

# Build the CIFAR-10 train / validation / test datasets with the default settings.
cifar10_train, cifar10_val, cifar10_test = load_cifar()

In [None]:
# Original CIFAR-10 class label -> new binary label:
# label 0 stays 0 and label 2 becomes 1.
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']


def _keep_and_relabel(dataset):
    """Return a list of (img, new_label) for the samples whose label is in label_map."""
    kept = []
    for img, label in dataset:
        if label in label_map:
            kept.append((img, label_map[label]))
    return kept


# Keep only airplanes and birds in each part of the dataset and remap
# their labels with label_map.
cifar2_train = _keep_and_relabel(cifar10_train)
cifar2_val = _keep_and_relabel(cifar10_val)
cifar2_test = _keep_and_relabel(cifar10_test)

for caption, subset in (
    ('Size of the training dataset: ', cifar2_train),
    ('Size of the validation dataset: ', cifar2_val),
    ('Size of the test dataset: ', cifar2_test),
):
    print(caption, len(subset))

In [None]:
# Network

class MyMLP(nn.Module):
    """Fully connected classifier for 32x32 RGB images.

    Three hidden layers (512 -> 128 -> 32 units, ReLU) followed by a linear
    output layer that produces one raw logit per class.

    Args:
        n_classes: Number of output classes. Defaults to 2, matching the
            two-class (airplane / bird) dataset used in this notebook.
    """

    # Architecture
    def __init__(self, n_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(32*32*3, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 32)
        # Output layer: one logit per class.
        self.fc4 = nn.Linear(32, n_classes)

    # Forward pass
    def forward(self, x):
        """Return class logits of shape (batch, n_classes) for a batch of images."""
        # Flatten everything except the batch dimension.
        out = torch.flatten(x, 1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        # No activation here: nn.CrossEntropyLoss expects unnormalized logits.
        return self.fc4(out)


In [None]:
# Training loop

def train(n_epochs, optimizer, model, loss_fn, train_loader):
    """Train ``model`` with ``optimizer`` and return the mean loss of every epoch.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer holding the parameters of ``model``.
        model: Network to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor (assumed to be averaged over the batch, which is the
            default for PyTorch losses).
        train_loader: Sized iterable yielding ``(imgs, labels)`` batches.

    Returns:
        List with one float per epoch: the sum of the batch losses divided
        by the number of batches.
    """
    batches_per_epoch = len(train_loader)
    epoch_means = []

    model.train()

    # Drop any gradients left over from earlier use of the model.
    optimizer.zero_grad(set_to_none=True)

    for epoch in range(1, n_epochs + 1):
        running_loss = 0

        for batch_imgs, batch_labels in train_loader:
            # Forward pass and loss for this batch.
            batch_loss = loss_fn(model(batch_imgs), batch_labels)

            # Backward pass, parameter update, then reset the gradients
            # so they do not leak into the next batch.
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # .item() turns the loss into a plain number (no graph attached).
            running_loss += batch_loss.item()

        mean_loss = running_loss / batches_per_epoch
        epoch_means.append(mean_loss)

        # Report the first epoch and then every tenth one.
        should_report = epoch % 10 == 0 or epoch == 1
        if should_report:
            print('{}  |  Epoch {}  |  Training loss {:.3f}'.format(
                datetime.now().time(), epoch, mean_loss))

    return epoch_means

In [None]:
# Training loop with manual parameter updates (no optimizer object)

def train_manual_update(n_epochs, lr, model, loss_fn, train_loader):
    """Train ``model`` with hand-written plain gradient descent steps.

    Each batch performs the update ``p <- p - lr * p.grad`` on every
    parameter, i.e. the same step as SGD without momentum or weight decay.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        lr: Learning rate (step size).
        model: Network to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        train_loader: Sized iterable yielding ``(imgs, labels)`` batches.

    Returns:
        List with one float per epoch: the sum of the batch losses divided
        by the number of batches.
    """
    n_batch = len(train_loader)
    losses_train = []
    model.train()

    # Start from clean gradients in case the model was used before.
    model.zero_grad()

    for epoch in range(1, n_epochs + 1):

        loss_train = 0

        for imgs, labels in train_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            loss.backward()

            # The update itself must not be recorded by autograd.
            with torch.no_grad():
                for par in model.parameters():
                    # Parameters that did not take part in the forward pass
                    # have no gradient; leave them untouched.
                    if par.grad is None:
                        continue
                    # In-place update: rebinding `par` would only change the
                    # local name and leave the model's weights unchanged.
                    par -= lr * par.grad

            # Gradients accumulate across backward() calls, so reset them
            # before the next batch.
            model.zero_grad()

            loss_train += loss.item()

        losses_train.append(loss_train / n_batch)

        if epoch == 1 or epoch % 10 == 0:
            print('{}  |  Epoch {}  |  Training loss {:.3f}'.format(
                datetime.now().time(), epoch, loss_train / n_batch))

    return losses_train

In [None]:
# Train a fresh MyMLP with the hand-written update rule.

loss_fn = nn.CrossEntropyLoss()
# shuffle=False: the batches are visited in the same order on every run.
train_loader = torch.utils.data.DataLoader(
    cifar2_train,
    shuffle=False,
    batch_size=64,
)

# Re-seed right before building the model so its initial weights are repeatable.
torch.manual_seed(123)
model = MyMLP()

manual_update_loss = train_manual_update(
    n_epochs=21, lr=0.05, model=model, loss_fn=loss_fn, train_loader=train_loader)
print(manual_update_loss)

In [None]:
# Train a fresh MyMLP with torch.optim.SGD through train().

loss_fn = nn.CrossEntropyLoss()
# shuffle=False: the batches are visited in the same order on every run.
train_loader = torch.utils.data.DataLoader(
    cifar2_train,
    shuffle=False,
    batch_size=64,
)

# Re-seed right before building the model so its initial weights are repeatable.
torch.manual_seed(123)
model = MyMLP()
# Plain SGD: no momentum.
optimizer = torch.optim.SGD(model.parameters(), momentum=0, lr=0.05)

losssss = train(
    n_epochs=21, optimizer=optimizer, model=model, loss_fn=loss_fn, train_loader=train_loader)
print(losssss)

In [None]:
# Modified train_manual_update (L2 regularization / ridge)

def train_manual_update2(n_epochs, lr, model, loss_fn, train_loader, weight_decay=0.0):
    """Train ``model`` with hand-written gradient descent and optional L2 regularization.

    Each batch performs, for every parameter ``p``::

        p <- p - lr * (p.grad + weight_decay * p)

    which is the step ``torch.optim.SGD`` takes with ``momentum=0`` and the
    same ``weight_decay``. The L2 penalty only enters through the update; the
    reported losses are the plain ``loss_fn`` values.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        lr: Learning rate (step size).
        model: Network to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        train_loader: Sized iterable yielding ``(imgs, labels)`` batches.
        weight_decay: L2 regularization strength. The default 0.0 gives
            plain gradient descent.

    Returns:
        List with one float per epoch: the sum of the batch losses divided
        by the number of batches.
    """
    n_batch = len(train_loader)
    losses_train = []
    model.train()

    # Start from clean gradients in case the model was used before.
    model.zero_grad()

    for epoch in range(1, n_epochs + 1):

        loss_train = 0

        for imgs, labels in train_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            loss.backward()

            # The update itself must not be recorded by autograd.
            with torch.no_grad():
                for par in model.parameters():
                    # Parameters that did not take part in the forward pass
                    # have no gradient; leave them untouched.
                    if par.grad is None:
                        continue
                    step_dir = par.grad
                    if weight_decay != 0:
                        # Gradient of (weight_decay / 2) * ||p||^2 is weight_decay * p.
                        step_dir = step_dir + weight_decay * par
                    par -= lr * step_dir

            # Gradients accumulate across backward() calls, so reset them
            # before the next batch.
            model.zero_grad()

            loss_train += loss.item()

        losses_train.append(loss_train / n_batch)

        if epoch == 1 or epoch % 10 == 0:
            print('{}  |  Epoch {}  |  Training loss {:.3f}'.format(
                datetime.now().time(), epoch, loss_train / n_batch))

    return losses_train

In [None]:

# Train a fresh MyMLP with the manual update and L2 regularization.

# Re-seed right before building the model so its initial weights are repeatable.
torch.manual_seed(123)
model = MyMLP()
# shuffle=False: the batches are visited in the same order on every run.
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64, shuffle=False)
loss_fn = nn.CrossEntropyLoss()


manual_update_lossss = train_manual_update2(
    n_epochs = 21,
    lr = 0.05,
    model = model,
    loss_fn = loss_fn,
    train_loader = train_loader,
    weight_decay = 0.5
)
# Show the losses of THIS run; `manual_update_loss` holds the result of the
# earlier, unregularized run.
print(manual_update_lossss)