<a href="https://colab.research.google.com/github/m-abbas-ansari/ASD-Classification/blob/main/ASD_Classification_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Notebook Setup and Dataset Download

In [19]:
from google.colab import drive

drive.mount('./gdrive')

Mounted at ./gdrive


In [None]:
!pip -q install wget wandb torchmetrics
import wget
wget.download("https://zenodo.org/record/2647418/files/TrainingDataset.rar?download=1")
!mkdir Dataset
!unrar x TrainingDataset.rar Dataset

### Imports and Utilities

In [2]:
from PIL import Image
import os
import shutil
import numpy as np
import torch.utils.data as data
import cv2
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models
import torch.optim as optim
from torchvision import transforms
import tensorflow as tf
import pandas as pd
import operator
from glob import glob
import wandb
import matplotlib.pyplot as plt
from torchmetrics.functional import accuracy, auc, precision_recall, f1_score

In [3]:
def read_dataset(anno_path):
    anno_dict = dict()
    max_len = dict()
    # Saliency4ASD has 300 images
    for i in range(1,301):
        img = cv2.imread(os.path.join(anno_path,'Images',str(i)+'.png'))
        y_lim, x_lim, _ = img.shape
        anno_dict[i] = dict()
        anno_dict[i]['img_size'] = [y_lim,x_lim]
        asd = pd.read_csv(os.path.join(anno_path,'ASD','ASD_scanpath_'+str(i)+'.txt'))
        ctrl = pd.read_csv(os.path.join(anno_path,'TD','TD_scanpath_'+str(i)+'.txt'))
        group_name = ['ctrl','asd']
        for flag, group in enumerate([ctrl, asd]):
            anno_dict[i][group_name[flag]] = dict()
            anno_dict[i][group_name[flag]]['fixation'] = []
            anno_dict[i][group_name[flag]]['duration'] = []
            cur_idx = list(group['Idx'])
            cur_x = list(group[' x'])
            cur_y = list(group[' y'])
            cur_dur = list(group[' duration'])
            tmp_fix = []
            tmp_dur = []
            for j in range(len(cur_idx)):
                # finish loading data for one subject
                if cur_idx[j] == 0  and j != 0:
                    anno_dict[i][group_name[flag]]['fixation'].append(tmp_fix)
                    anno_dict[i][group_name[flag]]['duration'].append(tmp_dur)
                    tmp_fix = []
                    tmp_dur = []
                tmp_fix.append([cur_y[j],cur_x[j]])
                tmp_dur.append(cur_dur[j])
            # save data of the last subject
            anno_dict[i][group_name[flag]]['fixation'].append(tmp_fix)
            anno_dict[i][group_name[flag]]['duration'].append(tmp_dur)

    return anno_dict

In [4]:
class Dataset(data.Dataset):
    def __init__(self,img_dir,data,max_len,img_height,img_width,transform):
        self.img_dir = img_dir
        self.initial_dataset(data)
        self.max_len = max_len
        self.img_height = img_height
        self.img_width = img_width
        self.transform = transform

    def initial_dataset(self,data):
        self.fixation = []
        self.duration = []
        self.label = []
        self.img_id = []
        self.img_size = []

        for img_id in data.keys():
            # if not img_id in valid_id:
            #     continue
            for group_label, group in enumerate(['ctrl','asd']):
                self.fixation.extend(data[img_id][group]['fixation'])
                self.duration.extend(data[img_id][group]['duration'])
                self.img_id.extend([os.path.join(self.img_dir,str(img_id)+'.png')]*len(data[img_id][group]['fixation']))
                self.label.extend([group_label]*len(data[img_id][group]['fixation']))
                self.img_size.extend([data[img_id]['img_size']]*len(data[img_id][group]['fixation']))

    def get_fix_dur(self,idx):
        fixs = self.fixation[idx]
        durs = self.duration[idx]
        y_lim, x_lim = self.img_size[idx]
        fixation = []
        duration = []
        invalid = 0
        # only consider the first k fixations
        for i in range(self.max_len):
            if i+1 <= len(fixs):
                y_fix, x_fix = fixs[i]
                dur = durs[i]
                x_fix = int(x_fix*(self.img_width/float(x_lim))/32)
                y_fix = int(y_fix*(self.img_height/float(y_lim))/33)
                if x_fix >=0 and y_fix>=0:
                    fixation.append(y_fix*25 + x_fix) # get the corresponding index of fixation on the downsampled feature map
                    duration.append(dur) # duration of corresponding fixation
                else:
                    invalid += 1
            else:
                fixation.append(0) # pad if necessary
                duration.append(0)
        for i in range(invalid):
            fixation.append(0)
            duration.append(0)
        fixation = torch.from_numpy(np.array(fixation).astype('int'))
        duration = torch.from_numpy(np.array(duration).astype('int'))
        return fixation, duration

    def __getitem__(self,index):
        img = Image.open(self.img_id[index])
        if self.transform is not None:
            img = self.transform(img)
        label = torch.FloatTensor([self.label[index]])
        fixation, duration = self.get_fix_dur(index)
        return img, label, fixation, duration

    def __len__(self,):
        return len(self.fixation)

In [10]:
import random

def simple_img_split(anno_dict, val_ratio):
  num_imgs = len(anno_dict.keys())
  print(num_imgs)
  val_idx = list(random.sample(range(1, num_imgs+1), k=int(val_ratio*num_imgs))) # randomly select validation images
  train_idx = list(set(val_idx) ^ set(range(1, num_imgs+1)))
  train_dict = {k: anno_dict[k] for k in train_idx}
  val_dict = {k: anno_dict[k] for k in val_idx}

  return train_dict, val_dict

os.makedirs('checkpoints', exist_ok=True)

### Archtitecture

In [6]:
def clip_gradient(optimizer, grad_clip):
    for group in optimizer.param_groups:
        for param in group['params']:
            param.grad.data.clamp_(-grad_clip, grad_clip)

def adjust_lr(optimizer, epoch):
    "adatively adjust lr based on epoch"
    if epoch <= 0 :
        lr = LR
    else :
        lr = LR * (0.5 ** (float(epoch) / 2))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [7]:
class G_LSTM(nn.Module):
	"""
	LSTM implementation proposed by A. Graves (2013),
	it has more parameters compared to original LSTM
	"""
	def __init__(self,input_size=2048,hidden_size=512):
		super(G_LSTM,self).__init__()
		# without batch_norm
		self.input_x = nn.Linear(input_size,hidden_size,bias=True)
		self.forget_x = nn.Linear(input_size,hidden_size,bias=True)
		self.output_x = nn.Linear(input_size,hidden_size,bias=True)
		self.memory_x = nn.Linear(input_size,hidden_size,bias=True)

		self.input_h = nn.Linear(hidden_size,hidden_size,bias=True)
		self.forget_h = nn.Linear(hidden_size,hidden_size,bias=True)
		self.output_h = nn.Linear(hidden_size,hidden_size,bias=True)
		self.memory_h = nn.Linear(hidden_size,hidden_size,bias=True)

		self.input_c = nn.Linear(hidden_size,hidden_size,bias=True)
		self.forget_c = nn.Linear(hidden_size,hidden_size,bias=True)
		self.output_c = nn.Linear(hidden_size,hidden_size,bias=True)

	def forward(self,x,state):
		h, c = state
		i = torch.sigmoid(self.input_x(x) + self.input_h(h) + self.input_c(c))
		f = torch.sigmoid(self.forget_x(x) + self.forget_h(h) + self.forget_c(c))
		g = torch.tanh(self.memory_x(x) + self.memory_h(h))

		next_c = torch.mul(f,c) + torch.mul(i,g)
		o = torch.sigmoid(self.output_x(x) + self.output_h(h) + self.output_c(next_c))
		h = torch.mul(o,next_c)
		state = (h,next_c)

		return state

In [8]:
# Original code from https://github.com/szzexpoi/attention_asd_screening/blob/master/model/model.py
# Code was modified to include time-dependent representation techniques 
# i.e time-masking and time-event joint embedding

class Sal_seq(nn.Module):
    def __init__(self, backend, seq_len, all_lstm=False, crop_seq=False, mask=False, joint=False, time_proj_dim=128, hidden_size=512):
        super(Sal_seq,self).__init__()
        self.seq_len = seq_len
        self.mask = mask
        self.joint = joint
        self.crop = crop_seq
        self.all = all_lstm
        # defining backend
        if backend == 'resnet':
            resnet = models.resnet50(pretrained=True)
            self.init_resnet(resnet)
            input_size = 2048
        elif backend == 'vgg':
            vgg = models.vgg19(pretrained=True)
            self.init_vgg(vgg)
            input_size = 512
        else:
            assert 0, 'Backend not implemented'
        
        self.rnn = G_LSTM(input_size,hidden_size)
        self.decoder = nn.Linear(hidden_size,1,bias=True) # comment for multi-modal distillation
        self.hidden_size = hidden_size
        if self.all:
          self.predecode = nn.Linear(hidden_size*seq_len, hidden_size, bias=True)

        if self.mask or self.joint:
          self.time_projection = nn.Linear(1, time_proj_dim) # project duration to a time emedding space
          if self.mask:
            self.project_emb = nn.Linear(time_proj_dim, input_size) # take time embedding to the same emedding as input feature
          if self.joint:
            self.time_emb = nn.Linear(time_proj_dim, input_size, bias=False) # time embedding matrix 

    def init_resnet(self,resnet):
        self.backend = nn.Sequential(*list(resnet.children())[:-2])

    def init_vgg(self,vgg):
        # self.backend = vgg.features
        self.backend = nn.Sequential(*list(vgg.features.children())[:-2]) # omitting the last Max Pooling

    def init_hidden(self,batch): #initializing hidden state as all zero
        h = torch.zeros(batch,self.hidden_size).cuda()
        c = torch.zeros(batch,self.hidden_size).cuda()
        return (Variable(h),Variable(c))

    def process_lengths(self,input):
        """
        Computing the lengths of sentences in current batchs
        """
        max_length = input.size(1)
        lengths = list(max_length - input.data.eq(0).sum(1).squeeze())
        return lengths

    def crop_seq(self,x,lengths):
        """
        Adaptively select the hidden state at the end of sentences
        """
        batch_size = x.size(0)
        seq_length = x.size(1)
        mask = x.data.new().resize_as_(x.data).fill_(0)
        for i in range(batch_size):
            mask[i][lengths[i]-1].fill_(1)
        mask = Variable(mask)
        x = x.mul(mask)
        x = x.sum(1).view(batch_size, x.size(2))
        return x

    def forward(self,x,fixation, duration):
        valid_len = self.process_lengths(fixation) # computing valid fixation lengths
        x = self.backend(x) # [12, 2048, 19, 25]
        batch, feat, h, w = x.size()
        x = x.view(batch,feat,-1) # [12, 2048, 475]
        
        # recurrent loop
        state = self.init_hidden(batch) # initialize hidden state
        fixation = fixation.view(fixation.size(0),1,fixation.size(1)) # [12, 1, 14]
        fixation = fixation.expand(fixation.size(0),feat,fixation.size(2)) # [12, 2048, 14]
        #print(f'x before gather: {x}')
        x = x.gather(2,fixation)
        output = []
        for i in range(self.seq_len):
            # extract features corresponding to current fixation
            cur_x = x[:,:,i].contiguous()
            if self.mask or self.joint:
              cur_t = duration[:, i].contiguous().unsqueeze(1)
              cur_t_proj = self.time_projection(cur_t)
              
              if self.joint: # time-event joint embedding
                cur_t_enc = torch.softmax(cur_t_proj, dim=1)
                cur_t_emb = self.time_emb(cur_t_enc)
                cur_x = (cur_x + cur_t_emb)/2.0

              if self.mask: # time mask
                cur_t_proj = torch.relu(cur_t_proj)
                cur_t_proj = self.project_emb(cur_t_proj)
                time_mask = torch.sigmoid(cur_t_proj)
                cur_x = torch.mul(cur_x, time_mask)

            #LSTM forward
            state = self.rnn(cur_x,state)
            out = state[0].view(batch,1,self.hidden_size)
            output.append(out)
        
        # selecting hidden states from the valid fixations without padding
        output = torch.cat(output, 1) # [12, 14, 512]
        if self.all:
          seq_len = output.size(1)
          output = self.predecode(output.view(batch, self.hidden_size*seq_len))
        else:
          if self.crop:
            output = self.crop_seq(output,valid_len)
          else:
            output = output[:,-1,:].view(batch, self.hidden_size) # select the last output state of LSTM
        output = torch.sigmoid(self.decoder(output))
        return output

### Hyperparameters

In [11]:
LR = 1e-4
img_dir = 'Dataset/TrainingData/Images'
anno_dir = 'Dataset/TrainingData'
backend = 'resnet'
checkpoint_path= 'checkpoints'
num_epochs = 10
val_ratio = 0.1
batch_size = 12
max_len = 14
clip = 10
img_height = 600
img_width = 800


#Sal_seq
hidden_size = 512
mask = False        # flag to whether or not to create a time mask to endcode duration into fixations
joint = False       # flag to whether or not to create a time-event joint embedding
crop_seq = True
all_lstm = False
time_proj_dim = 256

### Data Loading and Model Intialization

In [12]:
anno = read_dataset(anno_dir)

transform = transforms.Compose([transforms.Resize((img_height,img_width)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

train_data, val_data = simple_img_split(anno, val_ratio)

train_set = Dataset(img_dir, train_data, max_len, img_height, img_width, transform)
val_set = Dataset(img_dir, val_data, max_len, img_height, img_width, transform)
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)

300


In [None]:
if mask or joint:
  model = Sal_seq(backend=backend,seq_len=max_len,all_lstm=all_lstm,crop_seq=crop_seq, mask=mask,joint=joint,time_proj_dim=time_proj_dim,hidden_size=hidden_size)  
else:
  model = Sal_seq(backend=backend,seq_len=max_len,all_lstm=all_lstm,crop_seq=crop_seq, hidden_size=hidden_size)

if torch.cuda.is_available():
  model = model.cuda()

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=1e-5) 

In [None]:
total_params = sum([p.numel() for p in model.parameters()])
print(f"Total parameters in model: {total_params:,}")

### Train and Validation Loop

In [16]:
from tqdm import tqdm


def train(iteration):
    avg_loss = 0
    preds = []
    targets = []
    for j, (img,target,fix, dur) in enumerate(tqdm(trainloader)):
        if len(img) < batch_size:
            continue
        img, target, fix, dur = Variable(img), Variable(target.type(torch.FloatTensor)), Variable(fix,requires_grad=False), Variable(dur.type(torch.FloatTensor), requires_grad=False)
        if torch.cuda.is_available():
          img, target, fix, dur = img.cuda(), target.cuda(), fix.cuda(), dur.cuda()
        optimizer.zero_grad()

        pred = model(img,fix, dur)
        loss = F.binary_cross_entropy(pred,target)
        loss.backward()
        if clip != -1:
            clip_gradient(optimizer,clip)
        optimizer.step()
        avg_loss = (avg_loss*np.maximum(0,j) + loss.data.cpu().numpy())/(j+1)

        if j%25 == 0:
            wandb.log({'bce loss': avg_loss}, step=iteration)
        iteration += 1

        preds.append(pred.cpu())
        targets.append(target.to(torch.int16).cpu())
    with torch.no_grad():
      preds = torch.cat(preds, 0)
      targets = torch.cat(targets, 0)
      acc = accuracy(preds, targets)
      auc_v = auc(preds, targets, reorder=True)
      pre, rec = precision_recall(preds, targets)
      score = f1_score(preds, targets)
      print(f'\nT {epoch}: acc = {acc.item():.2f} auc = {auc_v.item():.2f} pre = {pre.item():.2f} rec = {rec.item():.2f} f1_score = {score.item():.2f}')

    return iteration

In [17]:
def validation(epoch):
  
  preds = []
  targets = []
  with torch.no_grad():
    for _, (img,target,fix, dur) in enumerate(valloader):
      img, target, fix, dur = Variable(img), Variable(target.type(torch.FloatTensor)), Variable(fix,requires_grad=False), Variable(dur.type(torch.FloatTensor), requires_grad=False)
      img, fix, dur = img.cuda(), fix.cuda(), dur.cuda()
      #print(f'img: {img.shape} target: {target} fix: {fix.shape}')
      # break
      pred = model(img,fix, dur)
      preds.append(pred.cpu())
      targets.append(target.to(torch.int16))

  preds = torch.cat(preds, 0)
  targets = torch.cat(targets, 0)
  
  acc = accuracy(preds, targets)
  auc_v = auc(preds, targets, reorder=True)
  pre, rec = precision_recall(preds, targets)
  score = f1_score(preds, targets)
  print(f'V {epoch}: acc = {acc.item():.2f} auc = {auc_v.item():.2f} pre = {pre.item():.2f} rec = {rec.item():.2f} f1_score = {score.item():.2f}')
  wandb.log({'accuracy': acc.item(), 
             'auc': auc_v.item(),
             'precision': pre.item(),
             'recall': rec.item(),
             'f1 score': score.item()})
  return score.item()


### Logging setup and Training

In [None]:
name_of_run="rand"
wandb.init(project="asd-lstm", name=name_of_run)

In [21]:
os.makedirs("gdrive/MyDrive/ASD/LSTM-Models", exist_ok=True)

In [None]:
iteration = 0
best_f1, best_epoch = 0,0
f1_s = validation(0)
for epoch in range(num_epochs):
    adjust_lr(optimizer,epoch)
    
    iteration = train(iteration)
    f1_s = validation(epoch)
  
    if f1_s > best_f1:
        torch.save(model.state_dict(),os.path.join(checkpoint_path,'best_model_epoch'+str(epoch)+'.pth'))
        best_f1 = f1_s
        best_epoch = epoch

print(f'Best F1 score at epoch {best_epoch}: {f1_s}')
dir_name = f'gdrive/MyDrive/ASD/LSTM-Models/{name_of_run}/'
src_file = f'{checkpoint_path}/best_model_epoch{best_epoch}.pth'
dest = f'gdrive/MyDrive/ASD/LSTM-Models/{name_of_run}/best_model_epoch{best_epoch}.pth'
os.makedirs(dir_name, exist_ok=True)
shutil.move(src_file, dest)