Python code to fine-tune the projection layer on another dataset

In [None]:
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import ViTModel, AutoTokenizer, AutoModelForCausalLM, AutoConfig
from tqdm import tqdm

# --- Device ---
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --- Load ViT Model (Frozen) ---
# The vision encoder is read from the local "./vit" directory, switched to
# eval mode, and excluded from gradient computation.
vit_config = AutoConfig.from_pretrained("./vit")
vit_model = ViTModel.from_pretrained("./vit", config=vit_config)
vit_model.load_state_dict(load_file("./vit/model.safetensors"))
vit_model.to(device)
vit_model.eval()
for p in vit_model.parameters():
    p.requires_grad_(False)

# --- Load Phi-2 (Frozen) ---
# The tokenizer reuses its EOS token as the padding token so that batched
# tokenization with padding=True works.
phi_tokenizer = AutoTokenizer.from_pretrained("./phi-2")
phi_tokenizer.pad_token = phi_tokenizer.eos_token
# Weights are loaded in float16; the empty-string key in device_map assigns
# the whole model to `device`.
phi_model = AutoModelForCausalLM.from_pretrained(
    "./phi-2",
    torch_dtype=torch.float16,
    device_map={"": device},
)
phi_model.eval()
for p in phi_model.parameters():
    p.requires_grad_(False)

# --- Load Trained Projection Layer ---
# A single linear map from the ViT hidden size to the Phi-2 hidden size,
# initialised from a previously saved checkpoint and left in train mode so
# it can be fine-tuned further.
vit_dim, phi_dim = vit_model.config.hidden_size, phi_model.config.hidden_size
projector = nn.Linear(in_features=vit_dim, out_features=phi_dim)
projector.to(device)
projector.load_state_dict(load_file("projector_finetuned.safetensors"))
projector.train()


In [1]:
def fine_tune_projection_layer(dataloader, projector, num_epochs=3, lr=1e-4):
    """Fine-tune the projection layer with the ViT and Phi-2 models held fixed.

    For every batch the ViT CLS embedding is projected into the language
    model's embedding space and prepended, as a single extra position, to the
    embedded text "Question: <q> Answer: <a>". The language-model loss is
    computed on the answer tokens only: the image position, the prompt tokens
    and all padding positions are labelled -100 so they are ignored.

    Relies on the notebook-level globals `vit_model`, `phi_model`,
    `phi_tokenizer` and `device`.

    Args:
        dataloader: Iterable of batches. Each batch is indexed with the keys
            "image" (an object with `.to(device)`, passed to the ViT as
            `pixel_values`), "question" and "answer" (iterables of strings).
        projector: Module applied to the CLS embedding; only its parameters
            are optimised, in place.
        num_epochs: Number of passes over `dataloader`.
        lr: Learning rate for AdamW.

    Returns:
        None. The mean loss of each epoch is printed.
    """
    optimizer = torch.optim.AdamW(projector.parameters(), lr=lr)
    # Same accessor the inference code uses; avoids depending on the
    # model-specific attribute path `phi_model.model.embed_tokens`.
    embed_tokens = phi_model.get_input_embeddings()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        trained_batches = 0
        skipped_batches = 0
        projector.train()

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            images = batch["image"].to(device)
            questions = batch["question"]
            answers = batch["answer"]

            # The vision encoder is frozen, so no graph is needed for it.
            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]

            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{p} {a.strip()}" for p, a in zip(prompts, answers)]

            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True)
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            # Number of real (non-padding) prompt tokens per example. The
            # attention mask is used instead of comparing ids against the pad
            # id because the pad token is the EOS token here.
            # NOTE(review): assumes the prompt tokenizes to a prefix of the
            # full text's tokenization - confirm for this tokenizer.
            prompt_lens = phi_tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).attention_mask.sum(dim=1).to(device)

            # Build labels: ignore (-100) the prompt tokens and every padding
            # position. `first_real` is the index of the first non-padding
            # token, so the prompt span is located correctly for either
            # padding side.
            positions = torch.arange(input_ids.shape[1], device=device).unsqueeze(0)
            first_real = attention_mask.argmax(dim=1, keepdim=True)
            is_prompt = positions < (first_real + prompt_lens.unsqueeze(1))
            labels = input_ids.masked_fill(is_prompt | (attention_mask == 0), -100)

            # Without any supervised token the loss would be NaN and the
            # optimizer step would corrupt the projector weights.
            if not (labels != -100).any():
                skipped_batches += 1
                continue

            token_embeds = embed_tokens(input_ids)
            # Cast the projected image embedding to whatever dtype the
            # language model embeds in, rather than hard-coding float16.
            projected_embed = projector(cls_token).unsqueeze(1).to(token_embeds.dtype)
            inputs_embeds = torch.cat([projected_embed, token_embeds], dim=1)

            # Account for the prepended image position: it is attended to
            # but never used as a prediction target.
            batch_size = input_ids.shape[0]
            attention_mask = torch.cat([
                torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=device),
                attention_mask
            ], dim=1)
            labels = torch.cat([
                torch.full((batch_size, 1), -100, dtype=labels.dtype, device=device),
                labels
            ], dim=1)

            output = phi_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = output.loss
            epoch_loss += loss.item()
            trained_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Average over the batches actually trained on; this also works for
        # loaders without __len__ and avoids dividing by zero on empty ones.
        print(f"Epoch {epoch+1} | Loss: {epoch_loss / max(trained_batches, 1):.4f}")
        if skipped_batches:
            print(f"Epoch {epoch+1} | Skipped {skipped_batches} batch(es) with no answer tokens to supervise")


How to load the complete model for inference

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, ViTModel
from safetensors.torch import load_file
import os

# --- Device Setup ---
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --- Load ViT (Frozen) ---
# Vision encoder from the local "./vit" directory, in eval mode and with
# gradients disabled on every parameter.
vit_config = AutoConfig.from_pretrained("./vit")
vit_model = ViTModel.from_pretrained("./vit", config=vit_config)
vit_model.load_state_dict(load_file("./vit/model.safetensors"))
vit_model = vit_model.to(device)
vit_model = vit_model.eval()
for p in vit_model.parameters():
    p.requires_grad_(False)

# --- Load Phi-2 Model (Frozen) ---
phi_tokenizer = AutoTokenizer.from_pretrained("./phi-2")
# Reuse the EOS token for padding so that a pad token is always defined.
phi_tokenizer.pad_token = phi_tokenizer.eos_token
# float16 weights; the empty-string device_map key assigns the whole model
# to `device`.
phi_model = AutoModelForCausalLM.from_pretrained(
    "./phi-2",
    torch_dtype=torch.float16,
    device_map={"": device},
)
phi_model.eval()
for p in phi_model.parameters():
    p.requires_grad_(False)

# --- Load Trained Projection Layer ---
# Linear map from the ViT hidden size to the Phi-2 hidden size, restored
# from the fine-tuned checkpoint and put in eval mode for inference.
vit_dim, phi_dim = vit_model.config.hidden_size, phi_model.config.hidden_size

projector = nn.Linear(in_features=vit_dim, out_features=phi_dim)
projector.to(device)
projector.load_state_dict(load_file("projector_finetuned.safetensors"))
projector.eval()


In [None]:
def run_inference(image_tensor, question_text, max_new_tokens=50):
    """Answer a question about a single image with the ViT -> projector -> Phi-2 stack.

    The ViT CLS embedding is projected into the language model's embedding
    space and prepended, as one extra position, to the embedded prompt
    "Question: <question_text> Answer:". The answer is decoded greedily.

    Relies on the notebook-level globals `vit_model`, `phi_model`,
    `phi_tokenizer`, `projector` and `device`.

    Args:
        image_tensor: torch.Tensor of shape [1, 3, H, W]; an unbatched
            [3, H, W] tensor is also accepted and gets a batch dimension.
        question_text: str
        max_new_tokens: Upper bound on the number of generated tokens.
    Returns:
        generated answer: str
    """
    # Convenience: accept a single unbatched image.
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)

    with torch.no_grad():
        # 1. Get ViT image embedding
        vit_out = vit_model(pixel_values=image_tensor.to(device))
        cls_token = vit_out.last_hidden_state[:, 0, :]  # [1, vit_dim]

        # 2. Tokenize prompt
        prompt = f"Question: {question_text.strip()} Answer:"
        inputs = phi_tokenizer(prompt, return_tensors="pt").to(device)

        token_embeds = phi_model.get_input_embeddings()(inputs.input_ids)
        # Match the language model's embedding dtype instead of hard-coding
        # float16, so the concatenation works for any model precision.
        projected_embed = projector(cls_token).unsqueeze(1).to(token_embeds.dtype)  # [1, 1, phi_dim]
        inputs_embeds = torch.cat([projected_embed, token_embeds], dim=1)

        # Extend the mask by one position for the prepended image embedding.
        attention_mask = torch.cat([
            torch.ones((1, 1), dtype=inputs.attention_mask.dtype, device=device),
            inputs.attention_mask
        ], dim=1)

        # 3. Generate answer
        generated_ids = phi_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Passed explicitly so generate() does not have to fall back to
            # the EOS id (and warn) when it needs a padding id.
            pad_token_id=phi_tokenizer.pad_token_id
        )
        output_text = phi_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # NOTE(review): when generating from inputs_embeds the returned ids
        # are expected to hold only the new tokens, so this replace is a
        # defensive strip for the case where the model echoes the prompt.
        return output_text.replace(prompt, "").strip()


In [None]:
from PIL import Image
from torchvision import transforms

# Preprocessing pipeline: resize to 224x224 (the ViT input size) and convert
# the PIL image to a tensor.
# NOTE(review): no mean/std normalisation is applied here - confirm this
# matches the preprocessing used when the projector was trained.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
image = Image.open("test_image.png").convert("RGB")
# Index with None to add the leading batch dimension: [1, 3, H, W].
image_tensor = transform(image)[None]

# Ask one question about the image and show the model's answer.
question = "What is the name of the document?"
answer = run_inference(image_tensor, question)
print("Generated Answer:", answer)
