# Experiment 1

## Objectives
1. Combine DINOv3 ViT-S+ (vision encoder) and Gemma-3-270M (language model) and get the combined model to produce output.
2. Build the data pipeline for the MS-COCO image-captioning dataset and test the model on that pipeline.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoImageProcessor, AutoModel
from transformers.image_utils import load_image

# Use the current accelerator (CUDA / MPS / XPU, ...) when one is available, otherwise fall back to CPU.
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")

# Fixed seed so the randomly initialised adapter (next cell) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

gemma_id = "google/gemma-3-270m"
vit_id = "facebook/dinov3-vits16plus-pretrain-lvd1689m"

gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_id)
# NOTE: despite its name this is an image processor (turns images into `pixel_values`), not a tokenizer.
# The name is kept because other cells refer to it.
vit_tokenizer = AutoImageProcessor.from_pretrained(vit_id)

# Place both models on `device` explicitly and only once. device_map="auto" is intentionally not used:
# it hands placement to accelerate (possibly sharding/offloading and attaching dispatch hooks), which
# conflicts with moving the model afterwards via .to(device) and adds an accelerate dependency.
gemma = AutoModelForCausalLM.from_pretrained(gemma_id, dtype=torch.float32).to(device)
vit = AutoModel.from_pretrained(vit_id, dtype=torch.float32).to(device)

# Freeze the vision encoder; Gemma's parameters are left trainable in this cell.
vit.requires_grad_(False)
vit.eval()

In [None]:
# Widths of the two backbones: the adapter maps ViT hidden states into Gemma's token-embedding space.
gemma_embed_size = gemma.get_input_embeddings().weight.shape[1]
vit_embed_size = vit.config.hidden_size

# Trainable projector: LayerNorm -> Linear -> GELU(tanh) -> Linear, from vit_embed_size to gemma_embed_size.
# New nn.Module parameters are created on CPU, so it is moved to `device` to match the backbones;
# otherwise the first forward pass on an accelerator fails with a device-mismatch error.
mlp_adapter = nn.Sequential(
    nn.LayerNorm(vit_embed_size),
    nn.Linear(vit_embed_size, gemma_embed_size),
    nn.GELU(approximate="tanh"),
    nn.Linear(gemma_embed_size, gemma_embed_size),
).to(device)

def prepare_inputs(images, captions):
    """Build Gemma input embeddings with projected image tokens prepended to the caption tokens.

    Args:
        images: either preprocessed pixel values as a ``torch.Tensor`` (fed to the ViT as-is),
            or anything the image processor ``vit_tokenizer`` accepts (e.g. a list of PIL images),
            which is then converted to pixel values first.
        captions: a string or list of strings, tokenized with ``gemma_tokenizer`` (padded to the
            longest caption in the batch).

    Returns:
        new_embeds: (bs, num_img_tok + num_txt_tok, gemma_dim) image embeddings followed by the
            caption token embeddings.
        new_mask: (bs, num_img_tok + num_txt_tok) attention mask, same dtype as the tokenizer's mask;
            image positions are always 1.
        input_ids: (bs, num_txt_tok) caption token ids (e.g. for building labels).
        num_img_tok: number of image tokens prepended to every sequence.

    Raises:
        ValueError: if the number of images and captions differ.
    """
    if isinstance(images, torch.Tensor):
        pixel_values = images
    else:
        pixel_values = vit_tokenizer(images, return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(device)

    # The ViT is frozen, so no autograd graph is needed for it. no_grad (not inference_mode) is used on
    # purpose: inference-mode tensors cannot be saved for backward by the trainable adapter below.
    with torch.no_grad():
        vit_out = vit(pixel_values=pixel_values)

    # The model returns a ModelOutput container, not a tensor: project its per-token hidden states.
    image_embed = mlp_adapter(vit_out.last_hidden_state)  # (bs, num_img_tok, gemma_dim)

    # NOTE(review): if gemma_tokenizer pads on the left, pad tokens sit between the image prefix and the
    # caption (masked out, but they still shift positions) — confirm padding_side is what you want.
    tok = gemma_tokenizer(captions, return_tensors="pt", padding=True)
    input_ids = tok["input_ids"].to(device)
    attention_mask = tok["attention_mask"].to(device)

    input_embeds = gemma.get_input_embeddings()(input_ids)  # (bs, num_txt_tok, gemma_dim)
    bs = input_embeds.shape[0]
    if image_embed.shape[0] != bs:
        raise ValueError(
            f"Got {image_embed.shape[0]} images but {bs} captions; they must be paired one-to-one."
        )
    num_img_tok = image_embed.shape[1]

    new_embeds = torch.cat([image_embed, input_embeds], dim=1)  # (bs, num_img_tok + num_txt_tok, gemma_dim)
    # Image tokens are never padding, so they are always attended to. Use the text mask's dtype so the
    # concatenation does not silently promote the whole mask to float.
    image_mask = torch.ones(bs, num_img_tok, dtype=attention_mask.dtype, device=device)
    new_mask = torch.cat([image_mask, attention_mask], dim=1)  # (bs, num_img_tok + num_txt_tok)

    return new_embeds, new_mask, input_ids, num_img_tok

In [None]:
from datasets import load_dataset

# MS-COCO 2017 captions as packaged by lmms-lab on the Hugging Face Hub.
# NOTE(review): lmms-lab datasets are packaged for evaluation and may only ship "val"/"test" splits —
# confirm a "train" split exists (e.g. datasets.get_dataset_split_names) before relying on this.
train_dataset = load_dataset(
    path="lmms-lab/COCO-Caption2017",
    split="train",
)

# Display one record to inspect the available fields (Jupyter renders the last expression).
train_dataset[0]