# Training ResNET
In this notebook one pretrained ResNet is trained and one ResNet from scratch is trained.
The models are saved in models/.

In [1]:
seed=3 # (1,2,3,4,5)
model_folder = "models/"
#Amount google driv
from google.colab import drive
import os

gdrive_path='/content/gdrive/MyDrive/case_study_opti'

# This will mount your google drive under 'MyDrive'
drive.mount('/content/gdrive', force_remount=True)
# In order to access the files in this notebook we have to navigate to the correct folder
os.chdir(gdrive_path)
# Check manually if all files are present
print(sorted(os.listdir()))

Mounted at /content/gdrive
['data', 'latent-communication']


In [2]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.2.4-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.4.0.post0-py3-none-any.whl (868 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m868.8/868.8 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.2.4-py3-none-any.whl (802 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m802.2/802.2 kB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<4.0,>=1.13.0->lightning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12

In [3]:
%load_ext autoreload
%autoreload 2

**ResNet Pytorch implementation for MNIST classification**
First we import the required packages.

In [4]:
import torch
#Set seed


%matplotlib inline
import torch.nn as nn
from matplotlib import pyplot as plt
import numpy as np
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision import transforms
import torch.optim as optim
import time
import tqdm as tqdm
from torch.autograd import Variable
import importlib
from lightning import LightningModule

model_resnet = importlib.import_module('latent-communication.resnet.model_def')
utils = importlib.import_module('latent-communication.resnet.utils')

## **Load Dataset**
We can load data from pytorch dataset and preprocess it using *transform* function.

Note that the ResNet implemented in torchvision take RGB images as inputs, which has three channels. So, here we repeat the single-channel grey scale digits image three times to fit the torchvision model.

In [5]:
transform = transforms.Compose([transforms.ToTensor(),
                                # expand channel from 1 to 3 to fit
                                # ResNet pretrained model
                                transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                ])
batch_size = 256

data_file_path="./data"
# download dataset
mnist_train = datasets.MNIST(root=data_file_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root=data_file_path, train=False, download=True, transform=transform)
print(len(mnist_train), len(mnist_test))

# Load dataset
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
    shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
    shuffle=True, num_workers=0)

60000 10000


## **Building the model**

The torchvision model is pretrained on ImageNet with 1000 classes of output, therefore, the network structure is not suitable for the classification in MNIST dataset.

In [6]:
# print pretrain model structure
net = models.resnet18()
print(net)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## **Modify Pretrain Model Structure**
The main structure of the ResNet can be split into two parts: the feature generator (G) and the classifier (F). The pretrained weights on the feature generator can be reused and a new classifier can be trained to fit the calssfication task in MNIST.

In the following codes, *ResNetFeatrueExtractor18* reproduces the feature extraction parts of the ResNet18, with an option to load the pretained model. And *ResClassifier* use a fully connected layer to get 10 class predictions.



In [7]:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)

# calculate test accuracy
def test_accuracy(data_iter, model):
    """Evaluate testset accuracy of a model."""
    acc_sum,n = 0,0
    for (imgs, labels) in data_iter:
        # send data to the GPU if cuda is availabel
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
        model.eval()
        with torch.no_grad():
            labels = labels.long()
            acc_sum += torch.sum((torch.argmax(model(imgs), dim=1) == labels)).float()
            n += labels.shape[0]
    return acc_sum.item()/n

## **Pre-trained model**

### Training

In [11]:
model = model = model_resnet.ResNet(pretrained=True)

if torch.cuda.is_available():
    model = model.cuda()

# setting up optimizer for both feature generator G and classifier F.
opt = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0005,momentum=0.9)

# loss function
criterion = nn.CrossEntropyLoss()

for epoch in range(0, 10):
    n, start = 0, time.time()
    train_l_sum = torch.tensor([0.0], dtype=torch.float32)
    train_acc_sum = torch.tensor([0.0], dtype=torch.float32)
    for i, (imgs, labels) in tqdm.tqdm(enumerate(iter(train_loader))):
        model.train()
        imgs = Variable(imgs)
        labels = Variable(labels)
        # train on GPU if possible
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
            train_l_sum = train_l_sum.cuda()
            train_acc_sum = train_acc_sum.cuda()

        opt.zero_grad()
        # predicted labels
        label_hat = model(imgs)

        # loss function
        loss= criterion(label_hat, labels)
        loss.backward()
        opt.step()

        # calcualte training error
        model.eval()
        labels = labels.long()
        train_l_sum += loss.float()
        train_acc_sum += (torch.sum((torch.argmax(label_hat, dim=1) == labels))).float()
        n += labels.shape[0]
    test_acc = test_accuracy(iter(test_loader), model)
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\
        % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc, time.time() - start))


235it [00:19, 11.92it/s]


epoch 1, loss 0.0005, train acc 0.962, test acc 0.990, time 22.7 sec


235it [00:20, 11.49it/s]


epoch 2, loss 0.0001, train acc 0.993, test acc 0.992, time 22.0 sec


235it [00:19, 11.94it/s]


epoch 3, loss 0.0001, train acc 0.996, test acc 0.992, time 21.4 sec


235it [00:19, 12.31it/s]


epoch 4, loss 0.0000, train acc 0.996, test acc 0.993, time 20.5 sec


235it [00:19, 12.10it/s]


epoch 5, loss 0.0000, train acc 0.997, test acc 0.993, time 21.0 sec


235it [00:19, 11.94it/s]


epoch 6, loss 0.0000, train acc 0.998, test acc 0.994, time 21.1 sec


235it [00:19, 12.09it/s]


epoch 7, loss 0.0000, train acc 0.999, test acc 0.994, time 20.9 sec


235it [00:19, 12.13it/s]


epoch 8, loss 0.0000, train acc 0.999, test acc 0.993, time 20.8 sec


235it [00:19, 12.12it/s]


epoch 9, loss 0.0000, train acc 0.999, test acc 0.994, time 20.8 sec


235it [00:19, 12.12it/s]


epoch 10, loss 0.0000, train acc 0.999, test acc 0.993, time 20.8 sec


In [12]:
## Save the model
torch.save(model.state_dict(), f'latent-communication/resnet/models/pretrained_model_seed{seed}.pth')

## **Training without Pre-trained model**

In [15]:
# setting pretrained to False. The rest is the same

seed=3
torch.manual_seed(seed)

model = model_resnet.ResNet(pretrained=False)

if torch.cuda.is_available():
    model = model.cuda()

opt = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0005,momentum=0.9)
criterion = nn.CrossEntropyLoss()

for epoch in range(0, 10):
    n, start = 0, time.time()
    train_l_sum = torch.tensor([0.0], dtype=torch.float32)
    train_acc_sum = torch.tensor([0.0], dtype=torch.float32)
    for i, (imgs, labels) in tqdm.tqdm(enumerate(iter(train_loader))):
        model.train()
        imgs = Variable(imgs)
        labels = Variable(labels)
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
            train_l_sum = train_l_sum.cuda()
            train_acc_sum = train_acc_sum.cuda()

        opt.zero_grad()

        label_hat = model(imgs)

        # loss function
        loss= criterion(label_hat, labels)
        loss.backward()
        opt.step()


        # calcualte training error
        model.eval()
        labels = labels.long()
        train_l_sum += loss.float()
        train_acc_sum += (torch.sum((torch.argmax(label_hat, dim=1) == labels))).float()
        n += labels.shape[0]
    test_acc = test_accuracy(iter(test_loader), model)
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\
        % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc, time.time() - start))


235it [00:19, 11.79it/s]


epoch 1, loss 0.0006, train acc 0.952, test acc 0.986, time 22.0 sec


235it [00:20, 11.23it/s]


epoch 2, loss 0.0001, train acc 0.989, test acc 0.990, time 22.3 sec


235it [00:19, 11.90it/s]


epoch 3, loss 0.0001, train acc 0.994, test acc 0.986, time 21.2 sec


235it [00:19, 12.31it/s]


epoch 4, loss 0.0000, train acc 0.997, test acc 0.989, time 20.5 sec


235it [00:19, 12.19it/s]


epoch 5, loss 0.0000, train acc 0.998, test acc 0.990, time 20.7 sec


235it [00:19, 11.97it/s]


epoch 6, loss 0.0000, train acc 0.999, test acc 0.991, time 21.0 sec


235it [00:19, 11.97it/s]


epoch 7, loss 0.0000, train acc 1.000, test acc 0.989, time 21.1 sec


235it [00:19, 12.08it/s]


epoch 8, loss 0.0000, train acc 1.000, test acc 0.992, time 20.9 sec


235it [00:19, 12.11it/s]


epoch 9, loss 0.0000, train acc 1.000, test acc 0.992, time 20.8 sec


235it [00:19, 12.09it/s]


epoch 10, loss 0.0000, train acc 1.000, test acc 0.992, time 20.9 sec


In [17]:
## Save the model
torch.save(model.state_dict(), f'latent-communication/resnet/models/model_seed{seed}.pth')