<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Transformers_All_VQA_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q transformers
!pip -q install timm

In [2]:
import torch
from torch import nn
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast
from PIL import Image
import requests
from torchvision.models import resnet18, resnet34, resnet101
from torchvision import transforms
from timm import create_model

# --- Demo image -------------------------------------------------------------
img_url = 'https://www.animalfunfacts.net/images/stories/pets/dogs/pembroke_welsh_corgi_l.jpg'
# Use a timeout and fail fast on HTTP errors so a dead link does not hang the
# notebook or hand PIL an HTML error page.
response = requests.get(img_url, stream=True, timeout=30)
response.raise_for_status()
# Force 3-channel RGB: Normalize below is configured with three mean/std values
# and would fail on grayscale, palette or RGBA images.
img_raw = Image.open(response.raw).convert("RGB")

# Per-channel statistics handed to transforms.Normalize.
mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=mean, std=std)])
img = transform(img_raw)[None]  # prepend a batch axis -> [1, 3, 224, 224]

# --- Demo question ----------------------------------------------------------
test_question = ["Where is the dog?"]
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# Pad every question to a fixed length of 20 tokens.
inputs = bert_tokenizer(test_question, return_tensors="pt", padding="max_length",max_length=20,)


VisualBERT (ResNet101)

In [None]:
class VisualBERT_VQA(nn.Module):
    """VisualBERT encoder (pretrained checkpoint) followed by a linear answer classifier.

    Args:
        num_labels: number of answer classes produced by the classifier.
    """

    def __init__(self, num_labels=2):
        super().__init__()
        self.visualbert = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.cls = nn.Linear(768, num_labels)

    def forward(self, inputs):
        """Return logits of shape [batch, num_labels].

        `inputs` is unpacked as keyword arguments into VisualBertModel and must
        contain an 'attention_mask' entry.
        """
        hidden = self.visualbert(**inputs).last_hidden_state

        # Position to pool from: (number of attended text tokens) - 2, one per sample.
        gather_pos = inputs['attention_mask'].sum(1) - 2
        # Broadcast that position across the hidden dimension -> [batch, 1, hidden].
        gather_idx = gather_pos[:, None, None].expand(gather_pos.size(0), 1, hidden.size(-1))

        pooled = hidden.gather(1, gather_idx)  # [batch, 1, hidden]
        return self.cls(pooled).squeeze(1)

# ResNet-101 as a grid-feature extractor: drop global pooling and the classifier
# so the forward pass returns the last convolutional feature map.
model_visual_feat = resnet101(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
with torch.no_grad():  # feature extraction only; no autograd graph needed
    # torchvision's ResNet flattens after avgpool, so this is [B, C*H*W] in
    # channel-major order.
    grid_feats = model_visual_feat(img)
# Restore [B, C, H*W] first, then move channels last so each of the 49 tokens is
# the 2048-d feature vector of ONE spatial location. A direct
# .view(-1, 49, 2048) would interleave channels and positions.
visual_embeds = grid_feats.view(grid_feats.size(0), 2048, -1).transpose(1, 2).contiguous()
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

inputs.update(
    {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print('visual_embeds', visual_embeds.shape, 'Text:', inputs['input_ids'].shape)
model = VisualBERT_VQA()
model.eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)

visual_embeds torch.Size([1, 49, 2048]) Text: torch.Size([1, 20])


Some weights of the model checkpoint at uclanlp/visualbert-vqa-coco-pre were not used when initializing VisualBertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing VisualBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VisualBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Logits: tensor([[-0.4455,  0.3589]], grad_fn=<SqueezeBackward1>) Prediction: tensor([1])


VisualBERT (ResNet34)

In [None]:
class VisualBERT_VQA(nn.Module):
    """VisualBERT VQA classifier whose visual projection takes 512-d features.

    The configuration comes from "uclanlp/visualbert-vqa-coco-pre" with
    `visual_embedding_dim` overridden to 512. The VisualBertModel is built from
    that configuration only; no checkpoint weights are loaded here.

    Args:
        num_labels: number of answer classes produced by the classifier.
    """

    def __init__(self, num_labels=2):
        super().__init__()
        vb_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        vb_config.visual_embedding_dim = 512
        self.config = vb_config
        self.visualbert = VisualBertModel(config=self.config)
        self.cls = nn.Linear(768, num_labels)

    def forward(self, inputs):
        """Return logits of shape [batch, num_labels]; `inputs` needs 'attention_mask'."""
        hidden = self.visualbert(**inputs).last_hidden_state

        # Position to pool from: (number of attended text tokens) - 2, one per sample.
        gather_pos = inputs['attention_mask'].sum(1) - 2
        # Broadcast that position across the hidden dimension -> [batch, 1, hidden].
        gather_idx = gather_pos[:, None, None].expand(gather_pos.size(0), 1, hidden.size(-1))
        pooled = hidden.gather(1, gather_idx)  # [batch, 1, hidden]
        return self.cls(pooled).squeeze(1)

# ResNet-34 as a grid-feature extractor: drop global pooling and the classifier
# so the forward pass returns the last convolutional feature map.
model_visual_feat = resnet34(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
with torch.no_grad():  # feature extraction only; no autograd graph needed
    # torchvision's ResNet flattens after avgpool, so this is [B, C*H*W] in
    # channel-major order.
    grid_feats = model_visual_feat(img)
# Restore [B, C, H*W] first, then move channels last so each of the 49 tokens is
# the 512-d feature vector of ONE spatial location. A direct
# .view(-1, 49, 512) would interleave channels and positions.
visual_embeds = grid_feats.view(grid_feats.size(0), 512, -1).transpose(1, 2).contiguous()
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
inputs.update(
    {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print('visual_embeds', visual_embeds.shape, 'Text:', inputs['input_ids'].shape)

model = VisualBERT_VQA()
model.eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)




visual_embeds torch.Size([1, 49, 512]) Text: torch.Size([1, 20])
self.visualbert.config.visual_embedding_dim: 512
tensor([5]) 1 768
torch.Size([1, 1, 768])
Logits: tensor([[ 0.5161, -0.5943]], grad_fn=<SqueezeBackward1>) Prediction: tensor([0])


# ViT_VQA(ResNet18)

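(Note: `BertTokenizerFast` is the tokenizer class that `AutoTokenizer` resolves to for `bert-base-uncased`, so the two can be used interchangeably here.)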

In [3]:
from timm import create_model
class ViT_VQA(nn.Module):
    """VQA model: VisualBERT embedding layer -> timm ViT blocks -> mean pool -> linear head.

    Only the embedding layer of the (config-initialised) VisualBertModel is used
    in `forward`; the ViT supplies the transformer blocks and the final norm.

    Args:
        num_labels: number of answer classes produced by the classifier.
    """

    def __init__(self, num_labels=2):
        super().__init__()
        vb_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        vb_config.visual_embedding_dim = 512
        self.config = vb_config
        self.visualbert = VisualBertModel(config=self.config)
        self.embeddings = self.visualbert.embeddings

        self.vit = create_model("vit_base_patch16_224", pretrained=True)
        self.cls = nn.Linear(768, num_labels)

    def forward(self, inputs):
        """Return logits of shape [batch, num_labels].

        Reads 'input_ids', 'token_type_ids', 'visual_embeds' and
        'visual_token_type_ids' from `inputs`.
        """
        embed_kwargs = dict(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=inputs['visual_token_type_ids'],
            image_text_alignment=None,
        )
        # Joint text + visual token sequence: [batch, text_len + num_visual, hidden].
        tokens = self.embeddings(**embed_kwargs)

        encoded = self.vit.norm(self.vit.blocks(tokens))
        # Average over every token (text and visual alike) before classifying.
        return self.cls(encoded.mean(dim=1))

# ResNet-18 as a grid-feature extractor: drop global pooling and the classifier
# so the forward pass returns the last convolutional feature map.
model_visual_feat = resnet18(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
with torch.no_grad():  # feature extraction only; no autograd graph needed
    # torchvision's ResNet flattens after avgpool, so this is [B, C*H*W] in
    # channel-major order.
    grid_feats = model_visual_feat(img)
# Restore [B, C, H*W] first, then move channels last so each of the 49 tokens is
# the 512-d feature vector of ONE spatial location. A direct
# .view(-1, 49, 512) would interleave channels and positions.
visual_embeds = grid_feats.view(grid_feats.size(0), 512, -1).transpose(1, 2).contiguous()
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
inputs.update(
    {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print('visual_embeds', visual_embeds.shape, 'Text:', inputs['input_ids'].shape)

model = ViT_VQA(num_labels=2)
model.eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

visual_embeds torch.Size([1, 49, 512]) Text: torch.Size([1, 20])


Downloading config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

Logits: tensor([[ 1.0024, -1.4413]], grad_fn=<AddmmBackward0>) Prediction: tensor([0])


# DeiT_VQA(ResNet18)
DeiT: Data-efficient Image Transformers - https://arxiv.org/abs/2012.12877

In [5]:
from timm import create_model
class DeiT_VQA(nn.Module):
    """VQA model: VisualBERT embedding layer -> timm DeiT blocks -> mean pool -> linear head.

    Only the embedding layer of the (config-initialised) VisualBertModel is used
    in `forward`; DeiT supplies the transformer blocks and the final norm.

    Args:
        num_labels: number of answer classes produced by the classifier.
    """

    def __init__(self, num_labels=2):
        super().__init__()
        vb_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        vb_config.visual_embedding_dim = 512
        self.config = vb_config
        self.visualbert = VisualBertModel(config=self.config)
        self.embeddings = self.visualbert.embeddings

        self.deit = create_model("deit_base_patch16_224", pretrained=True)
        self.cls = nn.Linear(768, num_labels)

    def forward(self, inputs):
        """Return logits of shape [batch, num_labels].

        Reads 'input_ids', 'token_type_ids', 'visual_embeds' and
        'visual_token_type_ids' from `inputs`.
        """
        embed_kwargs = dict(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=inputs['visual_token_type_ids'],
            image_text_alignment=None,
        )
        # Joint text + visual token sequence: [batch, text_len + num_visual, hidden].
        tokens = self.embeddings(**embed_kwargs)
        encoded = self.deit.norm(self.deit.blocks(tokens))
        # Average over every token (text and visual alike) before classifying.
        return self.cls(encoded.mean(dim=1))

# ResNet-18 as a grid-feature extractor: drop global pooling and the classifier
# so the forward pass returns the last convolutional feature map.
model_visual_feat = resnet18(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
with torch.no_grad():  # feature extraction only; no autograd graph needed
    # torchvision's ResNet flattens after avgpool, so this is [B, C*H*W] in
    # channel-major order.
    grid_feats = model_visual_feat(img)
# Restore [B, C, H*W] first, then move channels last so each of the 49 tokens is
# the 512-d feature vector of ONE spatial location. A direct
# .view(-1, 49, 512) would interleave channels and positions.
visual_embeds = grid_feats.view(grid_feats.size(0), 512, -1).transpose(1, 2).contiguous()
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
inputs.update(
    {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print('visual_embeds', visual_embeds.shape, 'Text:', inputs['input_ids'].shape)

model = DeiT_VQA(num_labels=2)
model.eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)

visual_embeds torch.Size([1, 49, 512]) Text: torch.Size([1, 20])


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth


Logits: tensor([[-0.4111, -0.2708]], grad_fn=<AddmmBackward0>) Prediction: tensor([1])


# CaiT_VQA(ResNet18)
CaiT: Class-Attention in Image Transformers (https://arxiv.org/abs/2103.17239)

In [18]:
from timm import create_model
class CaiT_VQA(nn.Module):
    """VQA model: VisualBERT embedding layer (hidden size 192) -> timm CaiT -> linear head.

    The embedding width is set to 192, the input width of the `self.cls` head
    that consumes the CaiT output.

    Args:
        num_labels: number of answer classes produced by the classifier.
    """

    def __init__(self, num_labels=2):
        super().__init__()
        vb_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        vb_config.visual_embedding_dim = 512
        vb_config.hidden_size = 192
        self.config = vb_config
        self.visualbert = VisualBertModel(config=self.config)
        self.embeddings = self.visualbert.embeddings

        self.cait = create_model("cait_xxs24_224", pretrained=True)
        self.cls = nn.Linear(192, num_labels)

    def forward(self, inputs):
        """Return logits of shape [batch, num_labels].

        Reads 'input_ids', 'token_type_ids', 'visual_embeds' and
        'visual_token_type_ids' from `inputs`.
        """
        embed_kwargs = dict(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=inputs['visual_token_type_ids'],
            image_text_alignment=None,
        )
        # Joint text + visual token sequence: [batch, text_len + num_visual, hidden].
        tokens = self.cait.blocks(self.embeddings(**embed_kwargs))

        # Class-attention stage: the class token is refined against the frozen tokens.
        cls_tokens = self.cait.cls_token.expand(tokens.shape[0], -1, -1)
        for token_block in self.cait.blocks_token_only:
            cls_tokens = token_block(tokens, cls_tokens)

        sequence = self.cait.norm(torch.cat((cls_tokens, tokens), dim=1))
        # Pool by averaging the class token together with all other tokens.
        return self.cls(sequence.mean(dim=1))

# ResNet-18 as a grid-feature extractor: drop global pooling and the classifier
# so the forward pass returns the last convolutional feature map.
model_visual_feat = resnet18(pretrained=True)
model_visual_feat.avgpool = nn.Identity()
model_visual_feat.fc = nn.Identity()
model_visual_feat.eval()
with torch.no_grad():  # feature extraction only; no autograd graph needed
    # torchvision's ResNet flattens after avgpool, so this is [B, C*H*W] in
    # channel-major order.
    grid_feats = model_visual_feat(img)
# Restore [B, C, H*W] first, then move channels last so each of the 49 tokens is
# the 512-d feature vector of ONE spatial location. A direct
# .view(-1, 49, 512) would interleave channels and positions.
visual_embeds = grid_feats.view(grid_feats.size(0), 512, -1).transpose(1, 2).contiguous()
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
inputs.update(
    {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": visual_token_type_ids,
        "visual_attention_mask": visual_attention_mask,
    }
)

print('visual_embeds', visual_embeds.shape, 'Text:', inputs['input_ids'].shape)

model = CaiT_VQA(num_labels=2)
model.eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)

visual_embeds torch.Size([1, 49, 512]) Text: torch.Size([1, 20])
Logits: tensor([[-0.2232,  1.4499]], grad_fn=<AddmmBackward0>) Prediction: tensor([1])


# Swin-Transformer_VQA()

In [12]:
from torch import nn
from transformers import VisualBertModel, VisualBertConfig

# Initializing a VisualBERT visualbert-vqa-coco-pre style configuration
# NOTE(review): SwinTranformer_VQA below loads its own copy of this configuration in
# __init__, so this module-level `config` does not appear to be used by this cell.
config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
class VisualBertEmbeddings(nn.Module):
    """Text-only embedding layer: word + token-type + absolute position, LayerNorm, dropout.

    The visual arguments of `forward` are accepted for signature compatibility
    but are not used; this variant embeds the text tokens only.

    Args:
        config: object exposing `vocab_size`, `hidden_size`, `pad_token_id`,
            `max_position_embeddings`, `type_vocab_size`, `layer_norm_eps` and
            `hidden_dropout_prob`.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Precomputed [1, max_position_embeddings] index row, sliced per call.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        visual_embeds=None,
        visual_token_type_ids=None,
        image_text_alignment=None,
    ):
        """Embed a batch of text tokens.

        Args:
            input_ids: [batch, seq] token ids; optional if `inputs_embeds` is given.
            token_type_ids: [batch, seq] segment ids; defaults to all zeros.
            position_ids: [1 or batch, seq] positions; defaults to 0..seq-1.
            inputs_embeds: [batch, seq, hidden] precomputed word embeddings.
            visual_embeds, visual_token_type_ids, image_text_alignment: ignored.

        Returns:
            Tensor of shape [batch, seq, hidden].

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is provided.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if token_type_ids is None:
            # Single-segment input: every token belongs to segment 0.
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings

        # Absolute Position Embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        # The dropout module was constructed but never applied; it is a no-op in
        # eval mode, so inference results are unchanged.
        embeddings = self.dropout(embeddings)
        return embeddings

class SwinTranformer_VQA(nn.Module):
    """VQA model: Swin image tokens concatenated with 1024-d text embeddings, then pooled.

    The raw image goes through the Swin patch embedding and stages; the question
    is embedded by `VisualBertEmbeddings` with hidden size 1024. Both token
    sequences are concatenated, normalised by the Swin final norm, mean-pooled
    and classified.

    Args:
        num_labels: number of answer classes produced by the classifier.
    """

    def __init__(self, num_labels=2):
        super().__init__()
        vb_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        vb_config.hidden_size = 1024
        self.config = vb_config
        self.embeddings = VisualBertEmbeddings(config=self.config)

        self.swintran = create_model("swin_base_patch4_window7_224", pretrained=True)
        self.cls = nn.Linear(1024, num_labels)

    def forward(self, inputs):
        """Return logits of shape [batch, num_labels].

        Reads 'input_ids', 'token_type_ids' and 'visual_embeds' from `inputs`;
        'visual_embeds' is fed to the Swin patch embedding.
        """
        embed_kwargs = dict(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            position_ids=None,
            inputs_embeds=None,
            visual_embeds=inputs['visual_embeds'],
            visual_token_type_ids=None,
            image_text_alignment=None,
        )
        text_tokens = self.embeddings(**embed_kwargs)

        image_tokens = self.swintran.layers(self.swintran.patch_embed(inputs['visual_embeds']))

        # Image tokens first, text tokens second, joined along the sequence axis.
        joint = self.swintran.norm(torch.cat((image_tokens, text_tokens), dim=1))
        return self.cls(joint.mean(dim=1))

# The Swin variant consumes the raw image directly, so the visual slot carries
# `img` and the VisualBERT-specific visual fields are cleared.
inputs.update(visual_embeds=img, visual_token_type_ids=None, visual_attention_mask=None)

print( 'Text Embedding:', inputs['input_ids'].shape)

# Build the model in inference mode and score the question/image pair.
model = SwinTranformer_VQA(num_labels=2).eval()
logits = model(inputs)
pred_vqa = logits.argmax(-1)
print('Logits:',logits, 'Prediction:', pred_vqa)

visual_embeds torch.Size([1, 3136, 128]) Text: torch.Size([1, 20])
Logits: tensor([[0.0392, 0.5809]], grad_fn=<AddmmBackward0>) Prediction: tensor([1])


# Ablation on All Transformers for Classification

# Swin-Transformer

In [28]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class SwinTranformer_Features(nn.Module):
    """Step-by-step Swin forward pass that prints the tensor shape after every stage."""

    def __init__(self):
        super().__init__()
        self.swintran = create_model("swin_base_patch4_window7_224", pretrained=True)

    def forward(self, x):
        """Return the classification logits produced by the Swin head for image batch `x`."""
        # Stages in order: patch embedding, transformer stages, final norm,
        # mean pooling over the token axis.
        stages = (
            self.swintran.patch_embed,
            self.swintran.layers,
            self.swintran.norm,
            lambda tokens: tokens.mean(dim=1),
        )
        print(x.shape)
        for stage in stages:
            x = stage(x)
            print(x.shape)
        return self.swintran.head(x)

# Classify the demo image with the Swin wrapper and report the top-1 class index.
model = SwinTranformer_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

torch.Size([1, 3, 224, 224])
torch.Size([1, 3136, 128])
torch.Size([1, 49, 1024])
torch.Size([1, 49, 1024])
torch.Size([1, 1024])
prediction: 263


# ViT

In [32]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class ViT_Features(nn.Module):
    """Step-by-step ViT forward pass with shape prints.

    The timm model is placed on CUDA when available; `forward` moves its input
    to the same device so CPU tensors can be passed in directly.
    """

    def __init__(self):
        super(ViT_Features, self).__init__()
        model_name = "vit_base_patch16_224"
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.vit = create_model(model_name, pretrained=True).to(device)

    def forward(self, x):
        """Return classification logits for image batch `x` (any batch size).

        The logits live on the model's device.
        """
        # The weights may be on CUDA while the caller's tensor is on CPU.
        x = x.to(self.vit.pos_embed.device)
        patches = self.vit.patch_embed(x)
        pos_embed = self.vit.pos_embed
        print(patches.shape, pos_embed.shape)
        # cls_token is stored with batch size 1; expand it so batches larger
        # than one can be concatenated with the patch tokens.
        cls_tokens = self.vit.cls_token.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)
        print(tokens.shape)
        x = tokens + pos_embed
        x = self.vit.blocks(x)
        print('self.vit.blocks(x):', x.shape)
        x = self.vit.norm(x)
        # Pool by averaging all tokens (class token included).
        x = x.mean(dim=1)
        logits = self.vit.head(x)
        return logits

# Classify the demo image with the ViT wrapper and report the top-1 class index.
model = ViT_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

torch.Size([1, 196, 768]) torch.Size([1, 197, 768])
torch.Size([1, 197, 768])
self.vit.blocks(x): torch.Size([1, 197, 768])
prediction: 263


# DeiT: Data-efficient Image Transformers - https://arxiv.org/abs/2012.12877

In [13]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class DeiT_Features(nn.Module):
    """Step-by-step DeiT forward pass: patches + class token -> blocks -> norm -> mean pool -> head."""

    def __init__(self):
        super(DeiT_Features, self).__init__()
        self.deit = create_model("deit_base_patch16_224", pretrained=True)

    def forward(self, x):
        """Return classification logits for image batch `x` (any batch size)."""
        patches = self.deit.patch_embed(x)
        pos_embed = self.deit.pos_embed
        # cls_token is stored with batch size 1; expand it so batches larger
        # than one can be concatenated with the patch tokens.
        cls_tokens = self.deit.cls_token.expand(patches.shape[0], -1, -1)
        x = torch.cat((cls_tokens, patches), dim=1) + pos_embed
        x = self.deit.blocks(x)
        x = self.deit.norm(x)
        # Pool by averaging all tokens (class token included).
        x = x.mean(dim=1)
        logits = self.deit.head(x)
        return logits

# Classify the demo image with the DeiT wrapper and report the top-1 class index.
model = DeiT_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth


prediction: 263


# CaiT: Class-Attention in Image Transformers (https://arxiv.org/abs/2103.17239)

In [15]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class CaiT_Features(nn.Module):
    """Step-by-step CaiT forward pass: self-attention on patches, then class-attention."""

    def __init__(self):
        super().__init__()
        self.cait = create_model("cait_xxs24_224", pretrained=True)

    def forward(self, x):
        """Return classification logits for image batch `x`, read from the class token."""
        # Patch tokens with positional embedding; no class token at this stage.
        tokens = self.cait.patch_embed(x) + self.cait.pos_embed
        print(tokens.shape)
        tokens = self.cait.blocks(tokens)

        # Class-attention stage: only the class token is updated.
        cls_tokens = self.cait.cls_token.expand(tokens.shape[0], -1, -1)
        for token_block in self.cait.blocks_token_only:
            cls_tokens = token_block(tokens, cls_tokens)

        normed = self.cait.norm(torch.cat((cls_tokens, tokens), dim=1))
        # Position 0 holds the class token.
        return self.cait.head(normed[:, 0])

# Classify the demo image with the CaiT wrapper and report the top-1 class index.
model = CaiT_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

torch.Size([1, 196, 192])
prediction: 263


# BeiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)

In [12]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class BeiT_Features(nn.Module):
    """Step-by-step BEiT forward pass with mean pooling over the patch tokens."""

    def __init__(self):
        super().__init__()
        self.beit = create_model("beit_base_patch16_224", pretrained=True)

    def forward(self, x):
        """Return classification logits for image batch `x`."""
        net = self.beit
        patches = net.patch_embed(x)
        cls_tokens = net.cls_token.expand(patches.shape[0], -1, -1)
        x = torch.cat((cls_tokens, patches), dim=1)

        # Relative position bias shared by all blocks, when the model has one.
        if net.rel_pos_bias is None:
            shared_bias = None
        else:
            shared_bias = net.rel_pos_bias()

        for block in net.blocks:
            x = block(x, shared_rel_pos_bias=shared_bias)
        x = net.norm(x)

        # Skip the class token (index 0) and average the remaining tokens.
        pooled = x[:, 1:].mean(dim=1)
        return net.head(net.fc_norm(pooled))

# Classify the demo image with the BEiT wrapper and report the top-1 class index.
model = BeiT_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

prediction: 129


# CoaT: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399

In [24]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class CoaT_Features(nn.Module):
    """Thin wrapper around timm's "coat_mini" that runs features and head as two explicit steps.

    Note: the wrapped model is stored under the attribute name `beit`.
    """

    def __init__(self):
        super().__init__()
        self.beit = create_model("coat_mini", pretrained=True)

    def forward(self, x):
        """Return classification logits for image batch `x`."""
        return self.beit.forward_head(self.beit.forward_features(x))

# Classify the demo image with the CoaT wrapper and report the top-1 class index.
model = CoaT_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

prediction: 263


# CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification (Chen et al., ICCV 2021)

In [26]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class CrossViT_Features(nn.Module):
    """Thin wrapper around a timm CrossViT model that runs features and head as two explicit steps."""

    def __init__(self):
        super(CrossViT_Features, self).__init__()
        # This previously loaded "coat_mini" (copy-paste from the CoaT cell), so the
        # CrossViT section was re-running CoaT. "crossvit_base_240" is the CrossViT
        # model this notebook loads directly in a later cell.
        self.crossvit = create_model("crossvit_base_240", pretrained=True)

    def forward(self, x):
        """Return classification logits for image batch `x`."""
        x_feat = self.crossvit.forward_features(x)
        x = self.crossvit.forward_head(x_feat)
        return x

# Classify the demo image with the CrossViT wrapper and report the top-1 class index.
model = CrossViT_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

prediction: 263


# ConvMixer: Patches Are All You Need? (https://arxiv.org/pdf/2201.09792.pdf)

In [18]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class BeiT_Features(nn.Module):
    """Step-by-step ConvMixer forward pass: stem -> blocks -> pooling -> head.

    NOTE(review): despite the class and attribute names, this wraps timm's
    "convmixer_768_32". Defining it rebinds the name `BeiT_Features`, which
    shadows the BEiT wrapper defined earlier in the notebook.
    """

    def __init__(self):
        super().__init__()
        self.beit = create_model("convmixer_768_32", pretrained=True)

    def forward(self, x):
        """Return classification logits for image batch `x`."""
        net = self.beit
        for stage in (net.stem, net.blocks, net.pooling):
            x = stage(x)
        return net.head(x)

# Classify the demo image with the ConvMixer wrapper defined in this cell (it is
# named BeiT_Features) and report the top-1 class index.
model = BeiT_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

prediction: 263


# ConvNeXt: A ConvNet for the 2020s - https://arxiv.org/pdf/2201.03545.pdf

In [22]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from timm import create_model


class ConvNeXt_Features(nn.Module):
    """Step-by-step ConvNeXt forward pass, including each sub-layer of the classifier head."""

    def __init__(self):
        super().__init__()
        self.convnext = create_model("convnext_base", pretrained=True)

    def forward(self, x):
        """Return classification logits for image batch `x`."""
        backbone = self.convnext
        head = backbone.head
        # Backbone stages first, then the head's sub-layers in their original order.
        pipeline = (
            backbone.stem,
            backbone.stages,
            backbone.norm_pre,
            head.global_pool,
            head.norm,
            head.flatten,
            head.drop,
            head.fc,
        )
        for layer in pipeline:
            x = layer(x)
        return x

# Classify the demo image with the ConvNeXt wrapper and report the top-1 class index.
model = ConvNeXt_Features().eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print('prediction:', torch.argmax(logits).item())

prediction: 263


In [25]:
# Run timm's CrossViT directly (no wrapper class) on the demo image and print
# the top-1 class index.
model = create_model("crossvit_base_240", pretrained=True).eval()
logits = model(img)
pred = logits.argmax(dim=1).item()
print(pred)

Downloading: "https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth" to /root/.cache/torch/hub/checkpoints/crossvit_base_224.pth


263


In [None]:
# List the name and shape of every trainable parameter of the current `model`.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.shape)

# Embedding from Scratch

In [33]:
from torch import nn
from transformers import VisualBertModel, VisualBertConfig

# Initializing a VisualBERT visualbert-vqa-coco-pre style configuration
# (loaded from the "uclanlp/visualbert-vqa-coco-pre" checkpoint's config; no model weights
# are loaded by this line). It supplies the sizes that VisualBertEmbeddings reads.
config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
class VisualBertEmbeddings(nn.Module):
    """Joint text + visual embedding layer, instrumented with shape prints.

    Text tokens get word + token-type + absolute-position embeddings. Visual
    features are linearly projected to the hidden size and get their own
    token-type and position embeddings. Both sequences are concatenated along
    the token axis, then LayerNorm and dropout are applied.

    Args:
        config: object exposing `vocab_size`, `hidden_size`, `pad_token_id`,
            `max_position_embeddings`, `type_vocab_size`, `layer_norm_eps`,
            `hidden_dropout_prob`, `special_visual_initialize` and
            `visual_embedding_dim`.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Precomputed [1, max_position_embeddings] index row, sliced per call.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

        # For Visual Features
        # Token type and position embedding for image features
        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        if config.special_visual_initialize:
            # Start the visual tables as copies of the corresponding text tables.
            self.visual_token_type_embeddings.weight.data = nn.Parameter(
                self.token_type_embeddings.weight.data.clone(), requires_grad=True
            )
            self.visual_position_embeddings.weight.data = nn.Parameter(
                self.position_embeddings.weight.data.clone(), requires_grad=True
            )

        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        visual_embeds=None,
        visual_token_type_ids=None,
        image_text_alignment=None,
    ):
        """Embed text tokens and, when given, visual features.

        Args:
            input_ids: [batch, seq] token ids; optional if `inputs_embeds` is given.
            token_type_ids: [batch, seq] segment ids; defaults to all zeros.
            position_ids: text positions; defaults to 0..seq-1.
            inputs_embeds: [batch, seq, hidden] precomputed word embeddings.
            visual_embeds: [batch, num_visual, visual_embedding_dim] features;
                when None the output covers the text tokens only.
            visual_token_type_ids: [batch, num_visual]; defaults to all ones.
            image_text_alignment: accepted but ignored; every visual token
                uses position id 0.

        Returns:
            [batch, seq (+ num_visual), hidden] embeddings.

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is provided.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if input_ids is not None:
            print('bef', input_ids.shape)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        print('af', inputs_embeds.shape)

        if token_type_ids is None:
            # Single-segment text: every token belongs to segment 0.
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
        print('token_type_ids', token_type_ids.shape)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        print('token_type_embeddings', token_type_embeddings.shape)
        embeddings = inputs_embeds + token_type_embeddings

        # Absolute Position Embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = embeddings + position_embeddings

        if visual_embeds is not None:
            if visual_token_type_ids is None:
                # Visual tokens default to segment 1.
                visual_token_type_ids = torch.ones(
                    visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device
                )

            print('before:',visual_embeds.shape)
            visual_embeds = self.visual_projection(visual_embeds)
            print('after:',visual_embeds.shape)
            print('bef', visual_token_type_ids.shape)
            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)
            print('af', visual_token_type_embeddings.shape)
            # All visual tokens share position id 0.
            visual_position_ids = torch.zeros(
                *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device
            )
            print('bef',visual_position_ids.shape)
            visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)
            print('bef',visual_position_embeddings.shape)
            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings
            print('visual_embeddings', visual_embeddings.shape)
            # Text tokens first, visual tokens appended along the sequence axis.
            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings