<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/SurgicalGPT_V3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/lalithjets/SurgicalGPT.git

Cloning into 'SurgicalGPT'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (61/61), done.[K
remote: Total 98 (delta 48), reused 80 (delta 35), pack-reused 0[K
Receiving objects: 100% (98/98), 752.31 KiB | 2.00 MiB/s, done.
Resolving deltas: 100% (48/48), done.


Google Drive: https://drive.google.com/file/d/1K5YnSPMPvn2x1gtRAw2ZfxIqoIo2DX3I

In [2]:
# Downloading the VQA EndoVis18 Dataset https://drive.google.com/file/d/1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN/view?usp=sharing
!gdown --id 1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN

# Unzip the VQA EndoVis18 dataset
!unzip -q EndoVis-18-VQA.zip

Downloading...
From (original): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN
From (redirected): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN&confirm=t&uuid=e039104c-c47e-45e9-ba02-693a772b3a77
To: /content/EndoVis-18-VQA.zip
100% 2.70G/2.70G [00:26<00:00, 101MB/s]


In [3]:
# Move the dataset into the repo's dataset folder (the training cell reads from dataset/EndoVis-18-VQA)
!mv -f EndoVis-18-VQA /content/SurgicalGPT/dataset

Training:

In [None]:
%cd /content/SurgicalGPT
import os
import sys
import argparse
from torch import nn
from torch import optim
import torch.utils.data
import torch.nn.functional as F
from torch.utils.data  import DataLoader
from transformers import BertTokenizer, GPT2Tokenizer
from train import seed_everything
from dataloaders.dataloaderGPT2Classification import EndoVis18VQAGPTClassification
# from models.EFGPT2Classification import EFVLEGPT2RS18Classification, EFVLEGPT2SwinClassification
from utils import save_clf_checkpoint, adjust_learning_rate, calc_acc, calc_precision_recall_fscore, calc_classwise_acc

from transformers import  VisualBertConfig, GPT2Config
from transformers import VisualBertModel, GPT2Model, ViTModel, SwinModel

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

def get_arg(argv=None):
    '''
    Build the command-line parser for VQA classification training and parse it.

    Args:
        argv: optional list of argument strings to parse. When None (default),
            an empty argument list is parsed if the code runs inside an IPython
            kernel ('ipykernel' in sys.modules), otherwise sys.argv is parsed.

    Returns:
        argparse.Namespace holding the parsed options.
    '''

    def _str2bool(value):
        # argparse hands over raw strings; without this, "--validate False"
        # would produce the truthy string "False".
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser(description='VisualQuestionAnswerClassification')

    # VB Model parameters
    parser.add_argument('--emb_dim',        type=int,   default=300,                                help='dimension of word embeddings.')
    parser.add_argument('--n_heads',        type=int,   default=8,                                  help='Multi-head attention.')
    parser.add_argument('--dropout',        type=float, default=0.1,                                help='dropout')
    parser.add_argument('--encoder_layers', type=int,   default=6,                                  help='the number of layers of encoder in Transformer.')

    # Training parameters
    parser.add_argument('--epochs',         type=int,   default=2,                                 help='number of epochs to train for (if early stopping is not triggered).') #80, 26
    parser.add_argument('--batch_size',     type=int,   default=40,                                 help='batch_size')
    parser.add_argument('--workers',        type=int,   default=1,                                  help='for data-loading; right now, only 1 works with h5pys.')
    parser.add_argument('--print_freq',     type=int,   default=100,                                help='print training/validation stats every __ batches.')

    # existing checkpoint
    parser.add_argument('--checkpoint',     default=None,                                           help='path to checkpoint, None if none.')

    parser.add_argument('--lr',             type=float, default=0.00001,                           help='0.000005, 0.00001, 0.000005')
    parser.add_argument('--checkpoint_dir', default= 'checkpoints/efvlegpt2Swin/m18_v1_z_qf_',            help='med_vqa_c$version$/m18/c80/m18_vid$temporal_size$/c80_vid$temporal_size$') #clf_v1_2_1x1/med_vqa_c3
    parser.add_argument('--dataset_type',   default= 'm18',                                          help='med_vqa/m18/c80/m18_vid/c80_vid')
    parser.add_argument('--dataset_cat',    default= 'cat1',                                        help='cat1/cat2/cat3')
    parser.add_argument('--tokenizer_ver',  default= 'gpt2v1',                                      help='btv2/btv3/gpt2v1')
    # Numeric options need an explicit type, otherwise values given on the
    # command line stay strings (e.g. the tokenizer's max_length would get "25").
    parser.add_argument('--question_len',   type=int,   default= 25,                                help='25')
    parser.add_argument('--model_ver',      default= 'efvlegpt2Swin',                                          help='vb/vbrm/efvlegpt2rs18/efvlegpt2Swin/"')  #vrvb/gpt2rs18/gpt2ViT/gpt2Swin/biogpt2rs18/vilgpt2vqa/efgpt2rs18gr/efvlegpt2Swingr
    parser.add_argument('--model_subver',   default= 'v1',                                          help='V0,v1/v2/v3/v4')
    parser.add_argument('--vis_pos_emb',   default= 'zeroes',                                           help='None, zeroes, pos')
    parser.add_argument('--patch_size',     type=int,   default= 5,                                 help='1/2/3/4/5')

    parser.add_argument('--num_class',      type=int,   default= 2,                                 help='25')
    # parser.add_argument('--temporal_size',  default= 1,                                             help='1/2/3/4/5')
    parser.add_argument('--validate',       type=_str2bool, default=False,                          help='When only validation required False/True')

    if argv is not None:
        return parser.parse_args(list(argv))

    # Inside a notebook kernel sys.argv carries the kernel's own arguments,
    # so parse an empty list and fall back to the defaults.
    if 'ipykernel' in sys.modules:
        args = parser.parse_args([])
    else:
        args = parser.parse_args()
    return args

class EFVLEGPT2SwinClassification(nn.Module):
    '''
    Early-fusion VQA classifier: Swin image features and GPT-2 word embeddings
    are concatenated (question first) and passed through a GPT-2 model, then
    mean-pooled and classified.

    Original author notes on sub-versions:
        v0: visual embedding : Default patch1 + embedding from VB + GPT2 decoder
        v1: visual embedding : Default patch1 + GPT2 decoder

    Args:
        num_class: number of output classes of the final linear layer.
        model_subver: stored as self.sub_ver; not consulted by forward().
        vis_pos_emb: stored as self.vis_pos_emb; not consulted by forward(),
            which always uses zero position ids for the visual tokens.
    '''
    def __init__(self, num_class = 12, model_subver = 'v0', vis_pos_emb = None):
        super(EFVLEGPT2SwinClassification, self).__init__()

        self.sub_ver = model_subver
        self.vis_pos_emb = vis_pos_emb

        ## image processing
        self.img_feature_extractor = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

        ## Question_embedding: only the token-embedding table (wte) of a GPT-2 model is kept
        question_embedder = GPT2Model.from_pretrained('gpt2')
        self.question_embedder = question_embedder.wte

        ## GPT2 visual_context_aware_decoder
        self.VCAdecoder = GPT2Model.from_pretrained('gpt2')

        ## intermediate_layers
        self.intermediate_layer = nn.Linear(768, 512)  #(512+768)
        # NOTE: despite the attribute name this is a BatchNorm1d layer. The name
        # is kept because renaming it would change the state_dict keys.
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)

        ## classifier
        self.classifier = nn.Linear(512, num_class)

    def forward(self, input, img):
        '''
        Args:
            input: mapping with 'input_ids' and 'attention_mask' tensors for the
                tokenized questions. Both entries are moved to the model's
                device in place.
            img: mapping with a 'pixel_values' tensor (moved to the model's
                device in place); the whole mapping is unpacked into the Swin
                feature extractor.

        Returns:
            Tensor of classifier outputs, one row per sample.
        '''
        # Use the device the module's own weights live on instead of relying on
        # a module-level `device` global that only exists after the training
        # script has run.
        device = next(self.parameters()).device

        ## image encoder features
        img['pixel_values'] = img['pixel_values'].to(device)
        img_feature = self.img_feature_extractor(**img)

        ## visual Embedding : id type 1, pos: zero
        visual_embeds = img_feature[0]
        visual_shape = visual_embeds.shape[:-1]
        visual_attention_mask = torch.ones(visual_shape, dtype=torch.float, device=device)
        visual_id_type = torch.ones(visual_shape, dtype=torch.long, device=device)
        visual_position_id = torch.zeros(visual_shape, dtype=torch.long, device=device)

        ## question embedding: id type 0, pos incremental
        input['input_ids'] = input['input_ids'].to(device)
        input['attention_mask'] = input['attention_mask'].to(device)

        question_embeds = self.question_embedder(input['input_ids'])
        question_attention_mask = input['attention_mask']
        batch_size, question_len = question_embeds.size()[0], question_embeds.size()[1]
        question_id_type = torch.zeros(batch_size, question_len, dtype=torch.long, device=device)
        # Position ids 0..question_len-1, repeated for every sample in the batch.
        question_position_id = torch.arange(0, question_len, device=device)
        question_position_id = question_position_id.unsqueeze(0).repeat(batch_size, 1)

        ## question first; concatenation on dim=1 requires both embeddings to share the last dimension
        inputs_embeds = torch.cat((question_embeds, visual_embeds), dim=1)
        attention_mask = torch.cat((question_attention_mask, visual_attention_mask), dim=1)
        token_type_ids = torch.cat((question_id_type, visual_id_type), dim=1)
        position_ids = torch.cat((question_position_id, visual_position_id), dim=1)

        ## VCA_GPT2 decoder
        decoder_output = self.VCAdecoder(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids = position_ids, token_type_ids = token_type_ids)

        # Mean over the sequence axis: swap to (batch, hidden, seq), pool seq to
        # length 1, swap back and drop the singleton axis. The average runs over
        # every sequence position; the attention mask is not applied here.
        decoder_output = decoder_output.last_hidden_state.swapaxes(1,2)
        decoder_output = F.adaptive_avg_pool1d(decoder_output,1)
        decoder_output = decoder_output.swapaxes(1,2).squeeze(1)

        ## intermediate layers
        out = self.intermediate_layer(decoder_output)
        out = self.LayerNorm(out)
        out = self.dropout(out)

        ## classifier
        out = self.classifier(out)
        return out

def train(args, train_dataloader, model, criterion, optimizer, epoch, tokenizer, device):
    '''
    Run one training epoch and print the epoch metrics.

    Args:
        args: namespace providing `question_len` and `model_ver`.
        train_dataloader: iterable yielding (_, visual_features, questions, labels).
        model: module called as model(tokenized_questions, visual_features).
        criterion: loss callable applied to (outputs, labels).
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the printed summary.
        tokenizer: callable that tokenizes a list of question strings.
        device: device that labels (and non-Swin/ViT visual features) are moved to.

    Returns:
        Accuracy over all batches of the epoch, as returned by calc_acc.

    Raises:
        ValueError: if the dataloader yields no batches.
    '''

    model.train()
    total_loss = 0.0
    # Per-batch results are collected in lists and concatenated once at the
    # end, instead of re-concatenating a growing tensor on every step.
    true_batches = []
    pred_batches = []

    for _, v_f, q, labels in train_dataloader:
        inputs = tokenizer(list(q), padding="max_length", max_length=args.question_len, return_tensors="pt")

        # Visual features
        if args.model_ver == "efvlegpt2Swin" or args.model_ver == 'efvlegpt2ViT':
            # The batch mapping is kept as is (the model moves pixel_values to
            # its device); only axis 1 of pixel_values is squeezed.
            visual_features = v_f
            visual_features['pixel_values'] = torch.squeeze(visual_features['pixel_values'],1)
        else:
            visual_features = v_f.to(device)

        # labels
        labels = labels.to(device)
        outputs = model(inputs, visual_features)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # statistics: total_loss is the sum (not the mean) of the batch losses
        total_loss += loss.item()

        _, predicted = torch.max(F.softmax(outputs.detach(), dim=1), 1)
        true_batches.append(labels.detach().cpu())
        pred_batches.append(predicted.cpu())

    if not true_batches:
        raise ValueError('train_dataloader yielded no batches; cannot compute metrics')

    label_true = torch.cat(true_batches, 0)
    label_pred = torch.cat(pred_batches, 0)

    # loss and acc
    acc = calc_acc(label_true, label_pred)
    # The class-wise accuracy result is not used below; the call is kept
    # because the helper's side effects are not visible from this file.
    calc_classwise_acc(label_true, label_pred)
    precision, recall, fscore = calc_precision_recall_fscore(label_true, label_pred)
    print('Train: epoch: %d loss: %.6f | Acc: %.6f | Precision: %.6f | Recall: %.6f | FScore: %.6f' %(epoch, total_loss, acc, precision, recall, fscore))
    return acc


def validate(args, val_loader, model, criterion, epoch, tokenizer, device, save_output = False):
    '''
    Evaluate the model on a validation loader and print the metrics.

    Args:
        args: namespace providing `question_len` and `model_ver`.
        val_loader: iterable yielding (file_names, visual_features, questions, labels).
        model: module called as model(tokenized_questions, visual_features).
        criterion: loss callable applied to (outputs, labels).
        epoch: epoch number, used only in the printed summary.
        tokenizer: callable that tokenizes a list of question strings.
        device: device that labels (and non-Swin/ViT visual features) are moved to.
        save_output: accepted for interface compatibility; currently ignored.

    Returns:
        Tuple (acc, c_acc, precision, recall, fscore), where c_acc is always 0.0.

    Raises:
        ValueError: if the loader yields no batches.
    '''

    model.eval()
    total_loss = 0.0
    # Collected per batch and concatenated once after the loop.
    true_batches = []
    pred_batches = []

    with torch.no_grad():
        for _, v_f, q, labels in val_loader:

            # prepare questions
            inputs = tokenizer(list(q), padding="max_length", max_length=args.question_len, return_tensors="pt")

            # Visual features
            if args.model_ver == "efvlegpt2Swin" or args.model_ver == 'efvlegpt2ViT':
                visual_features = v_f
                visual_features['pixel_values'] = torch.squeeze(visual_features['pixel_values'],1)
            else:
                visual_features = v_f.to(device)

            labels = labels.to(device)
            outputs = model(inputs, visual_features)

            # loss: total_loss is the sum (not the mean) of the batch losses
            loss = criterion(outputs,labels)
            total_loss += loss.item()

            _, predicted = torch.max(F.softmax(outputs, dim=1), 1)
            true_batches.append(labels.cpu())
            pred_batches.append(predicted.cpu())

    if not true_batches:
        raise ValueError('val_loader yielded no batches; cannot compute metrics')

    label_true = torch.cat(true_batches, 0)
    label_pred = torch.cat(pred_batches, 0)

    acc = calc_acc(label_true, label_pred)
    c_acc = 0.0  # class-wise accuracy is not computed during validation
    precision, recall, fscore = calc_precision_recall_fscore(label_true, label_pred)
    print('Test: epoch: %d loss: %.6f | Acc: %.6f | Precision: %.6f | Recall: %.6f | FScore: %.6f' %(epoch, total_loss, acc, precision, recall, fscore))

    return (acc, c_acc, precision, recall, fscore)


if __name__ == '__main__':
    # Training script: builds the tokenizer, datasets and model, then trains
    # for args.epochs epochs, saving a checkpoint whenever validation accuracy
    # matches or beats the best value seen so far.

    args = get_arg()

    # Create the folder the checkpoints are written under. With the default
    # checkpoint_dir this is 'checkpoints/efvlegpt2Swin'.
    # NOTE(review): assumes save_clf_checkpoint treats checkpoint_dir as a
    # path prefix (the default ends in '_'); confirm against utils.
    checkpoint_parent = os.path.dirname(args.checkpoint_dir)
    if checkpoint_parent:
        os.makedirs(checkpoint_parent, exist_ok=True)

    seed_everything()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    start_epoch = 1
    best_epoch = [0]
    best_results = [0.0]
    epochs_since_improvement = 0
    final_args = {"emb_dim": args.emb_dim, "n_heads": args.n_heads, "dropout": args.dropout, "encoder_layers": args.encoder_layers}

    # GPT-2 has no padding token; reuse EOS so padding="max_length" works.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    # data location
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]

    folder_head = 'dataset/EndoVis-18-VQA/seq_'
    folder_tail = '/vqa/Classification/*.txt'

    # NOTE(review): num_workers is fixed at 8 here and args.workers is not used.
    train_dataset = EndoVis18VQAGPTClassification(train_seq, folder_head, folder_tail, model_ver=args.model_ver)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size= args.batch_size, shuffle=True, num_workers=8)
    val_dataset = EndoVis18VQAGPTClassification(val_seq, folder_head, folder_tail, model_ver=args.model_ver)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size= args.batch_size, shuffle=False, num_workers=8)

    # num_classes (overrides the --num_class argument)
    args.num_class = 18

    # Move the model to its device before building the optimizer, so the
    # optimizer is constructed over the parameters in their final location.
    model = EFVLEGPT2SwinClassification(num_class = args.num_class, model_subver = args.model_subver, vis_pos_emb = args.vis_pos_emb)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    criterion = nn.CrossEntropyLoss().to(device)

    # Epochs are numbered from start_epoch; run exactly args.epochs of them.
    # (range(start_epoch, args.epochs) ran one epoch fewer than requested.)
    for epoch in range(start_epoch, start_epoch + args.epochs):

        # Decay the learning rate after every 5 consecutive epochs without improvement.
        if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0:
            adjust_learning_rate(optimizer, 0.8)

        # train
        train_acc = train(args, train_dataloader=train_dataloader, model = model, criterion=criterion, optimizer=optimizer, epoch=epoch, tokenizer = tokenizer, device = device)

        # validation
        test_acc, test_c_acc, test_precision, test_recall, test_fscore = validate(args, val_loader=val_dataloader, model = model, criterion=criterion, epoch=epoch, tokenizer = tokenizer, device = device)

        if test_acc >= best_results[0]:
            print('Best Epoch:', epoch)
            epochs_since_improvement = 0
            best_results[0] = test_acc
            best_epoch[0] = epoch
            save_clf_checkpoint(args.checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, best_results[0], final_args)
        else:
            # Without this increment the learning-rate decay above can never trigger.
            epochs_since_improvement += 1



/content/SurgicalGPT


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

Total files: 1560 | Total question: 9014
Total files: 447 | Total question: 2769




config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]