In [2]:
import os
import io
import copy
import cv2
import csv
import torch
import torchvision
import pandas as pd
import numpy as np
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.utils as utils
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.models import vgg16_bn
from torch.utils.tensorboard import SummaryWriter

In [3]:
root_dir = "/home/rishab/alexnet_attention/train"

# Hyperparameters
batch_size = 16
learning_rate = 0.001
train_size = 0.8   # fraction of images used for training; the remainder is validation
split_seed = 42    # fixed seed so the train/val split is reproducible across kernel restarts

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the IMAGENET1K_V1 VGG16-BN backbone used below
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = datasets.ImageFolder(root=root_dir, transform=transform)

# Split *indices*, not the Dataset itself: handing the Dataset to
# train_test_split makes sklearn index every sample, i.e. load and transform
# every image up front and keep all of them in memory. Subsets stay lazy.
train_idx, val_idx = train_test_split(
    np.arange(len(dataset)),
    train_size=train_size,
    shuffle=True,
    random_state=split_seed,
)
train_data = torch.utils.data.Subset(dataset, train_idx)
val_data = torch.utils.data.Subset(dataset, val_idx)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [4]:
class AttentionBlock(nn.Module):
    """Additive attention over a local feature map, guided by a global one.

    Both inputs are projected to ``attn_features`` channels with 1x1 convs.
    The global projection is bilinearly upsampled by ``up_factor`` (when > 1)
    so it lines up spatially with the local map. ``phi`` then turns
    ``relu(W_l(l) + W_g(g))`` into a single-channel score map.

    Args:
        in_features_l: channels of the local feature map ``l``.
        in_features_g: channels of the global feature map ``g``.
        attn_features: channels of the shared projection space.
        up_factor: spatial scale from ``g`` to ``l``.
        normalize_attn: if True the scores are softmax-normalised over all
            spatial positions and the weighted features are summed; otherwise
            a sigmoid gate is applied and the weighted features are averaged.

    ``forward(l, g)`` returns ``(attention_map, pooled)``. ``attention_map``
    has shape ``(N, 1, H, W)`` and ``pooled`` has shape ``(N, C)``, where
    ``(N, C, H, W)`` is the shape of ``l``.
    """

    def __init__(self, in_features_l, in_features_g, attn_features, up_factor, normalize_attn=True):
        super().__init__()
        self.up_factor = up_factor
        self.normalize_attn = normalize_attn
        # 1x1 projections of local and global features into a shared space
        self.W_l = nn.Conv2d(in_channels=in_features_l, out_channels=attn_features,
                             kernel_size=1, padding=0, bias=False)
        self.W_g = nn.Conv2d(in_channels=in_features_g, out_channels=attn_features,
                             kernel_size=1, padding=0, bias=False)
        # collapses the joint projection to one compatibility score per position
        self.phi = nn.Conv2d(in_channels=attn_features, out_channels=1,
                             kernel_size=1, padding=0, bias=True)

    def forward(self, l, g):
        n, c, h, w = l.size()
        local_proj = self.W_l(l)
        global_proj = self.W_g(g)
        if self.up_factor > 1:
            global_proj = F.interpolate(global_proj, scale_factor=self.up_factor,
                                        mode='bilinear', align_corners=False)
        scores = self.phi(F.relu(local_proj + global_proj))  # (n, 1, h, w)

        if self.normalize_attn:
            # spatial softmax: weights over all h*w positions sum to 1
            attn = F.softmax(scores.view(n, 1, -1), dim=2).view(n, 1, h, w)
            weighted = attn.expand_as(l) * l
            pooled = weighted.view(n, c, -1).sum(dim=2)
        else:
            attn = torch.sigmoid(scores)
            weighted = attn.expand_as(l) * l
            pooled = F.adaptive_avg_pool2d(weighted, (1, 1)).view(n, c)
        return attn, pooled

In [5]:
class AttnVGG(nn.Module):
    """VGG16-BN backbone with two attention heads over intermediate features.

    The ImageNet-pretrained ``vgg16_bn`` feature extractor is cut into five conv
    blocks. The slices skip indices 6, 13, 23, 33 and everything from 43 on,
    which are torchvision's MaxPool2d layers; pooling is applied explicitly in
    ``forward`` instead. The global descriptor comes from average-pooling
    ``pool5``. Two ``AttentionBlock`` s use ``pool5`` to attend over ``pool3``
    (1/8 of the input) and ``pool4`` (1/16). The three descriptors are
    concatenated (512 + 256 + 512 channels) and fed to a linear classifier.

    Args:
        num_classes: number of output logits.
        normalize_attn: forwarded to both attention blocks (spatial softmax +
            weighted sum if True, sigmoid gate + average pooling otherwise).
        dropout: optional dropout probability applied to the concatenated
            descriptor before the classifier; ``None`` disables dropout.

    ``forward(x)`` returns ``[logits, a1, a2]``, where ``a1``/``a2`` are the
    attention maps over ``pool3``/``pool4``. Input height and width must be
    multiples of 32 so that the attention up-factors (4 and 2) line ``pool5``
    up with ``pool3``/``pool4``.
    """

    def __init__(self, num_classes, normalize_attn=False, dropout=None):
        super(AttnVGG, self).__init__()
        net = models.vgg16_bn(weights=torchvision.models.VGG16_BN_Weights.IMAGENET1K_V1)
        layers = list(net.features.children())
        self.conv_block1 = nn.Sequential(*layers[0:6])
        self.conv_block2 = nn.Sequential(*layers[7:13])
        self.conv_block3 = nn.Sequential(*layers[14:23])
        self.conv_block4 = nn.Sequential(*layers[24:33])
        self.conv_block5 = nn.Sequential(*layers[34:43])
        # Global average pool to 1x1. For a 224x224 input this equals the
        # previous AvgPool2d(7) on the 7x7 pool5 map, and it also works for
        # other input sizes. It has no parameters, so state_dicts are unaffected.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dpt = None
        if dropout is not None:
            self.dpt = nn.Dropout(dropout)
        self.cls = nn.Linear(in_features=512+512+256, out_features=num_classes, bias=True)

        # attention heads: (local channels, global channels, attn channels, up_factor)
        self.attn1 = AttentionBlock(256, 512, 256, 4, normalize_attn=normalize_attn)
        self.attn2 = AttentionBlock(512, 512, 256, 2, normalize_attn=normalize_attn)

        # Only the newly added layers are re-initialised; the pretrained
        # backbone weights are left as loaded.
        self.reset_parameters(self.cls)
        self.reset_parameters(self.attn1)
        self.reset_parameters(self.attn2)

    def reset_parameters(self, module):
        """Initialise Conv2d (Kaiming normal), BatchNorm2d (1/0) and Linear (N(0, 0.01)) layers in ``module``."""
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0., 0.01)
                nn.init.constant_(m.bias, 0.)

    def forward(self, x):
        block1 = self.conv_block1(x)        # /1
        pool1 = F.max_pool2d(block1, 2, 2)  # /2
        block2 = self.conv_block2(pool1)    # /2
        pool2 = F.max_pool2d(block2, 2, 2)  # /4
        block3 = self.conv_block3(pool2)    # /4
        pool3 = F.max_pool2d(block3, 2, 2)  # /8
        block4 = self.conv_block4(pool3)    # /8
        pool4 = F.max_pool2d(block4, 2, 2)  # /16
        block5 = self.conv_block5(pool4)    # /16
        pool5 = F.max_pool2d(block5, 2, 2)  # /32

        g = self.pool(pool5).flatten(1)      # (N, 512) global descriptor
        a1, g1 = self.attn1(pool3, pool5)    # attend over /8 features
        a2, g2 = self.attn2(pool4, pool5)    # attend over /16 features
        g_hat = torch.cat((g, g1, g2), dim=1)  # (N, 512 + 256 + 512)
        if self.dpt is not None:
            g_hat = self.dpt(g_hat)
        out = self.cls(g_hat)

        return [out, a1, a2]

In [6]:
# Size the classifier head from the dataset instead of hard-coding it, so the
# number of logits always matches the ImageFolder class count.
num_classes = len(dataset.classes)
model = AttnVGG(num_classes=num_classes, normalize_attn=True)
print(model)

AttnVGG(
  (conv_block1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (conv_block2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (conv_block3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0

In [7]:
from sklearn.metrics import confusion_matrix
def compute_metrics(all_labels, all_preds, num_classes, epoch):
    """Summarise classification performance from true and predicted labels.

    Args:
        all_labels: ground-truth class indices.
        all_preds: predicted class indices, aligned with ``all_labels``.
        num_classes: classes are taken to be ``0 .. num_classes-1``; this fixes
            the size of the confusion matrix.
        epoch: unused; kept so existing call sites keep working.

    Returns:
        ``(accuracy, mean_sensitivity, mean_precision, CM)`` where:

        - ``CM`` is the ``num_classes x num_classes`` confusion matrix (rows are
          the true class, columns the predicted class).
        - ``accuracy`` is a fraction in [0, 1].
        - The two means are macro averages over classes.

        A class with no true samples has undefined sensitivity, and a class
        that is never predicted has undefined precision. Such classes are left
        out of the respective mean instead of turning it into NaN. A value with
        nothing to average over (e.g. empty input) is returned as NaN.
    """
    CM = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))

    tp = np.diag(CM)
    n_true = CM.sum(axis=1)   # tp + fn for each class
    n_pred = CM.sum(axis=0)   # tp + fp for each class
    n_total = CM.sum()

    acc = tp.sum() / n_total if n_total > 0 else float("nan")

    with np.errstate(divide="ignore", invalid="ignore"):
        class_sensitivity = np.where(n_true > 0, tp / n_true, np.nan)
        class_precision = np.where(n_pred > 0, tp / n_pred, np.nan)

    val_mean_sensitivity = np.nanmean(class_sensitivity) if np.any(n_true > 0) else float("nan")
    val_mean_precision = np.nanmean(class_precision) if np.any(n_pred > 0) else float("nan")
    return acc, val_mean_sensitivity, val_mean_precision, CM

In [8]:
class EarlyStopping():
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch as ``es(model, val_loss)``. It returns
    True when training should stop and False otherwise.

    Args:
        patience: number of consecutive non-improving calls tolerated; the
            call that makes the counter reach ``patience`` returns True.
        min_delta: a loss counts as an improvement only if it is *more than*
            ``min_delta`` below the best loss so far.
        restore_best_weigths: (parameter name kept as-is for compatibility)
            if True, the best weights are loaded back into ``model`` when
            stopping.

    Attributes:
        best_model: deep copy of the model holding the best weights so far.
        best_loss: lowest validation loss seen so far (None before the first call).
        counter: consecutive non-improving calls.
        status: short human-readable progress string ("k/patience" or "Stopped On k").
    """

    def __init__(self, patience=5, min_delta=0, restore_best_weigths=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weigths = restore_best_weigths
        self.best_model = None
        self.best_loss = None
        self.counter = 0
        self.status = ""

    def __call__(self, model, val_loss):
        if self.best_loss is None:
            # first call: the current weights are the best seen so far
            self.best_loss = val_loss
            self.best_model = copy.deepcopy(model)
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_model.load_state_dict(model.state_dict())
        else:
            # No sufficient improvement. This includes a decrease of exactly
            # min_delta, e.g. an unchanged loss when min_delta == 0.
            self.counter += 1
            if self.counter >= self.patience:
                self.status = f"Stopped On {self.counter}"
                if self.restore_best_weigths:
                    model.load_state_dict(self.best_model.state_dict())
                return True
        self.status = f"{self.counter}/{self.patience}"
        return False

In [9]:
def visualize_attn(I, a, up_factor, nrow):
    """Blend a grid of attention maps, rendered as a JET heatmap, over an image grid.

    Args:
        I: image tensor laid out channels-first (C, H, W), e.g. the output of
            ``utils.make_grid``; presumably already scaled to [0, 1] — confirm.
        a: batch of attention maps (N, 1, h, w).
        up_factor: bilinear upsampling applied to ``a`` when > 1; it should
            bring the maps to the resolution of the images that make up ``I``.
        nrow: maps per row in the heatmap grid (should match how ``I`` was built).

    Returns:
        (C, H, W) float tensor on the CPU: 0.6 * image + 0.4 * heatmap.
    """
    base = I.permute((1, 2, 0)).cpu().numpy()

    maps = a
    if up_factor > 1:
        maps = F.interpolate(maps, scale_factor=up_factor, mode='bilinear', align_corners=False)

    # per-map min-max normalisation to [0, 1], then to 8-bit for OpenCV
    grid = utils.make_grid(maps, nrow=nrow, normalize=True, scale_each=True)
    grid_u8 = grid.permute((1, 2, 0)).mul(255).byte().cpu().numpy()
    heat_bgr = cv2.applyColorMap(grid_u8, cv2.COLORMAP_JET)
    heat = np.float32(cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)) / 255

    blended = 0.6 * base + 0.4 * heat
    return torch.from_numpy(blended).permute(2, 0, 1)

In [10]:
num_epochs = 30                 # total passes over train_loader
best_val_loss = float("inf")    # lowest validation loss so far; any finite loss beats it

In [11]:
# Pass the tolerance by keyword: EarlyStopping's first positional parameter is
# `patience`, not `min_delta`, so a positional 1e-8 would stop training after
# the first non-improving epoch.
es = EarlyStopping(min_delta=1e-8)

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def lr_lambda(epoch):
    """Step decay: the base learning rate is multiplied by 0.1 every 10 epochs."""
    decay_steps = epoch // 10
    return np.power(0.1, decay_steps)


scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [16]:
# A single writer for the whole run; it is closed after the loop.
writer = SummaryWriter("logs")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BEST_CHECKPOINT_DIR = '/home/rishab/alexnet_attention/saved_model_1'
LAST_CHECKPOINT_DIR = '/home/rishab/alexnet_attention/last_epoch_model'
log_images = True
# attn1 is computed on pool3 (1/8 of the input) and attn2 on pool4 (1/16);
# these factors bring both maps back to input resolution.
base_up_factor = 8


def save_checkpoint(net, directory, filename):
    """Save ``net.state_dict()`` to ``directory/filename``, creating the directory if needed."""
    os.makedirs(directory, exist_ok=True)
    torch.save(net.state_dict(), os.path.join(directory, filename))


def log_images_and_attention(summary_writer, net, batch, split, epoch):
    """Log ``batch`` as an image grid plus both attention overlays under ``<split>/...`` tags."""
    grid = utils.make_grid(batch, nrow=4, normalize=True, scale_each=True)
    summary_writer.add_image(f'{split}/image', grid, global_step=epoch)
    # Forward pass only for visualisation: no autograd graph needed.
    with torch.no_grad():
        _, a1, a2 = net(batch)
    if a1 is not None:
        attn1 = visualize_attn(grid, a1, up_factor=base_up_factor, nrow=4)
        summary_writer.add_image(f'{split}/attention_map_1', attn1, global_step=epoch)
    if a2 is not None:
        attn2 = visualize_attn(grid, a2, up_factor=2 * base_up_factor, nrow=4)
        summary_writer.add_image(f'{split}/attention_map_2', attn2, global_step=epoch)


# val_loader is not shuffled, so its first batch is identical every epoch;
# fetch it once for consistent attention snapshots.
fixed_batch = next(iter(val_loader))[0][0:16].to(device)

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    train_loss = 0
    for images, labels in train_loader:
        inputs = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs, _, _ = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)
    # Keep images from the last *training* batch for logging; `inputs` is
    # overwritten by the validation loop below.
    train_batch = inputs[0:16]

    writer.add_scalar("Training Loss", avg_train_loss, epoch)

    # Adjusting Learning Rate
    scheduler.step()

    # ---- validation ----
    model.eval()
    total = 0
    correct = 0
    val_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            inputs = images.to(device)
            labels = labels.to(device)
            outputs, _, _ = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)  # sample-weighted sum
            _, predict = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predict == labels).sum().item()
            all_preds.extend(predict.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader.dataset)
    accuracy_val = 100 * correct / total

    print(f'Epoch [{epoch+1}/{num_epochs}], Traning Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss: .4f}, Validation Accuracy,{accuracy_val:.2f}%')

    # compute_metrics takes (labels, predictions) in that order.
    acc, val_mean_sensitivity, val_mean_precision, CM = compute_metrics(
        all_labels, all_preds, num_classes=model.cls.out_features, epoch=epoch)
    writer.add_scalar('val/accuracy', acc * 100, epoch)
    writer.add_scalar('val/mean_recall', val_mean_sensitivity, epoch)
    writer.add_scalar('val/mean_precision', val_mean_precision, epoch)
    writer.add_scalar("Validation Loss", val_loss, epoch)

    fig, ax = plt.subplots(figsize=(20, 10))
    sns.heatmap(CM, annot=True, fmt="d", cmap="coolwarm", ax=ax)
    ax.set(xlabel="Predicted class", ylabel="True class",
           title=f"Validation confusion matrix (epoch {epoch + 1})")
    writer.add_figure("heatmap", fig, global_step=epoch)  # add_figure closes fig by default

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(model, BEST_CHECKPOINT_DIR, 'best_model.pth')

    if log_images:
        # model is still in eval mode here, so BatchNorm statistics are untouched
        log_images_and_attention(writer, model, train_batch, 'train', epoch)
        log_images_and_attention(writer, model, fixed_batch, 'val', epoch)

    # Uncomment to enable early stopping:
    # if es(model, val_loss):
    #     print("Early Stopping")
    #     break

# Saved after the loop so it also runs if training stops early.
save_checkpoint(model, LAST_CHECKPOINT_DIR, 'last_model.pth')
writer.close()

print('Training finished.')

Epoch [1/30], Traning Loss: 0.5915, Validation Loss:  0.7262, Validation Accuracy,77.95%
Epoch [2/30], Traning Loss: 0.5208, Validation Loss:  0.6321, Validation Accuracy,81.21%
Epoch [3/30], Traning Loss: 0.4684, Validation Loss:  0.6994, Validation Accuracy,79.52%
Epoch [4/30], Traning Loss: 0.4142, Validation Loss:  0.6892, Validation Accuracy,80.14%
Epoch [5/30], Traning Loss: 0.2555, Validation Loss:  0.5491, Validation Accuracy,84.16%
Epoch [6/30], Traning Loss: 0.2034, Validation Loss:  0.5480, Validation Accuracy,84.28%
Epoch [7/30], Traning Loss: 0.1721, Validation Loss:  0.5619, Validation Accuracy,84.32%
Epoch [8/30], Traning Loss: 0.1506, Validation Loss:  0.5788, Validation Accuracy,84.05%
Epoch [9/30], Traning Loss: 0.1308, Validation Loss:  0.5811, Validation Accuracy,84.39%
Epoch [10/30], Traning Loss: 0.1100, Validation Loss:  0.6010, Validation Accuracy,83.82%
Epoch [11/30], Traning Loss: 0.0949, Validation Loss:  0.5979, Validation Accuracy,84.28%
Epoch [12/30], Tran