# GPT-2 & ResNet50

## Dataloader

In [1]:
import os
import glob
from PIL import Image

from torch.utils.data import Dataset
import torchvision.transforms as transforms


class SurgicalVQADataset(Dataset):
    def __init__(self, seq, folder_head, folder_tail, labels, transform=None):
        
        self.transform = transform
        
        # files, question and answers
        filenames = []
        for curr_seq in seq: filenames = filenames + glob.glob(folder_head + str(curr_seq) + folder_tail)
        self.vqas = []
        for file in filenames:
            file_data = open(file, "r")
            lines = [line.strip("\n") for line in file_data if line != "\n"]
            file_data.close()
            for line in lines: self.vqas.append([file, line])
        print('Total files: %d | Total question: %.d' %(len(filenames), len(self.vqas)))
        
        # Labels
        self.labels = labels
        
    def __len__(self):
        return len(self.vqas)

    def __getitem__(self, idx):
        
        # img
        loc = self.vqas[idx][0].split('/')
        img_loc = os.path.join(loc[0],loc[1],loc[2], 'left_frames',loc[-1].split('_')[0]+'.png')
        img = Image.open(img_loc)
        if self.transform: img = self.transform(img)
            
        # question and answer
        question = self.vqas[idx][1].split('|')[0]
        label = self.labels.index(str(self.vqas[idx][1].split('|')[1]))

        return img, question, label

## Model Version 1

In [2]:
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torchvision import models

class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()

        # text processing
        self.text_feature_extractor = SentenceTransformer('bert-base-nli-mean-tokens')
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(2816, num_classes)

    def forward(self, img, text):
        img_feature = self.img_feature_extractor(img)
        
        text_feature = self.text_feature_extractor.encode(text)
        text_feature = torch.tensor(text_feature).cuda()
        
        img_text_features = torch.cat((img_feature, text_feature), dim=1)
        
        out = self.classifier(img_text_features)
        return out

## Model Version 2

In [2]:
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torchvision import models

class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()

        # text processing
        self.text_feature_extractor = SentenceTransformer('bert-large-nli-mean-tokens')
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(3072, num_classes)

    def forward(self, img, text):
        img_feature = self.img_feature_extractor(img)
        
        text_feature = self.text_feature_extractor.encode(text)
        text_feature = torch.tensor(text_feature).cuda()
        
        img_text_features = torch.cat((img_feature, text_feature), dim=1)
        
        out = self.classifier(img_text_features)
        return out

## Model Version 3

In [2]:
import torch
import torch.nn as nn
from torchvision import models

import nltk
nltk.download('punkt')
from InferSent.models import InferSent

class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()

        # text processing
        params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
                        'pool_type': 'max', 'dpout_model': 0.0, 'version': 2}
        self.text_feature_extractor = InferSent(params_model)
        self.text_feature_extractor.load_state_dict(torch.load('InferSent/encoder/infersent2.pkl'))
        self.text_feature_extractor.set_w2v_path('InferSent/fastText/crawl-300d-2M.vec')
        self.text_feature_extractor.build_vocab_k_words(K=100000)
        
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(6144, num_classes)

    def forward(self, img, text):
        img_feature = self.img_feature_extractor(img)
        
        text_feature = self.text_feature_extractor.encode(text) #infersent.encode(query)[0]
        text_feature = torch.tensor(text_feature).cuda()
        
        img_text_features = torch.cat((img_feature, text_feature), dim=1)
        
        out = self.classifier(img_text_features)
        return out

[nltk_data] Downloading package punkt to /home/mobarak/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Model Version 4

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from transformers import GPT2Tokenizer, GPT2Model


class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()

        # text processing
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.text_feature_extractor = GPT2Model.from_pretrained('gpt2')
 
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(2816, num_classes)

    def forward(self, img, text):
        
        # image
        img_feature = self.img_feature_extractor(img)
        
        # text
        encoded_text = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        encoded_text['input_ids'] = encoded_text['input_ids'].cuda()
        encoded_text['attention_mask'] = encoded_text['attention_mask'].cuda()
        text_feature = self.text_feature_extractor(**encoded_text)
        text_feature = text_feature.last_hidden_state.swapaxes(1,2)
        text_feature = F.adaptive_avg_pool1d(text_feature,1)
        text_feature = text_feature.swapaxes(1,2).squeeze(1)        
        img_text_features = torch.cat((img_feature, text_feature), dim=1)
        
        out = self.classifier(img_text_features)
        return out

## Metrics

In [3]:
import numpy as np

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score

def calc_acc(y_true, y_pred):
    acc = accuracy_score(y_true, y_pred)
    return acc

def calc_classwise_acc(y_true, y_pred):
    matrix = confusion_matrix(y_true, y_pred)
    classwise_acc = matrix.diagonal()/matrix.sum(axis=1)
    return classwise_acc

def calc_map(y_true, y_scores):
    mAP = average_precision_score(y_true, y_scores,average=None)
    return mAP

## Test model

In [4]:
import torch.nn.functional as F

def test_model(epoch, model, valid_dataloader):
    
    model.eval()

    total_loss = 0.0    
    label_true = None
    label_pred = None
    label_score = None
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for i, (imgs, q, labels) in enumerate(valid_dataloader, 0):
            questions = []
            for question in q: questions.append(question)
            imgs, labels = imgs.cuda(), labels.cuda()
            
            outputs = model(imgs, questions)

            loss = criterion(outputs,labels)
            total_loss += loss.item()
        
            scores, predicted = torch.max(F.softmax(outputs, dim=1).data, 1)    
            label_true = labels.data.cpu() if label_true == None else torch.cat((label_true, labels.data.cpu()), 0)
            label_pred = predicted.data.cpu() if label_pred == None else torch.cat((label_pred, predicted.data.cpu()), 0)
            label_score = scores.data.cpu() if label_score == None else torch.cat((label_score, scores.data.cpu()), 0)

            
    acc, c_acc, mAP = calc_acc(label_true, label_pred), calc_classwise_acc(label_true, label_pred), 0.0#calc_map(label_true, label_score)

    print('Test: epoch: %d loss: %.6f | Acc: %.6f | mAP: %.6f' %(epoch, total_loss, acc, mAP))
    print(c_acc)
    
    return (acc, c_acc, mAP)

## Train model

In [5]:
from torch import optim
def train_model(epoch, model, train_dataloader, lr):  # train model
    
    model.train()
    
    total_loss = 0.0    
    label_true = None
    label_pred = None
    label_score = None
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = lr, weight_decay = 0)
    
    for i, (imgs, q, labels) in enumerate(train_dataloader,0):
        questions = []
        for question in q: questions.append(question)
        imgs, labels = imgs.cuda(), labels.cuda()
        
        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = model(imgs, questions)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        total_loss += loss.item()
        
        scores, predicted = torch.max(F.softmax(outputs, dim=1).data, 1)    
        label_true = labels.data.cpu() if label_true == None else torch.cat((label_true, labels.data.cpu()), 0)
        label_pred = predicted.data.cpu() if label_pred == None else torch.cat((label_pred, predicted.data.cpu()), 0)
        label_score = scores.data.cpu() if label_score == None else torch.cat((label_score, scores.data.cpu()), 0)

    
    # loss and acc
    acc, c_acc, mAP = calc_acc(label_true, label_pred), calc_classwise_acc(label_true, label_pred), 0.0#calc_map(label_true, label_score)

    print('Train: epoch: %d loss: %.6f | Acc: %.6f | mAP: %.6f' %(epoch, total_loss, acc, mAP))
    return

## Main

In [6]:
import os
import torch

from torchvision import transforms
from torch.utils.data import DataLoader

os.environ["CUDA_VISIBLE_DEVICES"]="2"

def seed_everything(seed=27):
    '''
    Set random seed for reproducible experiments
    Inputs: seed number 
    '''
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
if __name__ == "__main__":
     
    # Set random seed
    seed_everything()  
    
    # Device Count
    num_gpu = torch.cuda.device_count()
    
    # hyperparameters
    bs = 32
    epochs = 150
    lr = 0.00001
    
    checkpoint_dir = 'checkpoints/v4/simple/'
    
    # train and test dataloader
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]
    folder_head = 'dataset/instruments18/seq_'
    folder_tail = '/vqa/simple/*.txt'

    labels = ['kidney',
          'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation',
          'Tool_Manipulation', 'Cutting', 'Cauterization', 'Suction', 
          'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing',
          'left-top', 'right-top', 'left-bottom', 'right-bottom']

    transform = transforms.Compose([
                transforms.Resize((300,256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
                ])

    # train_dataset
    train_dataset = SurgicalVQADataset(train_seq, folder_head, folder_tail, labels, transform=transform)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size= bs, shuffle=True)

    # Val_dataset
    val_dataset = SurgicalVQADataset(val_seq, folder_head, folder_tail, labels, transform=transform)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size= bs, shuffle=False)
    
    # model
    model = Surgical_VQA(num_classes=len(labels)).cuda()
    
    best_epoch = [0]
    best_results = [0.0]
    
    for epoch in range(1, epochs):
        train_model(epoch, model, train_dataloader, lr)
        test_acc, test_c_acc, mAP = test_model(epoch, model, train_dataloader)
    
        if test_acc >= best_results[0]:
            best_results[0] = test_acc
            best_epoch[0] = epoch
        
        print('Best epoch: %d | Best acc: %.6f' %(best_epoch[0], best_results[0]))
        checkpoint = {'lr': lr, 'b_s': bs, 'state_dict': model.state_dict() }
        save_name = "checkpoint_" + str(epoch) + '_epoch.pth'
        
        torch.save(checkpoint, os.path.join(checkpoint_dir, save_name))

Total files: 1560 | Total question: 9014
Total files: 447 | Total question: 2769
Train: epoch: 1 loss: 422.574315 | Acc: 0.564566 | mAP: 0.000000
Test: epoch: 1 loss: 219.202679 | Acc: 0.758487 | mAP: 0.000000
[1.         0.90981109 0.         0.79790026 0.69448373 0.
 0.51966874 0.         0.         0.2601626  0.         0.
 0.         0.         0.82813537 0.76112026 0.57815443 0.47789474]
Best epoch: 1 | Best acc: 0.758487
Train: epoch: 2 loss: 204.987015 | Acc: 0.760484 | mAP: 0.000000
Test: epoch: 2 loss: 155.270176 | Acc: 0.813623 | mAP: 0.000000
[1.         0.86532602 0.         0.89501312 0.67892504 0.34375
 0.99378882 0.         0.         0.64227642 0.         0.
 0.         0.59722222 0.84140677 0.77182867 0.56685499 0.80210526]
Best epoch: 2 | Best acc: 0.813623
Train: epoch: 3 loss: 158.492421 | Acc: 0.809629 | mAP: 0.000000
Test: epoch: 3 loss: 126.301651 | Acc: 0.855891 | mAP: 0.000000
[1.         0.83729433 0.         0.95275591 0.73833098 0.484375
 0.99378882 0.      

Train: epoch: 24 loss: 58.316690 | Acc: 0.932438 | mAP: 0.000000
Test: epoch: 24 loss: 47.112253 | Acc: 0.949634 | mAP: 0.000000
[1.         0.92809263 0.94915254 0.98687664 0.89816124 0.7890625
 1.         0.92307692 1.         0.92682927 1.         1.
 1.         1.         0.93563371 0.93327842 0.95291902 0.96631579]
Best epoch: 24 | Best acc: 0.949634
Train: epoch: 25 loss: 57.603463 | Acc: 0.936210 | mAP: 0.000000
Test: epoch: 25 loss: 46.880654 | Acc: 0.944974 | mAP: 0.000000
[1.         0.93540524 0.96610169 0.98687664 0.87553041 0.7421875
 1.         0.92307692 1.         0.91869919 1.         1.
 1.         1.         0.93629728 0.90939044 0.94161959 0.97052632]
Best epoch: 24 | Best acc: 0.949634
Train: epoch: 26 loss: 57.521896 | Acc: 0.936210 | mAP: 0.000000
Test: epoch: 26 loss: 46.532212 | Acc: 0.944198 | mAP: 0.000000
[1.         0.92687386 0.96610169 0.98687664 0.88967468 0.5859375
 1.         0.92307692 1.         0.95934959 1.         1.
 1.         1.         0.93497

Train: epoch: 47 loss: 50.481581 | Acc: 0.941314 | mAP: 0.000000
Test: epoch: 47 loss: 45.018283 | Acc: 0.948303 | mAP: 0.000000
[1.         0.92687386 0.96610169 0.98687664 0.89533239 0.75
 1.         1.         1.         0.95934959 1.         1.
 1.         1.         0.94890511 0.90856672 0.95480226 0.96631579]
Best epoch: 27 | Best acc: 0.950189
Train: epoch: 48 loss: 50.775249 | Acc: 0.940648 | mAP: 0.000000
Test: epoch: 48 loss: 43.416329 | Acc: 0.949190 | mAP: 0.000000
[1.         0.92870201 0.94915254 0.98687664 0.92362093 0.640625
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.93231586 0.93657331 0.94161959 0.97473684]
Best epoch: 27 | Best acc: 0.950189
Train: epoch: 49 loss: 49.980672 | Acc: 0.942756 | mAP: 0.000000
Test: epoch: 49 loss: 44.152036 | Acc: 0.947304 | mAP: 0.000000
[1.         0.92809263 0.96610169 0.98687664 0.90947666 0.7421875
 1.         1.         1.         0.96747967 1.         1.
 1.         1.         0.94359655 0.

Test: epoch: 70 loss: 42.907676 | Acc: 0.948968 | mAP: 0.000000
[1.         0.93723339 1.         0.98687664 0.87411598 0.7578125
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.94226941 0.9291598  0.93408663 0.97263158]
Best epoch: 56 | Best acc: 0.950854
Train: epoch: 71 loss: 47.567824 | Acc: 0.943310 | mAP: 0.000000
Test: epoch: 71 loss: 42.913863 | Acc: 0.946971 | mAP: 0.000000
[1.         0.9329677  0.96610169 0.98687664 0.8854314  0.75
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.93497014 0.92257002 0.94161959 0.97052632]
Best epoch: 56 | Best acc: 0.950854
Train: epoch: 72 loss: 47.590756 | Acc: 0.944531 | mAP: 0.000000
Test: epoch: 72 loss: 42.723470 | Acc: 0.948857 | mAP: 0.000000
[1.         0.94759293 0.94915254 0.98687664 0.85289958 0.75
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.93231586 0.93904448 0.94161959 0.97263158]
Best epoch: 56 | Best acc: 0.950854
T

Train: epoch: 94 loss: 46.191159 | Acc: 0.946860 | mAP: 0.000000
Test: epoch: 94 loss: 42.776588 | Acc: 0.946084 | mAP: 0.000000
[1.         0.93235832 0.96610169 0.98687664 0.88118812 0.75
 1.         1.         1.         0.90243902 1.         1.
 1.         1.         0.95023225 0.90032949 0.94161959 0.97263158]
Best epoch: 56 | Best acc: 0.950854
Train: epoch: 95 loss: 46.282883 | Acc: 0.943089 | mAP: 0.000000
Test: epoch: 95 loss: 42.693340 | Acc: 0.946750 | mAP: 0.000000
[1.         0.93479586 0.94915254 0.98687664 0.87270156 0.75
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.94094227 0.91762768 0.94161959 0.97473684]
Best epoch: 56 | Best acc: 0.950854
Train: epoch: 96 loss: 46.217405 | Acc: 0.945529 | mAP: 0.000000
Test: epoch: 96 loss: 42.541985 | Acc: 0.949412 | mAP: 0.000000
[1.         0.92626447 0.94915254 0.98687664 0.92362093 0.75
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.93165229 0.9324547  

Test: epoch: 117 loss: 42.270046 | Acc: 0.947637 | mAP: 0.000000
[1.         0.93113955 0.94915254 0.98687664 0.89674682 0.75
 1.         1.         1.         0.90243902 1.         1.
 1.         1.         0.91506304 0.94810544 0.94161959 0.97473684]
Best epoch: 103 | Best acc: 0.950854
Train: epoch: 118 loss: 45.519553 | Acc: 0.944642 | mAP: 0.000000
Test: epoch: 118 loss: 42.389150 | Acc: 0.949856 | mAP: 0.000000
[1.         0.92321755 0.94915254 0.98687664 0.93069307 0.75
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.94293298 0.92092257 0.94161959 0.97263158]
Best epoch: 103 | Best acc: 0.950854
Train: epoch: 119 loss: 45.672694 | Acc: 0.943976 | mAP: 0.000000
Test: epoch: 119 loss: 42.023839 | Acc: 0.947859 | mAP: 0.000000
[1.         0.93479586 0.94915254 0.98687664 0.87835926 0.75
 1.         1.         1.         0.91869919 1.         1.
 1.         1.         0.93430657 0.93080725 0.94161959 0.97263158]
Best epoch: 103 | Best acc: 0.95085

Test: epoch: 140 loss: 42.182461 | Acc: 0.947415 | mAP: 0.000000
[1.         0.93235832 1.         0.98687664 0.89250354 0.75
 1.         1.         1.         0.80487805 1.         1.
 1.         1.         0.93297943 0.9324547  0.94161959 0.97473684]
Best epoch: 103 | Best acc: 0.950854
Train: epoch: 141 loss: 45.079163 | Acc: 0.946528 | mAP: 0.000000
Test: epoch: 141 loss: 42.138873 | Acc: 0.949745 | mAP: 0.000000
[1.         0.9213894  0.96610169 0.98687664 0.93635078 0.75
 1.         1.         1.         0.88617886 1.         1.
 1.         1.         0.933643   0.93163097 0.94161959 0.97473684]
Best epoch: 103 | Best acc: 0.950854
Train: epoch: 142 loss: 44.908445 | Acc: 0.945085 | mAP: 0.000000
Test: epoch: 142 loss: 41.796982 | Acc: 0.947526 | mAP: 0.000000
[1.         0.93418647 0.98305085 0.98687664 0.87270156 0.75
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.93231586 0.93327842 0.94161959 0.97473684]
Best epoch: 103 | Best acc: 0.95085