# Sentence Transformer & ResNet50

## Dataloader

In [1]:
import os
import glob
from PIL import Image

from torch.utils.data import Dataset
import torchvision.transforms as transforms


class SurgicalVQADataset(Dataset):
    def __init__(self, seq, folder_head, folder_tail, labels, transform=None):
        
        self.transform = transform
        
        # files, question and answers
        filenames = []
        for curr_seq in seq: filenames = filenames + glob.glob(folder_head + str(curr_seq) + folder_tail)
        self.vqas = []
        for file in filenames:
            file_data = open(file, "r")
            lines = [line.strip("\n") for line in file_data if line != "\n"]
            file_data.close()
            for line in lines: self.vqas.append([file, line])
        print('Total files: %d | Total question: %.d' %(len(filenames), len(self.vqas)))
        
        # Labels
        self.labels = labels
        
    def __len__(self):
        return len(self.vqas)

    def __getitem__(self, idx):
        
        # img
        loc = self.vqas[idx][0].split('/')
        img_loc = os.path.join(loc[0],loc[1],loc[2], 'left_frames',loc[-1].split('_')[0]+'.png')
        img = Image.open(img_loc)
        if self.transform: img = self.transform(img)
            
        # question and answer
        question = self.vqas[idx][1].split('|')[0]
        label = self.labels.index(str(self.vqas[idx][1].split('|')[1]))

        return img, question, label

## Model Version 1

In [2]:
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torchvision import models

class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()

        # text processing
        self.text_feature_extractor = SentenceTransformer('bert-base-nli-mean-tokens')
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(2816, num_classes)

    def forward(self, img, text):
        img_feature = self.img_feature_extractor(img)
        
        text_feature = self.text_feature_extractor.encode(text)
        text_feature = torch.tensor(text_feature).cuda()
        
        img_text_features = torch.cat((img_feature, text_feature), dim=1)
        
        out = self.classifier(img_text_features)
        return out

## Model Version 2

In [2]:
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torchvision import models

class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()

        # text processing
        self.text_feature_extractor = SentenceTransformer('bert-large-nli-mean-tokens')
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(3072, num_classes)

    def forward(self, img, text):
        img_feature = self.img_feature_extractor(img)
        
        text_feature = self.text_feature_extractor.encode(text)
        text_feature = torch.tensor(text_feature).cuda()
        
        img_text_features = torch.cat((img_feature, text_feature), dim=1)
        
        out = self.classifier(img_text_features)
        return out

## Metrics

In [3]:
import numpy as np

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score

def calc_acc(y_true, y_pred):
    acc = accuracy_score(y_true, y_pred)
    return acc

def calc_classwise_acc(y_true, y_pred):
    matrix = confusion_matrix(y_true, y_pred)
    classwise_acc = matrix.diagonal()/matrix.sum(axis=1)
    return classwise_acc

def calc_map(y_true, y_scores):
    mAP = average_precision_score(y_true, y_scores,average=None)
    return mAP

## Test model

In [4]:
import torch.nn.functional as F

def test_model(epoch, model, valid_dataloader):
    
    model.eval()

    total_loss = 0.0    
    label_true = None
    label_pred = None
    label_score = None
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for i, (imgs, q, labels) in enumerate(valid_dataloader, 0):
            questions = []
            for question in q: questions.append(question)
            imgs, labels = imgs.cuda(), labels.cuda()
            
            outputs = model(imgs, questions)

            loss = criterion(outputs,labels)
            total_loss += loss.item()
        
            scores, predicted = torch.max(F.softmax(outputs, dim=1).data, 1)    
            label_true = labels.data.cpu() if label_true == None else torch.cat((label_true, labels.data.cpu()), 0)
            label_pred = predicted.data.cpu() if label_pred == None else torch.cat((label_pred, predicted.data.cpu()), 0)
            label_score = scores.data.cpu() if label_score == None else torch.cat((label_score, scores.data.cpu()), 0)

            
    acc, c_acc, mAP = calc_acc(label_true, label_pred), calc_classwise_acc(label_true, label_pred), 0.0#calc_map(label_true, label_score)

    print('Test: epoch: %d loss: %.6f | Acc: %.6f | mAP: %.6f' %(epoch, total_loss, acc, mAP))
    print(c_acc)
    
    return (acc, c_acc, mAP)

## Train model

In [5]:
from torch import optim
def train_model(epoch, model, train_dataloader, lr):  # train model
    
    model.train()
    
    total_loss = 0.0    
    label_true = None
    label_pred = None
    label_score = None
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = lr, weight_decay = 0)
    
    for i, (imgs, q, labels) in enumerate(train_dataloader,0):
        questions = []
        for question in q: questions.append(question)
        imgs, labels = imgs.cuda(), labels.cuda()
        
        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = model(imgs, questions)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        total_loss += loss.item()
        
        scores, predicted = torch.max(F.softmax(outputs, dim=1).data, 1)    
        label_true = labels.data.cpu() if label_true == None else torch.cat((label_true, labels.data.cpu()), 0)
        label_pred = predicted.data.cpu() if label_pred == None else torch.cat((label_pred, predicted.data.cpu()), 0)
        label_score = scores.data.cpu() if label_score == None else torch.cat((label_score, scores.data.cpu()), 0)

    
    # loss and acc
    acc, c_acc, mAP = calc_acc(label_true, label_pred), calc_classwise_acc(label_true, label_pred), 0.0#calc_map(label_true, label_score)

    print('Train: epoch: %d loss: %.6f | Acc: %.6f | mAP: %.6f' %(epoch, total_loss, acc, mAP))
    return

## Main

In [6]:
import os
import torch

from torchvision import transforms
from torch.utils.data import DataLoader

os.environ["CUDA_VISIBLE_DEVICES"]="0"

def seed_everything(seed=27):
    '''
    Set random seed for reproducible experiments
    Inputs: seed number 
    '''
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
if __name__ == "__main__":
     
    # Set random seed
    seed_everything()  
    
    # Device Count
    num_gpu = torch.cuda.device_count()
    
    # hyperparameters
    bs = 32
    epochs = 150
    lr = 0.00001
    
    checkpoint_dir = 'checkpoints/v2/simple/'
    
    # train and test dataloader
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]
    folder_head = 'dataset/instruments18/seq_'
    folder_tail = '/vqa/simple/*.txt'

    labels = ['kidney',
          'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation',
          'Tool_Manipulation', 'Cutting', 'Cauterization', 'Suction', 
          'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing',
          'left-top', 'right-top', 'left-bottom', 'right-bottom']

    transform = transforms.Compose([
                transforms.Resize((300,256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
                ])

    # train_dataset
    train_dataset = SurgicalVQADataset(train_seq, folder_head, folder_tail, labels, transform=transform)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size= bs, shuffle=True)

    # Val_dataset
    val_dataset = SurgicalVQADataset(val_seq, folder_head, folder_tail, labels, transform=transform)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size= bs, shuffle=False)
    
    # model
    model = Surgical_VQA(num_classes=len(labels)).cuda()
    
    best_epoch = [0]
    best_results = [0.0]
    
    for epoch in range(1, epochs):
        train_model(epoch, model, train_dataloader, lr)
        test_acc, test_c_acc, mAP = test_model(epoch, model, train_dataloader)
    
        if test_acc >= best_results[0]:
            best_results[0] = test_acc
            best_epoch[0] = epoch
        
        print('Best epoch: %d | Best acc: %.6f' %(best_epoch[0], best_results[0]))
        checkpoint = {'lr': lr, 'b_s': bs, 'state_dict': model.state_dict() }
        save_name = "checkpoint_" + str(epoch) + '_epoch.pth'
        
        torch.save(checkpoint, os.path.join(checkpoint_dir, save_name))

Total files: 1560 | Total question: 9014
Total files: 447 | Total question: 2769
Test: epoch: 1 loss: 605.971992 | Acc: 0.324939 | mAP: 0.000000
Test: epoch: 1 loss: 494.138120 | Acc: 0.498669 | mAP: 0.000000
[0.91089744 0.76721511 0.         0.         0.2446959  0.
 0.32505176 0.         0.         0.         0.         0.
 0.         0.         0.54611812 0.5107084  0.06403013 0.01684211]
Best epoch: 1 | Best acc: 0.498669
Test: epoch: 2 loss: 455.929187 | Acc: 0.520635 | mAP: 0.000000
Test: epoch: 2 loss: 394.542856 | Acc: 0.604726 | mAP: 0.000000
[1.         0.70505789 0.         0.17847769 0.48797737 0.
 0.42028986 0.         0.         0.5203252  0.         0.
 0.         0.         0.64299934 0.6630972  0.3747646  0.17052632]
Best epoch: 2 | Best acc: 0.604726
Test: epoch: 3 loss: 378.927593 | Acc: 0.605613 | mAP: 0.000000
Test: epoch: 3 loss: 331.437542 | Acc: 0.684935 | mAP: 0.000000
[1.         0.81170018 0.         0.56955381 0.59264498 0.0625
 0.60248447 0.         0.     

Test: epoch: 24 loss: 100.537047 | Acc: 0.916574 | mAP: 0.000000
Test: epoch: 24 loss: 86.584318 | Acc: 0.934103 | mAP: 0.000000
[1.         0.92199878 0.96610169 0.98687664 0.87411598 0.65625
 1.         0.92307692 1.         0.95934959 0.33333333 1.
 1.         1.         0.91838089 0.90527183 0.89265537 0.95578947]
Best epoch: 23 | Best acc: 0.937431
Test: epoch: 25 loss: 97.851239 | Acc: 0.917462 | mAP: 0.000000
Test: epoch: 25 loss: 84.060434 | Acc: 0.936876 | mAP: 0.000000
[1.         0.91163924 0.94915254 0.98687664 0.90240453 0.6796875
 1.         0.92307692 1.         0.95121951 0.16666667 0.87804878
 1.         1.         0.91174519 0.91680395 0.94538606 0.94526316]
Best epoch: 23 | Best acc: 0.937431
Test: epoch: 26 loss: 95.753885 | Acc: 0.921012 | mAP: 0.000000
Test: epoch: 26 loss: 82.231095 | Acc: 0.938096 | mAP: 0.000000
[1.         0.95185862 0.94915254 0.98687664 0.87128713 0.578125
 1.         0.92307692 1.         0.89430894 0.33333333 0.87804878
 0.92307692 1.     

Test: epoch: 47 loss: 68.079378 | Acc: 0.933881 | mAP: 0.000000
Test: epoch: 47 loss: 60.636830 | Acc: 0.943199 | mAP: 0.000000
[1.         0.92382693 0.96610169 0.98687664 0.89250354 0.703125
 1.         0.92307692 1.         0.97560976 1.         1.
 1.         1.         0.92368945 0.92092257 0.93973635 0.96      ]
Best epoch: 42 | Best acc: 0.944087
Test: epoch: 48 loss: 67.240613 | Acc: 0.935767 | mAP: 0.000000
Test: epoch: 48 loss: 60.108697 | Acc: 0.941314 | mAP: 0.000000
[1.         0.93418647 0.96610169 0.98687664 0.85714286 0.734375
 1.         0.92307692 1.         0.97560976 1.         1.
 1.         1.         0.92169874 0.91433278 0.93785311 0.95789474]
Best epoch: 42 | Best acc: 0.944087
Test: epoch: 49 loss: 66.871841 | Acc: 0.935212 | mAP: 0.000000
Test: epoch: 49 loss: 59.608473 | Acc: 0.945529 | mAP: 0.000000
[1.         0.92017063 0.96610169 0.98687664 0.91230552 0.7421875
 1.         0.92307692 1.         0.95934959 1.         1.
 1.         1.         0.93165229 0

Test: epoch: 70 loss: 53.108719 | Acc: 0.944198 | mAP: 0.000000
[1.         0.94576478 0.96610169 0.98687664 0.85855728 0.75
 1.         1.         1.         0.87804878 1.         1.
 1.         1.         0.9270073  0.91680395 0.93973635 0.96421053]
Best epoch: 49 | Best acc: 0.945529
Test: epoch: 71 loss: 58.138755 | Acc: 0.938762 | mAP: 0.000000
Test: epoch: 71 loss: 53.252082 | Acc: 0.944198 | mAP: 0.000000
[1.         0.9250457  0.96610169 0.98687664 0.88260255 0.828125
 1.         1.         1.         0.95934959 1.         1.
 1.         1.         0.92966158 0.91350906 0.93973635 0.95789474]
Best epoch: 49 | Best acc: 0.945529
Test: epoch: 72 loss: 57.467102 | Acc: 0.940870 | mAP: 0.000000
Test: epoch: 72 loss: 52.742980 | Acc: 0.944642 | mAP: 0.000000
[1.         0.94576478 0.96610169 0.98687664 0.83451202 0.75
 1.         1.         1.         0.95934959 1.         1.
 1.         1.         0.93497014 0.91598023 0.93973635 0.96421053]
Best epoch: 49 | Best acc: 0.945529
Test

Test: epoch: 94 loss: 53.661840 | Acc: 0.942090 | mAP: 0.000000
Test: epoch: 94 loss: 49.383733 | Acc: 0.946639 | mAP: 0.000000
[1.         0.92382693 0.96610169 0.98687664 0.93352192 0.75
 1.         1.         1.         0.88617886 1.         1.
 1.         1.         0.92435302 0.92421746 0.93973635 0.96210526]
Best epoch: 85 | Best acc: 0.947748
Test: epoch: 95 loss: 53.292409 | Acc: 0.940648 | mAP: 0.000000
Test: epoch: 95 loss: 49.643275 | Acc: 0.946195 | mAP: 0.000000
[1.         0.93174893 0.96610169 0.98687664 0.89957567 0.7578125
 1.         1.         1.         0.83739837 1.         1.
 1.         1.         0.93032515 0.92751236 0.93973635 0.96      ]
Best epoch: 85 | Best acc: 0.947748
Test: epoch: 96 loss: 53.099043 | Acc: 0.941424 | mAP: 0.000000
Test: epoch: 96 loss: 49.224976 | Acc: 0.947193 | mAP: 0.000000
[1.         0.93113955 0.96610169 0.98687664 0.89674682 0.75
 1.         1.         1.         0.95121951 1.         1.
 1.         1.         0.92833444 0.9266886

Test: epoch: 117 loss: 47.698119 | Acc: 0.947970 | mAP: 0.000000
[1.         0.92687386 0.96610169 0.98687664 0.91230552 0.75
 1.         1.         1.         0.94308943 1.         1.
 1.         1.         0.93497014 0.92174629 0.93973635 0.96631579]
Best epoch: 111 | Best acc: 0.948635
Test: epoch: 118 loss: 50.714145 | Acc: 0.943532 | mAP: 0.000000
Test: epoch: 118 loss: 47.550668 | Acc: 0.948635 | mAP: 0.000000
[1.         0.92382693 0.96610169 0.98687664 0.92503536 0.75
 1.         1.         1.         0.90243902 1.         1.
 1.         1.         0.93231586 0.92668863 0.94161959 0.97473684]
Best epoch: 118 | Best acc: 0.948635
Test: epoch: 119 loss: 50.558547 | Acc: 0.943310 | mAP: 0.000000
Test: epoch: 119 loss: 47.442971 | Acc: 0.949079 | mAP: 0.000000
[1.         0.93113955 0.96610169 0.98687664 0.9165488  0.75
 1.         1.         1.         0.89430894 1.         1.
 1.         1.         0.93430657 0.92751236 0.93973635 0.96631579]
Best epoch: 119 | Best acc: 0.949079


Test: epoch: 141 loss: 49.254164 | Acc: 0.944642 | mAP: 0.000000
Test: epoch: 141 loss: 46.432620 | Acc: 0.948968 | mAP: 0.000000
[1.         0.92870201 0.96610169 0.98687664 0.91796322 0.75
 1.         1.         1.         0.91056911 1.         1.
 1.         1.         0.93629728 0.92504119 0.94161959 0.96421053]
Best epoch: 129 | Best acc: 0.949745
Test: epoch: 142 loss: 49.079990 | Acc: 0.944753 | mAP: 0.000000
Test: epoch: 142 loss: 46.415242 | Acc: 0.946639 | mAP: 0.000000
[1.         0.9250457  0.96610169 0.98687664 0.88260255 0.75
 1.         1.         1.         0.95121951 1.         1.
 1.         1.         0.93629728 0.9291598  0.94161959 0.96421053]
Best epoch: 129 | Best acc: 0.949745
Test: epoch: 143 loss: 49.005354 | Acc: 0.946195 | mAP: 0.000000
Test: epoch: 143 loss: 46.277344 | Acc: 0.947193 | mAP: 0.000000
[1.         0.93601463 0.96610169 0.98687664 0.87694484 0.75
 1.         1.         1.         0.89430894 1.         1.
 1.         1.         0.93430657 0.9291