In [1]:
# Cell 1: Installs and Imports
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, Blip2Processor, Blip2Model, Blip2Config, Blip2VisionModel, Blip2QFormerModel, BlipImageProcessor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
from PIL import Image
import json
import os
import random
import numpy as np
from tqdm.notebook import tqdm
import torch.nn.functional as F
import math
import time
import traceback


print(f"PyTorch Version: {torch.__version__}")
# print(f"Transformers Version: {transformers.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Capability: {torch.cuda.get_device_capability(0)}")

PyTorch Version: 2.6.0+cu124
CUDA Available: True
CUDA Device Name: NVIDIA GeForce RTX 4090
CUDA Capability: (8, 9)


In [2]:
# Cell 2: Configuration Class (CFG) - Updated with Early Stopping

class CFG:
    # --- Paths ---
    data_path = "./data/LANDMARK-IN-VIETNAM/"  # Adjust to where train.json, dev.json, test.json are located
    image_path = "./data/LANDMARK-IN-VIETNAM/"  # Base path for resolving image paths in JSON
    model_path = "./ViBLIP_QFormer_Trained"  # Output directory for saved models

    # --- Model Selection ---
    blip2_model_name = "Salesforce/blip2-flan-t5-xl"  # Uses ViT-B by default
    text_tokenizer_name = "vinai/phobert-base"  # Vietnamese tokenizer

    # --- Training Parameters ---
    seed = 42
    batch_size = 32  # Reduced for RTX 4090 stability
    num_workers = 8  # Adjusted for typical CPU
    qformer_lr = 1e-4
    weight_decay = 0.05
    patience = 2  # For LR scheduler
    factor = 0.8
    epochs = 1
    early_stop_patience = 3  # Stop if no improvement for 3 epochs
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = True

    # --- Image/Text Parameters ---
    image_size = 224
    max_length = 77

    # --- Loss/Saving Parameters ---
    temperature = 0.07
    save_best_only = True
    metric_to_track = "val_itc_acc"
    mode = "max"

config = CFG()
os.makedirs(config.model_path, exist_ok=True)
print(f"--- ViBLIP Q-Former Training Configuration ---")
print(f"Device: {config.device}")
print(f"Base BLIP-2 Model: {config.blip2_model_name}")
print(f"Text Tokenizer: {config.text_tokenizer_name}")
print(f"Batch Size: {config.batch_size}")
print(f"Use AMP: {config.use_amp}")
print(f"Epochs: {config.epochs}")
print(f"Q-Former LR: {config.qformer_lr}")
print(f"Early Stop Patience: {config.early_stop_patience}")
print(f"Output Path: {config.model_path}")
print(f"Data Path (JSONs): {os.path.abspath(config.data_path)}")
print(f"Image Base Path: {os.path.abspath(config.image_path)}")
print(f"---------------------------------------------\n")
if config.data_path == "." and config.image_path == ".":
    print("WARNING: Using current directory for data and image paths. Ensure JSON files and images are present.")

--- ViBLIP Q-Former Training Configuration ---
Device: cuda
Base BLIP-2 Model: Salesforce/blip2-flan-t5-xl
Text Tokenizer: vinai/phobert-base
Batch Size: 32
Use AMP: True
Epochs: 1
Q-Former LR: 0.0001
Early Stop Patience: 3
Output Path: ./ViBLIP_QFormer_Trained
Data Path (JSONs): /home/researcher/huypq69/TuningModels/data/LANDMARK-IN-VIETNAM
Image Base Path: /home/researcher/huypq69/TuningModels/data/LANDMARK-IN-VIETNAM
---------------------------------------------



In [3]:
# Cell 3: Seeding for Reproducibility

def set_seed(seed=config.seed):
    print(f"Setting seed: {seed}")
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = True

set_seed()

Setting seed: 42


In [4]:
# Cell 4: Metric Calculation Utilities

class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, count=1):
        if torch.is_tensor(val): val = val.item()
        if isinstance(val, (int, float)):
            self.sum += val * count
            self.count += count
            self.avg = self.sum / self.count if self.count != 0 else 0

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def compute_recall_at_k(similarity_matrix, k, dim):
    n = similarity_matrix.shape[1-dim]
    if n == 0 or k <= 0: return 0.0
    effective_k = min(k, n)
    correct_count = 0
    top_k_indices = torch.topk(similarity_matrix, effective_k, dim=dim).indices
    ground_truth = torch.arange(n, device=similarity_matrix.device)

    if dim == 0: # I2T
        for img_idx in range(n):
            if ground_truth[img_idx] in top_k_indices[:, img_idx]: correct_count += 1
    elif dim == 1: # T2I
        for txt_idx in range(n):
            if ground_truth[txt_idx] in top_k_indices[txt_idx, :]: correct_count += 1
    else: raise ValueError("dim must be 0 or 1")
    return correct_count / n

def compute_metrics(image_embeddings, text_embeddings):
    image_embeddings = image_embeddings.float()
    text_embeddings = text_embeddings.float()

    sim_matrix = text_embeddings @ image_embeddings.T
    n = sim_matrix.shape[0]
    default_metrics = {
        "i2t_acc": 0.0, "t2i_acc": 0.0, "avg_acc": 0.0,
        "avg_cosine_sim": 0.0,
        "i2t_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0},
        "t2i_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}
    }
    if n == 0: return default_metrics

    try:
        ground_truth = torch.arange(n, device=sim_matrix.device)
        i2t_preds = torch.argmax(sim_matrix, dim=0)
        t2i_preds = torch.argmax(sim_matrix, dim=1)
        i2t_acc = (i2t_preds == ground_truth).float().mean().item()
        t2i_acc = (t2i_preds == ground_truth).float().mean().item()
        avg_acc = (i2t_acc + t2i_acc) / 2
        diag_len = min(sim_matrix.shape[0], sim_matrix.shape[1])
        avg_cosine_sim = torch.diagonal(sim_matrix[:diag_len, :diag_len]).mean().item()

        i2t_recall = {}
        t2i_recall = {}
        for k in [1, 5, 10]:
            k_str = f"R@{k}"
            i2t_recall[k_str] = compute_recall_at_k(sim_matrix, k, dim=0)
            t2i_recall[k_str] = compute_recall_at_k(sim_matrix, k, dim=1)

        return {
            "i2t_acc": i2t_acc, "t2i_acc": t2i_acc, "avg_acc": avg_acc,
            "avg_cosine_sim": avg_cosine_sim,
            "i2t_recall": i2t_recall, "t2i_recall": t2i_recall
        }
    except Exception as e:
        print(f"Error during metric calculation: {e}")
        print(f"Shapes: ImgEmb={image_embeddings.shape}, TxtEmb={text_embeddings.shape}, SimMtx={sim_matrix.shape}")
        return default_metrics

print("Metric utilities defined.")

Metric utilities defined.


In [5]:
# Cell 5: Dataset Class Definition

class ImageCaptionDataset(Dataset):
    def __init__(self, json_path, image_base_path, tokenizer, image_processor, max_length):
        super().__init__()
        print(f"Loading data from: {os.path.abspath(json_path)}")
        self.data = []
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                self.data = json.load(f)
            print(f"Loaded {len(self.data)} samples from {os.path.basename(json_path)}.")
        except FileNotFoundError:
            print(f"ERROR: JSON file not found at {json_path}")
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
        except Exception as e:
            print(f"Unexpected error loading JSON: {e}")

        self.image_base_path = image_base_path
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        if not os.path.isdir(self.image_base_path):
            print(f"WARNING: Image base path does not exist: {os.path.abspath(self.image_base_path)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if idx >= len(self.data): raise IndexError("Index out of bounds")
        item = self.data[idx]
        relative_image_path = item.get('image_path')
        caption = item.get('caption', '')
        image = None

        if relative_image_path:
            image_path = os.path.normpath(os.path.join(self.image_base_path, relative_image_path))
            try:
                img_pil = Image.open(image_path).convert('RGB')
                image_processed = self.image_processor(images=img_pil, return_tensors="pt")
                image = image_processed['pixel_values'].squeeze(0)
            except FileNotFoundError:
                image = None
            except Exception as e:
                print(f"Error processing image {image_path}: {e}")
                image = None

        if image is None:
            c = 3
            h = w = config.image_size
            image = torch.zeros((c, h, w))

        text_inputs = self.tokenizer(
            caption, padding='max_length', truncation=True,
            max_length=self.max_length, return_tensors='pt'
        )
        input_ids = text_inputs['input_ids'].squeeze(0)
        attention_mask = text_inputs['attention_mask'].squeeze(0)

        if input_ids.dim() == 0: input_ids = input_ids.unsqueeze(0)
        if attention_mask.dim() == 0: attention_mask = attention_mask.unsqueeze(0)

        return {
            "pixel_values": image,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "raw_caption": caption
        }

print("ImageCaptionDataset class defined.")

ImageCaptionDataset class defined.


In [6]:
# Cell 6: Model Loading & Freezing

model = None
blip_config_loaded = None
model_loaded = False

try:
    print(f"Loading BLIP-2 configuration for: {config.blip2_model_name}")
    blip_config_loaded = Blip2Config.from_pretrained(config.blip2_model_name)

    print(f"Loading BLIP-2 model: {config.blip2_model_name}")
    model_dtype = torch.float16 if config.use_amp and config.device == torch.device('cuda') else torch.float32
    model = Blip2Model.from_pretrained(
        config.blip2_model_name,
        config=blip_config_loaded,
        torch_dtype=model_dtype
    )

    print("Freezing Vision Model and Language Model parameters...")
    frozen_params_count = 0
    total_params = 0

    if hasattr(model, 'vision_model'):
        for param in model.vision_model.parameters():
            param.requires_grad = False
            frozen_params_count += param.numel()
        print(f"  Vision model frozen.")
    else:
        print("  Warning: model.vision_model not found.")

    if hasattr(model, 'language_model'):
        for param in model.language_model.parameters():
            param.requires_grad = False
            frozen_params_count += param.numel()
        print(f"  Language model frozen.")
    else:
        print("  Warning: model.language_model not found.")

    trainable_params_count = 0
    if hasattr(model, 'qformer'):
        print("Verifying Q-Former parameters are trainable...")
        model.qformer.train()
        for param in model.qformer.parameters():
            param.requires_grad = True
            trainable_params_count += param.numel()

        proj_layers_found = 0
        for proj_name in ['vision_proj', 'text_proj']:
            if hasattr(model, proj_name):
                layer = getattr(model, proj_name)
                if layer is not None and isinstance(layer, nn.Module):
                    print(f"  Verifying {proj_name} parameters are trainable...")
                    layer.train()
                    for param in layer.parameters():
                        param.requires_grad = True
                    trainable_params_count += sum(p.numel() for p in layer.parameters())
                    proj_layers_found += 1

        if proj_layers_found == 0:
            print("  Note: Projection layers (vision_proj, text_proj) not found.")

        model.to(config.device)
        model_loaded = True
        total_params = sum(p.numel() for p in model.parameters())
        print("\nModel components loaded successfully.")
        print(f"  Total parameters: ~{total_params / 1e6:.2f} M")
        print(f"  Frozen parameters: ~{frozen_params_count / 1e6:.2f} M")
        print(f"  Trainable parameters: ~{trainable_params_count / 1e6:.2f} M")

    else:
        print("ERROR: model.qformer not found!")

except Exception as e:
    print(f"ERROR loading model '{config.blip2_model_name}': {e}")
    traceback.print_exc()
    model = None

Loading BLIP-2 configuration for: Salesforce/blip2-flan-t5-xl
Loading BLIP-2 model: Salesforce/blip2-flan-t5-xl


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Freezing Vision Model and Language Model parameters...
  Vision model frozen.
  Language model frozen.
Verifying Q-Former parameters are trainable...
  Note: Projection layers (vision_proj, text_proj) not found.

Model components loaded successfully.
  Total parameters: ~3942.45 M
  Frozen parameters: ~3835.71 M
  Trainable parameters: ~105.14 M


In [7]:
# Cell 7: Data Setup (Tokenizer, Image Processor, Datasets, DataLoaders)

tokenizer = None
image_processor = None
train_loader = None
dev_loader = None
data_setup_ok = False

if model_loaded:
    try:
        print(f"Loading Tokenizer: {config.text_tokenizer_name}")
        tokenizer = AutoTokenizer.from_pretrained(config.text_tokenizer_name)
        print(f"  PhoBERT Tokenizer Vocab Size: {tokenizer.vocab_size}")

        print(f"Loading Image Processor for: {config.blip2_model_name}")
        image_processor = BlipImageProcessor.from_pretrained(config.blip2_model_name)
        if hasattr(image_processor, 'size'):
            processor_size = image_processor.size['height'] if isinstance(image_processor.size, dict) else image_processor.size
            if processor_size != config.image_size:
                print(f"  Updating config.image_size from {config.image_size} to {processor_size}")
                config.image_size = processor_size

        print("\nCreating datasets...")
        train_json = os.path.join(config.data_path, "train.json")
        dev_json = os.path.join(config.data_path, "val.json")

        train_dataset = ImageCaptionDataset(
            json_path=train_json, image_base_path=config.image_path,
            tokenizer=tokenizer, image_processor=image_processor, max_length=config.max_length
        )
        dev_dataset = ImageCaptionDataset(
            json_path=dev_json, image_base_path=config.image_path,
            tokenizer=tokenizer, image_processor=image_processor, max_length=config.max_length
        )

        if not train_dataset.data: raise ValueError("Training data failed to load.")
        if not dev_dataset.data: print("Warning: Validation data not loaded.")

        print("\nCreating dataloaders...")
        num_workers = min(config.num_workers, os.cpu_count() if os.cpu_count() else 1)
        print(f"Using {num_workers} workers.")

        train_loader = DataLoader(
            train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True if config.device == torch.device("cuda") else False, drop_last=True
        )
        print(f"Train loader created ({len(train_loader)} batches).")

        if dev_dataset.data:
            dev_loader = DataLoader(
                dev_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers,
                pin_memory=True if config.device == torch.device("cuda") else False, drop_last=False
            )
            print(f"Validation loader created ({len(dev_loader)} batches).")

        data_setup_ok = True
        print("\nData setup complete.")

    except Exception as e:
        print(f"ERROR during data setup: {e}")
        traceback.print_exc()
else:
    print("Skipping data setup because model failed to load.")

Loading Tokenizer: vinai/phobert-base
  PhoBERT Tokenizer Vocab Size: 64000
Loading Image Processor for: Salesforce/blip2-flan-t5-xl

Creating datasets...
Loading data from: /home/researcher/huypq69/TuningModels/data/LANDMARK-IN-VIETNAM/train.json
Loaded 19844 samples from train.json.
Loading data from: /home/researcher/huypq69/TuningModels/data/LANDMARK-IN-VIETNAM/val.json
Loaded 5667 samples from val.json.

Creating dataloaders...
Using 8 workers.
Train loader created (620 batches).
Validation loader created (178 batches).

Data setup complete.


In [8]:
# Cell 8: Optimizer & Scheduler Setup

optimizer = None
lr_scheduler = None
optimizer_setup_ok = False

if model_loaded and data_setup_ok:
    print("\nSetting up optimizer and scheduler...")
    try:
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        param_count = sum(p.numel() for p in trainable_params)
        print(f"Found {len(trainable_params)} parameter tensors to optimize (~{param_count / 1e6:.2f} M).")

        if not trainable_params:
            raise ValueError("No trainable parameters found.")

        optimizer = optim.AdamW(trainable_params, lr=config.qformer_lr, weight_decay=config.weight_decay)
        print(f"Optimizer AdamW initialized with lr={config.qformer_lr:.1e}")

        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode=config.mode, factor=config.factor, patience=config.patience, verbose=True
        )
        print(f"LR Scheduler ReduceLROnPlateau initialized (mode='{config.mode}')")
        optimizer_setup_ok = True

    except Exception as e:
        print(f"ERROR setting up optimizer/scheduler: {e}")
        traceback.print_exc()
else:
    print("Skipping optimizer setup due to previous errors.")


Setting up optimizer and scheduler...
Found 257 parameter tensors to optimize (~106.74 M).
Optimizer AdamW initialized with lr=1.0e-04
LR Scheduler ReduceLROnPlateau initialized (mode='max')


In [9]:
# Cell 9: Loss Function Definitions

def calculate_itc_loss(image_feats_norm, text_feats_norm, temperature):
    logits = (image_feats_norm @ text_feats_norm.T) / temperature
    logits = logits.float()
    batch_size = image_feats_norm.shape[0]
    if batch_size == 0: return torch.tensor(0.0, device=logits.device)
    labels = torch.arange(batch_size, device=logits.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.T, labels)
    return (loss_i + loss_t) / 2.0

def calculate_itm_loss(model, outputs, batch_size, device):
    """Image-Text Matching Loss with Hard Negative Mining"""
    if batch_size == 0 or not hasattr(model, 'qformer') or not hasattr(outputs, 'qformer_outputs'):
        return torch.tensor(0.0, device=device)

    try:
        # Extract Q-Former multimodal features (CLS token)
        multimodal_feats = outputs.qformer_outputs.last_hidden_state[:, 0]  # [batch_size, hidden_size]

        # Hard negative mining: Find mismatched pairs with high ITC similarity
        image_feats = model.vision_proj(outputs.image_embeds) if hasattr(model, 'vision_proj') else outputs.image_embeds
        text_feats = model.text_proj(outputs.text_embeds) if hasattr(model, 'text_proj') else outputs.text_embeds
        image_feats_norm = F.normalize(image_feats, dim=-1)
        text_feats_norm = F.normalize(text_feats, dim=-1)
        sim_matrix = image_feats_norm @ text_feats_norm.T
        sim_matrix.fill_diagonal_(-float('inf'))  # Exclude true pairs
        hard_neg_indices = torch.argmax(sim_matrix, dim=1)  # [batch_size]

        # Create negative pairs by pairing images with hard-negative texts
        neg_input_ids = outputs.input_ids[hard_neg_indices]
        neg_attention_mask = outputs.attention_mask[hard_neg_indices]
        pixel_values = outputs.pixel_values

        with torch.no_grad():
            neg_outputs = model(
                pixel_values=pixel_values,
                input_ids=neg_input_ids,
                attention_mask=neg_attention_mask
            )
        neg_multimodal_feats = neg_outputs.qformer_outputs.last_hidden_state[:, 0]

        # Combine positive and negative features
        all_feats = torch.cat([multimodal_feats, neg_multimodal_feats], dim=0)  # [2*batch_size, hidden_size]
        itm_logits = model.itm_head(all_feats) if hasattr(model, 'itm_head') else nn.Linear(all_feats.size(-1), 2).to(device)(all_feats)

        # Labels: 1 for positive pairs, 0 for negative pairs
        itm_labels = torch.cat([torch.ones(batch_size, dtype=torch.long), torch.zeros(batch_size, dtype=torch.long)]).to(device)

        return F.cross_entropy(itm_logits, itm_labels)
    except Exception as e:
        print(f"Error in ITM loss calculation: {e}")
        return torch.tensor(0.0, device=device)

def calculate_itg_loss(model_outputs, target_ids, target_mask):
    return torch.tensor(0.0, device=target_ids.device)

print("Loss functions defined: ITC and ITM implemented, ITG is placeholder.") 

Loss functions defined: ITC and ITM implemented, ITG is placeholder.


In [10]:
# Cell 10: Training Loop - Updated with Early Stopping

ready_to_train = model_loaded and optimizer_setup_ok and data_setup_ok and train_loader is not None

if ready_to_train:
    print(f"\n=============== Starting Q-Former Training ===============")
    print(f"Epochs: {config.epochs}, Batch Size: {config.batch_size}, Device: {config.device}, AMP: {config.use_amp}")
    print(f"Tracking metric: '{config.metric_to_track}' (mode: {config.mode})")
    print(f"Early Stopping Patience: {config.early_stop_patience} epochs")

    best_val_metric = -float('inf') if config.mode == "max" else float('inf')
    early_stop_counter = 0
    history = {'train_loss': [], 'train_itc_loss': [], 'train_itm_loss': [], 'validation_results': []}
    start_train_time = time.time()
    scaler = GradScaler(enabled=config.use_amp)

    for epoch in range(config.epochs):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch+1}/{config.epochs} ---")

        model.train()
        train_loss_meter = AvgMeter(f"Train Total E{epoch+1}")
        train_itc_meter = AvgMeter(f"Train ITC E{epoch+1}")
        train_itm_meter = AvgMeter(f"Train ITM E{epoch+1}")

        progress_bar = tqdm(train_loader, desc=f"Training E{epoch+1}", leave=True, unit="batch")

        for step, batch in enumerate(progress_bar):
            optimizer.zero_grad(set_to_none=True)

            pixel_values = batch['pixel_values'].to(config.device)
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            batch_size = pixel_values.size(0)
            if batch_size == 0: continue

            expected_dtype = torch.float16 if config.use_amp else torch.float32
            pixel_values = pixel_values.to(dtype=expected_dtype)

            with autocast(enabled=config.use_amp):
                try:
                    outputs = model(
                        pixel_values=pixel_values,
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True
                    )

                    image_embeds = outputs.image_embeds
                    text_embeds = outputs.text_embeds

                    image_feats = model.vision_proj(image_embeds) if hasattr(model, 'vision_proj') else image_embeds
                    text_feats = model.text_proj(text_embeds) if hasattr(model, 'text_proj') else text_embeds

                    image_feats_norm = F.normalize(image_feats, dim=-1)
                    text_feats_norm = F.normalize(text_feats, dim=-1)

                    loss_itc = calculate_itc_loss(image_feats_norm, text_feats_norm, config.temperature)
                    loss_itm = calculate_itm_loss(model, outputs, batch_size, config.device)
                    loss_itg = torch.tensor(0.0, device=config.device)

                    total_loss = loss_itc + loss_itm + loss_itg

                except Exception as forward_err:
                    print(f"Error at step {step}: {forward_err}")
                    traceback.print_exc()
                    continue

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                print(f"Warning: NaN/Inf loss at step {step}. Skipping.")
                continue

            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss_meter.update(total_loss.item(), batch_size)
            train_itc_meter.update(loss_itc.item(), batch_size)
            train_itm_meter.update(loss_itm.item(), batch_size)
            progress_bar.set_postfix(loss=f"{train_loss_meter.avg:.4f}", itc=f"{train_itc_meter.avg:.4f}", itm=f"{train_itm_meter.avg:.4f}")

        history['train_loss'].append(train_loss_meter.avg)
        history['train_itc_loss'].append(train_itc_meter.avg)
        history['train_itm_loss'].append(train_itm_meter.avg)
        print(f"Epoch {epoch+1}: Train Loss={train_loss_meter.avg:.4f} (ITC={train_itc_meter.avg:.4f}, ITM={train_itm_meter.avg:.4f})")

        val_results = None
        current_val_metric = -float('inf') if config.mode == "max" else float('inf')

        if dev_loader:
            val_results = validate_qformer_epoch(model, dev_loader, config.device, epoch+1)
            history['validation_results'].append(val_results)
            current_val_metric = val_results.get(config.metric_to_track, current_val_metric)

            try:
                metric_for_scheduler = val_results.get(config.metric_to_track, val_results.get('loss', float('inf')))
                lr_scheduler.step(metric_for_scheduler)
                current_lrs = [group['lr'] for group in optimizer.param_groups]
                print(f"  Validation Metrics: {val_results}")
                print(f"  Current LR(s): {[f'{lr:.2e}' for lr in current_lrs]}")
            except Exception as e:
                print(f"Error stepping scheduler: {e}")
        else:
            history['validation_results'].append(None)

        is_best = False
        if dev_loader:
            if config.mode == "max" and current_val_metric > best_val_metric:
                is_best = True
                early_stop_counter = 0  # Reset counter on improvement
            elif config.mode == "min" and current_val_metric < best_val_metric:
                is_best = True
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                print(f"  No improvement in {config.metric_to_track}. Early stop counter: {early_stop_counter}/{config.early_stop_patience}")
            if is_best:
                best_val_metric = current_val_metric

        model.cpu()
        save_dict = {
            'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': lr_scheduler.state_dict(),
            'train_loss': train_loss_meter.avg, 'validation_results': val_results,
            'best_val_metric': best_val_metric, 'metric_tracked': config.metric_to_track,
            'config_blip2_model_name': config.blip2_model_name, 'config_text_tokenizer_name': config.text_tokenizer_name
        }
        model.to(config.device)

        try:
            if config.save_best_only and dev_loader:
                if is_best:
                    best_ckpt_path = os.path.join(config.model_path, "ViBLIP_QFormer_best.pt")
                    torch.save(save_dict, best_ckpt_path)
                    print(f"  Saved Best Model (Epoch {epoch+1}, {config.metric_to_track}={current_val_metric:.4f})")
            else:
                epoch_ckpt_path = os.path.join(config.model_path, f"ViBLIP_QFormer_epoch_{epoch+1}.pt")
                torch.save(save_dict, epoch_ckpt_path)
                print(f"  Saved Epoch {epoch+1} Checkpoint")
                if is_best and dev_loader:
                    best_ckpt_path = os.path.join(config.model_path, "ViBLIP_QFormer_best.pt")
                    torch.save(save_dict, best_ckpt_path)
                    print(f"  (Also marked as best model)")
        except Exception as e:
            print(f"ERROR saving checkpoint for epoch {epoch+1}: {e}")

        epoch_end_time = time.time()
        print(f"--- Epoch {epoch+1} Time: {epoch_end_time - epoch_start_time:.2f} seconds ---")

        # Early stopping check
        if dev_loader and early_stop_counter >= config.early_stop_patience:
            print(f"\nEarly stopping triggered after {early_stop_counter} epochs without improvement in {config.metric_to_track}.")
            break

    end_train_time = time.time()
    print(f"\n=============== Q-Former Training Finished ===============")
    print(f"Total Training Time: {(end_train_time - start_train_time)/60:.2f} minutes")

    try:
        final_model_path = os.path.join(config.model_path, 'ViBLIP_QFormer_final_epoch.pt')
        model.cpu()
        final_save_dict = {
            'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
            'best_val_metric': best_val_metric, 'metric_tracked': config.metric_to_track
        }
        torch.save(final_save_dict, final_model_path)
        print(f"Final epoch model state saved to {final_model_path}")
        best_ckpt_path = os.path.join(config.model_path, "ViBLIP_QFormer_best.pt")
        if os.path.exists(best_ckpt_path):
            print(f"Best model saved to: {best_ckpt_path}")
    except Exception as e:
        print(f"ERROR saving final model state: {e}")
    print(f"========================================================")

else:
    print("ERROR: Prerequisites not met. Training loop skipped.")


Epochs: 1, Batch Size: 32, Device: cuda, AMP: True
Tracking metric: 'val_itc_acc' (mode: max)
Early Stopping Patience: 3 epochs

--- Epoch 1/1 ---


  scaler = GradScaler(enabled=config.use_amp)


Training E1:   0%|          | 0/620 [00:00<?, ?batch/s]

  with autocast(enabled=config.use_amp):
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [852,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [852,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [852,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [852,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [852,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [852,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/py

Error at step 0: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Cell 11: Validation Loop Implementation

def validate_qformer_epoch(model, dataloader, device, epoch_num):
    print(f"--- Running Validation Epoch {epoch_num} ---")
    model.eval()
    val_loss_meter = AvgMeter(f"Val Total E{epoch_num}")
    val_itc_meter = AvgMeter(f"Val ITC E{epoch_num}")
    val_itm_meter = AvgMeter(f"Val ITM E{epoch_num}")
    val_itc_acc_meter = AvgMeter(f"Val ITC Acc E{epoch_num}")
    val_itm_acc_meter = AvgMeter(f"Val ITM Acc E{epoch_num}")

    progress_bar = tqdm(dataloader, desc=f"Validating E{epoch_num}", leave=True, unit="batch")

    all_image_feats = []
    all_text_feats = []

    with torch.no_grad():
        for batch in progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            batch_size = pixel_values.size(0)
            if batch_size == 0: continue

            expected_dtype = torch.float16 if config.use_amp else torch.float32
            pixel_values = pixel_values.to(dtype=expected_dtype)

            with autocast(enabled=config.use_amp):
                try:
                    outputs = model(
                        pixel_values=pixel_values,
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True
                    )

                    image_embeds = outputs.image_embeds
                    text_embeds = outputs.text_embeds

                    image_feats = model.vision_proj(image_embeds) if hasattr(model, 'vision_proj') else image_embeds
                    text_feats = model.text_proj(text_embeds) if hasattr(model, 'text_proj') else text_embeds

                    image_feats_norm = F.normalize(image_feats, dim=-1)
                    text_feats_norm = F.normalize(text_feats, dim=-1)

                    loss_itc = calculate_itc_loss(image_feats_norm, text_feats_norm, config.temperature)
                    loss_itm = calculate_itm_loss(model, outputs, batch_size, device)
                    total_loss = loss_itc + loss_itm

                    all_image_feats.append(image_feats_norm)
                    all_text_feats.append(text_feats_norm)

                    metrics = compute_metrics(image_feats_norm, text_feats_norm)
                    itc_acc = metrics['avg_acc']

                    multimodal_feats = outputs.qformer_outputs.last_hidden_state[:, 0]
                    itm_logits = model.itm_head(multimodal_feats) if hasattr(model, 'itm_head') else nn.Linear(multimodal_feats.size(-1), 2).to(device)(multimodal_feats)
                    itm_preds = torch.argmax(itm_logits, dim=-1)
                    itm_labels = torch.ones(batch_size, dtype=torch.long).to(device)
                    itm_acc = (itm_preds == itm_labels).float().mean().item()

                except Exception as e:
                    print(f"Error in validation batch: {e}")
                    continue

                val_loss_meter.update(total_loss.item(), batch_size)
                val_itc_meter.update(loss_itc.item(), batch_size)
                val_itm_meter.update(loss_itm.item(), batch_size)
                val_itc_acc_meter.update(itc_acc, batch_size)
                val_itm_acc_meter.update(itm_acc, batch_size)

                progress_bar.set_postfix(loss=f"{val_loss_meter.avg:.4f}", itc_acc=f"{val_itc_acc_meter.avg:.4f}", itm_acc=f"{val_itm_acc_meter.avg:.4f}")

    all_image_feats = torch.cat(all_image_feats, dim=0)
    all_text_feats = torch.cat(all_text_feats, dim=0)
    final_metrics = compute_metrics(all_image_feats, all_text_feats)

    results = {
        'loss': val_loss_meter.avg,
        'val_itc_acc': val_itc_acc_meter.avg,
        'val_itm_acc': val_itm_acc_meter.avg,
        'val_itc_loss': val_itc_meter.avg,
        'val_itm_loss': val_itm_meter.avg,
        'i2t_recall': final_metrics['i2t_recall'],
        't2i_recall': final_metrics['t2i_recall']
    }

    print(f"Validation Epoch {epoch_num} Results: Loss={results['loss']:.4f}, ITC Acc={results['val_itc_acc']:.4f}, ITM Acc={results['val_itm_acc']:.4f}")
    print(f"  I2T Recall: {results['i2t_recall']}")
    print(f"  T2I Recall: {results['t2i_recall']}")

    return results

print("Validation function implemented.")

Validation function implemented.


In [None]:
# Cell 12: Test Set Evaluation

print("\n=============== Starting Test Set Evaluation ===============")

test_json_path = os.path.join(config.data_path, "test.json")
evaluation_performed = False
model_loaded_for_test = False

if not (model_loaded and data_setup_ok):
    print("Skipping test evaluation: Model or data setup failed.")
elif not os.path.exists(test_json_path):
    print(f"Skipping test evaluation: Test JSON not found ({test_json_path}).")
else:
    print(f"Loading test data from: {test_json_path}")
    if 'tokenizer' in globals() and 'image_processor' in globals():
        test_dataset = ImageCaptionDataset(
            json_path=test_json_path, image_base_path=config.image_path,
            tokenizer=tokenizer, image_processor=image_processor, max_length=config.max_length
        )

        if test_dataset.data:
            num_workers = min(config.num_workers, os.cpu_count() if os.cpu_count() else 1)
            test_loader = DataLoader(
                test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers,
                pin_memory=True if config.device == torch.device("cuda") else False, drop_last=False
            )
            print(f"Test loader created with {len(test_loader)} batches.")

            model_to_test = None
            try:
                blip_config_test = Blip2Config.from_pretrained(config.blip2_model_name)
                model_to_test = Blip2Model.from_pretrained(
                    config.blip2_model_name, config=blip_config_test,
                    torch_dtype=torch.float16 if config.use_amp else torch.float32
                )
                for param in model_to_test.vision_model.parameters(): param.requires_grad = False
                if hasattr(model_to_test, 'language_model'):
                    for param in model_to_test.language_model.parameters(): param.requires_grad = False
                print("Model structure for testing created.")

                best_model_path = os.path.join(config.model_path, "ViBLIP_QFormer_best.pt")
                if os.path.exists(best_model_path):
                    print(f"Loading best model weights from: {best_model_path}")
                    checkpoint = torch.load(best_model_path, map_location='cpu')
                    state_dict = checkpoint['model_state_dict']

                    if next(iter(state_dict)).startswith('module.'):
                        from collections import OrderedDict
                        state_dict = OrderedDict((k[7:], v) for k, v in state_dict.items())

                    load_result = model_to_test.load_state_dict(state_dict, strict=False)
                    print(f"Load Result: {load_result}")
                    model_to_test.to(config.device)
                    model_loaded_for_test = True
                    print("Loaded trained weights into model structure.")

                    print("\nRunning evaluation on test set...")
                    test_results = validate_qformer_epoch(model_to_test, test_loader, config.device, "Test")
                    evaluation_performed = True
                    print("\n--- Test Set Results ---")
                    metric_log_str = f"  Loss: {test_results['loss']:.4f}\n"
                    metric_log_str += f"  ITC Acc: {test_results['val_itc_acc']:.4f}\n"
                    metric_log_str += f"  ITM Acc: {test_results['val_itm_acc']:.4f}\n"
                    metric_log_str += f"  I2T Recall: {test_results['i2t_recall']}\n"
                    metric_log_str += f"  T2I Recall: {test_results['t2i_recall']}\n"
                    print(metric_log_str.strip())
                    print("------------------------")

                else:
                    print(f"ERROR: Best model checkpoint not found at {best_model_path}.")

            except Exception as e:
                print(f"ERROR during test setup or evaluation: {e}")
                traceback.print_exc()
        else:
            print("Could not load test data. Skipping test evaluation.")
    else:
        print("Skipping test evaluation: Tokenizer or Image Processor not available.")

if not evaluation_performed:
    print("Test set evaluation was not performed.")

print("\n================= Evaluation Finished =================")


Skipping test evaluation: Model or data setup failed.
Test set evaluation was not performed.

