In [1]:
%load_ext autoreload
%autoreload 2

import functools

import torch
import torch.nn as nn
from torchvision import transforms
from transformers import BertModel, BertTokenizer
from PIL import Image

from MultimodalContextualAttentionNetwork import MultimodalContextualAttention
from HierarchicalEncodingNetwork import HierarchicalEncodingNetwork as HMCAN
from resnet_50 import ModifiedResNet50 as resnet

  from .autonotebook import tqdm as notebook_tqdm


### Data Preprocessing

In [2]:
@functools.lru_cache(maxsize=None)
def _load_default_tokenizer():
    """Load the ``bert-base-uncased`` tokenizer once per kernel and reuse it."""
    return BertTokenizer.from_pretrained('bert-base-uncased')


def preprocess_text(text, max_length=512, tokenizer=None):
    """Convert ``text`` into fixed-length BERT ``input_ids`` and ``attention_mask``.

    The text is word-piece tokenised and wrapped in ``[CLS]`` ... ``[SEP]``.
    It is then truncated so the sequence never exceeds ``max_length``, always
    keeping the trailing ``[SEP]``. Finally it is right-padded with token id 0.

    Args:
        text: Raw input string.
        max_length: Total sequence length, including ``[CLS]`` and ``[SEP]``.
            Defaults to 512, the maximum input length of ``bert-base-uncased``.
        tokenizer: Object exposing ``tokenize`` and ``convert_tokens_to_ids``.
            Defaults to a cached ``bert-base-uncased`` ``BertTokenizer``.

    Returns:
        Tuple ``(input_ids, attention_mask)``. Each is an integer tensor of
        shape ``(1, max_length)``. The mask is 1 for real tokens and 0 for
        padding.

    Raises:
        ValueError: If ``max_length`` is smaller than 2, leaving no room for
            ``[CLS]`` and ``[SEP]``.
    """
    if max_length < 2:
        raise ValueError(f'max_length must be at least 2, got {max_length}')
    if tokenizer is None:
        # from_pretrained is slow (disk/network); reuse a single instance.
        tokenizer = _load_default_tokenizer()

    # Truncate the word-pieces so [CLS] + content + [SEP] fits in max_length.
    # Without this, long texts produced tensors longer than max_length.
    word_pieces = tokenizer.tokenize(text)[: max_length - 2]
    tokens = ['[CLS]'] + word_pieces + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    n_real = len(token_ids)
    n_pad = max_length - n_real  # always >= 0 thanks to the truncation above
    input_ids = token_ids + [0] * n_pad  # 0 is BERT's [PAD] token id
    attention_mask = [1] * n_real + [0] * n_pad  # 1 = real token, 0 = padding

    # unsqueeze(0) adds the leading batch dimension expected by the model.
    return (
        torch.tensor(input_ids).unsqueeze(0),
        torch.tensor(attention_mask).unsqueeze(0),
    )

In [3]:
# Sample caption and image path fed through the pipeline by the cells below.
text, filename = 'My name is Slim Shady', 'img.jpeg'

In [4]:
# Tokenise the sample caption into fixed-length BERT inputs.
# NOTE(review): `id` shadows Python's builtin `id()`. The name is kept because
# the forward-pass cell passes `id` to the model.
id, attn_mask = preprocess_text(text)
print(f'size of id: {id.shape}')
print(f'size of attn mask: {attn_mask.shape}')
# Token ids and mask must line up position-for-position before reaching BERT.
assert id.shape == attn_mask.shape, 'input ids and attention mask must have the same shape'

size of id: torch.Size([1, 512])
size of attn mask: torch.Size([1, 512])


In [5]:
# Image pipeline: resize, centre-crop to 224x224, convert to a tensor, then
# normalise each channel with the ImageNet mean/std values.
preprocess = transforms.Compose(
    [
        transforms.Resize(224),      # shorter side -> 224 px, aspect ratio kept
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),       # PIL image -> float tensor (C, H, W) in [0, 1]
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
# Load the sample image.
# - convert('RGB') guarantees 3 channels; grayscale, palette or RGBA files
#   would otherwise break the 3-channel Normalize step.
# - The context manager closes the file handle that PIL keeps open for lazy
#   loading. convert() returns an independent image, so it survives the close.
with Image.open(filename) as raw_image:
    input_image = raw_image.convert('RGB')
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model: (1, C, H, W)
assert tuple(input_batch.shape[:2]) == (1, 3), f'unexpected image batch shape {tuple(input_batch.shape)}'
print(f'size of processed image: {input_batch.shape}')

size of processed image: torch.Size([1, 3, 224, 224])


In [7]:
# Exploratory reshape of the raw pixel batch into a sequence of per-pixel
# channel vectors: (batch, channels, H, W) -> (batch, H * W, channels).
# NOTE(review): `input2` is not used by any later visible cell; the model is
# fed `input_batch` directly.
input2 = input_batch.flatten(start_dim=2).transpose(1, 2)
input2.shape


torch.Size([1, 50176, 3])

### Model Setup and Forward Pass

In [8]:
text_d_model=768
img_d_model=2048

In [9]:
# Pretrained BERT encoder that also returns every layer's hidden states.
bert_model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)

# Freeze all BERT weights: Module.requires_grad_(False) clears requires_grad
# on every parameter, so BERT acts as a fixed feature extractor.
bert_model.requires_grad_(False)

In [10]:
resnet_model = resnet()

# Freeze only the backbone's feature extractor. Any parameters outside
# `resnet_feature_extractor` (if present) stay trainable.
for frozen_param in resnet_model.resnet_feature_extractor.parameters():
    frozen_param.requires_grad_(False)

In [11]:
# Contextual cross-modal attention block operating at the text width.
mcan = MultimodalContextualAttention(
    d_model=text_d_model,
    nhead=8,
    # NOTE(review): reuses the image feature width (2048) as the feed-forward
    # hidden size. Confirm this coupling is intended.
    dim_feedforward=img_d_model,
    dropout=0.1,
)

In [12]:
# Hierarchical model wiring the frozen encoders and the MCAN block together.
hmcan = HMCAN(
    mcan_model=mcan,
    bert=bert_model,
    resnet=resnet_model,
    # NOTE(review): 768 + 2048 = 2816, but the MCAN log reports a combined
    # output width of 1536 per pass. Confirm which width HMCAN expects.
    output_dim=(text_d_model + img_d_model),
    groups=3,
)

In [13]:
# Single forward pass in inference mode:
# - eval() disables dropout (MCAN was built with dropout=0.1).
# - no_grad() skips building the autograd graph, which is not needed here.
# NOTE(review): the last recorded run raised inside the model modules during
# the second contextual pass. The (1, 512) text attention mask was used as
# key_padding_mask for the 49-token image sequence. That needs fixing in
# HierarchicalEncodingNetwork / MultimodalContextualAttentionNetwork, not here.
hmcan.eval()
with torch.no_grad():
    out = hmcan(id, attn_mask, input_batch)

After resnet
Img feature size = torch.Size([1, 49, 768]) 

Getting grouped text features size: torch.Size([1, 512, 768]) 

Inserting text and image into mcan

Inside mcan
Text: torch.Size([1, 512, 768])
Image: torch.Size([1, 49, 768])
Entering Contextual Transformers

Input1 before encoding: torch.Size([1, 512, 768])
Input2 before cross attn: torch.Size([1, 49, 768])
Inside Cross Attention
Size of Query: torch.Size([1, 512, 768])
Size of Key: torch.Size([1, 512, 768])
Size of Value: torch.Size([1, 512, 768])

Re-arranging Q, K, V
Size of Query: torch.Size([512, 1, 768])
Size of Key: torch.Size([512, 1, 768])
Size of Value: torch.Size([512, 1, 768])

After multi head
Size of Query: torch.Size([512, 1, 768])
Input1 after encoding: torch.Size([1, 512, 768])
Input2 after cross attn: torch.Size([1, 512, 768])
Pooling outputs...
Input1 after pooling: torch.Size([1, 768])
Input2 after pooling: torch.Size([1, 768])
Combined output: torch.Size([1, 1536])
Input1 before encoding: torch.Size([1, 4

AssertionError: expecting key_padding_mask shape of (1, 49), but got torch.Size([1, 512])