<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Surgical_LlaVA_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Google Drive link: https://drive.google.com/file/d/1K5YnSPMPvn2x1gtRAw2ZfxIqoIo2DX3I

In [2]:
# Download the EndoVis18 VQA dataset archive from Google Drive:
# https://drive.google.com/file/d/1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN/view?usp=sharing
!gdown --id 1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN

# Unzip the downloaded archive quietly (-q suppresses the per-file listing)
!unzip -q EndoVis-18-VQA.zip

Downloading...
From (original): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN
From (redirected): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN&confirm=t&uuid=77f3a854-e5f3-417a-b5c4-b676b115a241
To: /content/EndoVis-18-VQA.zip
100% 2.70G/2.70G [00:52<00:00, 51.2MB/s]


In [5]:
# Install pinned versions of transformers, bitsandbytes and accelerate
# (-q keeps pip's output short).
!pip install -q transformers==4.36.0
!pip install -q bitsandbytes==0.41.3 accelerate==0.25.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m32.2 MB/s[0m eta [36m0:00:00[0m
[?25h

Utils

In [3]:
import torch
import os
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_fscore_support

class AverageMeter:
    """
    Running statistics for a scalar metric.

    Attributes:
        val:   value passed to the most recent ``update`` call.
        sum:   weighted total of all values seen since the last ``reset``.
        count: total weight (number of samples) seen since the last ``reset``.
        avg:   ``sum / count`` after the most recent ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every tracked statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, shrink_factor):
    """
    Multiply the learning rate of every parameter group by ``shrink_factor``.

    :param optimizer: object exposing ``param_groups``, a list of dicts that each hold an ``'lr'`` entry.
    :param shrink_factor: multiplier applied to each group's learning rate (expected in (0, 1)).
    """
    print("\nDECAYING learning rate.")

    groups = optimizer.param_groups
    for idx in range(len(groups)):
        groups[idx]['lr'] = groups[idx]['lr'] * shrink_factor

    # Only the first group's rate is reported.
    first_lr = groups[0]['lr']
    print("The new learning rate is %f\n" % (first_lr,))


def save_clf_checkpoint(checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, Acc, final_args):
    """
    Serialise the current training state with ``torch.save``.

    ``checkpoint_dir`` is used as a plain string prefix: the output path is
    ``checkpoint_dir + 'Best.pth.tar'``, so no path separator is inserted.
    The ``model`` and ``optimizer`` objects themselves are stored in the
    saved dict, together with the epoch counters, ``Acc`` and ``final_args``.
    """
    filename = checkpoint_dir + 'Best.pth.tar'
    torch.save(
        {
            'epoch': epoch,
            'epochs_since_improvement': epochs_since_improvement,
            'Acc': Acc,
            'model': model,
            'optimizer': optimizer,
            'final_args': final_args,
        },
        filename,
    )

def calc_acc(y_true, y_pred):
    """Return the overall accuracy of ``y_pred`` against ``y_true`` (sklearn ``accuracy_score``)."""
    return accuracy_score(y_true, y_pred)

def calc_classwise_acc(y_true, y_pred, labels=None):
    """
    Per-class accuracy: confusion-matrix diagonal divided by each row's total.

    :param y_true: ground-truth class labels.
    :param y_pred: predicted class labels.
    :param labels: optional list of class labels forwarded to
        ``confusion_matrix`` so the result has a fixed length and order.
        With the default ``None`` the matrix is built as before.
    :return: 1-D float array with one entry per class. Classes with no true
        samples (row total of zero) are reported as NaN, as before, but
        without triggering a divide-by-zero RuntimeWarning.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=labels)
    support = matrix.sum(axis=1)

    # Pre-fill with NaN and divide only where the row actually has samples.
    classwise_acc = np.full(support.shape, np.nan, dtype=float)
    np.divide(matrix.diagonal(), support, out=classwise_acc, where=support > 0)
    return classwise_acc

def calc_map(y_true, y_scores):
    """
    Return sklearn's ``average_precision_score`` for ``y_scores`` against ``y_true``.

    ``average=None`` is passed through, so no averaging across classes is
    requested from sklearn.
    """
    return average_precision_score(y_true, y_scores, average=None)

def calc_precision_recall_fscore(y_true, y_pred):
    """
    Return macro-averaged ``(precision, recall, fscore)``.

    ``zero_division=1`` is passed to sklearn's
    ``precision_recall_fscore_support``; the fourth element of its result is
    discarded.
    """
    scores = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=1)
    precision, recall, fscore, _ = scores
    return (precision, recall, fscore)


def seed_everything(seed=27):
    '''
    Set random seeds for reproducible experiments.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and puts cuDNN into
    deterministic, non-benchmarking mode.

    Inputs: seed number (integer)
    '''
    import random  # function-scope import; only this function needs it

    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32); wrap so any integer accepted by
    # torch.manual_seed keeps working here as well.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Exported for child processes; the hash seed of the already-running
    # interpreter is fixed at start-up and is not changed by this.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Training:

In [None]:
# %cd /content/SurgicalGPT
import os
import sys
import numpy as np
import argparse
from torch import nn
from torch import optim
import torch.utils.data
import torch.nn.functional as F
from torch.utils.data  import DataLoader
from transformers import BertTokenizer, GPT2Tokenizer
# from train import seed_everything
# from dataloaders.dataloaderGPT2Classification import EndoVis18VQAGPTClassification
# from models.EFGPT2Classification import EFVLEGPT2RS18Classification, EFVLEGPT2SwinClassification
# from utils import save_clf_checkpoint, adjust_learning_rate, calc_acc, calc_precision_recall_fscore, calc_classwise_acc

# from transformers import  VisualBertConfig, GPT2Config
from transformers import ViTModel, SwinModel
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

def get_arg():
    """
    Build the argument parser and return the parsed arguments.

    When running inside a Jupyter/IPython kernel (``ipykernel`` is loaded) the
    kernel's own ``sys.argv`` is ignored and only the defaults are used;
    otherwise the real command line is parsed.

    :return: ``argparse.Namespace`` with the model, training and dataset options.
    """

    def _str2bool(value):
        # argparse hands over the raw command-line string, and any non-empty
        # string (including 'False') is truthy, so convert explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser(description='VisualQuestionAnswerClassification')

    # VB Model parameters
    parser.add_argument('--emb_dim',        type=int,   default=300,                                help='dimension of word embeddings.')
    parser.add_argument('--n_heads',        type=int,   default=8,                                  help='Multi-head attention.')
    parser.add_argument('--dropout',        type=float, default=0.1,                                help='dropout')
    parser.add_argument('--encoder_layers', type=int,   default=6,                                  help='the number of layers of encoder in Transformer.')

    # Training parameters
    parser.add_argument('--epochs',         type=int,   default=2,                                 help='number of epochs to train for (if early stopping is not triggered).') #80, 26
    parser.add_argument('--batch_size',     type=int,   default=40,                                 help='batch_size')
    parser.add_argument('--workers',        type=int,   default=1,                                  help='for data-loading; right now, only 1 works with h5pys.')
    parser.add_argument('--print_freq',     type=int,   default=100,                                help='print training/validation stats every __ batches.')

    # existing checkpoint
    parser.add_argument('--checkpoint',     default=None,                                           help='path to checkpoint, None if none.')

    parser.add_argument('--lr',             type=float, default=0.00001,                           help='0.000005, 0.00001, 0.000005')
    parser.add_argument('--checkpoint_dir', default= 'checkpoints/efvlegpt2Swin/m18_v1_z_qf_',            help='med_vqa_c$version$/m18/c80/m18_vid$temporal_size$/c80_vid$temporal_size$') #clf_v1_2_1x1/med_vqa_c3
    parser.add_argument('--dataset_type',   default= 'm18',                                          help='med_vqa/m18/c80/m18_vid/c80_vid')
    parser.add_argument('--dataset_cat',    default= 'cat1',                                        help='cat1/cat2/cat3')
    parser.add_argument('--tokenizer_ver',  default= 'gpt2v1',                                      help='btv2/btv3/gpt2v1')
    # Numeric options need an explicit type, otherwise values given on the
    # command line arrive as strings while the defaults are ints.
    parser.add_argument('--question_len',   type=int,   default= 35,                                help='25')
    parser.add_argument('--model_ver',      default= 'efvlegpt2Swin',                                          help='vb/vbrm/efvlegpt2rs18/efvlegpt2Swin/"')  #vrvb/gpt2rs18/gpt2ViT/gpt2Swin/biogpt2rs18/vilgpt2vqa/efgpt2rs18gr/efvlegpt2Swingr
    parser.add_argument('--model_subver',   default= 'v1',                                          help='V0,v1/v2/v3/v4')
    parser.add_argument('--vis_pos_emb',   default= 'zeroes',                                           help='None, zeroes, pos')
    parser.add_argument('--patch_size',     type=int,   default= 5,                                 help='1/2/3/4/5')

    parser.add_argument('--num_class',      type=int,   default= 2,                                 help='25')
    # parser.add_argument('--temporal_size',  default= 1,                                             help='1/2/3/4/5')
    parser.add_argument('--validate',       type=_str2bool, default=False,                          help='When only validation required False/True')

    if 'ipykernel' in sys.modules:
        # Inside a notebook sys.argv belongs to the kernel launcher.
        args = parser.parse_args([])
    else:
        args = parser.parse_args()
    return args


from torch.utils.data import Dataset
from PIL import Image
import os
import glob
import torchvision.transforms as transforms
from torchvision import models
from torch import nn


class EndoVis18VQAGPTClassification(Dataset):
    '''
    EndoVis18 VQA classification dataset yielding (image, prompt, label) triples.

    Every non-blank line of every annotation file matched by
    ``folder_head + str(seq) + folder_tail`` is one sample of the form
    ``question|answer``.

    Example arguments:
        seq: train_seq  = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
             val_seq    = [1, 5, 16]
        folder_head     = 'dataset/EndoVis-18-VQA/seq_'
        folder_tail     = '/vqa/Classification/*.txt'

    :param seq: iterable of sequence ids appended to ``folder_head``.
    :param folder_head: path prefix of a sequence directory.
    :param folder_tail: glob pattern, relative to the sequence directory, of the annotation files.
    :param transform: optional callable applied to the RGB PIL image. When
        ``None`` (default) the image is returned as a NumPy array.
    '''
    def __init__(self, seq, folder_head, folder_tail, transform=None):
        # Previously accepted but silently dropped; keep it so it can be applied.
        self.transform = transform

        # files, question and answers
        filenames = []
        for curr_seq in seq:
            filenames.extend(glob.glob(folder_head + str(curr_seq) + folder_tail))

        self.vqas = []
        for file in filenames:
            # Context manager guarantees the handle is closed even on error.
            with open(file, "r") as file_data:
                lines = [line.strip("\n") for line in file_data if line != "\n"]
            self.vqas.extend([file, line] for line in lines)
        print('Total files: %d | Total question: %.d' %(len(filenames), len(self.vqas)))

        # Labels: the position in this list is the class index.
        self.labels = ['kidney', 'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation',
                        'Tool_Manipulation', 'Cutting', 'Cauterization', 'Suction',
                        'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing',
                        'left-top', 'right-top', 'left-bottom', 'right-bottom']

    def __len__(self):
        return len(self.vqas)

    def __getitem__(self, idx):
        file, line = self.vqas[idx]

        # img: the frame is looked up in <sequence dir>/left_frames/. The
        # sequence directory is taken as two levels above the annotation
        # file's folder (matches the '/vqa/Classification/*.txt' tail), so it
        # no longer depends on how many directories folder_head contains.
        annotation_dir, annotation_name = os.path.split(file)
        seq_dir = os.path.dirname(os.path.dirname(annotation_dir))
        img_loc = os.path.join(seq_dir, 'left_frames', annotation_name.split('_')[0] + '.png')
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(img_loc) as raw_img:
            img = raw_img.convert('RGB')

        # question and answer
        fields = line.split('|')
        question = fields[0]
        # strip() tolerates trailing whitespace / '\r' left over from CRLF files
        label = self.labels.index(str(fields[1]).strip())
        prompt = "<image>\nUSER:" + question + "\nASSISTANT:"

        if self.transform is not None:
            img = self.transform(img)
        else:
            img = np.array(img)
        return img, prompt, label


class Surgical_LlaVA(nn.Module):
    """
    Frozen 4-bit LLaVA-1.5-7B backbone followed by a small trainable classification head.

    The backbone's logits are average-pooled over their last dimension,
    leaving one value per sequence position; that vector is fed through
    ``intermediate_layer`` -> dropout -> ``classifier``.

    :param num_classes: number of output classes.
    :param seq_len: number of sequence positions the backbone produces for one
        sample; this is the input width of ``intermediate_layer`` (default 595,
        the previously hard-coded value).
    :param hidden_dim: width of the intermediate layer (default 512).
    :param dropout: dropout probability applied before the classifier (default 0.1).
    """
    def __init__(self, num_classes=18, seq_len=595, hidden_dim=512, dropout=0.1):
        super(Surgical_LlaVA, self).__init__()

        # 4-bit loading is requested through quantization_config only. Also
        # passing load_in_4bit=True is redundant, and newer transformers
        # releases raise a ValueError when both are given.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
        self.model_llava = LlavaForConditionalGeneration.from_pretrained(
            "llava-hf/llava-1.5-7b-hf",
            device_map="auto",
            quantization_config=quantization_config,
        )
        # Freeze the backbone; only the head below is trained.
        for param in self.model_llava.parameters():
            param.requires_grad = False

        ## intermediate_layers
        self.intermediate_layer = nn.Linear(seq_len, hidden_dim)
        # Not used in forward(). Kept (despite the name it is a BatchNorm1d) so
        # the module's attribute set and state_dict keys stay unchanged.
        self.LayerNorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)

        ## classifier
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, inputs):
        """
        :param inputs: mapping of keyword arguments for the LLaVA model (as returned by its processor).
        :return: tensor of class scores with ``num_classes`` entries per sample.
        :raises ValueError: if the backbone's sequence length does not match ``seq_len``.
        """
        out = self.model_llava(**inputs).logits
        # Pool away the last (vocabulary) dimension -> one value per position.
        out = F.adaptive_avg_pool1d(out, 1).squeeze(-1)

        expected_len = self.intermediate_layer.in_features
        if out.shape[-1] != expected_len:
            raise ValueError(
                "Surgical_LlaVA expected %d sequence positions from the backbone but got %d; "
                "construct the model with a matching seq_len." % (expected_len, out.shape[-1])
            )

        # Align with the head's device/dtype; a no-op when they already match.
        head_weight = self.intermediate_layer.weight
        out = out.to(device=head_weight.device, dtype=head_weight.dtype)

        out = self.intermediate_layer(out)
        out = self.dropout(out)
        out = self.classifier(out)
        return out

def train(args, train_dataloader, model, criterion, optimizer, epoch, processor, device):
    """
    Run one training epoch and print the epoch metrics.

    :param args: namespace providing ``question_len`` (max token length given to the processor).
    :param train_dataloader: iterable yielding ``(imgs, prompts, labels)`` batches.
    :param model: module called as ``model(inputs)`` and returning class scores.
    :param criterion: loss called as ``criterion(outputs, labels)``.
    :param optimizer: optimizer stepped once per batch.
    :param epoch: epoch number, used only in the printed summary.
    :param processor: callable turning prompts and images into model inputs.
    :param device: device the labels are moved to.
    :return: accuracy over all samples seen in this epoch.
    :raises ValueError: if the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Collect per-batch CPU tensors and concatenate once at the end instead
    # of re-allocating a growing tensor on every batch.
    true_batches = []
    pred_batches = []

    for imgs, prompts, labels in train_dataloader:
        inputs = processor(text=prompts, images=imgs, max_length=int(args.question_len), padding="max_length", return_tensors="pt")

        # labels
        labels = labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # statistics (total_loss is the sum of the per-batch losses)
        total_loss += loss.item()

        # softmax is monotonic, so the arg-max of the raw scores is the predicted class
        predicted = outputs.detach().argmax(dim=1)
        true_batches.append(labels.detach().cpu())
        pred_batches.append(predicted.cpu())

    if not true_batches:
        raise ValueError("train() received an empty dataloader; no metrics can be computed.")

    label_true = torch.cat(true_batches, 0)
    label_pred = torch.cat(pred_batches, 0)

    # loss and acc
    acc = calc_acc(label_true, label_pred)
    precision, recall, fscore = calc_precision_recall_fscore(label_true, label_pred)
    print('Train: epoch: %d loss: %.6f | Acc: %.6f | Precision: %.6f | Recall: %.6f | FScore: %.6f' %(epoch, total_loss, acc, precision, recall, fscore))
    return acc


def validate(args, val_loader, model, criterion, epoch, processor, device, save_output = False):
    """
    Evaluate the model on ``val_loader`` without gradient tracking and print the metrics.

    :param args: namespace providing ``question_len`` (max token length given to the processor).
    :param val_loader: iterable yielding ``(imgs, prompts, labels)`` batches.
    :param model: module called as ``model(inputs)`` and returning class scores.
    :param criterion: loss called as ``criterion(outputs, labels)``.
    :param epoch: epoch number, used only in the printed summary.
    :param processor: callable turning prompts and images into model inputs.
    :param device: device the labels are moved to.
    :param save_output: accepted for interface compatibility; currently unused.
    :return: ``(acc, c_acc, precision, recall, fscore)``; ``c_acc`` is always ``0.0``.
    :raises ValueError: if the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    # Collect per-batch CPU tensors and concatenate once at the end.
    true_batches = []
    pred_batches = []

    with torch.no_grad():
        for imgs, prompts, labels in val_loader:
            inputs = processor(text=prompts, images=imgs, max_length=int(args.question_len), padding="max_length", return_tensors="pt")

            labels = labels.to(device)
            outputs = model(inputs)

            # loss (total_loss is the sum of the per-batch losses)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # softmax is monotonic, so the arg-max of the raw scores is the predicted class
            predicted = outputs.argmax(dim=1)
            true_batches.append(labels.cpu())
            pred_batches.append(predicted.cpu())

    if not true_batches:
        raise ValueError("validate() received an empty dataloader; no metrics can be computed.")

    label_true = torch.cat(true_batches, 0)
    label_pred = torch.cat(pred_batches, 0)

    acc = calc_acc(label_true, label_pred)
    c_acc = 0.0  # class-wise accuracy is not computed here; placeholder kept for callers
    precision, recall, fscore = calc_precision_recall_fscore(label_true, label_pred)
    print('Test: epoch: %d loss: %.6f | Acc: %.6f | Precision: %.6f | Recall: %.6f | FScore: %.6f' %(epoch, total_loss, acc, precision, recall, fscore))

    return (acc, c_acc, precision, recall, fscore)


# Entry point: build the datasets, model and optimizer, then alternate
# training and validation, checkpointing whenever validation accuracy does
# not get worse.
if __name__ == '__main__':

    args = get_arg()
    # checkpoint_dir is a filename *prefix*; create the directory it lives in
    # rather than a hard-coded path, so a custom --checkpoint_dir also works.
    os.makedirs(os.path.dirname(args.checkpoint_dir) or '.', exist_ok=True)
    seed_everything()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    start_epoch = 1
    best_epoch = [0]
    best_results = [0.0]
    epochs_since_improvement = 0
    final_args = {"emb_dim": args.emb_dim, "n_heads": args.n_heads, "dropout": args.dropout, "encoder_layers": args.encoder_layers}

    # data location
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]

    folder_head = 'EndoVis-18-VQA/seq_'
    folder_tail = '/vqa/Classification/*.txt'

    train_dataset = EndoVis18VQAGPTClassification(train_seq, folder_head, folder_tail)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size= args.batch_size, shuffle=True, num_workers=8)
    val_dataset = EndoVis18VQAGPTClassification(val_seq, folder_head, folder_tail)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size= args.batch_size, shuffle=False, num_workers=8)

    # num_classes
    args.num_class = 18

    model = Surgical_LlaVA(num_classes=args.num_class)
    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

    # NOTE(review): the backbone is loaded with device_map="auto"; confirm that
    # moving the whole wrapper afterwards is intended for the quantized weights.
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    # Epochs are numbered from start_epoch (1), so the upper bound must be
    # inclusive to actually train for args.epochs epochs.
    for epoch in range(start_epoch, args.epochs + 1):

        if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0:
            adjust_learning_rate(optimizer, 0.8)

        # train
        train_acc = train(args, train_dataloader=train_dataloader, model = model, criterion=criterion, optimizer=optimizer, epoch=epoch, processor = processor, device = device)

        # validation
        test_acc, test_c_acc, test_precision, test_recall, test_fscore = validate(args, val_loader=val_dataloader, model = model, criterion=criterion, epoch=epoch, processor = processor, device = device)

        if test_acc >= best_results[0]:
            print('Best Epoch:', epoch)
            epochs_since_improvement = 0
            best_results[0] = test_acc
            best_epoch[0] = epoch
            save_clf_checkpoint(args.checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, best_results[0], final_args)
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))


Total files: 1560 | Total question: 9014
Total files: 447 | Total question: 2769




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]