## This notebook fits the CNN (VGG16) to randomly permuted training labels for exercise 2.5.

**Load the Required Packages**

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt

# append the filepath to where torch is installed
sys.path.append('/home/millerm/.local/lib/python3.10/site-packages') 
# sys.path.append('/home/username/.local/lib/python3.10/site-packages')

import torch
import torchvision

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary
import torchvision.transforms as transforms
from torchvision.transforms import v2

We load helper functions from pytorchcv. Because importing some of them directly can cause problems, the functions we need (`train_long` and `validate`) are redefined further below.

In [None]:
!wget https://raw.githubusercontent.com/MicrosoftDocs/pytorchfundamentals/main/computer-vision-pytorch/pytorchcv.py

In [None]:
from pytorchcv import train, plot_results, display_dataset, train_long

**Load the Model**

This notebook uses the saved weights in `models/20_model_state.pth`. Because we were unable to load the fully pickled model on the student cluster, we provide only its state dictionary, which is loaded into a freshly constructed VGG16 below.

In [None]:
from torchvision.models import VGG16_Weights

# Start from torchvision's default pretrained VGG16; its classifier head is replaced below
# and all weights are then overwritten by the saved state dict.
pretrained_weights = VGG16_Weights.DEFAULT
model = torchvision.models.vgg16(weights=pretrained_weights)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print('Doing computations on device = {}'.format(device))

model.to(device)

In [None]:
# Replace VGG16's head with a deeper MLP ending in 2 logits.
# Layers are appended in the same order as the saved checkpoint expects
# (Linear/ReLU/Dropout x4, then the final Linear at index 12).
layer_widths = [25088, 4096, 4096, 4096, 4096]
head_layers = []
for in_width, out_width in zip(layer_widths[:-1], layer_widths[1:]):
    head_layers += [
        nn.Linear(in_width, out_width),
        nn.ReLU(),
        nn.Dropout(0.5, inplace=False),
    ]
head_layers.append(nn.Linear(layer_widths[-1], 2))

model.classifier = nn.Sequential(*head_layers).to(device)

In [None]:
# map_location lets a checkpoint saved from a GPU session load on a CPU-only kernel.
# weights_only=True restricts unpickling to tensors/containers, which is all a state dict needs.
model.load_state_dict(torch.load('models/20_model_state.pth', map_location=device, weights_only=True))

**Transform the Dataset**

In [None]:
# Un-normalised (and not grayscale-converted) versions of the three splits,
# presumably used to measure the normalisation statistics below.
trans_wo_norm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

_xray_root = "ml4h_data/project1/chest_xray/"
dataset0_wo_norm, dataset1_wo_norm, dataset2_wo_norm = (
    torchvision.datasets.ImageFolder(_xray_root + split + "/", transform=trans_wo_norm)
    for split in ("train", "test", "val")
)

In [None]:
def _per_channel(value):
    """Repeat one grayscale statistic for each of the 3 channels fed to the network."""
    return torch.tensor([value] * 3)

# Normalisation mean/std per split (0 = train, 1 = test, 2 = val),
# presumably measured on the *_wo_norm datasets.
mean0, std0 = _per_channel(0.5832), _per_channel(0.1413)
mean1, std1 = _per_channel(0.5763), _per_channel(0.1453)
mean2, std2 = _per_channel(0.6020), _per_channel(0.1401)

In [None]:
# One Normalize per split, built from that split's statistics.
std_normalise_0 = transforms.Normalize(
    mean=mean0,
    std=std0
)
std_normalise_1 = transforms.Normalize(
    mean=mean1,
    std=std1
)
std_normalise_2 = transforms.Normalize(
    mean=mean2,
    std=std2
)


def _xray_transform(normalise):
    """Resize -> 224x224 centre crop -> 3-channel grayscale tensor -> ``normalise``."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        normalise
    ])


# Every split is normalised. Leaving test/val un-normalised would feed the model inputs
# on a different scale from the normalised training images it is fitted on.
# (Alternative: reuse std_normalise_0 for all splits to apply training statistics everywhere.)
trans0 = _xray_transform(std_normalise_0)
trans1 = _xray_transform(std_normalise_1)
trans2 = _xray_transform(std_normalise_2)

# Optional augmentation. To take effect it must be composed into a per-image transform
# (e.g. transforms.Compose([trans0, random_trans])), not applied to a Dataset object.
random_trans = v2.RandomOrder([
        v2.GaussianBlur(3)
])

In [None]:
dataset_0 = torchvision.datasets.ImageFolder("ml4h_data/project1/chest_xray/train/", transform=trans0)
dataset_1 = torchvision.datasets.ImageFolder("ml4h_data/project1/chest_xray/test", transform=trans1)
dataset_2 = torchvision.datasets.ImageFolder("ml4h_data/project1/chest_xray/val", transform=trans2)

# NOTE: `dataset_0 = random_trans(dataset_0)` was dropped. torchvision v2 transforms pass
# inputs they do not recognise (such as a Dataset object) through unchanged, so it never
# blurred any image. To augment, compose it into the per-image transform instead, e.g.
# ImageFolder(..., transform=transforms.Compose([trans0, random_trans])).

**Permute the labels of the training data**

In [None]:
lables = np.array([lable for _, lable in trainset.dataset.imgs])

In [None]:
np.random.seed(0)
lables = np.random.permutation(lables)

In [None]:
for i, (_,lable) in enumerate(trainset.dataset.imgs):
    trainset.dataset.imgs[i] = (trainset.dataset.imgs[i][0], lables[i])

**Define data loaders**

In [None]:
# Split the training folder into 3500 images for fitting and the remainder for validation.
num_samples = 3500
torch.manual_seed(1234)  # makes random_split reproducible
trainset, testset = torch.utils.data.random_split(dataset_0, [num_samples, len(dataset_0) - num_samples])
# Reshuffle training batches every epoch (seeded for reproducibility); evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True,
                                           generator=torch.Generator().manual_seed(1234))
test_loader  = torch.utils.data.DataLoader(testset, batch_size=32)

In [None]:
def train_long(net, train_loader, test_loader, epochs=5, lr=0.001, optimizer=None, loss_fn=nn.NLLLoss(), print_freq=10, device=None):
    """Train ``net`` for ``epochs`` epochs, logging progress and validating after each epoch.

    This shadows the ``train_long`` imported from pytorchcv.

    Args:
        net: model to train; switched to train mode at the start of every epoch.
        train_loader: iterable of ``(features, labels)`` mini-batches used for the updates.
        test_loader: loader passed to ``validate`` at the end of each epoch.
        epochs: number of passes over ``train_loader``.
        lr: learning rate of the default Adam optimizer (unused when ``optimizer`` is given).
        optimizer: optional pre-built optimizer; defaults to ``Adam(net.parameters(), lr=lr)``.
        loss_fn: loss called as ``loss_fn(net_output, int64_labels)``. The default NLLLoss
            expects log-probabilities; pass e.g. ``nn.CrossEntropyLoss()`` for raw logits.
        print_freq: print running train accuracy/loss every ``print_freq`` mini-batches.
        device: where batches are moved; defaults to the device of ``net``'s first parameter.
    """
    if device is None:
        device = next(net.parameters()).device
    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        net.train()
        total_loss, acc, count = 0.0, 0, 0
        for i, (features, labels) in enumerate(train_loader):
            lbls = labels.long().to(device)
            optimizer.zero_grad()
            out = net(features.to(device))
            loss = loss_fn(out, lbls)
            loss.backward()
            optimizer.step()
            # .item() detaches the value; summing the tensor itself would keep every
            # batch's autograd graph alive for the whole epoch.
            total_loss += loss.item()
            predicted = out.argmax(dim=1)
            acc += (predicted == lbls).sum().item()
            count += len(labels)
            if i % print_freq == 0:
                # NOTE: total_loss sums per-batch losses, so total_loss/count is scaled by
                # 1/batch_size for mean-reduced losses (kept for parity with pytorchcv).
                print("Epoch {}, minibatch {}: train acc = {}, train loss = {}".format(epoch, i, acc / count, total_loss / count))
        vl, va = validate(net, test_loader, loss_fn)
        print("Epoch {} done, validation acc = {}, validation loss = {}".format(epoch, va, vl))


In [None]:
def validate(net, dataloader, loss_fn=nn.NLLLoss(), device=None):
    """Evaluate ``net`` on ``dataloader`` in eval mode without tracking gradients.

    Args:
        net: model to evaluate; left in eval mode afterwards.
        dataloader: iterable of ``(features, labels)`` mini-batches.
        loss_fn: loss called as ``loss_fn(net_output, int64_labels)``. The default NLLLoss
            expects log-probabilities; pass e.g. ``nn.CrossEntropyLoss()`` for raw logits.
        device: where batches are moved; defaults to the device of ``net``'s first parameter.

    Returns:
        ``(loss, accuracy)`` as floats: the sum of per-batch losses divided by the number of
        samples (scaled by 1/batch_size for mean-reduced losses, as in pytorchcv), and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        device = next(net.parameters()).device
    net.eval()
    count, acc, loss = 0, 0, 0.0
    with torch.no_grad():
        for features, labels in dataloader:
            lbls = labels.long().to(device)
            out = net(features.to(device))
            loss += loss_fn(out, lbls).item()
            acc += (out.argmax(dim=1) == lbls).sum().item()
            count += len(labels)
    if count == 0:
        raise ValueError("validate() received a dataloader with no samples")
    return loss / count, acc / count

In [None]:
# Unfreeze the convolutional feature extractor so a retraining run updates it as well as the head.
model.features.requires_grad_(True);

Again, retraining the network (both the unfrozen feature extractor and the classifier) takes at least half an hour, so we recommend using the provided model loaded above.

In [None]:
# Set RETRAIN = True to re-fit the network on the permuted labels (takes 30+ minutes).
RETRAIN = False
if RETRAIN:
    default_device = device  # train_long/validate may read this global to place batches
    train_long(model, train_loader, test_loader, lr=0.00005,
               loss_fn=torch.nn.CrossEntropyLoss(), epochs=10, print_freq=90)

In [None]:
# validate() may read the `default_device` global, which is otherwise only set inside the
# optional retraining cell. The head outputs raw logits, so CrossEntropyLoss gives a
# meaningful loss; NLLLoss, the default, expects log-probabilities.
default_device = device
validate(model, train_loader, loss_fn=nn.CrossEntropyLoss())

On the 3,500 training images with permuted labels, the model reaches a training accuracy of 87.8%.

In [None]:
# Persist only the weights (a state dict), in the same form as the checkpoint loaded earlier.
perm_checkpoint_path = 'perm_20_model_state.pth'
torch.save(model.state_dict(), perm_checkpoint_path)