<a href="https://colab.research.google.com/github/m-abbas-ansari/ASD-Classification/blob/main/ASD_Classification_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Notebook Setup and Dataset Download

In [None]:
# Colab-only: mount Google Drive under ./gdrive so files can be written to it
# from this runtime.
from google.colab import drive

drive.mount('./gdrive')

In [None]:
# Install the extra packages this notebook imports (wget, wandb, torchmetrics).
# NOTE(review): versions are not pinned; the torchmetrics helpers imported in
# the next section (`auc`, `precision_recall`) presumably only exist in older
# releases -- confirm and pin a known-good version.
!pip -q install wget wandb torchmetrics
import wget
# Download the training archive from Zenodo into the working directory.
wget.download("https://zenodo.org/record/2647418/files/TrainingDataset.rar?download=1")
# Extract the archive into ./Dataset.
# NOTE(review): `mkdir` without `-p` reports an error when the cell is re-run.
!mkdir Dataset
!unrar x TrainingDataset.rar Dataset

### Imports and Utilities

In [3]:
from PIL import Image
import os
import shutil
import numpy as np
import torch.utils.data as data
import cv2
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models
import torch.optim as optim
from torchvision import transforms
import tensorflow as tf
import pandas as pd
import operator
from glob import glob
import wandb
import matplotlib.pyplot as plt
from torchmetrics.functional import accuracy, auc, precision_recall, f1_score

In [4]:
def read_dataset(anno_path, num_images=300):
    """Load image sizes and per-subject scanpaths for every stimulus image.

    For each image id ``i`` in ``1..num_images`` this reads
    ``<anno_path>/Images/<i>.png`` (only for its height/width) and the two
    scanpath files ``<anno_path>/TD/TD_scanpath_<i>.txt`` and
    ``<anno_path>/ASD/ASD_scanpath_<i>.txt``.  Each scanpath file is parsed
    with ``pandas.read_csv`` and must provide the columns ``'Idx'``, ``' x'``,
    ``' y'`` and ``' duration'`` (note the leading spaces).  A row whose
    ``Idx`` is 0 starts the scanpath of a new subject.

    :param anno_path: root directory containing ``Images``, ``TD`` and ``ASD``.
    :param num_images: number of consecutively numbered images to read
        (defaults to 300, the value the notebook originally hard-coded).
    :return: dict mapping image id to
        ``{'img_size': [height, width],
           'ctrl': {'fixation': [...], 'duration': [...]},
           'asd':  {'fixation': [...], 'duration': [...]}}``
        where ``fixation`` holds one list of ``[y, x]`` pairs per subject and
        ``duration`` the matching list of durations.
    :raises FileNotFoundError: if a stimulus image cannot be read.
    """
    # (key stored in the result, folder / file prefix on disk)
    group_files = (('ctrl', 'TD'), ('asd', 'ASD'))

    anno_dict = dict()
    for i in range(1, num_images + 1):
        img_path = os.path.join(anno_path, 'Images', str(i) + '.png')
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f'Could not read stimulus image: {img_path}')
        y_lim, x_lim = img.shape[:2]
        anno_dict[i] = {'img_size': [y_lim, x_lim]}

        for group_name, folder in group_files:
            scanpaths = pd.read_csv(
                os.path.join(anno_path, folder, folder + '_scanpath_' + str(i) + '.txt'))
            cur_idx = list(scanpaths['Idx'])
            cur_x = list(scanpaths[' x'])
            cur_y = list(scanpaths[' y'])
            cur_dur = list(scanpaths[' duration'])

            fixations, durations = [], []
            tmp_fix, tmp_dur = [], []
            for j in range(len(cur_idx)):
                # Idx restarting at 0 (except on the very first row) marks the
                # beginning of the next subject's scanpath.
                if cur_idx[j] == 0 and j != 0:
                    fixations.append(tmp_fix)
                    durations.append(tmp_dur)
                    tmp_fix, tmp_dur = [], []
                tmp_fix.append([cur_y[j], cur_x[j]])
                tmp_dur.append(cur_dur[j])
            # flush the last subject
            fixations.append(tmp_fix)
            durations.append(tmp_dur)

            anno_dict[i][group_name] = {'fixation': fixations, 'duration': durations}

    return anno_dict

In [5]:
class Dataset(data.Dataset):
    """Scanpath dataset: one sample per (image, subject) pair.

    Each sample is ``(img, label, fixation, duration)`` where ``label`` is 0
    for the ``'ctrl'`` group and 1 for the ``'asd'`` group, and ``fixation`` /
    ``duration`` are integer tensors of length ``max_len``.

    :param img_dir: directory holding ``<img_id>.png`` files.
    :param data: dict as produced by ``read_dataset`` (image id ->
        ``{'img_size', 'ctrl', 'asd'}``).
    :param max_len: number of fixations kept per scanpath (shorter scanpaths
        are zero-padded).
    :param img_height: height the images are resized to by ``transform``.
    :param img_width: width the images are resized to by ``transform``.
    :param transform: optional callable applied to the PIL image.
    """

    # Divisors that turn a (resized) pixel coordinate into a feature-map cell.
    CELL_WIDTH = 32
    CELL_HEIGHT = 33
    # Down-sampling factor used to derive the feature-map size from the resized
    # image.  For the default 600x800 input this yields a 19x25 grid, i.e. the
    # row stride of 25 that was previously hard-coded.
    # NOTE(review): assumes a stride-32 backbone -- confirm for other backends.
    FEATURE_STRIDE = 32

    def __init__(self,img_dir,data,max_len,img_height,img_width,transform):
        self.img_dir = img_dir
        self.initial_dataset(data)
        self.max_len = max_len
        self.img_height = img_height
        self.img_width = img_width
        self.transform = transform

    def initial_dataset(self,data):
        """Flatten the nested annotation dict into parallel per-sample lists."""
        self.fixation = []
        self.duration = []
        self.label = []
        self.img_id = []
        self.img_size = []

        for img_id in data.keys():
            img_path = os.path.join(self.img_dir,str(img_id)+'.png')
            for group_label, group in enumerate(['ctrl','asd']):
                scanpaths = data[img_id][group]['fixation']
                n_subjects = len(scanpaths)
                self.fixation.extend(scanpaths)
                self.duration.extend(data[img_id][group]['duration'])
                self.img_id.extend([img_path]*n_subjects)
                self.label.extend([group_label]*n_subjects)
                self.img_size.extend([data[img_id]['img_size']]*n_subjects)

    def get_fix_dur(self,idx):
        """Return ``(fixation, duration)`` int64 tensors of length ``max_len``.

        Only the first ``max_len`` fixations are considered.  Each one is
        mapped to a flat cell index ``row * grid_width + col`` on the
        down-sampled feature map.  Fixations that fall outside the grid
        (negative cell, or at/beyond the right or bottom border) are dropped,
        so they can never produce an index outside the feature map.  The
        result is right-padded with zeros.
        """
        fixs = self.fixation[idx]
        durs = self.duration[idx]
        y_lim, x_lim = self.img_size[idx]

        # ceil(size / stride): number of feature-map columns / rows
        grid_w = -(-int(self.img_width) // self.FEATURE_STRIDE)
        grid_h = -(-int(self.img_height) // self.FEATURE_STRIDE)
        x_scale = self.img_width/float(x_lim)
        y_scale = self.img_height/float(y_lim)

        fixation = []
        duration = []
        for (y_fix, x_fix), dur in zip(fixs[:self.max_len], durs[:self.max_len]):
            col = int(x_fix*x_scale/self.CELL_WIDTH)
            row = int(y_fix*y_scale/self.CELL_HEIGHT)
            if 0 <= col < grid_w and 0 <= row < grid_h:
                fixation.append(row*grid_w + col)
                duration.append(dur)

        # pad (missing and dropped fixations alike) up to max_len
        n_pad = self.max_len - len(fixation)
        fixation.extend([0]*n_pad)
        duration.extend([0]*n_pad)

        # int64 explicitly: the indices feed torch.gather, and 'int' is only
        # 32 bits wide on some platforms.
        fixation = torch.from_numpy(np.array(fixation).astype(np.int64))
        duration = torch.from_numpy(np.array(duration).astype(np.int64))
        return fixation, duration

    def __getitem__(self,index):
        # Close the file handle promptly and normalise the mode so the
        # 3-channel transform also works for grayscale / RGBA PNGs.
        with Image.open(self.img_id[index]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.FloatTensor([self.label[index]])
        fixation, duration = self.get_fix_dur(index)
        return img, label, fixation, duration

    def __len__(self,):
        return len(self.fixation)

In [6]:
import random

def simple_img_split(anno_dict, val_ratio):
  """Randomly split the annotation dict into train / validation by image id.

  The split is done on the dict's actual keys, so it also works when image
  ids are not the contiguous range ``1..N``.  Uses the global ``random``
  state; seed it beforehand for a reproducible split.

  :param anno_dict: dict keyed by image id.
  :param val_ratio: fraction of images placed in the validation split
      (``int(val_ratio * N)`` images are selected).
  :return: ``(train_dict, val_dict)`` -- disjoint views of ``anno_dict``.
  """
  img_ids = list(anno_dict.keys())
  num_imgs = len(img_ids)
  print(num_imgs)
  val_idx = random.sample(img_ids, k=int(val_ratio*num_imgs)) # randomly select validation images
  held_out = set(val_idx)
  train_idx = [k for k in img_ids if k not in held_out]
  train_dict = {k: anno_dict[k] for k in train_idx}
  val_dict = {k: anno_dict[k] for k in val_idx}

  return train_dict, val_dict

# Create the local checkpoint directory (no error if it already exists).
# NOTE(review): the literal duplicates the `checkpoint_path` value set in the
# hyperparameter cell further down -- keep the two in sync.
os.makedirs('checkpoints', exist_ok=True)

### Architecture

In [7]:
def clip_gradient(optimizer, grad_clip):
    """Clamp every gradient element to ``[-grad_clip, grad_clip]`` in place.

    Parameters that have no gradient (frozen, or not reached by the last
    backward pass) are skipped instead of raising ``AttributeError``.

    :param optimizer: optimizer whose ``param_groups`` hold the parameters.
    :param grad_clip: non-negative clamp bound.
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is None:
                continue
            param.grad.data.clamp_(-grad_clip, grad_clip)

def adjust_lr(optimizer, epoch, base_lr=None):
    """Adaptively set the learning rate based on the epoch.

    The rate is ``base_lr`` for ``epoch <= 0`` and
    ``base_lr * 0.5 ** (epoch / 2)`` afterwards, i.e. it halves every two
    epochs.  The value is written to every parameter group.

    :param optimizer: optimizer to update.
    :param epoch: current epoch number.
    :param base_lr: starting learning rate; when omitted the module-level
        ``LR`` constant is used (the original behaviour).
    """
    if base_lr is None:
        base_lr = LR

    if epoch <= 0:
        lr = base_lr
    else:
        lr = base_lr * (0.5 ** (float(epoch) / 2))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [8]:
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding.

    The table is registered as a non-persistent buffer: it follows the module
    through ``.cuda()`` / ``.to()`` / ``.cpu()`` but is not written to the
    ``state_dict``, so existing checkpoints keep loading unchanged.
    """
    def __init__(self, d_model, max_len, device):
        """
        constructor of sinusoid encoding class

        :param d_model: dimension of model
        :param max_len: max sequence length
        :param device: kept for interface compatibility only; the table is
            built on the CPU and moves together with the module, so this
            value is never interpreted (callers may pass any placeholder)
        """
        super(PositionalEncoding, self).__init__()

        # (max_len, 1) column of positions 0 .. max_len-1
        pos = torch.arange(0, max_len).float().unsqueeze(dim=1)

        # even channel indices 0, 2, 4, ... (the "2i" of the sinusoid formula)
        _2i = torch.arange(0, d_model, step=2).float()

        angles = pos / (10000 ** (_2i / d_model))

        # same size as the input matrix (it is added to the input)
        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # for odd d_model there is one fewer odd channel than even ones
        encoding[:, 1::2] = torch.cos(angles)[:, : d_model // 2]

        # a buffer never requires grad, so no explicit requires_grad = False
        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x):
        """Return the encoding broadcast to ``x``'s shape (batch, seq_len, d_model).

        Only the first ``seq_len`` positions are used, so inputs shorter than
        ``max_len`` are supported.
        """
        batch, seq_len, d_model = x.size()
        return self.encoding[:seq_len].unsqueeze(0).expand(batch, seq_len, d_model)

In [9]:
class SelfAttention(nn.Module):
    """Multi-head attention with per-head projections shared across heads.

    The embedding is split into ``heads`` chunks of ``head_dim`` channels;
    the same ``head_dim -> head_dim`` linear maps (no bias) are applied to
    every chunk for values, keys and queries.  Scores are scaled by
    ``sqrt(embed_size)`` before the softmax.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embedding size needs to be divisible by heads"

        per_head = self.head_dim
        self.values = nn.Linear(per_head, per_head, bias=False)
        self.keys = nn.Linear(per_head, per_head, bias=False)
        self.queries = nn.Linear(per_head, per_head, bias=False)
        self.fc_out = nn.Linear(heads * per_head, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend ``query`` over ``keys`` / ``values``.

        :param values: (N, value_len, embed_size)
        :param keys: (N, key_len, embed_size)
        :param query: (N, query_len, embed_size)
        :param mask: ``None`` or a tensor broadcastable to
            (N, heads, query_len, key_len); positions where it equals 0 are
            excluded from the softmax.
        :return: (N, query_len, embed_size)
        """
        n_batch = query.shape[0]
        q_len = query.shape[1]
        n_heads, d_head = self.heads, self.head_dim

        def split_heads(t):
            # (N, len, embed_size) -> (N, len, heads, head_dim)
            return t.reshape(n_batch, t.shape[1], n_heads, d_head)

        v = self.values(split_heads(values))
        k = self.keys(split_heads(keys))
        q = self.queries(split_heads(query))

        # dot product of every query with every key, per head:
        # (N, q_len, h, d) x (N, k_len, h, d) -> (N, h, q_len, k_len)
        scores = torch.einsum("nqhd,nkhd->nhqk", [q, k])

        # masked positions get a hugely negative score -> ~0 weight
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-1e20"))

        # normalise over the key axis
        weights = torch.softmax(scores / (self.embed_size ** (1 / 2)), dim=3)

        # weighted sum of values: (N, h, q_len, l) x (N, l, h, d) -> (N, q_len, h, d)
        context = torch.einsum("nhql,nlhd->nqhd", [weights, v])
        # merge the heads back into one embedding axis
        merged = context.reshape(n_batch, q_len, n_heads * d_head)

        # output projection keeps the shape (N, q_len, embed_size)
        return self.fc_out(merged)


class TransformerBlock(nn.Module):
    """One post-norm transformer layer: attention and feed-forward sub-layers.

    Each sub-layer is followed by a residual add, ``LayerNorm`` and dropout
    (the same ``Dropout`` module is used after both).
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # position-wise MLP: embed -> expansion * embed -> embed
        expand = nn.Linear(embed_size, hidden_size)
        contract = nn.Linear(hidden_size, embed_size)
        self.feed_forward = nn.Sequential(expand, nn.ReLU(), contract)

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Apply the block; the residual connection is taken from ``query``."""
        attended = self.attention(value, key, query, mask)

        # residual + norm, then dropout
        normed = self.norm1(attended + query)
        x = self.dropout(normed)

        # second sub-layer, same pattern around the feed-forward MLP
        transformed = self.feed_forward(x)
        return self.dropout(self.norm2(transformed + x))


class Encoder(nn.Module):
    """Stack of ``TransformerBlock`` layers over inputs that are already embedded.

    The positional encoding is added to the input, dropout is applied, and
    the result is passed through ``num_layers`` blocks as self-attention
    (query, key and value are all the same tensor).
    """

    def __init__(self, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        print(f'Got parameters: embed_size = {embed_size} num_layers = {num_layers} heads = {heads} device = {device} forward_expansion = {forward_expansion}\
        dropout = {dropout} max_lenght = {max_length}')
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.position_embedding = PositionalEncoding(embed_size, max_length, device)

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                TransformerBlock(embed_size, heads, dropout=dropout, forward_expansion=forward_expansion)
            )
        self.layers = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x`` of shape (N, seq_length, emb_dim); the shape is preserved."""
        # The three sizes are not used below; the unpacking only fails early
        # when x is not 3-D.
        _batch, _seq_length, _emb_dim = x.shape

        with_position = x + self.position_embedding(x)
        hidden = self.dropout(with_position)

        # self-attention: the same tensor serves as value, key and query
        for block in self.layers:
            hidden = block(hidden, hidden, hidden, mask)

        return hidden

In [10]:
class Sal_transformer(nn.Module):
    """CNN backbone + transformer encoder over the features at fixated cells.

    The image is encoded by a pretrained ResNet-50 or VGG-19 trunk, the
    feature vectors at the fixated grid cells are gathered into a sequence,
    optionally projected to ``emb_dim``, run through the ``Encoder`` and
    reduced to a single sigmoid probability per sample.
    """

    def __init__(self, backend, seq_len, num_blocks, heads, device, expansion, drop_prob, emb_dim):
        super(Sal_transformer, self).__init__()
        print(f'Got parameters: backend = {backend} num_blocks = {num_blocks} heads = {heads} device = {device} expansion = {expansion}\
        drop_prob = {drop_prob} emb_dim = {emb_dim}')
        self.seq_len = seq_len
        self.emb_dim = emb_dim

        # backbone selection; input_size is the backbone's channel count
        if backend == 'resnet':
            self.init_resnet(models.resnet50(pretrained=True))
            self.input_size = 2048
        elif backend == 'vgg':
            self.init_vgg(models.vgg19(pretrained=True))
            self.input_size = 512
        else:
            assert 0, f"Backend '{backend}' not implemented"

        # a projection layer exists only when the channel count differs
        if self.input_size != emb_dim:
            self.project = nn.Linear(self.input_size, emb_dim)

        self.transformer = Encoder(emb_dim, num_blocks, heads, device, expansion, drop_prob, seq_len)

        # classification head: two linear reductions per token, then one
        # linear layer over the flattened sequence
        quarter, eighth = emb_dim//4, emb_dim//8
        head_layers = [
            nn.Linear(emb_dim, quarter),
            nn.Linear(quarter, eighth),
            nn.Flatten(),
            nn.Linear(eighth*seq_len, 1),
        ]
        self.decoder = nn.Sequential(*head_layers)

    def init_resnet(self, resnet):
        """Keep every ResNet child module except the last two."""
        trunk = list(resnet.children())[:-2]
        self.backend = nn.Sequential(*trunk)

    def init_vgg(self, vgg):
        """Keep the VGG feature layers except the last two."""
        trunk = list(vgg.features.children())[:-2]  # omitting the last Max Pooling
        self.backend = nn.Sequential(*trunk)

    def forward(self, x, fixation, duration):
        """Return one probability per sample, shape (batch, 1).

        :param x: image batch.
        :param fixation: (batch, seq_len) integer indices into the flattened
            feature map.
        :param duration: accepted for interface compatibility; not used.
        """
        feats = self.backend(x)  # e.g. [12, 2048, 19, 25]
        batch, feat, _h, _w = feats.size()
        feats = feats.view(batch, feat, -1)  # flatten the grid: [12, 2048, 475]

        # replicate each index across all channels so gather selects whole
        # feature vectors: [12, 14] -> [12, 1, 14] -> [12, 2048, 14]
        n_samples, n_fix = fixation.size(0), fixation.size(1)
        gather_idx = fixation.view(n_samples, 1, n_fix).expand(n_samples, feat, n_fix)

        tokens = feats.gather(2, gather_idx).transpose(1, 2)  # [12, 14, 2048]
        if self.input_size != self.emb_dim:
            tokens = self.project(tokens)

        encoded = self.transformer(tokens)
        return torch.sigmoid(self.decoder(encoded))

### Hyperparameters

In [11]:
# Base learning rate: passed to Adam and decayed per epoch by adjust_lr().
LR = 1e-4
# Dataset locations (created by the download cell).
img_dir = 'Dataset/TrainingData/Images'
anno_dir = 'Dataset/TrainingData'
# CNN backbone: 'resnet' or 'vgg'.
backend = 'resnet'
# Local directory where the per-epoch best checkpoints are written.
checkpoint_path= 'checkpoints'
num_epochs = 10
# Fraction of images held out for validation.
val_ratio = 0.1
# train() skips any batch that has fewer than batch_size samples.
batch_size = 12
# Fixations kept per scanpath (shorter ones are zero-padded); also the
# transformer's sequence length.
max_len = 14
# Element-wise gradient clamp bound; -1 disables clipping in train().
clip = 10
# Size every image is resized to before entering the backbone.
img_height = 600
img_width = 800

#Sal_transformer
num_blocks=1 
heads=4
# NOTE(review): 'gpu' is not a valid torch device string. In the code shown
# here the value is only stored and printed by the model classes; the model
# is moved with .cuda() instead.
device='gpu'
expansion=4
drop_prob=0.4
# Equal to the ResNet channel count (2048), so no projection layer is created.
emb_dim=2048

### Data Loading and Model Initialization

In [None]:
# Parse all scanpath annotations (reads every stimulus image once for its size).
anno = read_dataset(anno_dir)

# Resize to the working resolution, convert to a tensor, and normalise with
# the per-channel mean / std given below.
transform = transforms.Compose([transforms.Resize((img_height,img_width)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Split by image id, so no image appears in both subsets.
# NOTE(review): no random seed is set before this call, so the split changes
# on every run.
train_data, val_data = simple_img_split(anno, val_ratio)

train_set = Dataset(img_dir, train_data, max_len, img_height, img_width, transform)
val_set = Dataset(img_dir, val_data, max_len, img_height, img_width, transform)
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)

300


In [None]:
# Build the model from the hyperparameter cell's values (max_len is the
# transformer's sequence length).
model = Sal_transformer(backend, 
                        max_len, 
                        num_blocks, 
                        heads, 
                        device, 
                        expansion, 
                        drop_prob,
                        emb_dim)

# Move to the GPU when one is available.
if torch.cuda.is_available():
  model = model.cuda()

# Adam over the parameters that require gradients, with weight decay 1e-5.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=1e-5) 

In [None]:
# Count the elements of every parameter tensor, whether trainable or not.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters in model: {total_params:,}")

### Train and Validation Loop

In [None]:
from tqdm import tqdm


def train(iteration):
    """Run one training epoch and print the epoch's training metrics.

    Relies on notebook globals: ``model``, ``optimizer``, ``trainloader``,
    ``batch_size``, ``clip`` and, for the printed summary only, the loop
    variable ``epoch``.  The running mean BCE loss is logged to wandb every
    25 batches.

    :param iteration: global step counter before this epoch.
    :return: the step counter after this epoch (one step per processed batch).
    """
    # Make sure dropout / batch-norm layers are in training mode no matter
    # what ran before this call.
    model.train()

    use_cuda = torch.cuda.is_available()
    avg_loss = 0
    preds = []
    targets = []
    for j, (img,target,fix, dur) in enumerate(tqdm(trainloader)):
        # skip an incomplete batch
        if len(img) < batch_size:
            continue
        target = target.type(torch.FloatTensor)
        dur = dur.type(torch.FloatTensor)
        if use_cuda:
            img, target, fix, dur = img.cuda(), target.cuda(), fix.cuda(), dur.cuda()
        optimizer.zero_grad()

        pred = model(img,fix, dur)
        loss = F.binary_cross_entropy(pred,target)
        loss.backward()
        if clip != -1:
            clip_gradient(optimizer,clip)
        optimizer.step()
        # running mean of the loss over the batches seen so far
        avg_loss = (avg_loss*j + loss.item())/(j+1)

        if j%25 == 0:
            wandb.log({'bce loss': avg_loss}, step=iteration)
        iteration += 1

        # detach before storing: otherwise every batch's autograd graph is
        # kept alive until the end of the epoch
        preds.append(pred.detach().cpu())
        targets.append(target.to(torch.int16).cpu())

    if not preds:
        # every batch was smaller than batch_size; nothing to summarise
        print('\nNo full training batch was processed; skipping epoch metrics.')
        return iteration

    with torch.no_grad():
        preds = torch.cat(preds, 0)
        targets = torch.cat(targets, 0)
        acc = accuracy(preds, targets)
        # NOTE(review): torchmetrics' functional `auc` integrates y over x with
        # the trapezoidal rule; if ROC-AUC is intended, `auroc` is probably the
        # right call -- confirm.
        auc_v = auc(preds, targets, reorder=True)
        pre, rec = precision_recall(preds, targets)
        score = f1_score(preds, targets)
        print(f'\nT {epoch}: acc = {acc.item():.2f} auc = {auc_v.item():.2f} pre = {pre.item():.2f} rec = {rec.item():.2f} f1_score = {score.item():.2f}')

    return iteration

In [None]:
def validation(epoch):
  """Evaluate the global ``model`` on ``valloader`` and log the metrics.

  The model is switched to eval mode for the pass (so dropout is disabled
  and batch-norm uses its running statistics) and restored to its previous
  mode afterwards.  Metrics are printed and logged to wandb.

  :param epoch: epoch number, used only in the printed summary.
  :return: the validation F1 score as a Python float.
  """
  preds = []
  targets = []
  use_cuda = torch.cuda.is_available()

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for img, target, fix, dur in valloader:
        target = target.type(torch.FloatTensor)
        dur = dur.type(torch.FloatTensor)
        # same device handling as the training loop
        if use_cuda:
          img, fix, dur = img.cuda(), fix.cuda(), dur.cuda()
        pred = model(img,fix, dur)
        preds.append(pred.cpu())
        targets.append(target.to(torch.int16))
  finally:
    # leave the model in whatever mode the caller had it in
    model.train(was_training)

  preds = torch.cat(preds, 0)
  targets = torch.cat(targets, 0)

  acc = accuracy(preds, targets)
  # NOTE(review): torchmetrics' functional `auc` integrates y over x with the
  # trapezoidal rule; if ROC-AUC is intended, `auroc` is probably the right
  # call -- confirm.
  auc_v = auc(preds, targets, reorder=True)
  pre, rec = precision_recall(preds, targets)
  score = f1_score(preds, targets)
  print(f'V {epoch}: acc = {acc.item():.2f} auc = {auc_v.item():.2f} pre = {pre.item():.2f} rec = {rec.item():.2f} f1_score = {score.item():.2f}')
  wandb.log({'accuracy': acc.item(), 
             'auc': auc_v.item(),
             'precision': pre.item(),
             'recall': rec.item(),
             'f1 score': score.item()})
  return score.item()


### Logging setup and Training

In [None]:
# Run name: used for the wandb run and, at the end of the notebook, for the
# Drive sub-folder that receives the best checkpoint.
name_of_run="rand"
wandb.init(project="asd-trans", name=name_of_run)

In [None]:
# Create the Drive folder that collects exported models (no error if it
# already exists). The path is under the ./gdrive mount point.
os.makedirs("gdrive/MyDrive/ASD/Transformer-Models", exist_ok=True)

In [None]:
iteration = 0
best_f1, best_epoch = 0,0
best_ckpt = None  # path of the checkpoint written for best_epoch, if any

# Metrics before any optimisation step, for reference only.
f1_s = validation(0)
for epoch in range(num_epochs):
    adjust_lr(optimizer,epoch)

    iteration = train(iteration)
    f1_s = validation(epoch)

    # keep a checkpoint whenever the validation F1 improves
    if f1_s > best_f1:
        best_ckpt = os.path.join(checkpoint_path,'best_model_epoch'+str(epoch)+'.pth')
        torch.save(model.state_dict(),best_ckpt)
        best_f1 = f1_s
        best_epoch = epoch

# report the best score, not the last epoch's
print(f'Best F1 score at epoch {best_epoch}: {best_f1}')
dir_name = f'gdrive/MyDrive/ASD/Transformer-Models/{name_of_run}/'
src_file = f'{checkpoint_path}/best_model_epoch{best_epoch}.pth'
dest = f'gdrive/MyDrive/ASD/Transformer-Models/{name_of_run}/best_model_epoch{best_epoch}.pth'
os.makedirs(dir_name, exist_ok=True)
if best_ckpt is not None:
    shutil.move(src_file, dest)
else:
    # the F1 never rose above 0, so no checkpoint was written
    print('Validation F1 never improved above 0; no checkpoint was saved or copied.')