In [14]:
import io
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torchvision.models as models
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision.transforms import ToTensor
# from network import AttnVGG
from torch.utils.tensorboard import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler

In [2]:
root_dir = "/home/rishab/alexnet_attention/train"

# Hyperparameters
batch_size = 16
learning_rate = 0.001

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = datasets.ImageFolder(root=root_dir,transform=transform)

train_size = 0.8 
train_data, val_data = train_test_split(dataset, train_size=train_size, shuffle=True)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [3]:
from sklearn.metrics import confusion_matrix
def compute_metrics(all_labels,all_preds,num_classes,epoch):
    CM = confusion_matrix(all_labels, all_preds, labels=list(range(16)))
    acc = np.sum(np.diag(CM)) / np.sum(CM)
    
    class_sensitivity = []
    class_precision = []
    class_metrics = []
    
    for class_idx in range(num_classes):
        tp = CM[class_idx, class_idx]
        fn = np.sum(CM[class_idx, :]) - tp
        fp = np.sum(CM[:, class_idx]) - tp
        tn = np.sum(CM) - tp - fn - fp
        
        sensitivity = tp / (tp + fn)
        precision = tp / (tp + fp)
        class_sensitivity.append(sensitivity)
        class_precision.append(precision)
        class_metrics.append([sensitivity, precision])
        
    val_mean_sensitivity = np.mean(class_sensitivity)
    val_mean_precision = np.mean(class_precision)
    return acc,val_mean_sensitivity,val_mean_precision,CM

In [4]:
best_val_loss = float("inf")
num_epochs = 30

In [16]:
model = models.vgg16_bn(weights=torchvision.models.VGG16_BN_Weights.IMAGENET1K_V1)

num_classes = 16
model.classifier[6] = torch.nn.Linear(4096, num_classes)

In [17]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

lr_lambda = lambda epoch : np.power(0.1, epoch//10)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [18]:
writer = SummaryWriter("logs_vgg16")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, labels in train_loader:
        inputs = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)
    
    writer.add_scalar("Training Loss" , avg_train_loss,epoch)

    # Adjusting Learning Rate
    scheduler.step()

    model.eval()
    total = 0
    correct = 0
    val_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
          inputs = images.to(device)
          labels = labels.to(device)
          outputs = model(inputs)
          loss = criterion(outputs,labels)
          val_loss += loss.item()*images.size(0)
          _,predict = torch.max(outputs, 1)
          total += labels.size(0)
          correct += (predict == labels).sum().item()
          all_preds.extend(predict.cpu().numpy())
          all_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader.dataset)
    accuracy_val = 100*correct / total
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Traning Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss: .4f}, Validation Accuracy,{accuracy_val:.2f}%')
    
    acc,val_mean_sensitivity,val_mean_precision,CM = compute_metrics(all_preds,all_labels,num_classes=16,epoch=epoch)
    writer.add_scalar('val/accuracy', acc*100, epoch)
    writer.add_scalar('val/mean_recall', val_mean_sensitivity,epoch)
    writer.add_scalar('val/precision_mel',val_mean_precision, epoch)
    writer.add_scalar("Validation Loss",val_loss,epoch)
    fig = plt.figure(figsize=(20,10))
    sns.heatmap(CM, annot=True, cmap="coolwarm")

    # Add the figure to the SummaryWriter
    writer.add_figure("heatmap", fig,global_step=epoch)

    writer.close()
    
    if val_loss < best_val_loss:
      best_val_loss = val_loss
      checkpoint_path = '/home/rishab/alexnet_attention/saved_model_1_vgg16'
      os.makedirs(checkpoint_path, exist_ok=True)
      checkpoint_path = os.path.join(checkpoint_path ,'best_model.pth')
      torch.save(model.state_dict(), checkpoint_path)
    
  
    if epoch == num_epochs - 1:
        checkpoint_path = '/home/rishab/alexnet_attention/last_epoch_model_vgg16'
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_path ,'last_model.pth')
        torch.save(model.state_dict(), checkpoint_path)
   
print('Training finished.')

Epoch [1/30], Traning Loss: 2.0221, Validation Loss:  1.8698, Validation Accuracy,48.24%


  sensitivity = tp / (tp + fn)


Epoch [2/30], Traning Loss: 1.3614, Validation Loss:  1.1543, Validation Accuracy,65.61%
Epoch [3/30], Traning Loss: 1.0898, Validation Loss:  1.0380, Validation Accuracy,68.87%
Epoch [4/30], Traning Loss: 0.8727, Validation Loss:  0.9647, Validation Accuracy,69.52%
Epoch [5/30], Traning Loss: 0.7468, Validation Loss:  0.8847, Validation Accuracy,74.27%
Epoch [6/30], Traning Loss: 0.6091, Validation Loss:  0.8848, Validation Accuracy,74.19%
Epoch [7/30], Traning Loss: 0.5147, Validation Loss:  0.8346, Validation Accuracy,75.23%
Epoch [8/30], Traning Loss: 0.4546, Validation Loss:  0.8533, Validation Accuracy,76.53%
Epoch [9/30], Traning Loss: 0.3765, Validation Loss:  0.9789, Validation Accuracy,75.15%
Epoch [10/30], Traning Loss: 0.3756, Validation Loss:  0.9474, Validation Accuracy,75.84%
Epoch [11/30], Traning Loss: 0.1507, Validation Loss:  0.8355, Validation Accuracy,79.37%
Epoch [12/30], Traning Loss: 0.0983, Validation Loss:  0.8554, Validation Accuracy,78.91%
Epoch [13/30], Tra