Cross-Attention for Feature Combination [Visual and Text Features]\
To replace the simple concatenation of `projected_embed` and the token embeddings with cross-attention between the image features and the text features, we need to introduce a cross-attention module.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import ViTModel, AutoConfig
from safetensors.torch import load_file
from tqdm import tqdm

# --- Device ---
# Prefer the first CUDA device; otherwise everything runs on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Load ViT (Frozen) ---
# The model is built from the local ./vit directory and its weights are then
# overwritten with the tensors stored in ./vit/model.safetensors.
vit_config = AutoConfig.from_pretrained("./vit")
vit_model = ViTModel.from_pretrained("./vit", config=vit_config)
vit_model.load_state_dict(load_file("./vit/model.safetensors"))
vit_model = vit_model.to(device)
vit_model.eval()
vit_model.requires_grad_(False)  # frozen: no ViT parameter receives gradients

# --- Load Phi-2 (Frozen) ---
phi_tokenizer = AutoTokenizer.from_pretrained("./phi-2")
# Padding reuses the EOS token, so pad positions cannot be told apart from a
# real EOS by token id alone.
phi_tokenizer.pad_token = phi_tokenizer.eos_token
phi_model = AutoModelForCausalLM.from_pretrained(
    "./phi-2",
    torch_dtype=torch.float16,
    device_map={"": device},
)
phi_model.eval()
phi_model.requires_grad_(False)  # frozen as well

# --- Dimensions ---
vit_dim = vit_model.config.hidden_size
phi_dim = phi_model.config.hidden_size

# --- Projection Layer (Trainable) ---
# Maps ViT features (vit_dim) to the language-model width (phi_dim).
projector = nn.Linear(vit_dim, phi_dim).to(device)

# --- Cross-Attention Module (Trainable) ---
class CrossAttention(nn.Module):
    """Fuse text token embeddings with image features via cross-attention.

    Text embeddings act as queries; image embeddings act as keys and values.
    The attended output is added back to the text embeddings (residual
    connection) and layer-normalised.

    Args:
        embed_dim: Feature size shared by the text and image embeddings.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, text_embeds, image_embed):
        """Return the text embeddings enriched with image information.

        Args:
            text_embeds: Query tensor of shape ``[B, T, D]``.
            image_embed: Key/value tensor of shape ``[B, S, D]``
                (``S == 1`` when a single CLS vector is passed).

        Returns:
            Tensor of shape ``[B, T, D]`` in the dtype of ``text_embeds``.
        """
        # The training code hands in float16 embeddings while this module's
        # weights keep their construction dtype (float32 by default). Mixed
        # dtypes make the attention projections raise, so compute in the
        # module's dtype and hand the result back in the caller's dtype.
        out_dtype = text_embeds.dtype
        work_dtype = self.ln.weight.dtype
        query = text_embeds.to(work_dtype)
        memory = image_embed.to(work_dtype)

        # NOTE: with a single key/value position the softmax weight is always 1,
        # so every text token receives the same image-derived vector.
        attended, _ = self.attn(query=query, key=memory, value=memory)
        return self.ln(query + attended).to(out_dtype)  # residual connection

# Trainable fusion module, sized to the language-model hidden width.
cross_attn = CrossAttention(phi_dim).to(device)

# --- Optimizer ---
# Only the projector and the cross-attention module are optimised.
optimizer = torch.optim.AdamW(
    [*projector.parameters(), *cross_attn.parameters()],
    lr=1e-4,
)

# --- Loss ---
# Targets equal to the pad token id are skipped (pad token == EOS token here).
loss_fn = nn.CrossEntropyLoss(ignore_index=phi_tokenizer.pad_token_id)

# --- Training Loop ---
def train_projection_layer(dataloader, projector, epochs=3):
    """Train ``projector`` and the module-level ``cross_attn`` on VQA batches.

    Each batch must provide ``"image"`` (passed to the ViT as ``pixel_values``),
    ``"question"`` and ``"answer"`` (sequences of strings). The frozen ViT and
    Phi-2 models, the tokenizer, ``cross_attn`` and ``optimizer`` are read from
    module scope.

    The language model is trained on ``"<question> <answer><eos>"``. Labels
    cover the same sequence as the inputs (the LM loss requires matching
    lengths); question and padding positions are set to -100 so only the answer
    and its EOS are supervised. The model shifts the labels internally.

    Args:
        dataloader: Iterable of batches; ``len(dataloader)`` is used for averaging.
        projector: Trainable module mapping ViT CLS features to the Phi-2 width.
        epochs: Number of passes over ``dataloader``.
    """
    projector.train()
    cross_attn.train()
    embed_tokens = phi_model.get_input_embeddings()

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            images = batch["image"].to(device)
            questions = list(batch["question"])
            answers = batch["answer"]

            # Visual features (the ViT is frozen, so no graph is needed)
            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]  # [B, vit_dim]
            # Stay in the projector's dtype: cross_attn's weights are not float16.
            projected_image = projector(cls_token).unsqueeze(1)  # [B, 1, phi_dim]

            # Question and answer share one sequence so inputs and labels align.
            full_texts = [
                f"{q} {a}{phi_tokenizer.eos_token}" for q, a in zip(questions, answers)
            ]
            inputs = phi_tokenizer(
                full_texts, return_tensors="pt", padding=True, truncation=True
            ).to(device)
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
            question_mask = phi_tokenizer(
                questions, return_tensors="pt", padding=True, truncation=True
            ).attention_mask.to(device)

            # Mask padding through the attention mask (pad id == EOS id, so the
            # id cannot be used) and mask the question prefix of every row.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            starts = attention_mask.argmax(dim=1).tolist()  # first real token per row
            question_lens = question_mask.sum(dim=1).tolist()
            for i, (start, n_question) in enumerate(zip(starts, question_lens)):
                labels[i, : start + n_question] = -100

            token_embeds = embed_tokens(input_ids)  # [B, T, D]

            # Cross-attention: text attends to image. Run it in the trainable
            # modules' dtype, then cast back to what the language model expects.
            fused_embeds = cross_attn(
                token_embeds.to(projected_image.dtype), projected_image
            ).to(token_embeds.dtype)  # [B, T, D]

            # Forward pass
            outputs = phi_model(
                inputs_embeds=fused_embeds,
                attention_mask=attention_mask,
                labels=labels,
            )

            loss = outputs.loss
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")


Code with some modification in the loss function calculation

In [None]:
## Code with some modification in the loss function calculation

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import ViTModel, AutoConfig
from safetensors.torch import load_file
from tqdm import tqdm

# --- Device ---
# Prefer the first CUDA device; otherwise everything runs on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Load ViT (Frozen) ---
# The model is built from the local ./vit directory and its weights are then
# overwritten with the tensors stored in ./vit/model.safetensors.
vit_config = AutoConfig.from_pretrained("./vit")
vit_model = ViTModel.from_pretrained("./vit", config=vit_config)
vit_model.load_state_dict(load_file("./vit/model.safetensors"))
vit_model = vit_model.to(device)
vit_model.eval()
vit_model.requires_grad_(False)  # frozen: no ViT parameter receives gradients

# --- Load Phi-2 (Frozen) ---
phi_tokenizer = AutoTokenizer.from_pretrained("./phi-2")
# Padding reuses the EOS token, so pad positions cannot be told apart from a
# real EOS by token id alone.
phi_tokenizer.pad_token = phi_tokenizer.eos_token
phi_model = AutoModelForCausalLM.from_pretrained(
    "./phi-2",
    torch_dtype=torch.float16,
    device_map={"": device},
)
phi_model.eval()
phi_model.requires_grad_(False)  # frozen as well

# --- Dimensions ---
vit_dim = vit_model.config.hidden_size
phi_dim = phi_model.config.hidden_size

# --- Projection Layer (Trainable) ---
# Maps ViT features (vit_dim) to the language-model width (phi_dim).
projector = nn.Linear(vit_dim, phi_dim).to(device)

# --- Cross-Attention Module (Trainable) ---
class CrossAttention(nn.Module):
    """Fuse text token embeddings with image features via cross-attention.

    Text embeddings act as queries; image embeddings act as keys and values.
    The attended output is added back to the text embeddings (residual
    connection) and layer-normalised.

    Args:
        embed_dim: Feature size shared by the text and image embeddings.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, text_embeds, image_embed):
        """Return the text embeddings enriched with image information.

        Args:
            text_embeds: Query tensor of shape ``[B, T, D]``.
            image_embed: Key/value tensor of shape ``[B, S, D]``
                (``S == 1`` when a single CLS vector is passed).

        Returns:
            Tensor of shape ``[B, T, D]`` in the dtype of ``text_embeds``.
        """
        # The training code hands in float16 embeddings while this module's
        # weights keep their construction dtype (float32 by default). Mixed
        # dtypes make the attention projections raise, so compute in the
        # module's dtype and hand the result back in the caller's dtype.
        out_dtype = text_embeds.dtype
        work_dtype = self.ln.weight.dtype
        query = text_embeds.to(work_dtype)
        memory = image_embed.to(work_dtype)

        # NOTE: with a single key/value position the softmax weight is always 1,
        # so every text token receives the same image-derived vector.
        attended, _ = self.attn(query=query, key=memory, value=memory)
        return self.ln(query + attended).to(out_dtype)

# Trainable fusion module, sized to the language-model hidden width.
cross_attn = CrossAttention(phi_dim).to(device)

# --- Optimizer ---
# Only the projector and the cross-attention module are optimised.
optimizer = torch.optim.AdamW(
    [*projector.parameters(), *cross_attn.parameters()],
    lr=1e-4,
)

# --- Explicit CrossEntropy Loss ---
# Targets equal to the pad token id are skipped (pad token == EOS token here).
# NOTE(review): ignore_index is the pad id, not -100, so any caller that marks
# masked positions with -100 must not feed those labels to this loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=phi_tokenizer.pad_token_id)

# --- Training Loop ---
def train_projection_layer(dataloader, projector, epochs=3):
    """Train ``projector`` and the module-level ``cross_attn`` with a manual loss.

    Each batch must provide ``"image"`` (passed to the ViT as ``pixel_values``),
    ``"question"`` and ``"answer"`` (sequences of strings). The frozen ViT and
    Phi-2 models, the tokenizer, ``cross_attn`` and ``optimizer`` are read from
    module scope.

    The loss is next-token cross-entropy over the answer tokens only: prompt
    and padding positions are labelled -100 and ignored. The module-level
    ``loss_fn`` is not used here because its ``ignore_index`` is the pad token
    id, which makes -100 an invalid target.

    Args:
        dataloader: Iterable of batches; ``len(dataloader)`` is used for averaging.
        projector: Trainable module mapping ViT CLS features to the Phi-2 width.
        epochs: Number of passes over ``dataloader``.
    """
    projector.train()
    cross_attn.train()
    embed_tokens = phi_model.get_input_embeddings()

    for epoch in range(epochs):
        total_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            images = batch["image"].to(device)
            questions = batch["question"]
            answers = batch["answer"]

            # ViT image features (the ViT is frozen, so no graph is needed)
            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]
            # Stay in the projector's dtype: cross_attn's weights are not float16.
            projected_image = projector(cls_token).unsqueeze(1)  # [B, 1, D]

            # Tokenize prompt + answer
            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{p} {a.strip()}" for p, a in zip(prompts, answers)]

            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
            prompt_mask = phi_tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).attention_mask.to(device)

            # Mask padding through the attention mask (pad id == EOS id, so the
            # id cannot be used) and mask the prompt prefix of every row.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            starts = attention_mask.argmax(dim=1).tolist()  # first real token per row
            prompt_lens = prompt_mask.sum(dim=1).tolist()
            for i, (start, n_prompt) in enumerate(zip(starts, prompt_lens)):
                labels[i, : start + n_prompt] = -100

            # Token embedding + cross-attention fusion. Fuse in the trainable
            # modules' dtype, then cast back to what the language model expects.
            token_embeds = embed_tokens(input_ids)  # [B, T, D]
            fused_embeds = cross_attn(
                token_embeds.to(projected_image.dtype), projected_image
            ).to(token_embeds.dtype)  # [B, T, D]

            # Forward pass to get logits
            logits = phi_model(inputs_embeds=fused_embeds, attention_mask=attention_mask).logits  # [B, T, V]

            # Compute loss manually. logits[:, t] predicts token t + 1, so drop
            # the last logit and the first label; float32 keeps the softmax stable.
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].float().reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")

Code with validation loop and ANLS metric
1. Validation loop

2. ANLS (Average Normalized Levenshtein Similarity) metric

3. Manual CrossEntropyLoss

4. Cross-attention fusion between ViT and Phi-2

In [None]:
###  Code with validation loop and ANLS metric
import numpy as np
import editdistance  # pip install editdistance

def normalized_levenshtein(pred, gt):
    """Return the normalised Levenshtein similarity of two strings, in [0, 1].

    Both strings are stripped and lower-cased before comparison. An empty
    ground truth scores 1.0 against an empty prediction and 0.0 otherwise.
    """
    prediction = pred.strip().lower()
    reference = gt.strip().lower()
    if not reference:
        return 0.0 if prediction else 1.0
    # Normalise the edit distance by the longer of the two strings.
    longest = max(len(prediction), len(reference))
    return 1 - editdistance.eval(prediction, reference) / longest

def compute_anls(preds, gts, threshold=0.5):
    """Return the mean thresholded similarity over paired predictions and answers.

    Each pair is scored with ``normalized_levenshtein``; scores below
    ``threshold`` count as 0. Pairs are formed with ``zip``, so extra items in
    the longer sequence are ignored.
    """
    similarities = (normalized_levenshtein(p, g) for p, g in zip(preds, gts))
    return np.mean([s if s >= threshold else 0 for s in similarities])

In [1]:
def train_projection_layer_with_validation(train_loader, val_loader, projector, epochs=3):
    """Train ``projector``/``cross_attn``, report in-batch ANLS, validate each epoch.

    Each batch must provide ``"image"``, ``"question"`` and ``"answer"``. The
    frozen ViT and Phi-2 models, the tokenizer, ``cross_attn`` and ``optimizer``
    are read from module scope.

    The loss is next-token cross-entropy over the answer tokens only (prompt
    and padding positions are labelled -100). The module-level ``loss_fn`` is
    not used because its ``ignore_index`` is the pad token id, which makes -100
    an invalid target. ANLS predictions are generated from the prompt alone so
    the gold answer is never part of the generation context.

    Args:
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches, forwarded to ``run_validation``.
        projector: Trainable module mapping ViT CLS features to the Phi-2 width.
        epochs: Number of passes over ``train_loader``.
    """
    embed_tokens = phi_model.get_input_embeddings()

    for epoch in range(epochs):
        # run_validation switches both modules to eval mode, so re-enable
        # training mode at the start of every epoch.
        projector.train()
        cross_attn.train()

        total_loss = 0.0
        total_anls = []

        for batch in tqdm(train_loader, desc=f"[Train] Epoch {epoch+1}"):
            images = batch["image"].to(device)
            questions = batch["question"]
            answers = batch["answer"]

            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]
            # Stay in the projector's dtype: cross_attn's weights are not float16.
            projected_image = projector(cls_token).unsqueeze(1)  # [B, 1, D]

            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{p} {a.strip()}" for p, a in zip(prompts, answers)]

            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
            prompt_mask = phi_tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).attention_mask.to(device)

            # Mask padding through the attention mask (pad id == EOS id, so the
            # id cannot be used) and mask the prompt prefix of every row.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            starts = attention_mask.argmax(dim=1).tolist()  # first real token per row
            prompt_lens = prompt_mask.sum(dim=1).tolist()
            for i, (start, n_prompt) in enumerate(zip(starts, prompt_lens)):
                labels[i, : start + n_prompt] = -100

            # Fuse in the trainable modules' dtype, then cast back to what the
            # language model expects.
            token_embeds = embed_tokens(input_ids)
            fused_embeds = cross_attn(
                token_embeds.to(projected_image.dtype), projected_image
            ).to(token_embeds.dtype)

            logits = phi_model(inputs_embeds=fused_embeds, attention_mask=attention_mask).logits
            # logits[:, t] predicts token t + 1, so drop the last logit and the
            # first label; float32 keeps the softmax stable.
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].float().reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # --- In-batch ANLS ---
            with torch.no_grad():
                # Generate from the prompt only, left-padded so every row ends
                # on a real token; the tokenizer setting is restored afterwards.
                original_side = phi_tokenizer.padding_side
                phi_tokenizer.padding_side = "left"
                try:
                    gen_inputs = phi_tokenizer(
                        prompts, return_tensors="pt", padding=True, truncation=True
                    ).to(device)
                finally:
                    phi_tokenizer.padding_side = original_side

                gen_embeds = cross_attn(
                    embed_tokens(gen_inputs.input_ids).to(projected_image.dtype),
                    projected_image,
                ).to(token_embeds.dtype)
                generated_ids = phi_model.generate(
                    inputs_embeds=gen_embeds,
                    attention_mask=gen_inputs.attention_mask,
                    max_new_tokens=50,
                    do_sample=False,
                    pad_token_id=phi_tokenizer.pad_token_id,
                )
                decoded_preds = phi_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                # Strip the prompt in case the decoded text echoes it.
                cleaned_preds = [pred.replace(prompt, "").strip() for pred, prompt in zip(decoded_preds, prompts)]
                anls_score = compute_anls(cleaned_preds, answers)
                total_anls.append(anls_score)

        avg_loss = total_loss / len(train_loader)
        avg_anls = np.mean(total_anls)
        print(f"[Train] Epoch {epoch+1} | Loss: {avg_loss:.4f} | ANLS: {avg_anls:.4f}")

        # Run Validation
        run_validation(val_loader, projector)

def run_validation(val_loader, projector):
    """Evaluate on ``val_loader`` and print the mean loss and ANLS.

    Both ``projector`` and the module-level ``cross_attn`` are switched to eval
    mode and left there. The loss is teacher-forced next-token cross-entropy
    over the answer tokens only (prompt and padding positions are labelled
    -100). Predictions for ANLS are generated from the prompt alone so the gold
    answer is never part of the generation context.

    Args:
        val_loader: Iterable of batches with ``"image"``, ``"question"`` and
            ``"answer"``; ``len(val_loader)`` is used for averaging.
        projector: Module mapping ViT CLS features to the Phi-2 width.
    """
    projector.eval()
    cross_attn.eval()
    embed_tokens = phi_model.get_input_embeddings()

    total_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="[Val]"):
            images = batch["image"].to(device)
            questions = batch["question"]
            answers = batch["answer"]

            vit_out = vit_model(pixel_values=images)
            cls_token = vit_out.last_hidden_state[:, 0, :]
            # Stay in the projector's dtype: cross_attn's weights are not float16.
            projected_image = projector(cls_token).unsqueeze(1)  # [B, 1, D]

            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{p} {a.strip()}" for p, a in zip(prompts, answers)]

            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
            prompt_mask = phi_tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).attention_mask.to(device)

            # Mask padding through the attention mask (pad id == EOS id, so the
            # id cannot be used) and mask the prompt prefix of every row.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            starts = attention_mask.argmax(dim=1).tolist()  # first real token per row
            prompt_lens = prompt_mask.sum(dim=1).tolist()
            for i, (start, n_prompt) in enumerate(zip(starts, prompt_lens)):
                labels[i, : start + n_prompt] = -100

            # Fuse in the trainable modules' dtype, then cast back to what the
            # language model expects.
            token_embeds = embed_tokens(input_ids)
            fused_embeds = cross_attn(
                token_embeds.to(projected_image.dtype), projected_image
            ).to(token_embeds.dtype)

            logits = phi_model(inputs_embeds=fused_embeds, attention_mask=attention_mask).logits
            # logits[:, t] predicts token t + 1, so drop the last logit and the
            # first label. -100 is ignored explicitly; the module-level loss_fn
            # ignores the pad id instead and would reject -100.
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].float().reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
            total_loss += loss.item()

            # Generate from the prompt only, left-padded so every row ends on a
            # real token; the tokenizer setting is restored afterwards.
            original_side = phi_tokenizer.padding_side
            phi_tokenizer.padding_side = "left"
            try:
                gen_inputs = phi_tokenizer(
                    prompts, return_tensors="pt", padding=True, truncation=True
                ).to(device)
            finally:
                phi_tokenizer.padding_side = original_side

            gen_embeds = cross_attn(
                embed_tokens(gen_inputs.input_ids).to(projected_image.dtype),
                projected_image,
            ).to(token_embeds.dtype)
            generated_ids = phi_model.generate(
                inputs_embeds=gen_embeds,
                attention_mask=gen_inputs.attention_mask,
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=phi_tokenizer.pad_token_id,
            )
            decoded_preds = phi_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # Strip the prompt in case the decoded text echoes it.
            cleaned_preds = [pred.replace(prompt, "").strip() for pred, prompt in zip(decoded_preds, prompts)]
            all_preds.extend(cleaned_preds)
            all_labels.extend(answers)

    avg_loss = total_loss / len(val_loader)
    anls = compute_anls(all_preds, all_labels)
    print(f"[Val] Loss: {avg_loss:.4f} | ANLS: {anls:.4f}")

In [2]:
# train_loader and val_loader must already exist in the notebook namespace;
# the NameError recorded below shows they did not when this cell was run.
train_projection_layer_with_validation(train_loader, val_loader, projector, epochs=3)


NameError: name 'train_loader' is not defined

Including Model checkpoint saving on best ANLS and progress plots
1. Model checkpointing based on the best validation ANLS

2. Progress plots for training loss and validation ANLS

3. Clean structure, preserving all previous features

In [None]:
import matplotlib.pyplot as plt
import os

In [None]:
def train_projection_layer_with_validation_and_checkpoint(
    train_loader,
    val_loader,
    projector,
    epochs=3,
    checkpoint_dir="./checkpoints"
):
    """Train, validate every epoch, checkpoint on best validation ANLS, then plot.

    Each batch must provide ``"image"``, ``"question"`` and ``"answer"``. The
    frozen ViT and Phi-2 models, the tokenizer, ``cross_attn`` and ``optimizer``
    are read from module scope. Saving is delegated to ``run_validation``,
    which receives the best ANLS seen so far.

    The loss is next-token cross-entropy over the answer tokens only (prompt
    and padding positions are labelled -100). The module-level ``loss_fn`` is
    not used because its ``ignore_index`` is the pad token id, which makes -100
    an invalid target. ANLS predictions are generated from the prompt alone so
    the gold answer is never part of the generation context.

    Args:
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        projector: Trainable module mapping ViT CLS features to the Phi-2 width.
        epochs: Number of passes over ``train_loader``.
        checkpoint_dir: Directory for the best-ANLS checkpoints; created if missing.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    embed_tokens = phi_model.get_input_embeddings()

    best_anls = -1
    history = {"train_loss": [], "val_anls": []}

    for epoch in range(epochs):
        # run_validation switches both modules to eval mode, so re-enable
        # training mode at the start of every epoch.
        projector.train()
        cross_attn.train()

        total_loss = 0.0
        total_anls = []

        for batch in tqdm(train_loader, desc=f"[Train] Epoch {epoch+1}"):
            images = batch["image"].to(device)
            questions = batch["question"]
            answers = batch["answer"]

            with torch.no_grad():
                vit_out = vit_model(pixel_values=images)
                cls_token = vit_out.last_hidden_state[:, 0, :]
            # Stay in the projector's dtype: cross_attn's weights are not float16.
            projected_image = projector(cls_token).unsqueeze(1)  # [B, 1, D]

            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{p} {a.strip()}" for p, a in zip(prompts, answers)]

            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
            prompt_mask = phi_tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).attention_mask.to(device)

            # Mask padding through the attention mask (pad id == EOS id, so the
            # id cannot be used) and mask the prompt prefix of every row.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            starts = attention_mask.argmax(dim=1).tolist()  # first real token per row
            prompt_lens = prompt_mask.sum(dim=1).tolist()
            for i, (start, n_prompt) in enumerate(zip(starts, prompt_lens)):
                labels[i, : start + n_prompt] = -100

            # Fuse in the trainable modules' dtype, then cast back to what the
            # language model expects.
            token_embeds = embed_tokens(input_ids)
            fused_embeds = cross_attn(
                token_embeds.to(projected_image.dtype), projected_image
            ).to(token_embeds.dtype)

            logits = phi_model(inputs_embeds=fused_embeds, attention_mask=attention_mask).logits
            # logits[:, t] predicts token t + 1, so drop the last logit and the
            # first label; float32 keeps the softmax stable.
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].float().reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Generate from the prompt only, left-padded so every row ends
                # on a real token; the tokenizer setting is restored afterwards.
                original_side = phi_tokenizer.padding_side
                phi_tokenizer.padding_side = "left"
                try:
                    gen_inputs = phi_tokenizer(
                        prompts, return_tensors="pt", padding=True, truncation=True
                    ).to(device)
                finally:
                    phi_tokenizer.padding_side = original_side

                gen_embeds = cross_attn(
                    embed_tokens(gen_inputs.input_ids).to(projected_image.dtype),
                    projected_image,
                ).to(token_embeds.dtype)
                generated_ids = phi_model.generate(
                    inputs_embeds=gen_embeds,
                    attention_mask=gen_inputs.attention_mask,
                    max_new_tokens=50,
                    do_sample=False,
                    pad_token_id=phi_tokenizer.pad_token_id,
                )
                decoded_preds = phi_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                # Strip the prompt in case the decoded text echoes it.
                cleaned_preds = [pred.replace(prompt, "").strip() for pred, prompt in zip(decoded_preds, prompts)]
                anls_score = compute_anls(cleaned_preds, answers)
                total_anls.append(anls_score)

        avg_loss = total_loss / len(train_loader)
        # run_validation writes the checkpoint itself when this epoch's ANLS
        # beats the best_anls passed in.
        avg_anls = run_validation(val_loader, projector, save_best=True, best_anls=best_anls, checkpoint_dir=checkpoint_dir)

        history["train_loss"].append(avg_loss)
        history["val_anls"].append(avg_anls)

        print(f"[Train] Epoch {epoch+1} | Loss: {avg_loss:.4f} | Train ANLS: {np.mean(total_anls):.4f} | Val ANLS: {avg_anls:.4f}")

        if avg_anls > best_anls:
            best_anls = avg_anls
            print(f"✅ New best ANLS: {best_anls:.4f} — model saved.")

    # Plot loss and ANLS
    plot_metrics(history)


In [None]:
def run_validation(val_loader, projector, save_best=False, best_anls=-1, checkpoint_dir="./checkpoints"):
    """Evaluate on ``val_loader``, optionally checkpoint, and return the ANLS.

    Both ``projector`` and the module-level ``cross_attn`` are switched to eval
    mode and left there. The loss is teacher-forced next-token cross-entropy
    over the answer tokens only (prompt and padding positions are labelled
    -100). Predictions for ANLS are generated from the prompt alone so the gold
    answer is never part of the generation context.

    Args:
        val_loader: Iterable of batches with ``"image"``, ``"question"`` and
            ``"answer"``; ``len(val_loader)`` is used for averaging.
        projector: Module mapping ViT CLS features to the Phi-2 width.
        save_best: When true, save both modules if the ANLS exceeds ``best_anls``.
        best_anls: Best ANLS seen so far by the caller.
        checkpoint_dir: Directory for the ``.safetensors`` files; created if missing.

    Returns:
        The ANLS over the whole validation set.
    """
    projector.eval()
    cross_attn.eval()
    embed_tokens = phi_model.get_input_embeddings()

    total_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="[Val]"):
            images = batch["image"].to(device)
            questions = batch["question"]
            answers = batch["answer"]

            vit_out = vit_model(pixel_values=images)
            cls_token = vit_out.last_hidden_state[:, 0, :]
            # Stay in the projector's dtype: cross_attn's weights are not float16.
            projected_image = projector(cls_token).unsqueeze(1)  # [B, 1, D]

            prompts = [f"Question: {q.strip()} Answer:" for q in questions]
            full_texts = [f"{p} {a.strip()}" for p, a in zip(prompts, answers)]

            inputs = phi_tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
            prompt_mask = phi_tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).attention_mask.to(device)

            # Mask padding through the attention mask (pad id == EOS id, so the
            # id cannot be used) and mask the prompt prefix of every row.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            starts = attention_mask.argmax(dim=1).tolist()  # first real token per row
            prompt_lens = prompt_mask.sum(dim=1).tolist()
            for i, (start, n_prompt) in enumerate(zip(starts, prompt_lens)):
                labels[i, : start + n_prompt] = -100

            # Fuse in the trainable modules' dtype, then cast back to what the
            # language model expects.
            token_embeds = embed_tokens(input_ids)
            fused_embeds = cross_attn(
                token_embeds.to(projected_image.dtype), projected_image
            ).to(token_embeds.dtype)

            logits = phi_model(inputs_embeds=fused_embeds, attention_mask=attention_mask).logits
            # logits[:, t] predicts token t + 1, so drop the last logit and the
            # first label. -100 is ignored explicitly; the module-level loss_fn
            # ignores the pad id instead and would reject -100.
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].float().reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
            total_loss += loss.item()

            # Generate from the prompt only, left-padded so every row ends on a
            # real token; the tokenizer setting is restored afterwards.
            original_side = phi_tokenizer.padding_side
            phi_tokenizer.padding_side = "left"
            try:
                gen_inputs = phi_tokenizer(
                    prompts, return_tensors="pt", padding=True, truncation=True
                ).to(device)
            finally:
                phi_tokenizer.padding_side = original_side

            gen_embeds = cross_attn(
                embed_tokens(gen_inputs.input_ids).to(projected_image.dtype),
                projected_image,
            ).to(token_embeds.dtype)
            generated_ids = phi_model.generate(
                inputs_embeds=gen_embeds,
                attention_mask=gen_inputs.attention_mask,
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=phi_tokenizer.pad_token_id,
            )
            decoded_preds = phi_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # Strip the prompt in case the decoded text echoes it.
            cleaned_preds = [pred.replace(prompt, "").strip() for pred, prompt in zip(decoded_preds, prompts)]
            all_preds.extend(cleaned_preds)
            all_labels.extend(answers)

    avg_loss = total_loss / len(val_loader)
    anls = compute_anls(all_preds, all_labels)

    # Save checkpoint if best
    if save_best and anls > best_anls:
        from safetensors.torch import save_file
        # Create the directory here too, so a direct call does not depend on
        # the training function having made it.
        os.makedirs(checkpoint_dir, exist_ok=True)
        save_file(projector.state_dict(), f"{checkpoint_dir}/projector_best_anls.safetensors")
        save_file(cross_attn.state_dict(), f"{checkpoint_dir}/cross_attn_best_anls.safetensors")

    print(f"[Val] Loss: {avg_loss:.4f} | ANLS: {anls:.4f}")
    return anls

In [None]:
##Plotting function

def plot_metrics(history):
    """Show training loss and validation ANLS side by side, one point per epoch.

    Args:
        history: Mapping with ``"train_loss"`` and ``"val_anls"`` sequences.
    """
    # (history key, legend label, y-axis label, title, extra plot kwargs)
    panels = (
        ("train_loss", "Train Loss", "Loss", "Training Loss", {}),
        ("val_anls", "Validation ANLS", "ANLS", "Validation ANLS", {"color": "green"}),
    )

    plt.figure(figsize=(10, 4))
    for position, (key, label, ylabel, title, line_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[key], label=label, **line_kwargs)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid()
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
## To run the model
# Trains for 5 epochs; checkpoint_dir is left at the function's default.
# train_loader and val_loader are not created in the visible cells and must be
# defined before this call.
train_projection_layer_with_validation_and_checkpoint(train_loader, val_loader, projector, epochs=5)


Resume training from a checkpoint and save model in .safetensors

In [None]:
####    Add helper functions to save and load checkpoint:

from safetensors.torch import save_file, load_file

def save_checkpoint(projector, cross_attn, checkpoint_dir="./checkpoints"):
    """Write both modules' state dicts to ``checkpoint_dir`` as ``.safetensors``.

    The directory is created if needed; existing files with the same names are
    overwritten.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    targets = {
        f"{checkpoint_dir}/projector_best_anls.safetensors": projector,
        f"{checkpoint_dir}/cross_attn_best_anls.safetensors": cross_attn,
    }
    for path, module in targets.items():
        save_file(module.state_dict(), path)
    print(f"✅ Checkpoint saved to {checkpoint_dir}")

def load_checkpoint(projector, cross_attn, checkpoint_dir="./checkpoints"):
    """Load both modules' weights in place from the files in ``checkpoint_dir``.

    Expects the file names written by ``save_checkpoint``.
    """
    sources = {
        f"{checkpoint_dir}/projector_best_anls.safetensors": projector,
        f"{checkpoint_dir}/cross_attn_best_anls.safetensors": cross_attn,
    }
    for path, module in sources.items():
        module.load_state_dict(load_file(path))
    print(f"🔁 Resumed from checkpoint in {checkpoint_dir}")

In [None]:
###   Add resume Option to Training Function......Update training function definition like:

def train_projection_layer_with_validation_and_checkpoint(
    train_loader,
    val_loader,
    projector,
    epochs=3,
    checkpoint_dir="./checkpoints",
    resume=False
):
    """Prepare a training run, optionally restoring weights from ``checkpoint_dir``.

    NOTE(review): this cell holds only the setup preamble. As written, the
    function returns ``None`` right after initialising ``best_anls`` and
    ``history``, and defining it replaces the earlier full definition of the
    same name. The epoch loop from that earlier definition is presumably meant
    to follow here — confirm before running.

    Args:
        train_loader: Training batches (unused in the visible part).
        val_loader: Validation batches (unused in the visible part).
        projector: Projection module; switched to train mode here.
        epochs: Number of epochs (unused in the visible part).
        checkpoint_dir: Directory that is created and, when resuming, read from.
        resume: When true, load the projector and cross-attention weights via
            ``load_checkpoint`` before training.
    """
    # Make sure the checkpoint directory exists before anything reads or writes it.
    os.makedirs(checkpoint_dir, exist_ok=True)
    if resume:
        # Restores weights into projector and the module-level cross_attn.
        load_checkpoint(projector, cross_attn, checkpoint_dir)

    projector.train()
    cross_attn.train()
    # NOTE(review): best_anls restarts at -1 even when resuming, so a later
    # "anls > best_anls" comparison would accept any first result — confirm
    # whether the previous best score should be restored as well.
    best_anls = -1
    history = {"train_loss": [], "val_anls": []}

In [None]:
######   Inside validation:

# Save if best
# NOTE(review): this looks like a fragment meant to replace the checkpoint
# branch inside run_validation; save_best, anls, best_anls and checkpoint_dir
# are parameters/locals of that function, so the cell presumably is not meant
# to be run on its own — confirm intended placement.
if save_best and anls > best_anls:
    # Inside run_validation this rebinds only the local best_anls; the caller
    # keeps its own copy of the best score.
    best_anls = anls
    save_checkpoint(projector, cross_attn, checkpoint_dir)

In [None]:
###   Run training with Resume

# With resume=True the function calls load_checkpoint, which reads the
# projector and cross-attention .safetensors files from checkpoint_dir, so
# those files must already exist there.
train_projection_layer_with_validation_and_checkpoint(
    train_loader,
    val_loader,
    projector,
    epochs=5,
    checkpoint_dir="./checkpoints",
    resume=True  # <-- Resume from .safetensors
)