In [1]:
import os
import pickle

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

try:
  import timm
except:
  print('timm does not exist')
  !pip install timm
  import timm

try:
  import torchmetrics
except:
  !pip install torchmetrics
  import torchmetrics

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Load Data

In [2]:
def _load_features_and_label(path):
    """Unpickle a preprocessed DataFrame and split off its "label" column.

    Returns a ``(df, features, labels)`` triple where ``features`` is ``df``
    without the "label" column and ``labels`` is the one-column DataFrame
    ``df[["label"]]``. Only use on trusted local files: ``pickle.load`` can
    execute arbitrary code.
    """
    with open(path, 'rb') as f:
        df = pickle.load(f)
        return df, df.drop(columns=["label"]), df[["label"]]


train_df, X_train, y_train = _load_features_and_label("preprocessed_data/train_df.pkl")
test_df, X_test, y_test = _load_features_and_label("preprocessed_data/test_df.pkl")

# Transfer Learning using ResNet
Reference: https://www.kaggle.com/code/paulojunqueira/mnist-with-pytorch-and-transfer-learning-timm

In [3]:
class customDataset(Dataset):
    """Map-style dataset over flattened single-channel images stored in DataFrames.

    Each row of ``X`` holds one image flattened to ``height * width`` pixel
    columns; the matching row of ``y`` holds its label.

    Args:
        X: DataFrame with one row per image and ``height * width`` columns.
        y: DataFrame/Series with one label per row of ``X``.
        transform: optional callable applied to each image array of shape
            (height, width, 1) before it is returned (e.g. ``transforms.ToTensor()``).
        image_shape: ``(height, width)`` of each image; defaults to (64, 64).

    Raises:
        ValueError: if ``X`` does not have ``height * width`` columns, or if the
            number of images and labels differ.

    Items are ``(image, label)`` where ``label`` is a length-1 array.
    """

    def __init__(self, X, y, transform=None, image_shape=(64, 64)):
        height, width = image_shape
        if X.shape[1] != height * width:
            # Without this check, reshape(-1, ...) could silently regroup pixels
            # into a different number of images than there are rows in X.
            raise ValueError(
                f"X has {X.shape[1]} columns, expected {height * width} "
                f"for image_shape={tuple(image_shape)}"
            )
        # Height x width x channel layout with a single trailing channel.
        self.X = X.values.reshape((-1, height, width, 1))
        # One label per row, kept as shape (n, 1).
        self.y = y.values.reshape(-1, 1)
        if len(self.X) != len(self.y):
            raise ValueError(
                f"got {len(self.X)} images but {len(self.y)} labels"
            )
        self.transform = transform

    def __getitem__(self, index):
        X = self.X[index]
        y = self.y[index]

        if self.transform:
            X = self.transform(X)

        return X, y

    def __len__(self):
        return len(self.X)

In [4]:
# Split data: hold out 20% of the training set for validation
# (used for early stopping / best-model selection).
SEED = 1
X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, random_state=SEED, test_size=0.2)

# Define dataset objects
transform = transforms.Compose([transforms.ToTensor()])

ds_train = customDataset(X_tr, y_tr, transform=transform)
ds_val = customDataset(X_val, y_val, transform=transform)
ds_test = customDataset(X_test, y_test, transform=transform)

# Create Dataset Loaders
# Only the training loader shuffles, using a seeded generator so epoch order is
# reproducible. Evaluation loaders keep a fixed order: shuffling adds nothing to
# loss/accuracy and makes prediction order differ between runs.
batch_size = 32
train_loader = DataLoader(
    ds_train, batch_size=batch_size, shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

In [5]:
NUM_CLASSES = 5
LEARNING_RATE = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained timm ResNet-50 with a single input channel and a NUM_CLASSES-way head.
timm_model = timm.create_model(model_name="resnet50", pretrained=True, num_classes=NUM_CLASSES, in_chans=1)
# Move the model to the target device *before* constructing the optimizer,
# as recommended by the torch.optim documentation.
timm_model = timm_model.to(device)

optimizer = torch.optim.Adam(timm_model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()
accuracy_fn = torchmetrics.Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(device)
# Multiply the learning rate by 0.1 every 5 calls to scheduler.step()
# (the training loop steps it once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [6]:
def train_step(model, dataloader, optimizer, loss_fn, accuracy_fn, device):
    """Run one training epoch over ``dataloader`` and report sample-weighted metrics.

    Args:
        model: network to train; moved to ``device`` and put in train mode.
        dataloader: iterable of ``(X, y)`` batches. ``y`` may be shaped (B,) or
            (B, 1); it is flattened to (B,) integer class indices.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: criterion returning the *mean* loss over a batch
            (e.g. ``nn.CrossEntropyLoss()`` with its default reduction).
        accuracy_fn: callable ``(preds, targets) -> tensor`` giving the batch accuracy.
        device: device each batch is moved to.

    Returns:
        ``(train_loss, train_acc, all_preds, all_targets)``. Loss and accuracy are
        averaged over samples, so a smaller final batch is not over-weighted;
        both are NaN if ``dataloader`` yields nothing. ``all_preds`` and
        ``all_targets`` are per-batch lists of 1-D class-index tensors,
        detached and moved to the CPU.
    """
    total_loss, total_correct, n_seen = 0.0, 0.0, 0
    all_preds, all_targets = [], []

    model.to(device)
    model.train()

    for X, y in dataloader:
        # Push to device (cuda)
        X, y = X.to(device), y.to(device)
        # Flatten rather than squeeze(): squeeze() turns a batch of one into a
        # 0-d tensor, which CrossEntropyLoss rejects against (1, C) logits.
        targets = y.long().reshape(-1)

        optimizer.zero_grad()
        logits = model(X.float())
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        # loss_fn / accuracy_fn return batch means; scale back to sums so the
        # epoch average is taken over samples rather than batches.
        batch_n = targets.numel()
        total_loss += loss.item() * batch_n
        y_preds = logits.argmax(dim=1)
        total_correct += accuracy_fn(y_preds, targets).item() * batch_n
        n_seen += batch_n

        # Keep stored results out of the autograd graph and off the accelerator.
        all_preds.append(y_preds.detach().cpu())
        all_targets.append(targets.detach().cpu())

    train_loss = total_loss / n_seen if n_seen else float("nan")
    train_acc = total_correct / n_seen if n_seen else float("nan")
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {100* train_acc:.2f}%")
    return train_loss, train_acc, all_preds, all_targets

def test_step(model, dataloader, loss_fn, accuracy_fn, device):

    test_loss, test_acc = 0, 0
    all_preds, all_targets = [], []

    model.to(device)
    model.eval()

    with torch.inference_mode():
        for i, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            logits = model(X.float())
            loss = loss_fn(logits, y.long().squeeze())

            test_loss += loss.item()
            y_preds = logits.argmax(dim=1)
            test_acc += accuracy_fn(y_preds, y.long().squeeze()).item()

            all_preds += [y_preds]
            all_targets += [y.long().squeeze()]

        test_loss /= len(dataloader)
        test_acc /= len(dataloader)
        print(f"Test loss: {test_loss:.5f} | Test accuracy: {100* test_acc:.2f}%")

    return test_loss, test_acc, all_preds, all_targets

In [7]:
epochs = 20
best_val_loss = float("inf")
early_stopping_counter = 0
early_stopping_threshold = 5  # consecutive epochs without val-loss improvement before stopping
best_model_path = "models/resnet_best_model.pth"

# torch.save does not create missing parent directories, so ensure the
# checkpoint folder exists before the first save.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(epochs):
    print("-"*50 + f" Current epoch: {epoch + 1} " + "-"*50)
    train_loss, train_acc, train_preds, train_targets = train_step(
        model=timm_model,
        dataloader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        accuracy_fn=accuracy_fn,
        device=device
    )

    # test_step is reused for validation, so its printed line reads "Test ...".
    val_loss, val_acc, val_preds, val_targets = test_step(
        model=timm_model,
        dataloader=val_loader,
        loss_fn=loss_fn,
        accuracy_fn=accuracy_fn,
        device=device
    )

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Checkpoint on every validation-loss improvement; otherwise count toward early stopping.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(timm_model.state_dict(), best_model_path)  # Save the best model
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_threshold:
            print("Early Stopping Triggered")
            break

-------------------------------------------------- Current epoch: 1 --------------------------------------------------
Train loss: 0.50197 | Train accuracy: 82.13%
Test loss: 0.18933 | Test accuracy: 93.53%
-------------------------------------------------- Current epoch: 2 --------------------------------------------------
Train loss: 0.11140 | Train accuracy: 96.49%
Test loss: 0.08591 | Test accuracy: 97.18%
-------------------------------------------------- Current epoch: 3 --------------------------------------------------
Train loss: 0.06890 | Train accuracy: 97.68%
Test loss: 0.10113 | Test accuracy: 96.72%
-------------------------------------------------- Current epoch: 4 --------------------------------------------------
Train loss: 0.04565 | Train accuracy: 98.49%
Test loss: 0.07197 | Test accuracy: 97.85%
-------------------------------------------------- Current epoch: 5 --------------------------------------------------
Train loss: 0.03630 | Train accuracy: 98.83%
Test los

# Test Results using Best Model

In [8]:
model_path = "models/resnet_best_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the same architecture (no pretrained=True: the weights come from the
# checkpoint) and load the best checkpoint saved during training.
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
timm_model = timm.create_model(model_name="resnet50", num_classes=5, in_chans=1)
timm_model.load_state_dict(torch.load(model_path, map_location=device))

# Only evaluation happens here, so no optimizer or scheduler is needed.
loss_fn = nn.CrossEntropyLoss()
accuracy_fn = torchmetrics.Accuracy(task="multiclass", num_classes=5).to(device)

test_loss, test_acc, test_preds, test_targets = test_step(
    model=timm_model,
    dataloader=test_loader,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
    device=device
)

Test loss: 0.30520 | Test accuracy: 92.19%


In [9]:
# Flatten the per-batch prediction / target tensors into flat lists of Python ints
# so sklearn can compare them element-wise.
test_preds_list = [label for batch in test_preds for label in batch.tolist()]
test_targets_list = [label for batch in test_targets for label in batch.tolist()]

# Per-class precision / recall / F1 plus macro and weighted averages, 4 decimals.
report = classification_report(test_targets_list, test_preds_list, digits=4)
print(report)

              precision    recall  f1-score   support

           0     1.0000    0.7500    0.8571        60
           1     0.8551    0.9833    0.9147        60
           2     0.8293    0.9714    0.8947        35
           3     0.9677    1.0000    0.9836        60
           4     0.9643    0.9000    0.9310        30

    accuracy                         0.9184       245
   macro avg     0.9233    0.9210    0.9162       245
weighted avg     0.9278    0.9184    0.9166       245

