<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Surgical_VQA_GPT_BioGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VQA Surgery (Naive version)

sentence-transformers by UKPLab (To classify text) <br>
page: https://www.libhunt.com/r/sentence-transformers<br>
github: https://github.com/UKPLab/sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms
from PIL import Image
import requests

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation', 
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']

preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


num_classes = 12
img_preprocessed = preprocess(img)
batch_img_tensor = torch.unsqueeze(img_preprocessed, 0)

model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.eval();
new_fc = torch.nn.Sequential(*list(model.fc.children())[:-1])
model.fc = new_fc
img_features = model(batch_img_tensor)
print(img_features.shape)

class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()
        #self.num_classes = num_classes

        # text processing
        self.text_feature_extractor = SentenceTransformer('bert-base-nli-mean-tokens').cuda()
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(2816, num_classes)

    def forward(self, img, text):
        img_feature = self.img_feature_extractor(img)
        text_feature = self.text_feature_extractor.encode([text])[0]
        img_text_features = torch.cat((img_feature, torch.tensor(text_feature).unsqueeze(0).cuda()), dim=1)
        out = self.classifier(img_text_features)
        return out

text = "What is the state of bipolar_forceps?"
SVQA = Surgical_VQA(num_classes=12).cuda()
output = SVQA(batch_img_tensor.cuda(), text)
answer = output.argmax(dim=1)
print('Question: {} \nAnswer: {}'.format(text, labels[answer.item()]))

#VQA: VisualBERT + ResNet (Early Fusion)

In [1]:
!pip -q install transformers

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast
from PIL import Image
import requests

class VisualBERT_VQA(nn.Module):
    def __init__(self, num_class=2):
        super(VisualBERT_VQA, self).__init__()
        self.visualbert = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.cls = nn.Linear(768, num_class)

    def forward(self, inputs):
        last_hidden_state = self.visualbert(**inputs).last_hidden_state #[1, 56, 768]

        # Get the index of the last text token
        index_to_gather = inputs['attention_mask'].sum(1) - 2  # as in original code 5
        index_to_gather = (
            index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, last_hidden_state.size(-1))
        ) # [b c hw]=[1, 1, 768]

        pooled_output = torch.gather(last_hidden_state, 1, index_to_gather) # [1, 1, 768]
        logits = self.cls(pooled_output).squeeze(1)
        return logits


url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation', 
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']
num_classes = 12

#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet50(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

#text feature
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, return_tensors="pt", padding="max_length",max_length=20,)


#update diction 
inputs.update(
    {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print('visual_embeds', inputs['visual_embeds'].shape, 'Text:', inputs['input_ids'].shape)
model = VisualBERT_VQA(num_class=18)
model.eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)  




visual_embeds torch.Size([2, 49, 2048]) Text: torch.Size([2, 20])


Some weights of the model checkpoint at uclanlp/visualbert-vqa-coco-pre were not used when initializing VisualBertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing VisualBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VisualBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Logits: tensor([[ 0.1263,  0.3838, -0.0967, -0.2127,  0.0030,  0.4668,  0.0651,  0.0341,
          0.1571,  0.0366, -0.2244, -0.2846,  0.4410, -0.0932,  0.1033, -0.1001,
          0.4315, -0.0655],
        [ 0.1263,  0.3838, -0.0967, -0.2127,  0.0030,  0.4668,  0.0651,  0.0341,
          0.1571,  0.0366, -0.2244, -0.2846,  0.4410, -0.0932,  0.1033, -0.1001,
          0.4315, -0.0655]], grad_fn=<SqueezeBackward1>) Prediction: tensor([5, 5])


#VQA: ChatGPT + ResNet (Early Fusion)

In [33]:
import torch
from torch import nn
from transformers import VisualBertModel, BertTokenizer, VisualBertConfig, GPT2Model, GPT2Tokenizer, GPT2Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GPT2_VQA(nn.Module):
    def __init__(self, num_class=2):
        super(GPT2_VQA, self).__init__()
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.config = GPT2Config.from_pretrained("gpt2")
        self.classifier = nn.Linear(59 * 768, num_class)

        self.config_bert = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.config_bert.visual_embedding_dim = 2048 #most right dim of the visual features
        self.config_bert.hidden_size = self.config.hidden_size
        self.config_bert.vocab_size = self.config.vocab_size 
        self.config_bert.pad_token_id = self.config.pad_token_id 

        self.visualbert = VisualBertModel(config=self.config_bert)
        self.embeddings = self.visualbert.embeddings

    def forward(self, inputs):
        hidden_states = self.embeddings(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=inputs['visual_token_type_ids'],
            image_text_alignment=None,
        )

        hidden_states = self.gpt2.drop(hidden_states)
        input_shape = inputs['input_ids'].size()
        visual_input_shape = inputs['visual_embeds'].size()[:-1]
        combined_attention_mask = torch.cat((inputs['attention_mask'], inputs['visual_attention_mask']), dim=-1)
        extended_attention_mask: torch.Tensor = self.gpt2.get_extended_attention_mask(
            combined_attention_mask, (input_shape[0], input_shape + visual_input_shape)
        )
        output_attentions = self.config.output_attentions
        head_mask = self.gpt2.get_head_mask(None, self.config.n_layer)
        past_key_values = tuple([None] * len(self.gpt2.h))
        for i, (block, layer_past) in enumerate(zip(self.gpt2.h, past_key_values)):
            outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=extended_attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    use_cache=None,
                    output_attentions=output_attentions,
                )
            
            hidden_states = outputs[0]

        hidden_states = self.gpt2.ln_f(hidden_states) #[2, 59, 768]
        x = torch.flatten(hidden_states, 1) 
        x = self.classifier(x)
        return x

num_classes = 18

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet50(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

#text feature
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
# inputs = tokenizer(questions, return_tensors="pt", padding="max_length",max_length=20,)#[2 20]
token_type_ids = torch.zeros(inputs['input_ids'].shape, dtype=torch.long) # zeros because text id types is ones

inputs.update(
    {
        "token_type_ids": token_type_ids,
        "visual_embeds": visual_embeds, #[2, 49, 2048]
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print(inputs['input_ids'].shape, inputs['token_type_ids'].shape, inputs['attention_mask'].shape, 
      inputs['visual_embeds'].shape, inputs['visual_token_type_ids'].shape, inputs['visual_attention_mask'].shape)

model = GPT2_VQA(num_class=18)
model.eval()
logits = model(inputs)
answer = logits.argmax(dim=1)
print(logits.shape, answer)


torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 49, 2048]) torch.Size([2, 49]) torch.Size([2, 49])
torch.Size([2, 18]) tensor([1, 1])


#VQA: ChatGPT + ResNet (Late Fusion)

In [None]:
!pip -q install transformers

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import BertTokenizer, GPT2Model, GPT2Tokenizer
from PIL import Image
import requests

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation', 
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']
num_classes = 12
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.eval();

class GPT2RS18Classification(nn.Module):
    def __init__(self, num_class = 12):
        super(GPT2RS18Classification, self).__init__()

        # text processing
        self.text_feature_extractor = GPT2Model.from_pretrained('gpt2')
 
        # image processing
        self.img_feature_extractor = models.resnet18(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #intermediate_layers
        self.intermediate_layer = nn.Linear(1280, 512)  #(512+768)
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)

        # classifier
        self.classifier = nn.Linear(512, num_class)

    def forward(self, input, img):
        
        # image encoder features
        img_feature = self.img_feature_extractor(img)
        
        # question tokenizer features
        input['input_ids'] = input['input_ids'].to(device)
        input['attention_mask'] = input['attention_mask'].to(device)

        # GPT text encoder
        text_feature = self.text_feature_extractor(**input) # [2, 10, 768]
        print(text_feature.last_hidden_state.shape)
        text_feature = text_feature.last_hidden_state.swapaxes(1,2)
        text_feature = F.adaptive_avg_pool1d(text_feature,1)
        text_feature = text_feature.swapaxes(1,2).squeeze(1)        
        
        # late visual-text fusion
        img_text_features = torch.cat((img_feature, text_feature), dim=1)

        # intermediate layers
        out =self.intermediate_layer(img_text_features)
        out = self.LayerNorm(out)
        out = self.dropout(out)

        # classifier
        out = self.classifier(out)
        print(out.size())
        return out


# questions = "What is the state of bipolar_forceps?"
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")

img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
new_fc = torch.nn.Sequential(*list(model.fc.children())[:-1])
model.fc = new_fc
img_features = model(batch_img_tensor)

SVQA = GPT2RS18Classification(num_class=12).cuda()
output = SVQA(inputs, batch_img_tensor.cuda())
answer = output.argmax(dim=1)
print('Question: {} \nAnswer: {}'.format(questions[0], labels[answer[0].item()]))

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

torch.Size([2, 10, 768])
torch.Size([2, 12])
Question: What is the state of bipolar_forceps? 
Answer: tool manipulation


#VQA: BioGPT + ResNet(Late Fusion)
src: https://huggingface.co/microsoft/biogpt<br>
git: https://github.com/microsoft/BioGPT<br>
paper: https://arxiv.org/abs/2210.10341<br>


In [None]:
!pip -q install transformers sacremoses

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/880.6 KB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.4/880.6 KB[0m [31m5.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m870.4/880.6 KB[0m [31m15.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m880.6/880.6 KB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import BioGptTokenizer, BioGptForCausalLM
from transformers import BertTokenizer, GPT2Tokenizer
from PIL import Image
import requests
url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation', 
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']
num_classes = 12
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.eval();

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GPT2RS18Classification(nn.Module):
    def __init__(self, num_class = 12):
        super(GPT2RS18Classification, self).__init__()

        # text processing
        self.text_feature_extractor = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
 
        # image processing
        self.img_feature_extractor = models.resnet18(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #intermediate_layers
        self.intermediate_layer = nn.Linear(42896, 512)  #(512+768)
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)

        # classifier
        self.classifier = nn.Linear(512, num_class)

    def forward(self, input, img):
        
        # image encoder features
        img_feature = self.img_feature_extractor(img)
        
        # question tokenizer features
        input['input_ids'] = input['input_ids'].to(device)
        input['attention_mask'] = input['attention_mask'].to(device)

        # GPT text encoder
        text_feature = self.text_feature_extractor(**input)
        text_feature = text_feature[0].swapaxes(1,2)
        #mobarak: [1, 12, 42384], text feature is too big compare to img. We may pool it to 512 the equal size of img
        #F.adaptive_avg_pool2d(output[0],[1, 512])
        text_feature = F.adaptive_avg_pool1d(text_feature,1) 
        text_feature = text_feature.swapaxes(1,2).squeeze()

        # late visual-text fusion
        #mobarak: advanced level fusion can be used instead of naive concat (e.g., multihead attention fusion)
        img_text_features = torch.cat((img_feature, text_feature), dim=1)

        # intermediate layers
        out =self.intermediate_layer(img_text_features)
        #mobarak: we may add one more intermidiate layer if the features size is bigger
        out = self.LayerNorm(out)
        out = self.dropout(out)

        # classifier
        out = self.classifier(out)
        print(out.size())
        return out



tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
tokenizer.pad_token = tokenizer.eos_token

question = "What is the state of bipolar_forceps?"
questions = []
questions.append(question)
questions.append(question)
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")

img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
new_fc = torch.nn.Sequential(*list(model.fc.children())[:-1])
model.fc = new_fc
img_features = model(batch_img_tensor)

SVQA = GPT2RS18Classification(num_class=12).cuda()
output = SVQA(inputs, batch_img_tensor.cuda())
answer = output.argmax(dim=1)
print('Question: {} \nAnswer: {}'.format(questions[0], labels[answer[0].item()]))

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/1.56G [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

torch.Size([2, 12])
Question: What is the state of bipolar_forceps? 
Answer: cauterization
