<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Surgical_VQA_GPT_BioGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VQA Surgery (Naive version)

sentence-transformers by UKPLab (To classify text) <br>
page: https://www.libhunt.com/r/sentence-transformers<br>
github: https://github.com/UKPLab/sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms
from PIL import Image
import requests

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation',
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']

preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


num_classes = 12
img_preprocessed = preprocess(img)
batch_img_tensor = torch.unsqueeze(img_preprocessed, 0)

model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.eval();
new_fc = torch.nn.Sequential(*list(model.fc.children())[:-1])
model.fc = new_fc
img_features = model(batch_img_tensor)
print(img_features.shape)

class Surgical_VQA(nn.Module):
    def __init__(self, num_classes=12):
        super(Surgical_VQA, self).__init__()
        #self.num_classes = num_classes

        # text processing
        self.text_feature_extractor = SentenceTransformer('bert-base-nli-mean-tokens').cuda()
        # image processing
        self.img_feature_extractor = models.resnet50(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #classifier
        self.classifier = nn.Linear(2816, num_classes)

    def forward(self, img, text):
        img_feature = self.img_feature_extractor(img)
        text_feature = self.text_feature_extractor.encode([text])[0]
        img_text_features = torch.cat((img_feature, torch.tensor(text_feature).unsqueeze(0).cuda()), dim=1)
        out = self.classifier(img_text_features)
        return out

text = "What is the state of bipolar_forceps?"
SVQA = Surgical_VQA(num_classes=12).cuda()
output = SVQA(batch_img_tensor.cuda(), text)
answer = output.argmax(dim=1)
print('Question: {} \nAnswer: {}'.format(text, labels[answer.item()]))

#VQA: VisualBERT + ResNet (Early Fusion)

In [2]:
!pip -q install transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m37.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m42.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast, BertTokenizer
from PIL import Image
import requests

class VisualBERT_VQA(nn.Module):
    def __init__(self, num_class=2):
        super(VisualBERT_VQA, self).__init__()
        self.visualbert = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.cls = nn.Linear(768, num_class)

    def forward(self, inputs):
        last_hidden_state = self.visualbert(**inputs).last_hidden_state #[1, 56, 768]

        # Get the index of the last text token
        index_to_gather = inputs['attention_mask'].sum(1) - 2  # as in original code 5
        index_to_gather = (
            index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, last_hidden_state.size(-1))
        ) # [b c hw]=[1, 1, 768]

        pooled_output = torch.gather(last_hidden_state, 1, index_to_gather) # [1, 1, 768]
        logits = self.cls(pooled_output).squeeze(1)
        return logits


url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation',
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']
num_classes = 12

#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet50(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

#text feature
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, return_tensors="pt", padding="max_length",max_length=20,)


#update diction
inputs.update(
    {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print('visual_embeds', inputs['visual_embeds'].shape, 'Text:', inputs['input_ids'].shape)
model = VisualBERT_VQA(num_class=18)
model.eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 101MB/s]


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

visual_embeds torch.Size([2, 49, 2048]) Text: torch.Size([2, 20])


Downloading (…)lve/main/config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/448M [00:00<?, ?B/s]

Logits: tensor([[-0.1505,  0.0851,  0.1359,  0.5542,  0.0769, -0.0526, -0.2740,  0.0462,
          0.0527, -0.4569,  0.5654,  0.4765, -0.3537,  0.3167, -0.2160, -0.2892,
          0.4020,  0.2451],
        [-0.1505,  0.0851,  0.1359,  0.5542,  0.0769, -0.0526, -0.2740,  0.0462,
          0.0527, -0.4569,  0.5654,  0.4765, -0.3537,  0.3167, -0.2160, -0.2892,
          0.4020,  0.2451]], grad_fn=<SqueezeBackward1>) Prediction: tensor([10, 10])


MLP: Projection of Vision Embedding for Language Vision Fusion:

In [4]:
url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

from typing import Tuple
class MLP(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])

model_visual_feat = models.resnet50(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
out = model_visual_feat(batch_img_tensor)
print(out.shape)
print(out.view(-1, 49, 2048).shape)

image_size = 2048#768
llm_embedding_size = 4096
image_length = 1
image_project = MLP((image_size, (llm_embedding_size * image_length) // 2,
                                     llm_embedding_size * image_length))
img_embed = image_project(out.view(-1, 49, 2048))
img_embed.shape

torch.Size([2, 100352])
torch.Size([2, 49, 2048])


torch.Size([2, 49, 4096])

Transformer: Projection of Vision Embedding for Language Vision Fusion:

In [36]:
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from typing import Tuple, Optional, Union
from enum import Enum
import itertools

class MlpTransformer(nn.Module):
    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class MultiHeadAttention(nn.Module):

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention


class TransformerLayer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)


class Transformer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec: # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)



class TransformerMapper(nn.Module):

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)


torch.Size([2, 49, 4096])

# Version2: VQA: VisualBERT + ResNet (Early Fusion)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast, BertTokenizer, VisualBertForQuestionAnswering
from PIL import Image
import requests
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class VisualBertClassification(nn.Module):

    def __init__(self, num_class = 18):
        super(VisualBertClassification, self).__init__()
        config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        config.visual_embedding_dim = 512
        # config.hidden_size = 2048
        # config.num_attention_heads = 8
        # self.VisualBertEncoder = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.VisualBertEncoder = VisualBertModel(config)
        self.classifier = nn.Linear(config.hidden_size, num_class)

        # self.VisualBertEncoder = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")
        # self.VisualBertEncoder.cls = nn.Linear(config.hidden_size, num_class)

    def forward(self, inputs, visual_embeds):
        # prepare visual embedding
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long).to(device)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float).to(device)

        # append visual features to text
        inputs.update({
                        "visual_embeds": visual_embeds.to(device),
                        "visual_token_type_ids": visual_token_type_ids,
                        "visual_attention_mask": visual_attention_mask,
                        "output_attentions": True
                        })

        inputs['input_ids'] = inputs['input_ids'].to(device)
        inputs['token_type_ids'] = inputs['token_type_ids'].to(device)
        inputs['attention_mask'] = inputs['attention_mask'].to(device)
        inputs['visual_token_type_ids'] = inputs['visual_token_type_ids'].to(device)
        inputs['visual_attention_mask'] = inputs['visual_attention_mask'].to(device)

        outputs = self.VisualBertEncoder(**inputs)
        outputs = self.classifier(outputs['pooler_output'])
        return outputs

num_classes = 18

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet18(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 512)

#text feature
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# tokenizer.pad_token = tokenizer.eos_token
question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
# inputs = tokenizer(questions, return_tensors="pt", padding="max_length",max_length=25,)#[2 20]

model = VisualBertClassification(num_class=18).to(device)
model.eval()
logits = model(inputs, visual_embeds)
answer = logits.argmax(dim=1)
print(logits.shape, answer)



torch.Size([2, 18]) tensor([1, 1], device='cuda:0')


#VQA: ChatGPT + ResNet (Early Fusion)

In [None]:
!pip -q install transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import VisualBertModel, BertTokenizer, VisualBertConfig, GPT2Model, GPT2Tokenizer, GPT2Config
from PIL import Image
import requests
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GPT2_VQA(nn.Module):
    def __init__(self, num_class=2):
        super(GPT2_VQA, self).__init__()
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.config = GPT2Config.from_pretrained("gpt2")
        self.classifier = nn.Linear(59 * 768, num_class)

        self.config_bert = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.config_bert.visual_embedding_dim = 2048 #most right dim of the visual features
        self.config_bert.hidden_size = self.config.hidden_size
        self.config_bert.vocab_size = self.config.vocab_size
        self.config_bert.pad_token_id = self.config.pad_token_id

        self.visualbert = VisualBertModel(config=self.config_bert)
        self.embeddings = self.visualbert.embeddings

    def forward(self, inputs):
        hidden_states = self.embeddings(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=inputs['visual_token_type_ids'],
            image_text_alignment=None,
        )

        hidden_states = self.gpt2.drop(hidden_states)
        input_shape = inputs['input_ids'].size()
        visual_input_shape = inputs['visual_embeds'].size()[:-1]
        combined_attention_mask = torch.cat((inputs['attention_mask'], inputs['visual_attention_mask']), dim=-1)
        extended_attention_mask: torch.Tensor = self.gpt2.get_extended_attention_mask(
            combined_attention_mask, (input_shape[0], input_shape + visual_input_shape)
        )
        output_attentions = self.config.output_attentions
        head_mask = self.gpt2.get_head_mask(None, self.config.n_layer)
        past_key_values = tuple([None] * len(self.gpt2.h))
        for i, (block, layer_past) in enumerate(zip(self.gpt2.h, past_key_values)):
            outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=extended_attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    use_cache=None,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]

        hidden_states = self.gpt2.ln_f(hidden_states) #[2, 59, 768]
        x = torch.flatten(hidden_states, 1)
        x = self.classifier(x)
        return x

num_classes = 18

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet50(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

#text feature
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
# inputs = tokenizer(questions, return_tensors="pt", padding="max_length",max_length=20,)#[2 20]
token_type_ids = torch.zeros(inputs['input_ids'].shape, dtype=torch.long) # zeros because text id types is ones

inputs.update(
    {
        "token_type_ids": token_type_ids,
        "visual_embeds": visual_embeds, #[2, 49, 2048]
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print(inputs['input_ids'].shape, inputs['token_type_ids'].shape, inputs['attention_mask'].shape,
      inputs['visual_embeds'].shape, inputs['visual_token_type_ids'].shape, inputs['visual_attention_mask'].shape)

model = GPT2_VQA(num_class=18)
model.eval()
logits = model(inputs)
answer = logits.argmax(dim=1)
print(logits.shape, answer)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:03<00:00, 29.1MB/s]


Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 49, 2048]) torch.Size([2, 49]) torch.Size([2, 49])


Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

torch.Size([2, 18]) tensor([7, 7])


In [None]:
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaForSequenceClassification, LlamaTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/tiny-random-LlamaForSequenceClassification")
model = AutoModelForSequenceClassification.from_pretrained("HuggingFaceH4/tiny-random-LlamaForSequenceClassification")
prompt = "I like you. I love you"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
logits

tensor([[-0.0137]], grad_fn=<IndexBackward0>)

In [None]:
!pip -q install transformers
!pip -q install sentencepiece

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m65.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m38.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m91.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m81.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaForSequenceClassification, LlamaTokenizer, AutoModelForSequenceClassification
# tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
# model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")
model_name = "decapoda-research/llama-7b-hf"
num_labels = 2 # replace with the actual number of labels in your classification task
# model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
tokenizer = LlamaTokenizer.from_pretrained(model_name)
text = "I like you. I love you"
encoded_input = tokenizer(text, truncation=True, return_tensors='pt')

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [None]:
!huggingface-cli login --token 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

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaForSequenceClassification, LlamaTokenizer, AutoModelForSequenceClassification

# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = LlamaForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf", num_labels=18)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
text = "I like you. I love you"
encoded_input = tokenizer(text, truncation=True, return_tensors='pt')
outputs = model(**encoded_input)
logits = outputs.logits
logits

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


tensor([[ 2.7643, -3.8482, -0.0450, -0.3592,  2.7904, -0.2264, -0.7342,  1.0870,
         -1.2946,  1.5603, -3.2351,  6.3604,  2.4915,  1.2961, -1.0647,  2.7150,
         -4.9605, -3.7603]], grad_fn=<IndexBackward0>)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import VisualBertModel, BertTokenizer, VisualBertConfig, GPT2Model, GPT2Tokenizer, GPT2Config
from PIL import Image
import requests
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GPT2_VQA(nn.Module):
    def __init__(self, num_class=2):
        super(GPT2_VQA, self).__init__()
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.config = GPT2Config.from_pretrained("gpt2")
        self.classifier = nn.Linear(59 * 768, num_class)

        self.config_bert = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.config_bert.visual_embedding_dim = 2048 #most right dim of the visual features
        self.config_bert.hidden_size = self.config.hidden_size
        self.config_bert.vocab_size = self.config.vocab_size
        self.config_bert.pad_token_id = self.config.pad_token_id

        self.visualbert = VisualBertModel(config=self.config_bert)
        self.embeddings = self.visualbert.embeddings

    def forward(self, inputs):
        hidden_states = self.embeddings(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=inputs['visual_token_type_ids'],
            image_text_alignment=None,
        )

        hidden_states = self.gpt2.drop(hidden_states)
        input_shape = inputs['input_ids'].size()
        visual_input_shape = inputs['visual_embeds'].size()[:-1]
        combined_attention_mask = torch.cat((inputs['attention_mask'], inputs['visual_attention_mask']), dim=-1)
        extended_attention_mask: torch.Tensor = self.gpt2.get_extended_attention_mask(
            combined_attention_mask, (input_shape[0], input_shape + visual_input_shape)
        )
        output_attentions = self.config.output_attentions
        head_mask = self.gpt2.get_head_mask(None, self.config.n_layer)
        past_key_values = tuple([None] * len(self.gpt2.h))
        for i, (block, layer_past) in enumerate(zip(self.gpt2.h, past_key_values)):
            outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=extended_attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    use_cache=None,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]

        hidden_states = self.gpt2.ln_f(hidden_states) #[2, 59, 768]
        x = torch.flatten(hidden_states, 1)
        x = self.classifier(x)
        return x

num_classes = 18

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet50(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

#text feature
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
# inputs = tokenizer(questions, return_tensors="pt", padding="max_length",max_length=20,)#[2 20]
token_type_ids = torch.zeros(inputs['input_ids'].shape, dtype=torch.long) # zeros because text id types is ones

inputs.update(
    {
        "token_type_ids": token_type_ids,
        "visual_embeds": visual_embeds, #[2, 49, 2048]
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print(inputs['input_ids'].shape, inputs['token_type_ids'].shape, inputs['attention_mask'].shape,
      inputs['visual_embeds'].shape, inputs['visual_token_type_ids'].shape, inputs['visual_attention_mask'].shape)

model = GPT2_VQA(num_class=18)
model.eval()
logits = model(inputs)
answer = logits.argmax(dim=1)
print(logits.shape, answer)


#VQA: ChatGPT + ResNet (Early Fusion) + Embedding

In [None]:
from PIL import Image
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import GPT2Model, GPT2Tokenizer, GPT2Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ViLEmbeddings(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

        # For Visual Features
        # Token type and position embedding for image features
        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        if config.special_visual_initialize:
            self.visual_token_type_embeddings.weight.data = nn.Parameter(
                self.token_type_embeddings.weight.data.clone(), requires_grad=True
            )
            self.visual_position_embeddings.weight.data = nn.Parameter(
                self.position_embeddings.weight.data.clone(), requires_grad=True
            )

        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)


    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        visual_embeds=None,
        visual_token_type_ids=None,
        image_text_alignment=None,
    ):

        input_shape = input_ids.size()
        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings

        # Absolute Position Embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings
        visual_embeds = self.visual_projection(visual_embeds)
        visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)
        visual_position_ids = torch.zeros(
            *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device
        )
        visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)
        visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings
        embeddings = torch.cat((embeddings, visual_embeddings), dim=1)

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class GPT2_VQA(nn.Module):
    def __init__(self, num_class=2, config_emb=None):
        super(GPT2_VQA, self).__init__()
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.config = GPT2Config.from_pretrained("gpt2")
        self.classifier = nn.Linear(59 * 768, num_class)

        self.config_emb = config_emb
        self.config_emb.visual_embedding_dim = 2048 #most right dim of the visual features
        self.config_emb.hidden_size = self.config.hidden_size
        self.config_emb.vocab_size = self.config.vocab_size
        self.config_emb.pad_token_id = self.config.pad_token_id

        self.embeddings = ViLEmbeddings(config=self.config_emb)
        # self.embeddings = self.visualbert.embeddings

    def forward(self, inputs):
        hidden_states = self.embeddings(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=inputs['visual_token_type_ids'],
            image_text_alignment=None,
        )

        hidden_states = self.gpt2.drop(hidden_states)
        input_shape = inputs['input_ids'].size()
        visual_input_shape = inputs['visual_embeds'].size()[:-1]
        combined_attention_mask = torch.cat((inputs['attention_mask'], inputs['visual_attention_mask']), dim=-1)
        extended_attention_mask: torch.Tensor = self.gpt2.get_extended_attention_mask(
            combined_attention_mask, (input_shape[0], input_shape + visual_input_shape)
        )
        output_attentions = self.config.output_attentions
        head_mask = self.gpt2.get_head_mask(None, self.config.n_layer)
        past_key_values = tuple([None] * len(self.gpt2.h))
        for i, (block, layer_past) in enumerate(zip(self.gpt2.h, past_key_values)):
            outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=extended_attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    use_cache=None,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]

        hidden_states = self.gpt2.ln_f(hidden_states) #[2, 59, 768]
        x = torch.flatten(hidden_states, 1)
        x = self.classifier(x)
        return x

num_classes = 18
url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet50(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

#text feature
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
# inputs = tokenizer(questions, return_tensors="pt", padding="max_length",max_length=20,)#[2 20]
token_type_ids = torch.zeros(inputs['input_ids'].shape, dtype=torch.long) # zeros because text id types is ones

inputs.update(
    {
        "token_type_ids": token_type_ids,
        "visual_embeds": visual_embeds, #[2, 49, 2048]
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print(inputs['input_ids'].shape, inputs['token_type_ids'].shape, inputs['attention_mask'].shape,
      inputs['visual_embeds'].shape, inputs['visual_token_type_ids'].shape, inputs['visual_attention_mask'].shape)

class config_emb:
    visual_embedding_dim = 2048
    vocab_size = 30522
    type_vocab_size = 2
    pad_token_id = 1
    hidden_size = 768
    max_position_embeddings = 512
    layer_norm_eps = 1e-12
    hidden_dropout_prob = 0.1
    special_visual_initialize = True

model = GPT2_VQA(num_class=18, config_emb=config_emb)
model.eval()
logits = model(inputs)
answer = logits.argmax(dim=1)
print('logits:',logits.shape, 'answer:',answer)


torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 49, 2048]) torch.Size([2, 49]) torch.Size([2, 49])
logits: torch.Size([2, 18]) answer: tensor([16, 16])


#VQA: ChatGPT + ResNet (Late Fusion)

In [None]:
!pip -q install transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import BertTokenizer, GPT2Model, GPT2Tokenizer
from PIL import Image
import requests

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation',
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']
num_classes = 12
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.eval();

class GPT2RS18Classification(nn.Module):
    def __init__(self, num_class = 12):
        super(GPT2RS18Classification, self).__init__()

        # text processing
        self.text_feature_extractor = GPT2Model.from_pretrained('gpt2')

        # image processing
        self.img_feature_extractor = models.resnet18(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #intermediate_layers
        self.intermediate_layer = nn.Linear(1280, 512)  #(512+768)
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)

        # classifier
        self.classifier = nn.Linear(512, num_class)

    def forward(self, input, img):

        # image encoder features
        img_feature = self.img_feature_extractor(img)

        # question tokenizer features
        input['input_ids'] = input['input_ids'].to(device)
        input['attention_mask'] = input['attention_mask'].to(device)

        # GPT text encoder
        text_feature = self.text_feature_extractor(**input) # [2, 10, 768]
        print(text_feature.last_hidden_state.shape)
        text_feature = text_feature.last_hidden_state.swapaxes(1,2)
        text_feature = F.adaptive_avg_pool1d(text_feature,1)
        text_feature = text_feature.swapaxes(1,2).squeeze(1)

        # late visual-text fusion
        img_text_features = torch.cat((img_feature, text_feature), dim=1)

        # intermediate layers
        out =self.intermediate_layer(img_text_features)
        out = self.LayerNorm(out)
        out = self.dropout(out)

        # classifier
        out = self.classifier(out)
        print(out.size())
        return out


# questions = "What is the state of bipolar_forceps?"
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")

img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
new_fc = torch.nn.Sequential(*list(model.fc.children())[:-1])
model.fc = new_fc
img_features = model(batch_img_tensor)

SVQA = GPT2RS18Classification(num_class=12).cuda()
output = SVQA(inputs, batch_img_tensor.cuda())
answer = output.argmax(dim=1)
print('Question: {} \nAnswer: {}'.format(questions[0], labels[answer[0].item()]))

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

torch.Size([2, 10, 768])
torch.Size([2, 12])
Question: What is the state of bipolar_forceps? 
Answer: tool manipulation


#VQA: BioGPT + ResNet(Late Fusion)
src: https://huggingface.co/microsoft/biogpt<br>
git: https://github.com/microsoft/BioGPT<br>
paper: https://arxiv.org/abs/2210.10341<br>


In [None]:
!pip -q install transformers sacremoses

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/880.6 KB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.4/880.6 KB[0m [31m5.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m870.4/880.6 KB[0m [31m15.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m880.6/880.6 KB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import BioGptTokenizer, BioGptForCausalLM
from transformers import BertTokenizer, GPT2Tokenizer
from PIL import Image
import requests
url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = [ 'grasping', 'retraction', 'tissue manipulation', 'tool manipulation',
          'cutting', 'cauterization', 'suction', 'looping', 'suturing', 'clipping', 'staple', 'ultrasound sensing']
num_classes = 12
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.eval();

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GPT2RS18Classification(nn.Module):
    def __init__(self, num_class = 12):
        super(GPT2RS18Classification, self).__init__()

        # text processing
        self.text_feature_extractor = BioGptForCausalLM.from_pretrained("microsoft/biogpt")

        # image processing
        self.img_feature_extractor = models.resnet18(pretrained=True)
        new_fc = nn.Sequential(*list(self.img_feature_extractor.fc.children())[:-1])
        self.img_feature_extractor.fc = new_fc

        #intermediate_layers
        self.intermediate_layer = nn.Linear(42896, 512)  #(512+768)
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)

        # classifier
        self.classifier = nn.Linear(512, num_class)

    def forward(self, input, img):

        # image encoder features
        img_feature = self.img_feature_extractor(img)

        # question tokenizer features
        input['input_ids'] = input['input_ids'].to(device)
        input['attention_mask'] = input['attention_mask'].to(device)

        # GPT text encoder
        text_feature = self.text_feature_extractor(**input)
        text_feature = text_feature[0].swapaxes(1,2)
        #mobarak: [1, 12, 42384], text feature is too big compare to img. We may pool it to 512 the equal size of img
        #F.adaptive_avg_pool2d(output[0],[1, 512])
        text_feature = F.adaptive_avg_pool1d(text_feature,1)
        text_feature = text_feature.swapaxes(1,2).squeeze()

        # late visual-text fusion
        #mobarak: advanced level fusion can be used instead of naive concat (e.g., multihead attention fusion)
        img_text_features = torch.cat((img_feature, text_feature), dim=1)

        # intermediate layers
        out =self.intermediate_layer(img_text_features)
        #mobarak: we may add one more intermidiate layer if the features size is bigger
        out = self.LayerNorm(out)
        out = self.dropout(out)

        # classifier
        out = self.classifier(out)
        print(out.size())
        return out



tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
tokenizer.pad_token = tokenizer.eos_token

question = "What is the state of bipolar_forceps?"
questions = list([question, question])
inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")

img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
new_fc = torch.nn.Sequential(*list(model.fc.children())[:-1])
model.fc = new_fc
img_features = model(batch_img_tensor)

SVQA = GPT2RS18Classification(num_class=12).cuda()
output = SVQA(inputs, batch_img_tensor.cuda())
answer = output.argmax(dim=1)
print('Question: {} \nAnswer: {}'.format(questions[0], labels[answer[0].item()]))

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/1.56G [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

torch.Size([2, 12])
Question: What is the state of bipolar_forceps? 
Answer: cauterization


#Token Smoothing

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast, BertTokenizer, VisualBertForQuestionAnswering
from PIL import Image
import requests

import math
import numpy as np

#####Gaussian#####
def get_gaussian_kernel_2d(ksize=0, sigma=0, channels=1):
    x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
    mean = (ksize - 1)/2.
    variance = sigma**2.
    gaussian_kernel = (1./(2.*math.pi*variance.view(channels,1,1) + 1e-16)) *\
        torch.exp( -torch.sum((xy_grid - mean)**2., dim=-1).view(1, ksize, ksize).repeat(channels,1,1) /\
        (2*variance.view(channels,1,1) + 1e-16)
        )
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel, dim=(1,2)).view(channels,1,1)
    return gaussian_kernel.unsqueeze(1).float()

class get_gaussian_filter(nn.Module):
    def __init__(self, ksize=3, sigma=0, channels=0):
        super(get_gaussian_filter, self).__init__()
        sigma = torch.tensor(sigma).repeat(channels) if np.isscalar(sigma) else sigma
        gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma, channels=channels)

        padding = ksize // 2
        self.gk_layer = nn.Conv2d(in_channels=channels, out_channels=channels,
                                    kernel_size=ksize, groups=channels,
                                    bias=False, padding=padding)
        self.gk_layer.weight.data = gkernel
        self.gk_layer.weight.requires_grad = False
    def forward(self, x):
        return self.gk_layer(x)




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class VisualBertClassification(nn.Module):

    def __init__(self, num_class = 18):
        super(VisualBertClassification, self).__init__()
        config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        config.visual_embedding_dim = 512
        # config.hidden_size = 2048
        # config.num_attention_heads = 8
        # self.VisualBertEncoder = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.VisualBertEncoder = VisualBertModel(config)
        self.classifier = nn.Linear(config.hidden_size, num_class)

        # self.VisualBertEncoder = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")
        # self.VisualBertEncoder.cls = nn.Linear(config.hidden_size, num_class)

    def forward(self, inputs, visual_embeds):
        # prepare visual embedding
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long).to(device)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float).to(device)

        # append visual features to text
        inputs.update({
                        "visual_embeds": visual_embeds.to(device),
                        "visual_token_type_ids": visual_token_type_ids,
                        "visual_attention_mask": visual_attention_mask,
                        "output_attentions": True
                        })

        inputs['input_ids'] = inputs['input_ids'].to(device)
        inputs['token_type_ids'] = inputs['token_type_ids'].to(device)
        inputs['attention_mask'] = inputs['attention_mask'].to(device)
        inputs['visual_token_type_ids'] = inputs['visual_token_type_ids'].to(device)
        inputs['visual_attention_mask'] = inputs['visual_attention_mask'].to(device)

        outputs = self.VisualBertEncoder(**inputs)
        outputs = self.classifier(outputs['pooler_output'])
        return outputs

num_classes = 18

url = 'https://www.frontiersin.org/files/MyHome%20Article%20Library/446547/446547_Thumb_400.jpg'
img = Image.open(requests.get(url, stream=True).raw)
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


#visual feature
img_preprocessed = preprocess(img)
batch_img_tensor = torch.stack([img_preprocessed, img_preprocessed])
model_visual_feat = models.resnet18(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()

#text feature
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
question = "What is the state of bipolar_forceps?"
questions = list([question, question])

model = VisualBertClassification(num_class=18).to(device)
smoothing_layer = get_gaussian_filter(ksize=3, sigma=0.1,channels=49)
#training
for epoch in range(2):
    visual_embeds = model_visual_feat(batch_img_tensor).view(-1, 49, 512)
    sigma = torch.tensor(np.random.uniform(0, 0.2, 49))
    is_smoothing = np.random.rand() > 0.5
    if is_smoothing:
        smoothing_layer = get_gaussian_filter(ksize=3, sigma=sigma, channels=49)
        visual_embeds = smoothing_layer(visual_embeds.unsqueeze(3)).squeeze()

    print('epoch:{}, smoothing:{}'.format(epoch, is_smoothing))
    print(visual_embeds[0,0,:2])
    inputs = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
    logits = model(inputs, visual_embeds)





epoch:0, smoothing:True
tensor([0.0000, 0.5964], grad_fn=<SliceBackward0>)
epoch:1, smoothing:True
tensor([2.6795e-08, 5.9638e-01], grad_fn=<SliceBackward0>)
