Code to develop a vision language model using ViT as vision encoder and Phi-2 as LLM

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers import ViTModel
from safetensors.torch import load_file

# ---- Configuration ----
# Module-level settings shared by the cells below: target device, the input
# image, the local checkpoint directories, and the text prompt.
device = "cuda" if torch.cuda.is_available() else "cpu"
image_path = "image.jpg"
vit_model_path = "./vit"
phi2_model_path = "./phi-2"
prompt = "Describe this image in detail."

# ---- 1. Load ViT Encoder (from safetensors) ----
vit_config = AutoConfig.from_pretrained(vit_model_path)
vit_model = ViTModel.from_pretrained(vit_model_path, config=vit_config)
# NOTE(review): from_pretrained above presumably already loads the weights
# stored in this directory; confirm whether this explicit reload is needed and
# that the checkpoint's key names match a bare ViTModel.
vit_model.load_state_dict(load_file(f"{vit_model_path}/model.safetensors"))
# Inference only: move to the target device and switch to eval mode.
vit_model = vit_model.to(device).eval()

# ---- 2. Load Phi-2 from Local Directory ----
phi_tokenizer = AutoTokenizer.from_pretrained(phi2_model_path)
# The language model is loaded in float16; device placement is delegated to
# device_map="auto" instead of an explicit .to(device).
phi_model = AutoModelForCausalLM.from_pretrained(
    phi2_model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# ---- 3. Vision-Language Wrapper ----
class VisionLanguageModel(nn.Module):
    """Prefix a causal language model's prompt with one projected image embedding.

    The vision encoder's first output position (named ``cls_token`` below) is
    mapped by a linear projector into the language model's embedding space and
    prepended to the prompt's token embeddings as a single extra position.

    Args:
        vit: Vision encoder; called as ``vit(pixel_values=...)`` and expected
            to return an object exposing ``last_hidden_state``.
        phi2: Causal language model; must provide ``get_input_embeddings()``
            and accept ``inputs_embeds``, ``attention_mask`` and ``use_cache``.
        vit_dim: Size of the vision encoder's hidden states.
        phi_dim: Size of the language model's token embeddings.
    """

    def __init__(self, vit, phi2, vit_dim, phi_dim):
        super().__init__()
        self.vit = vit
        self.phi2 = phi2
        self.projector = nn.Linear(vit_dim, phi_dim)

    def forward(self, image_tensor, input_ids, attention_mask):
        """Run the language model on ``[image embedding] + prompt embeddings``.

        Args:
            image_tensor: Batch of preprocessed images passed to the vision
                encoder as ``pixel_values``.
            input_ids: Prompt token ids, shape ``[B, T]``.
            attention_mask: Prompt attention mask, shape ``[B, T]``.

        Returns:
            Whatever ``phi2`` returns for the ``T + 1`` position sequence.
        """
        # The vision encoder is used as a fixed feature extractor.
        with torch.no_grad():
            vision_out = self.vit(pixel_values=image_tensor)
            cls_token = vision_out.last_hidden_state[:, 0, :]  # [B, vit_dim]

        # get_input_embeddings() avoids hard-coding an architecture-specific
        # attribute path (PhiForCausalLM has no `.transformer.wte`).
        token_embeds = self.phi2.get_input_embeddings()(input_ids)  # [B, T, phi_dim]

        vision_embed = self.projector(
            cls_token.to(self.projector.weight.dtype)
        ).unsqueeze(1)  # [B, 1, phi_dim]
        # The projector and the language model may use different dtypes or
        # devices (e.g. float32 projector, float16 LM); align before concat.
        vision_embed = vision_embed.to(
            device=token_embeds.device, dtype=token_embeds.dtype
        )
        full_embeds = torch.cat([vision_embed, token_embeds], dim=1)

        # The image position is always attended to. new_ones keeps the mask's
        # own dtype and device instead of relying on a module-level `device`.
        image_mask = attention_mask.new_ones((attention_mask.shape[0], 1))
        attention_mask = torch.cat([image_mask, attention_mask], dim=1)

        out = self.phi2(
            inputs_embeds=full_embeds,
            attention_mask=attention_mask,
            use_cache=True,
        )
        return out

# ---- 4. Preprocess Image ----
# ToTensor yields values in [0, 1]; Normalize with mean = std = 0.5 then maps
# them to [-1, 1].
# NOTE(review): confirm this matches the preprocessing the ViT checkpoint expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])
# Use a context manager so the underlying file handle is closed once the
# pixels have been converted.
with Image.open(image_path) as _raw_image:
    image = _raw_image.convert("RGB")
image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

# ---- 5. Prepare Text Inputs ----
input_ids = phi_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
attention_mask = torch.ones_like(input_ids)

# ---- 6. Build Wrapper and Run Inference ----
vit_dim = vit_model.config.hidden_size
phi_dim = phi_model.config.hidden_size  # taken from the loaded config

# NOTE(review): the projector inside the wrapper is freshly initialised here;
# load trained projector weights before expecting meaningful captions.
wrapper = VisionLanguageModel(vit_model, phi_model, vit_dim, phi_dim).to(device)

# ---- 7. Generate Tokens ----
# Greedy decoding: each step re-runs the full sequence and appends the argmax
# token. Stops early at end-of-sequence instead of always emitting 50 tokens.
max_new_tokens = 50  # change to desired max tokens
eos_token_id = phi_tokenizer.eos_token_id
with torch.no_grad():
    for _ in range(max_new_tokens):
        attention_mask = torch.ones_like(input_ids)
        out = wrapper(image_tensor, input_ids, attention_mask)
        next_token = torch.argmax(out.logits[:, -1:, :], dim=-1)
        # Batch size is 1 here, so .item() is well defined.
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token], dim=1)

# ---- 8. Decode and Print Result ----
# input_ids holds the prompt followed by the generated tokens, so the printed
# text starts with the prompt.
caption = phi_tokenizer.decode(input_ids[0], skip_special_tokens=True)
print("Image Summary:", caption)


In [None]:
##  This raises "'PhiForCausalLM' object has no attribute 'transformer'" in the forward function of the vision-language wrapper. Change only the failing line and leave everything else as it is.
# Change the line below
token_embeds = self.phi2.transformer.wte(input_ids)
#To
token_embeds = self.phi2.model.embed_tokens(input_ids)

In [None]:
###  Training  code

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import ViTModel, AutoConfig
from safetensors.torch import load_file
from tqdm import tqdm

# --- Device ---
# Single target device used for every tensor and module in the training cells.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Load ViT (Frozen) ---
vit_config = AutoConfig.from_pretrained("./vit")
vit_model = ViTModel.from_pretrained("./vit", config=vit_config)
# NOTE(review): from_pretrained above presumably already loads these weights;
# confirm whether this explicit reload is needed.
vit_model.load_state_dict(load_file("./vit/model.safetensors"))
vit_model = vit_model.to(device).eval()
# Freeze the vision encoder: only the projector is optimised below.
for p in vit_model.parameters(): p.requires_grad = False

# --- Load Phi-2 (Frozen) ---
phi_tokenizer = AutoTokenizer.from_pretrained("./phi-2")
# Reuse EOS as the padding token so batched tokenisation with padding=True
# works. Consequence: pad_token_id and eos_token_id are the same id.
phi_tokenizer.pad_token = phi_tokenizer.eos_token
# Load in float16 and place the whole model on `device`.
phi_model = AutoModelForCausalLM.from_pretrained(
    "./phi-2", torch_dtype=torch.float16, device_map={"": device}
)
phi_model.eval()
# Freeze the language model as well.
for p in phi_model.parameters(): p.requires_grad = False

# --- Projection Layer (Trainable) ---
# Linear map from the ViT hidden size to the language model hidden size.
vit_dim = vit_model.config.hidden_size
phi_dim = phi_model.config.hidden_size
projector = nn.Linear(vit_dim, phi_dim).to(device)

# --- Optimizer ---
# Bound to this specific `projector` instance's parameters.
optimizer = torch.optim.AdamW(projector.parameters(), lr=1e-4)

# --- Loss ---
# NOTE(review): not referenced by the training functions visible in this file
# (they use the loss returned by the model). Because pad == eos, this
# criterion would also ignore genuine EOS targets.
loss_fn = nn.CrossEntropyLoss(ignore_index=phi_tokenizer.pad_token_id)

# --- Training Loop ---
def train_projection_layer(dataloader, projector, epochs=3):
    """Train the projector so the frozen LM predicts answers from image + question.

    Each batch is laid out as ``[image slot] + question tokens + answer tokens``
    and the loss is computed on the answer tokens only. The earlier layout fed
    only image + question while using the answer ids as labels, so logits and
    labels had different sequence lengths.

    Args:
        dataloader: Iterable of dicts with keys ``"image"`` (tensor batch),
            ``"question"`` and ``"answer"`` (sequences of strings).
        projector: Module mapping ViT features to the LM embedding size.
        epochs: Number of passes over ``dataloader``.

    Relies on the module-level ``device``, ``vit_model``, ``phi_model``,
    ``phi_tokenizer`` and ``optimizer``.
    NOTE(review): ``optimizer`` is the module-level one; it only updates this
    ``projector`` if it was built from the same instance.
    """
    projector.train()
    eos = phi_tokenizer.eos_token
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            images = batch["image"].to(device)  # [B, 3, H, W]
            questions = batch["question"]
            answers = batch["answer"]

            # Image features come from the frozen ViT; only the projector
            # receives gradients.
            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]
            projected_embed = projector(cls_token).unsqueeze(1).to(torch.float16)

            # Tokenise question and answer separately so the answer span is
            # known exactly, then feed both to the model (teacher forcing).
            question_enc = phi_tokenizer(
                [q + " " + eos for q in questions],
                return_tensors="pt", padding=True, truncation=True,
            ).to(device)
            answer_enc = phi_tokenizer(
                answers, return_tensors="pt", padding=True, truncation=True,
            ).to(device)

            text_ids = torch.cat([question_enc.input_ids, answer_enc.input_ids], dim=1)
            token_embeds = phi_model.model.embed_tokens(text_ids).to(torch.float16)
            inputs_embeds = torch.cat([projected_embed, token_embeds], dim=1)

            batch_size = text_ids.shape[0]
            question_len = question_enc.input_ids.shape[1]

            # Image slot is always attended to; padding stays masked out.
            attention_mask = torch.cat([
                question_enc.attention_mask.new_ones((batch_size, 1)),
                question_enc.attention_mask,
                answer_enc.attention_mask,
            ], dim=1)

            # Supervise answer tokens only: the image slot, the question and
            # answer padding are set to -100 so the loss skips them.
            answer_labels = answer_enc.input_ids.masked_fill(
                answer_enc.attention_mask == 0, -100
            )
            labels = torch.cat([
                answer_labels.new_full((batch_size, 1 + question_len), -100),
                answer_labels,
            ], dim=1)

            # Forward pass
            output = phi_model(inputs_embeds=inputs_embeds,
                               attention_mask=attention_mask,
                               labels=labels)

            loss = output.loss
            total_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        print(f"Epoch {epoch+1} Loss: {total_loss / max(num_batches, 1):.4f}")


In [None]:
######   Function with mask input before Answer

def train_projection_layer(dataloader, projector, epochs=3):
    """Train the projector on ``"Question: ... Answer: ..."`` text with prompt masking.

    The model input is ``[image slot] + tokens("Question: q Answer: a")``. The
    loss is computed only on the answer tokens: the image slot, the prompt
    tokens and all padding positions are set to -100 in the labels.

    Args:
        dataloader: Iterable of dicts with keys ``"image"`` (tensor batch),
            ``"question"`` and ``"answer"`` (sequences of strings).
        projector: Module mapping ViT features to the LM embedding size.
        epochs: Number of passes over ``dataloader``.

    Relies on the module-level ``device``, ``vit_model``, ``phi_model``,
    ``phi_tokenizer`` and ``optimizer``.
    NOTE(review): ``optimizer`` is the module-level one; it only updates this
    ``projector`` if it was built from the same instance.
    """
    projector.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            images = batch["image"].to(device)  # [B, 3, H, W]
            questions = batch["question"]
            answers = batch["answer"]

            # ---- ViT: image to embedding ----
            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]
            projected_embed = projector(cls_token).unsqueeze(1).to(torch.float16)  # [B, 1, D]

            # ---- Construct full prompt ----
            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{prompt} {ans.strip()}" for prompt, ans in zip(prompts, answers)]

            # ---- Tokenize full prompt ----
            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True)
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            # ---- Tokenize prompt alone to get its length per sample ----
            # The attention mask gives the real token count directly; no
            # target-tokenizer context is needed for a decoder-only tokenizer.
            prompt_enc = phi_tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
            prompt_lens = prompt_enc.attention_mask.sum(dim=1).to(device)  # [B]

            # ---- Build label tensor ----
            # real_index numbers the non-padding tokens of each row 0, 1, 2...
            # so the prompt mask is independent of the padding side.
            real_index = attention_mask.cumsum(dim=1) - 1
            is_prompt = real_index < prompt_lens.unsqueeze(1)
            is_padding = attention_mask == 0
            labels = input_ids.clone()
            # pad == eos for this tokenizer, so padding must be masked
            # explicitly or the loss would be trained on it.
            labels[is_prompt | is_padding] = -100

            # ---- Embed text ----
            token_embeds = phi_model.model.embed_tokens(input_ids).to(torch.float16)  # [B, T, D]

            # ---- Concatenate <image> + tokens ----
            inputs_embeds = torch.cat([projected_embed, token_embeds], dim=1)
            attention_mask = torch.cat([
                attention_mask.new_ones((attention_mask.shape[0], 1)),
                attention_mask
            ], dim=1)
            labels = torch.cat([
                labels.new_full((labels.shape[0], 1), -100),  # mask image token
                labels
            ], dim=1)

            # ---- Forward ----
            outputs = phi_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            total_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        print(f"Epoch {epoch+1} Loss: {total_loss / max(num_batches, 1):.4f}")


*Updated code Using ANLS metric* 
1. Trains only the projection layer.
2. Formats input as "<image> Question: ... Answer: ..." and masks prompt + image during training loss computation.

In [1]:


import numpy as np
import editdistance  # Install via: pip install editdistance

def normalized_levenshtein(pred, gt):
    """Return the normalized Levenshtein similarity of two strings.

    Both strings are stripped and lower-cased first. The result is
    ``1 - edit_distance / max(len(pred), len(gt))``. When the ground truth is
    empty after stripping, the result is 1.0 for an empty prediction and 0.0
    otherwise.
    """
    candidate = pred.strip().lower()
    reference = gt.strip().lower()

    # Empty reference: only an empty prediction counts as a match.
    if not reference:
        if candidate:
            return 0.0
        return 1.0

    distance = editdistance.eval(candidate, reference)
    longest = max(len(candidate), len(reference))
    return 1 - distance / longest

def compute_anls(preds, gts, threshold=0.5):
    """Average Normalized Levenshtein Similarity (ANLS).

    Each prediction/ground-truth pair is scored with
    ``normalized_levenshtein``; scores below ``threshold`` count as 0. Pairs
    are formed with ``zip``, so extra items in the longer input are ignored.

    Args:
        preds: Predicted strings.
        gts: Ground-truth strings, aligned with ``preds``.
        threshold: Minimum similarity that still earns credit.

    Returns:
        The mean thresholded similarity, or 0.0 when there are no pairs
        (instead of NaN plus a RuntimeWarning from ``np.mean([])``).
    """
    similarities = (normalized_levenshtein(p, g) for p, g in zip(preds, gts))
    scores = [sim if sim >= threshold else 0 for sim in similarities]
    if not scores:
        return 0.0
    return np.mean(scores)

def train_projection_layer(dataloader, projector, epochs=3):
    """Train the projector and report loss and ANLS per epoch.

    Training input is ``[image slot] + tokens("Question: q Answer: a")`` with
    the loss restricted to answer tokens (image slot, prompt and padding are
    -100 in the labels). After each optimisation step, answers are generated
    greedily from ``[image slot] + prompt`` only and scored against the ground
    truth with ``compute_anls``.

    Args:
        dataloader: Iterable of dicts with keys ``"image"`` (tensor batch),
            ``"question"`` and ``"answer"`` (sequences of strings).
        projector: Module mapping ViT features to the LM embedding size.
        epochs: Number of passes over ``dataloader``.

    Relies on the module-level ``device``, ``vit_model``, ``phi_model``,
    ``phi_tokenizer`` and ``optimizer``.
    NOTE(review): ``optimizer`` is the module-level one; it only updates this
    ``projector`` if it was built from the same instance.
    """
    projector.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        total_anls = []

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            images = batch["image"].to(device)
            questions = batch["question"]
            answers = batch["answer"]

            # --- Image Embedding ---
            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]
            projected_embed = projector(cls_token).unsqueeze(1).to(torch.float16)

            # --- Prompt Construction ---
            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{p} {a.strip()}" for p, a in zip(prompts, answers)]

            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True)
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            # Real prompt length per sample, read from the attention mask.
            prompt_enc = phi_tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
            prompt_lens = prompt_enc.attention_mask.sum(dim=1).to(device)

            # --- Labels Masking ---
            # real_index numbers each row's non-padding tokens 0, 1, 2... so
            # the prompt mask does not depend on the padding side. Padding is
            # masked explicitly because pad == eos for this tokenizer.
            real_index = attention_mask.cumsum(dim=1) - 1
            labels = input_ids.clone()
            labels[(real_index < prompt_lens.unsqueeze(1)) | (attention_mask == 0)] = -100

            # --- Embedding + Concatenation ---
            token_embeds = phi_model.model.embed_tokens(input_ids).to(torch.float16)
            inputs_embeds = torch.cat([projected_embed, token_embeds], dim=1)

            attention_mask = torch.cat([
                attention_mask.new_ones((attention_mask.shape[0], 1)),
                attention_mask
            ], dim=1)
            labels = torch.cat([
                labels.new_full((labels.shape[0], 1), -100),
                labels
            ], dim=1)

            # --- Forward ---
            outputs = phi_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            total_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # --- Decode and Evaluate ANLS ---
            with torch.no_grad():
                # Generate from the prompt only: the training embeddings
                # already contain the ground-truth answer. Left padding keeps
                # each prompt's last real token at the end of its row, which
                # batched decoder-only generation needs.
                saved_side = phi_tokenizer.padding_side
                phi_tokenizer.padding_side = "left"
                try:
                    gen_enc = phi_tokenizer(
                        prompts, return_tensors="pt", padding=True, truncation=True
                    )
                finally:
                    phi_tokenizer.padding_side = saved_side
                gen_ids = gen_enc.input_ids.to(device)
                gen_mask = gen_enc.attention_mask.to(device)

                gen_embeds = torch.cat([
                    projected_embed.detach(),
                    phi_model.model.embed_tokens(gen_ids).to(torch.float16),
                ], dim=1)
                gen_mask = torch.cat([
                    gen_mask.new_ones((gen_mask.shape[0], 1)),
                    gen_mask
                ], dim=1)

                generated_ids = phi_model.generate(
                    inputs_embeds=gen_embeds,
                    attention_mask=gen_mask,
                    max_new_tokens=50,
                    do_sample=False,
                    pad_token_id=phi_tokenizer.pad_token_id
                )
                # With inputs_embeds and no input_ids, generate returns only
                # the newly generated tokens, so nothing is sliced off here.
                decoded_preds = phi_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                decoded_preds = [pred.replace(prompt, "").strip() for pred, prompt in zip(decoded_preds, prompts)]

                anls_score = compute_anls(decoded_preds, answers)
                total_anls.append(anls_score)

        # Guards keep an empty dataloader from raising or printing NaN.
        avg_loss = total_loss / max(num_batches, 1)
        avg_anls = np.mean(total_anls) if total_anls else 0.0
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | ANLS: {avg_anls:.4f}")


######  Output format:   Epoch 1 | Loss: 1.2345 | ANLS: 0.8421


SyntaxError: invalid syntax (654833717.py, line 1)

Here's the inference code that uses the trained projector and computes the Average Normalized Levenshtein Similarity (ANLS):

In [2]:
import torch
import numpy as np
import editdistance
from tqdm import tqdm

def normalized_levenshtein(pred, gt):
    """Similarity in terms of edit distance between two normalized strings.

    Inputs are stripped and lower-cased. An empty ground truth yields 1.0 if
    the prediction is empty too and 0.0 otherwise; in every other case the
    value is one minus the edit distance divided by the longer length.
    """
    left, right = (text.strip().lower() for text in (pred, gt))
    if right == "":
        # float(True) == 1.0 and float(False) == 0.0
        return float(left == "")
    scale = max(len(left), len(right))
    return 1 - editdistance.eval(left, right) / scale

def compute_anls(preds, gts, threshold=0.5):
    """Return the Average Normalized Levenshtein Similarity of paired strings.

    Every ``(prediction, ground truth)`` pair is scored with
    ``normalized_levenshtein``; a score below ``threshold`` contributes 0.
    Pairing uses ``zip``, so surplus items in the longer input are ignored.

    Args:
        preds: Predicted strings.
        gts: Ground-truth strings, aligned with ``preds``.
        threshold: Minimum similarity that still earns credit.

    Returns:
        Mean of the thresholded scores; 0.0 if there is nothing to score, so
        callers never receive NaN from averaging an empty list.
    """
    scores = []
    for p, g in zip(preds, gts):
        sim = normalized_levenshtein(p, g)
        scores.append(sim if sim >= threshold else 0)
    return np.mean(scores) if scores else 0.0

def run_inference(dataloader, vit_model, projector, phi_model, phi_tokenizer):
    """Generate answers for every batch and print the overall ANLS.

    For each batch the image is encoded by ``vit_model``, projected by
    ``projector`` and prepended as one embedding position to the embedded
    prompt ``"Question: <q> Answer:"``. Answers are then decoded greedily.

    Args:
        dataloader: Iterable of dicts with keys ``"image"`` (tensor batch),
            ``"question"`` and ``"answer"`` (sequences of strings).
        vit_model: Vision encoder returning ``last_hidden_state``.
        projector: Module mapping ViT features to the LM embedding size.
        phi_model: Causal LM exposing ``model.embed_tokens`` and ``generate``.
        phi_tokenizer: Tokenizer for ``phi_model``; batching with
            ``padding=True`` presumably requires a pad token to be set.

    Returns:
        Tuple ``(all_preds, all_gts, all_prompts)`` of parallel lists.

    All three modules are put in eval mode and left that way. Tensors are
    moved to the module-level ``device``.
    """
    vit_model.eval()
    projector.eval()
    phi_model.eval()

    all_preds, all_gts, all_prompts = [], [], []

    for batch in tqdm(dataloader, desc="Inference"):
        images = batch["image"].to(device)
        questions = batch["question"]
        answers = batch["answer"]

        with torch.no_grad():
            # Image embedding
            vit_out = vit_model(pixel_values=images)
            cls_token = vit_out.last_hidden_state[:, 0, :]
            projected_embed = projector(cls_token).unsqueeze(1).to(torch.float16)

            # Prompt: "<image> Question: ... Answer:"
            prompts = [f"Question: {q.strip()} Answer:" for q in questions]

            # Left padding keeps each prompt's last real token at the end of
            # its row, which batched decoder-only generation needs. The
            # tokenizer's own setting is restored afterwards.
            saved_side = phi_tokenizer.padding_side
            phi_tokenizer.padding_side = "left"
            try:
                inputs = phi_tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
            finally:
                phi_tokenizer.padding_side = saved_side
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            token_embeds = phi_model.model.embed_tokens(input_ids).to(torch.float16)
            inputs_embeds = torch.cat([projected_embed, token_embeds], dim=1)
            # The image slot is always attended to; pad positions stay masked.
            attention_mask = torch.cat([
                attention_mask.new_ones((attention_mask.shape[0], 1)),
                attention_mask
            ], dim=1)

            # Generate answer
            generated_ids = phi_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=phi_tokenizer.pad_token_id
            )
            # With inputs_embeds and no input_ids, generate returns only the
            # newly generated tokens, so slicing off a leading position would
            # drop the first token of the answer.
            decoded_preds = phi_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            cleaned_preds = [pred.replace(prompt, "").strip() for pred, prompt in zip(decoded_preds, prompts)]

            all_preds.extend(cleaned_preds)
            all_gts.extend(answers)
            all_prompts.extend(prompts)

    # ANLS Score
    anls_score = compute_anls(all_preds, all_gts)
    print(f"\nAverage ANLS: {anls_score:.4f}")
    return all_preds, all_gts, all_prompts
###  Requirements:
###   batch["image"], batch["question"], batch["answer"] must be provided by the dataloader.
###### projector, vit_model, phi_model, and phi_tokenizer must be loaded and moved to the correct device.

ModuleNotFoundError: No module named 'editdistance'