In [1]:
!conda info --envs


# conda environments:
#
base                     /home/dtrad/anaconda3
arvin                    /home/dtrad/anaconda3/envs/arvin
devito                   /home/dtrad/anaconda3/envs/devito
dummy                    /home/dtrad/anaconda3/envs/dummy
genai                    /home/dtrad/anaconda3/envs/genai
genai0                   /home/dtrad/anaconda3/envs/genai0
general                  /home/dtrad/anaconda3/envs/general
jax                      /home/dtrad/anaconda3/envs/jax
langchain                /home/dtrad/anaconda3/envs/langchain
langchain0               /home/dtrad/anaconda3/envs/langchain0
llmbook                  /home/dtrad/anaconda3/envs/llmbook
llmbook2                 /home/dtrad/anaconda3/envs/llmbook2
mlcrewes                 /home/dtrad/anaconda3/envs/mlcrewes
mlcrewes0                /home/dtrad/anaconda3/envs/mlcrewes0
pylops                   /home/dtrad/anaconda3/envs/pylops
pyparallel               /home/dtrad/anaconda3/envs/pyparallel
pytorch                  /home

In [2]:
import torch
import torch.nn as nn
from transformers import BertModel, ViTModel

In [3]:
# Define the Multimodal Transformer model
class MultimodalTransformer(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", image_model_name="google/vit-base-patch16-224"):
        super(MultimodalTransformer, self).__init__()
        
        # Load pre-trained text (BERT) and image (ViT) models
        self.text_model = BertModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)
        
        # Hidden size from the text and image models
        text_hidden_size = self.text_model.config.hidden_size
        image_hidden_size = self.image_model.config.hidden_size
        
        # Combine the text and image embeddings
        self.fc = nn.Linear(text_hidden_size + image_hidden_size, 256)
        self.classifier = nn.Linear(256, 2)  # Binary classification
        
        # Dropout for regularization
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, input_ids, attention_mask, pixel_values):
        # Process text inputs
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_embedding = text_outputs.pooler_output  # [batch_size, text_hidden_size]
        
        # Process image inputs
        image_outputs = self.image_model(pixel_values=pixel_values)
        image_embedding = image_outputs.pooler_output  # [batch_size, image_hidden_size]
        
        # Concatenate text and image embeddings
        combined_embedding = torch.cat((text_embedding, image_embedding), dim=1)
        
        # Pass through fully connected layers and classifier
        combined_embedding = self.dropout(self.fc(combined_embedding))
        logits = self.classifier(combined_embedding)
        
        return logits


In [4]:
# Instantiate the model
model = MultimodalTransformer()

# Print the model architecture
print(model)


Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/68.0k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/330M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MultimodalTransformer(
  (text_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [5]:
from transformers import BertTokenizer, ViTFeatureExtractor
from PIL import Image
import requests

In [6]:
# Load tokenizer and feature extractor
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/160 [00:00<?, ?B/s]

In [9]:
# Example text and image
text = "A cat sitting on a chair."
image_url = "https://example.com/cat.jpg"
#image = Image.open(requests.get(image_url, stream=True).raw)
image = Image.open("cat.jpeg")
# Tokenize text
text_inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

# Extract features from the image
image_inputs = feature_extractor(images=image, return_tensors="pt")

# Forward pass
logits = model(
    input_ids=text_inputs["input_ids"],
    attention_mask=text_inputs["attention_mask"],
    pixel_values=image_inputs["pixel_values"]
)

# Output
print("Logits:", logits)


Logits: tensor([[-0.1802,  0.1249]], grad_fn=<AddmmBackward0>)


In [10]:
image.show()