<a href="https://colab.research.google.com/github/mayarali/carcinoma_classification/blob/fatih/OxML_gradient_descent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import copy
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import io, models, transforms
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [8]:
class CustomDataset(Dataset):
    """Image dataset backed by a folder of PNG files and a label table.

    Every row of ``df`` describes one sample: the ``id`` column selects the
    file ``img_<id>.png`` inside ``img_folder`` and the ``malignant`` column
    holds the raw label for that image.
    """

    def __init__(self, img_folder, df, transform=None):
        """Keep references to the image directory, label table and transform.

        Args:
            img_folder: Directory that contains the ``img_<id>.png`` files.
            df: Table with at least the columns ``id`` and ``malignant``.
            transform: Optional callable applied to the opened image.
        """
        self.df = df
        self.transform = transform
        self.img_folder = img_folder

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label, id)`` for the row at position ``idx``.

        ``idx`` is positional (``iloc``), so the table's own index values do
        not matter. The label is the ``malignant`` value shifted up by one.
        """
        # presumably `malignant` is in {-1, 0, 1}, making labels {0, 1, 2} -
        # TODO confirm against labels.csv
        position = idx.tolist() if torch.is_tensor(idx) else idx

        sample_id = self.df['id'].iloc[position]
        image_path = os.path.join(self.img_folder, f"img_{sample_id}.png")
        image = Image.open(image_path)
        target = self.df['malignant'].iloc[position] + 1

        if self.transform:
            image = self.transform(image)
        return image, target, sample_id

In [None]:
# ---------------------------------------------------------------------------
# 5-fold stratified cross-validation: fine-tune a pretrained ResNet per fold.
#
# For every fold a fresh pretrained network gets a new 3-way classification
# head, its first child modules are frozen, and it is trained for EPOCHS
# epochs. The weights with the best validation micro-F1 are restored and
# saved to `softmax_<model_name>_f<fold>.pth`. Per-epoch loss and F1 for both
# phases are recorded in `loss_hist` and `scores`, keyed by fold index.
# ---------------------------------------------------------------------------
model_name = 'resnet18'
EPOCHS = 2
torch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
root = '/content/drive/MyDrive/OxML/MLx Cases/data/'
df = pd.read_csv(os.path.join(root, 'labels.csv'))
skf = StratifiedKFold(n_splits=5)
scores = {}
loss_hist = {}

# Number of leading child modules of the network whose parameters are frozen.
# presumably the stem plus layer1-layer3 of a torchvision ResNet - TODO
# confirm with `list(model.named_children())`.
N_FROZEN_CHILDREN = 7

# The transforms do not depend on the fold, so build them once.
train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.RandomVerticalFlip()])
val_transforms = transforms.Compose([transforms.ToTensor()])

# Constructors rather than instances: only the architecture selected by
# `model_name` is built (and its pretrained weights loaded) for each fold.
model_builders = {
    'resnet50': lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
    'resnet18': lambda: models.resnet18(weights=models.ResNet18_Weights.DEFAULT),
}

for fold, (train_index, val_index) in enumerate(skf.split(df['id'], df['malignant'])):
    print(f'Fold {fold}: ')

    dataset = {'train': CustomDataset(root, df.iloc[train_index], transform=train_transforms),
               'val': CustomDataset(root, df.iloc[val_index], transform=val_transforms)}

    dataloader = {x: DataLoader(dataset[x], batch_size=1, shuffle=True) for x in ['train', 'val']}
    print('Train dataset size:', len(dataset['train']))
    print('Val dataset size:', len(dataset['val']))

    # `model_dict` keeps its original name (a name -> model mapping) but now
    # holds only the model that is actually used.
    model_dict = {model_name: model_builders[model_name]()}
    model = model_dict[model_name]
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, 3)  # replace the head with a 3-class layer
    model.to(device)
    for i, child in enumerate(model.children()):
        if i < N_FROZEN_CHILDREN:
            for param in child.parameters():
                param.requires_grad = False
    # Only the parameters left trainable are handed to the optimizer.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, weight_decay=0.001)

    # Alternative weights left by the author: [0.42, 0.77, 0.81]
    # NOTE(review): with batch_size=1 and the default reduction='mean', the
    # loss is divided by the weight of the single target in the batch, so
    # these class weights cancel out and have no effect. Confirm whether
    # weighting is intended; if so use a larger batch or reduction='sum'.
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.21, 0.38, 0.41], dtype=torch.float32, device=device))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20)

    loss_hist[fold] = {'train': [], 'val': []}
    scores[fold] = {'train': [], 'val': []}
    best_model_wts = copy.deepcopy(model.state_dict())
    best_f1 = 0
    for epoch in range(EPOCHS):

        print('-'*5+'Epoch '+str(epoch)+'-'*5)
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()
            running_loss = 0.0
            label_batches = []
            pred_batches = []
            for data in dataloader[phase]:
                inputs, label = data[0].to(device), data[1].to(device)
                # Autograd bookkeeping is only needed while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, label)
                if is_train:
                    # Each batch builds its own graph, so nothing has to be
                    # retained; gradients accumulate in `.grad` across batches.
                    loss.backward()

                # Accumulate (not overwrite) so the epoch loss covers every batch.
                running_loss += loss.item() * inputs.shape[0]
                pred = torch.argmax(outputs, dim=1)
                label_batches.append(label)
                pred_batches.append(pred)
            if is_train:
                # A single update per epoch from the gradients accumulated over
                # all training batches, i.e. full-batch gradient descent.
                # NOTE(review): looks deliberate - confirm.
                optimizer.step()
                optimizer.zero_grad()
            labels = torch.cat(label_batches)
            preds = torch.cat(pred_batches)
            epoch_loss = running_loss/len(dataset[phase])
            loss_hist[fold][phase].append(epoch_loss)

            f1 = f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='micro')
            scores[fold][phase].append(f1)
            if phase == 'val' and f1 > best_f1:
                best_f1 = f1
                best_model_wts = copy.deepcopy(model.state_dict())
            print(f'{phase} loss : {epoch_loss}, F1 : {f1}')

        # Advance the learning-rate schedule once per epoch, after the
        # optimizer update; without this call the scheduler never takes effect.
        scheduler.step()

    # Restore and persist the weights of the best validation epoch of this fold.
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), f'softmax_{model_name}_f{fold}.pth')

def _plot_fold_curves(history, title):
    """Draw one panel per fold with the train and validation curves.

    Args:
        history: Mapping ``fold -> {'train': [...], 'val': [...]}`` holding one
            value per epoch, as filled in by the training loop.
        title: Figure title.

    Returns:
        The ``(fig, axs)`` pair, with ``axs`` a 1-D array of one Axes per fold.

    Raises:
        ValueError: If ``history`` contains no folds.
    """
    if not history:
        raise ValueError('history is empty: nothing to plot')

    # squeeze=False always yields a 2-D array, so a single fold works too.
    fig, axs = plt.subplots(1, len(history), figsize=(16, 4), sharey=True, squeeze=False)
    axs = axs[0]
    fig.suptitle(title)
    for fold, ax in zip(sorted(history), axs):
        # The validation curve keeps its original legend label 'test'.
        for phase, color, label in (('train', 'blue', 'train'), ('val', 'red', 'test')):
            values = history[fold][phase]
            # Use the recorded length rather than a global epoch count so the
            # plot stays valid if EPOCHS is changed after training.
            ax.plot(range(len(values)), values, color=color, label=label)
    # sharey=True already hides the y tick labels of the inner panels. Calling
    # set_yticklabels([]) on one of them would replace the formatter shared by
    # all panels and blank the first panel as well, so it is not called here.
    axs[0].legend()
    plt.show()
    return fig, axs


fig, axs = _plot_fold_curves(loss_hist, 'Loss')

fig, axs = _plot_fold_curves(scores, 'F1')