In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Import BLIP specific components + AutoModel/Tokenizer
from transformers import BlipVisionModel, BlipConfig, BlipImageProcessor # Vision parts
from transformers import AutoModel, AutoTokenizer, AutoConfig # Text parts (PhoBERT)
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import json
import os
import random
import numpy as np
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm
import torch.nn.functional as F
import math
import time # For timing epochs
import transformers

print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

PyTorch Version: 2.6.0+cu124
Transformers Version: 4.50.0
CUDA Available: True
CUDA Device Name: NVIDIA GeForce RTX 4090


In [2]:
# === Cell 2: Configuration Class (CFG) === MODIFIED ===
class CFG:
    # --- Paths ---
    data_path = "./json_data/"
    image_path = "./data/OpenViVQA-dataset/"

    # Output directory for saved models
    model_path = "./ViBLIP_vivqa"

    # --- Model Selection ---
    # Source for BLIP Vision Model components
    selected_vision_source = "Salesforce/blip-image-captioning-large"
    # Vietnamese Text Model
    selected_text_model = "vinai/phobert-large"
    text_tokenizer_name = selected_text_model # Use PhoBERT's tokenizer

    # --- Model parameters ---
    blip_vision_model_name = selected_vision_source # For loading vision model
    blip_image_processor_name = selected_vision_source # For loading image processor

    @property
    def text_embedding(self): # PhoBERT-Large output dim
        return 1024
    @property
    def vision_embedding(self): # BLIP-Large (ViT-L) output dim
        # Check config if unsure: BlipConfig.from_pretrained(self.selected_vision_source).vision_config.hidden_size
        return 1024

    projection_dim = 256 # Shared latent space dimension

    # --- Training parameters ---
    seed = 42
    batch_size = 8  # START MODERATE (16-32) and adjust based on GPU memory
    num_workers = 12  # Adjust based on system
    # Learning rates (tune these)
    projection_lr = 1e-4
    vision_encoder_lr = 1e-5 # LR for the BLIP vision encoder base
    text_encoder_lr = 1e-5   # LR for the PhoBERT base
    weight_decay = 1e-3
    patience = 3 # Scheduler patience
    factor = 0.8 # Scheduler reduction factor
    epochs = 1 # Adjust as needed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = False # Disabled for simplicity, can be re-enabled if needed

    # --- Image/Text parameters ---
    # Image size determined by BlipImageProcessor
    max_length = 256 # PhoBERT's default max length

    # --- Loss/Saving parameters ---
    temperature = 0.07
    learnable_temperature = True
    save_best_only = True
    metric_to_track = "avg_R@1" # Defaulting to Recall@1 average
    mode = "max"
    early_stopping_patience = 5
    early_stopping_min_delta = 0.001
    accumulation_steps = 1 # Example: use if batch_size needs to be very small

# --- Instantiate Config and Create Output Dir ---
config = CFG()
os.makedirs(config.model_path, exist_ok=True)
print(f"Using device: {config.device}")
print(f"Model output path: {config.model_path}")
print(f"Selected Vision Source: {config.selected_vision_source}")
print(f"Selected Text Model: {config.selected_text_model}")
print(f"Image base path (for resolving paths in JSON): {os.path.abspath(config.image_path)}")

Using device: cuda
Model output path: ./ViBLIP_vivqa
Selected Vision Source: Salesforce/blip-image-captioning-large
Selected Text Model: vinai/phobert-large
Image base path (for resolving paths in JSON): /home/researcher/huypq69/TuningModels/data/OpenViVQA-dataset


In [3]:
# === Cell 3: Seeding for Reproducibility ===
def set_seed(seed=config.seed):
    print(f"Setting seed: {seed}")
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

set_seed()

Setting seed: 42


In [4]:
# === Cell 4: Metric Calculation Utilities ===

class AvgMeter:
    """Computes and stores the average and current value"""
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, count=1):
        # Ensure val is a scalar number before adding to sum
        if torch.is_tensor(val):
             val = val.item() # Convert tensor to Python number
        if isinstance(val, (int, float)):
            self.sum += val * count
            self.count += count
            self.avg = self.sum / self.count if self.count != 0 else 0
        # else:
            # Optionally print a warning if the value is not usable
            # print(f"Warning: Cannot update AvgMeter '{self.name}' with value type {type(val)}")


    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text

def compute_recall_at_k(similarity_matrix, k, dim):
    n = similarity_matrix.shape[1-dim]
    correct_count = 0
    top_k_indices = torch.topk(similarity_matrix, k, dim=dim).indices
    ground_truth = torch.arange(n, device=similarity_matrix.device)

    if dim == 0: # I2T
        for img_idx in range(n):
            if ground_truth[img_idx] in top_k_indices[:, img_idx]:
                correct_count += 1
    elif dim == 1: # T2I
        for txt_idx in range(n):
             if ground_truth[txt_idx] in top_k_indices[txt_idx, :]:
                correct_count += 1
    else: raise ValueError("dim must be 0 or 1")
    return correct_count / n if n > 0 else 0.0

def compute_metrics(image_embeddings, text_embeddings):
    sim_matrix = text_embeddings @ image_embeddings.T
    sim_matrix = sim_matrix.float() # Ensure float for calculations
    n = sim_matrix.shape[0]
    if n == 0:
        # Return default zero metrics for empty batch
        return {
            "i2t_acc": 0.0, "t2i_acc": 0.0, "avg_acc": 0.0,
            "avg_cosine_sim": 0.0,
            "i2t_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0},
            "t2i_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}
        }

    ground_truth = torch.arange(n, device=sim_matrix.device)
    i2t_preds = torch.argmax(sim_matrix, dim=0)
    t2i_preds = torch.argmax(sim_matrix, dim=1)
    i2t_acc = (i2t_preds == ground_truth).float().mean().item()
    t2i_acc = (t2i_preds == ground_truth).float().mean().item()
    avg_acc = (i2t_acc + t2i_acc) / 2
    avg_cosine_sim = torch.diagonal(sim_matrix).mean().item()

    i2t_recall = {}
    t2i_recall = {}
    recall_k_values = [k for k in [1, 5, 10] if k <= n]
    for k in recall_k_values:
        i2t_recall[f"R@{k}"] = compute_recall_at_k(sim_matrix, k, dim=0)
        t2i_recall[f"R@{k}"] = compute_recall_at_k(sim_matrix, k, dim=1)

    # Ensure all keys R@1, R@5, R@10 exist even if k>n
    for k in [1, 5, 10]:
        k_str = f"R@{k}"
        if k_str not in i2t_recall: i2t_recall[k_str] = 0.0
        if k_str not in t2i_recall: t2i_recall[k_str] = 0.0

    metrics = {
        "i2t_acc": i2t_acc, "t2i_acc": t2i_acc, "avg_acc": avg_acc,
        "avg_cosine_sim": avg_cosine_sim,
        "i2t_recall": i2t_recall, "t2i_recall": t2i_recall
    }
    return metrics

print("Metric utilities defined.")

Metric utilities defined.


In [5]:
# === Cell 5: Dataset Class Definition (Separate Tokenizer/Processor) === MODIFIED ===
class CustomImageCaptionDataset(Dataset):
    def __init__(self, json_path, image_base_path, tokenizer, image_processor, max_length): # Changed arguments
        super().__init__()
        print(f"Attempting to load data from: {os.path.abspath(json_path)}")
        try:
            with open(json_path, 'r', encoding='utf-8') as f: self.data = json.load(f)
        except Exception as e: print(f"ERROR loading JSON {json_path}: {e}"); self.data = []

        print(f"Found {len(self.data)} samples in {os.path.basename(json_path)}.")
        self.image_base_path = image_base_path
        self.tokenizer = tokenizer
        self.image_processor = image_processor # Store image processor
        self.max_length = max_length
        try: # Get image size
             # BlipImageProcessor uses 'size' dictionary with 'shortest_edge' or 'height'/'width'
             if isinstance(image_processor.size, dict):
                 self.img_size = image_processor.size.get('shortest_edge', image_processor.size.get('height', 224))
             else: # Older versions might just have an int/tuple
                 self.img_size = image_processor.size
                 if isinstance(self.img_size, (tuple, list)): self.img_size = self.img_size[0]
        except AttributeError:
             print("Warning: Could not determine image size from processor, defaulting to 224.")
             self.img_size = 224
        print(f"Using image size: {self.img_size}x{self.img_size}")
        if not os.path.isdir(self.image_base_path): print(f"WARNING: Image base path does not exist: {os.path.abspath(self.image_base_path)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if idx >= len(self.data): raise IndexError("Index out of bounds")
        item = self.data[idx]
        relative_image_path = item.get('image_path')
        # Expecting a list of captions, take the first one
        captions = item.get('caption', [])
        caption = captions[0] if captions else ""

        # Load Image
        image = None
        pixel_values = torch.zeros((3, self.img_size, self.img_size)) # Dummy tensor
        if relative_image_path:
            image_path = os.path.normpath(os.path.join(self.image_base_path, relative_image_path))
            try:
                image = Image.open(image_path).convert('RGB')
            except FileNotFoundError: print(f"Warning: Img not found: {image_path}. Using dummy for idx {idx}.")
            except Exception as e: print(f"Warning: Error loading image {image_path}: {e}. Using dummy for idx {idx}.")
        else: print(f"Warning: Missing 'image_path' for idx {idx}. Using dummy.")
        if image is None: image = Image.new('RGB', (self.img_size, self.img_size))

        # Process Image using image_processor
        try:
            image_inputs = self.image_processor(images=image, return_tensors="pt")
            pixel_values = image_inputs['pixel_values'].squeeze(0) # Remove batch dim
        except Exception as e:
            print(f"Error processing image idx {idx}: {e}")
            # pixel_values remains the dummy tensor initialized earlier

        # Process Text using tokenizer
        try:
            text_inputs = self.tokenizer(
                caption, padding='max_length', truncation=True,
                max_length=self.max_length, return_tensors='pt'
            )
            input_ids = text_inputs['input_ids'].squeeze(0) # Remove batch dim
            attention_mask = text_inputs['attention_mask'].squeeze(0) # Remove batch dim
        except Exception as e:
            print(f"Error processing text '{caption}' idx {idx}: {e}")
            input_ids = torch.zeros(self.max_length, dtype=torch.long)
            attention_mask = torch.zeros(self.max_length, dtype=torch.long)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

print("CustomImageCaptionDataset class defined.")

CustomImageCaptionDataset class defined.


In [6]:
# === Cell 6: Model Definition (PhoBERT + BLIP Vision) === MODIFIED ===

class ImageEncoder(nn.Module):
    """Encodes images using BLIP's Vision Model (Large)."""
    def __init__(self, config_train, pretrained=True): # pretrained flag might not be directly used by from_pretrained
        super().__init__()
        self.config_train = config_train
        print(f"Initializing BLIP Vision Encoder from: {config_train.blip_vision_model_name}")
        # Load only the vision part
        self.vision_model = BlipVisionModel.from_pretrained(config_train.blip_vision_model_name)
        self.input_features = config_train.vision_embedding # Should be 1024 for ViT-L

        self.projection = nn.Linear(self.input_features, config_train.projection_dim, bias=False)
        print(f"  Added projection head: {self.input_features} -> {config_train.projection_dim}")

    def forward(self, pixel_values):
        vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=True)
        # BLIP Vision typically uses pooler_output
        image_features = vision_outputs.pooler_output
        projected_features = self.projection(image_features)
        projected_features = F.normalize(projected_features, p=2, dim=-1)
        return projected_features

class TextEncoder(nn.Module):
    """Encodes text using PhoBERT-Large."""
    def __init__(self, config_train, pretrained=True):
        super().__init__()
        self.config_train = config_train
        print(f"Initializing Text Encoder: {config_train.selected_text_model}")

        # Use the specific text model name from config
        if pretrained:
            self.model = AutoModel.from_pretrained(config_train.selected_text_model)
        else:
            model_config = AutoConfig.from_pretrained(config_train.selected_text_model)
            self.model = AutoModel.from_config(model_config)

        self.input_features = config_train.text_embedding # Should be 1024 for PhoBERT-Large
        actual_hidden_size = self.model.config.hidden_size
        if actual_hidden_size != self.input_features:
             print(f"WARNING: Configured text_embedding ({self.input_features}) does not match PhoBERT hidden size ({actual_hidden_size}). Using actual size.")
             self.input_features = actual_hidden_size # Use actual size

        self.projection = nn.Linear(self.input_features, config_train.projection_dim, bias=False)
        print(f"  Added projection head: {self.input_features} -> {config_train.projection_dim}")

    def forward(self, input_ids, attention_mask):
        # PhoBERT typically uses the [CLS] token's output for pooled representation
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        text_features = outputs.last_hidden_state[:, 0, :] # [CLS] token embedding

        projected_features = self.projection(text_features)
        projected_features = F.normalize(projected_features, p=2, dim=-1)
        return projected_features

class CustomBlipPhobertModel(nn.Module): # Renamed for clarity
    """Combines BLIP Vision encoder and PhoBERT Text encoder for contrastive retrieval."""
    def __init__(self, image_encoder, text_encoder, config_train):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.config_train = config_train

        if config_train.learnable_temperature:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / config_train.temperature))
            print(f"Using learnable temperature, initialized to {self.logit_scale.exp().item():.4f}")
        else:
            self.register_buffer('logit_scale', torch.tensor(np.log(1 / config_train.temperature)))
            print(f"Using fixed temperature: {config_train.temperature}")

    def forward(self, pixel_values, input_ids, attention_mask):
        image_features = self.image_encoder(pixel_values)
        text_features = self.text_encoder(input_ids, attention_mask)

        # Cosine similarity scaling
        if isinstance(self.logit_scale, nn.Parameter):
             current_logit_scale = self.logit_scale.exp()
        else:
             current_logit_scale = self.logit_scale.exp().to(image_features.device) # Move buffer if needed

        # Ensure FP32 for stability
        logits_per_image = current_logit_scale.float() * image_features.float() @ text_features.float().t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text, image_features, text_features

# --- Loss Function (Contrastive Loss) ---
def contrastive_loss(logits_per_image, logits_per_text):
    # Ensure logits are FP32 for cross_entropy
    logits_per_image = logits_per_image.float()
    logits_per_text = logits_per_text.float()

    batch_size = logits_per_image.shape[0]
    if batch_size == 0: return torch.tensor(0.0, device=logits_per_image.device, requires_grad=True)
    labels = torch.arange(batch_size, device=logits_per_image.device)
    loss_img = F.cross_entropy(logits_per_image, labels)
    loss_txt = F.cross_entropy(logits_per_text, labels)
    total_loss = (loss_img + loss_txt) / 2.0
    return total_loss

print("Custom BLIP+PhoBERT Model components and loss function defined.")

Custom BLIP+PhoBERT Model components and loss function defined.


In [7]:
# === Cell 7: Training and Validation Epoch Functions (No AMP/Scaler) ===

def train_epoch(model, dataloader, optimizer, device, epoch_num):
    model.train()
    loss_meter = AvgMeter(f"Train Loss E{epoch_num}")
    try: from tqdm.notebook import tqdm as pbar
    except ImportError: from tqdm import tqdm as pbar
    progress_bar = pbar(dataloader, desc=f"Training E{epoch_num}", leave=False, unit="batch")

    optimizer.zero_grad()
    for i, batch in enumerate(progress_bar):
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        batch_size = pixel_values.size(0)
        if batch_size == 0: continue

        logits_per_image, logits_per_text, _, _ = model(pixel_values, input_ids, attention_mask)

        # Ensure logits are float32 for loss
        loss = contrastive_loss(logits_per_image.float(), logits_per_text.float())
        loss = loss / config.accumulation_steps # Normalize loss

        loss.backward()

        if (i + 1) % config.accumulation_steps == 0 or (i + 1) == len(dataloader):
            optimizer.step()
            optimizer.zero_grad()

        loss_meter.update(loss.item() * config.accumulation_steps, batch_size) # Log un-normalized loss
        progress_bar.set_postfix(loss=f"{loss_meter.avg:.4f}")

    optimizer.zero_grad() # Clean up at end
    return loss_meter.avg


def validate_epoch(model, dataloader, device, epoch_num):
    model.eval()
    loss_meter = AvgMeter(f"Val Loss E{epoch_num}")
    all_image_embeddings = []
    all_text_embeddings = []
    try: from tqdm.notebook import tqdm as pbar
    except ImportError: from tqdm import tqdm as pbar
    progress_bar = pbar(dataloader, desc=f"Validation E{epoch_num}", leave=False, unit="batch")

    with torch.no_grad():
        for batch in progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            batch_size = pixel_values.size(0)
            if batch_size == 0: continue

            logits_per_image, logits_per_text, image_embeds, text_embeds = model(pixel_values, input_ids, attention_mask)
            loss = contrastive_loss(logits_per_image.float(), logits_per_text.float()) # Ensure FP32

            loss_meter.update(loss.item(), batch_size)
            all_image_embeddings.append(image_embeds.cpu())
            all_text_embeddings.append(text_embeds.cpu())
            progress_bar.set_postfix(loss=f"{loss_meter.avg:.4f}")

    if not all_image_embeddings or not all_text_embeddings:
         print("Warning: No embeddings collected during validation.")
         zero_metrics = { "loss": loss_meter.avg, "avg acc": 0.0, "avg cosine sim": 0.0,
                           "i2t recall R@1": 0.0, "i2t recall R@5": 0.0, "i2t recall R@10": 0.0,
                           "t2i recall R@1": 0.0, "t2i recall R@5": 0.0, "t2i recall R@10": 0.0,
                           "avg R@1": 0.0, "avg R@5": 0.0, "avg R@10": 0.0 }
         return zero_metrics

    all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
    all_text_embeddings = torch.cat(all_text_embeddings, dim=0)

    print(f"\nComputing metrics over {all_image_embeddings.shape[0]} validation samples...")
    validation_metrics = compute_metrics(all_image_embeddings.to(device), all_text_embeddings.to(device))

    final_results = {"loss": loss_meter.avg}
    for k, v in validation_metrics.items():
        if isinstance(v, dict): # Handle recall dicts
            for recall_k, recall_v in v.items(): final_results[f"{k.replace('_', ' ')} {recall_k}"] = recall_v
        else: final_results[k.replace('_', ' ')] = v
    return final_results

print("Training and Validation epoch functions defined (No AMP).")

Training and Validation epoch functions defined (No AMP).


In [8]:
# === Cell 8: Setup - Tokenizer and Image Processor === MODIFIED ===
from transformers import AutoTokenizer, BlipImageProcessor

tokenizer = None
image_processor = None

print(f"Loading Tokenizer: {config.text_tokenizer_name}")
try:
    tokenizer = AutoTokenizer.from_pretrained(config.text_tokenizer_name)
    print("PhoBERT Tokenizer loaded successfully.")
except Exception as e:
    print(f"ERROR loading tokenizer '{config.text_tokenizer_name}': {e}")

print(f"Loading Image Processor from: {config.blip_image_processor_name}")
try:
    # Use BlipImageProcessor associated with the vision model source
    image_processor = BlipImageProcessor.from_pretrained(config.blip_image_processor_name)
    print("BLIP Image Processor loaded successfully.")
except Exception as e:
    print(f"ERROR loading image processor '{config.blip_image_processor_name}': {e}")

Loading Tokenizer: vinai/phobert-large
PhoBERT Tokenizer loaded successfully.
Loading Image Processor from: Salesforce/blip-image-captioning-large
BLIP Image Processor loaded successfully.


In [9]:
# === Cell 9: Setup - Datasets and DataLoaders === MODIFIED ===
train_loader = None
dev_loader = None

if tokenizer and image_processor: # Check if both loaded
    print("\nCreating datasets...")
    train_json = os.path.join(config.data_path, "train.json")
    dev_json = os.path.join(config.data_path, "dev.json")

    train_dataset = CustomImageCaptionDataset( # Use the custom dataset class
        json_path=train_json, image_base_path=config.image_path,
        tokenizer=tokenizer, image_processor=image_processor, # Pass separate components
        max_length=config.max_length
    )
    dev_dataset = CustomImageCaptionDataset( # Use the custom dataset class
        json_path=dev_json, image_base_path=config.image_path,
        tokenizer=tokenizer, image_processor=image_processor, # Pass separate components
        max_length=config.max_length
    )

    if not train_dataset.data: print("\nERROR: Failed to load training data.")
    if not dev_dataset.data: print("\nWARNING: Failed to load validation data.")

    print("\nCreating dataloaders...")
    num_workers = min(config.num_workers, os.cpu_count() if os.cpu_count() else 1)
    print(f"Using {num_workers} workers for DataLoaders.")

    if train_dataset.data:
        train_loader = DataLoader(
            train_dataset, batch_size=config.batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=True if config.device == torch.device("cuda") else False,
            drop_last=True
        )
        print(f"Train loader created with {len(train_loader)} batches.")
    else: print("Skipping train loader creation.")

    if dev_dataset.data:
        dev_loader = DataLoader(
            dev_dataset, batch_size=config.batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True if config.device == torch.device("cuda") else False,
            drop_last=False
        )
        print(f"Validation loader created with {len(dev_loader)} batches.")
    else: print("Skipping validation loader creation.")

    if not train_loader: print("\nERROR: Train loader could not be created.")
else:
     print("ERROR: Tokenizer or Image Processor not loaded. Skipping dataset/loader creation.")


Creating datasets...
Attempting to load data from: /home/researcher/huypq69/TuningModels/json_data/train.json
Found 18899 samples in train.json.
Using image size: 384x384
Attempting to load data from: /home/researcher/huypq69/TuningModels/json_data/dev.json
Found 2239 samples in dev.json.
Using image size: 384x384

Creating dataloaders...
Using 12 workers for DataLoaders.
Train loader created with 2362 batches.
Validation loader created with 280 batches.


In [10]:
# === Cell 10: Setup - Model, Optimizer, Scheduler === MODIFIED ===
model = None
optimizer = None
lr_scheduler = None

print("\nInitializing model components...")
try:
    # Instantiate the modified encoders and model wrapper
    image_encoder = ImageEncoder(config).to(config.device) # Uses BlipVisionModel
    text_encoder = TextEncoder(config).to(config.device)   # Uses PhoBERT-Large
    model = CustomBlipPhobertModel(image_encoder, text_encoder, config).to(config.device)
    print(f"\nCustomBlipPhobertModel initialized successfully on {config.device}.")
    num_params_total = sum(p.numel() for p in model.parameters())
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {num_params_total / 1e6:.2f} M")
    print(f"Trainable parameters: {num_params_trainable / 1e6:.2f} M")

except Exception as e:
    print(f"ERROR initializing model components: {e}")
    model = None

if model:
    print("\nSetting up optimizer...")
    # Get parameters from the specific model components
    image_encoder_base_params = [p for p in model.image_encoder.vision_model.parameters() if p.requires_grad]
    image_head_params = [p for p in model.image_encoder.projection.parameters() if p.requires_grad]
    text_encoder_base_params = [p for p in model.text_encoder.model.parameters() if p.requires_grad] # PhoBERT base
    text_head_params = [p for p in model.text_encoder.projection.parameters() if p.requires_grad]
    logit_scale_param = [model.logit_scale] if isinstance(model.logit_scale, nn.Parameter) else []

    print(f"  Param counts (Trainable): VisionBase={len(image_encoder_base_params)}, VisionHead={len(image_head_params)}, TextBase={len(text_encoder_base_params)}, TextHead={len(text_head_params)}, LogitScale={len(logit_scale_param)}")

    optimizer_grouped_parameters = [
        {"params": image_encoder_base_params, "lr": config.vision_encoder_lr, "weight_decay": config.weight_decay},
        {"params": image_head_params, "lr": config.projection_lr, "weight_decay": config.weight_decay},
        {"params": text_encoder_base_params, "lr": config.text_encoder_lr, "weight_decay": config.weight_decay},
        {"params": text_head_params, "lr": config.projection_lr, "weight_decay": config.weight_decay},
        {"params": logit_scale_param, "lr": config.projection_lr, "weight_decay": 0}
    ]

    optimizer_grouped_parameters = [g for g in optimizer_grouped_parameters if g['params']]

    if not optimizer_grouped_parameters:
         print("ERROR: No trainable parameters found for the optimizer.")
    else:
        optimizer = optim.AdamW(optimizer_grouped_parameters)
        print(f"Optimizer AdamW initialized.")

        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode=config.mode, factor=config.factor, patience=config.patience
        )
        print(f"LR Scheduler ReduceLROnPlateau initialized (mode='{config.mode}', factor={config.factor}, patience={config.patience})")

        early_stopping_counter = 0 # Renamed from config.early_stopping_counter
        print(f"Early stopping initialized (patience={config.early_stopping_patience}, min_delta={config.early_stopping_min_delta})")

else:
    print("ERROR: Model not initialized. Skipping optimizer/scheduler setup.")


Initializing model components...
Initializing BLIP Vision Encoder from: Salesforce/blip-image-captioning-large


Some weights of BlipVisionModel were not initialized from the model checkpoint at Salesforce/blip-image-captioning-large and are newly initialized: ['embeddings.class_embedding', 'embeddings.patch_embedding.bias', 'embeddings.patch_embedding.weight', 'embeddings.position_embedding', 'encoder.layers.0.layer_norm1.bias', 'encoder.layers.0.layer_norm1.weight', 'encoder.layers.0.layer_norm2.bias', 'encoder.layers.0.layer_norm2.weight', 'encoder.layers.0.mlp.fc1.bias', 'encoder.layers.0.mlp.fc1.weight', 'encoder.layers.0.mlp.fc2.bias', 'encoder.layers.0.mlp.fc2.weight', 'encoder.layers.0.self_attn.projection.bias', 'encoder.layers.0.self_attn.projection.weight', 'encoder.layers.0.self_attn.qkv.bias', 'encoder.layers.0.self_attn.qkv.weight', 'encoder.layers.1.layer_norm1.bias', 'encoder.layers.1.layer_norm1.weight', 'encoder.layers.1.layer_norm2.bias', 'encoder.layers.1.layer_norm2.weight', 'encoder.layers.1.mlp.fc1.bias', 'encoder.layers.1.mlp.fc1.weight', 'encoder.layers.1.mlp.fc2.bias', '

  Added projection head: 1024 -> 256
Initializing Text Encoder: vinai/phobert-large
  Added projection head: 1024 -> 256
Using learnable temperature, initialized to 14.2857

CustomBlipPhobertModel initialized successfully on cuda.
Total parameters: 673.38 M
Trainable parameters: 673.38 M

Setting up optimizer...
  Param counts (Trainable): VisionBase=294, VisionHead=1, TextBase=391, TextHead=1, LogitScale=1
Optimizer AdamW initialized.
LR Scheduler ReduceLROnPlateau initialized (mode='max', factor=0.8, patience=3)
Early stopping initialized (patience=5, min_delta=0.001)


In [11]:
# === Cell 11: Training Loop (No AMP) === MODIFIED ===
if model and train_loader and optimizer and lr_scheduler:
    print(f"\nStarting training for {config.epochs} epochs...")
    print(f"Tracking metric: '{config.metric_to_track}' (mode: {config.mode})")

    best_val_metric = -float('inf') if config.mode == "max" else float('inf')
    history = {'train_loss': [], 'validation_results': []}
    start_train_time = time.time()
    no_improve_epochs = 0 # Use the variable defined in Cell 10

    for epoch in range(config.epochs):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch+1}/{config.epochs} ---")

        # --- Training ---
        train_loss = train_epoch(model, train_loader, optimizer, config.device, epoch+1)
        history['train_loss'].append(train_loss)
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}")

        # --- Validation ---
        val_results = {"loss": float('inf'), config.metric_to_track.replace('_', ' '): (-float('inf') if config.mode == 'max' else float('inf'))}
        if dev_loader:
            val_results = validate_epoch(model, dev_loader, config.device, epoch+1)
            history['validation_results'].append(val_results)
            print("  Validation Metrics:")
            metric_log_str = "  "
            sorted_keys = sorted(val_results.keys(), key=lambda x: (x.split()[0], int(x.split('@')[-1]) if '@' in x else -1))
            for name in sorted_keys: metric_log_str += f"{name}: {val_results[name]:.4f} | "
            print(metric_log_str.strip(" | "))

            # --- Scheduler Step ---
            current_val_metric_for_scheduler = val_results.get(config.metric_to_track.replace('_', ' '), None)
            if current_val_metric_for_scheduler is not None:
                lr_scheduler.step(current_val_metric_for_scheduler)
                current_lrs = [group['lr'] for group in optimizer.param_groups]
                # Updated LR string for this specific setup
                lr_str = f"  Current LRs: VisionBase={current_lrs[0]:.2e}, VisionHead={current_lrs[1]:.2e}, TextBase={current_lrs[2]:.2e}, TextHead={current_lrs[3]:.2e}"
                if len(current_lrs) > 4: lr_str += f", LogitScale={current_lrs[4]:.2e}"
                print(lr_str)
            else: print(f"  Warning: Metric '{config.metric_to_track}' not found. Scheduler not stepped.")
        else:
            print("  Validation skipped (no dev_loader).")
            history['validation_results'].append(None)

        # --- Save Checkpoint & Early Stopping Logic ---
        current_val_metric = val_results.get(config.metric_to_track.replace('_', ' '), -float('inf') if config.mode == "max" else float('inf'))
        is_best = False
        if dev_loader:
            if config.mode == "max":
                if current_val_metric > best_val_metric + config.early_stopping_min_delta: is_best = True
            else: # min mode
                if current_val_metric < best_val_metric - config.early_stopping_min_delta: is_best = True

            if is_best:
                print(f"  Metric '{config.metric_to_track}' improved from {best_val_metric:.4f} to {current_val_metric:.4f}")
                best_val_metric = current_val_metric; no_improve_epochs = 0
            else:
                no_improve_epochs += 1
                print(f"  Metric '{config.metric_to_track}' did not improve. Best: {best_val_metric:.4f}. Counter: {no_improve_epochs}/{config.early_stopping_patience}")

        # Update checkpoint names
        best_checkpoint_path = os.path.join(config.model_path, "phobert_blip_retrieval_best.pt")
        final_epoch_path = os.path.join(config.model_path, f"phobert_blip_retrieval_epoch_{epoch+1}.pt")

        # Save necessary configs and state dicts
        save_dict = {
            'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': lr_scheduler.state_dict(),
            'train_loss': train_loss, 'validation_results': val_results,
            'best_val_metric': best_val_metric, 'metric_tracked': config.metric_to_track,
            # Save relevant configs for reloading this specific architecture
            'vision_model_name': config.blip_vision_model_name,
            'text_model_name': config.selected_text_model,
            'projection_dim': config.projection_dim,
            'learnable_temperature': config.learnable_temperature
        }

        if config.save_best_only and dev_loader:
            if is_best:
                torch.save(save_dict, best_checkpoint_path)
                print(f"  Saved Best Model (Epoch {epoch+1}) to {best_checkpoint_path}")
        else:
            torch.save(save_dict, final_epoch_path)
            print(f"  Saved Epoch {epoch+1} Checkpoint to {final_epoch_path}")
            if is_best and dev_loader:
                 torch.save(save_dict, best_checkpoint_path)
                 print(f"  (Also saved as best model)")

        epoch_end_time = time.time()
        print(f"--- Epoch {epoch+1} Time: {epoch_end_time - epoch_start_time:.2f} seconds ---")

        if dev_loader and no_improve_epochs >= config.early_stopping_patience:
            print(f"\nEarly stopping triggered after {config.early_stopping_patience} epochs without improvement.")
            break

    end_train_time = time.time()
    total_train_time = end_train_time - start_train_time
    print(f"\n=============== Training Finished ================")
    print(f"Total Training Time: {total_train_time:.2f} seconds ({total_train_time/60:.2f} minutes)")
    # Update final model path name
    final_model_path = os.path.join(config.model_path, 'phobert_blip_retrieval_final_epoch.pt')
    torch.save(save_dict, final_model_path)
    print(f"Final epoch model state saved to {final_model_path}")
    best_model_file = os.path.join(config.model_path, "phobert_blip_retrieval_best.pt")
    if dev_loader and os.path.exists(best_model_file):
        print(f"Best model based on '{config.metric_to_track}' ({best_val_metric:.4f}) is saved at: {best_model_file}")
    elif dev_loader: print("Best model checkpoint file not found.")
    print(f"=================================================")
else:
    print("ERROR: Prerequisites for training (model, dataloader, optimizer, scheduler) not met. Training loop skipped.")


Starting training for 1 epochs...
Tracking metric: 'avg_R@1' (mode: max)

--- Epoch 1/1 ---


Training E1:   0%|          | 0/2362 [00:00<?, ?batch/s]

In [12]:
# === Cell 12: Final Evaluation on Test Set === MODIFIED ===
print("\n=============== Starting Test Set Evaluation ===============")

test_loader = None
model_to_test = None
test_json_path = os.path.join(config.data_path, "test.json")

# 1. Check if test data, tokenizer, and image processor exist
if os.path.exists(test_json_path) and 'tokenizer' in globals() and tokenizer and 'image_processor' in globals() and image_processor:
    print(f"Loading test data from: {test_json_path}")
    try:
        test_dataset = CustomImageCaptionDataset( # Use the custom dataset
            json_path=test_json_path, image_base_path=config.image_path,
            tokenizer=tokenizer, image_processor=image_processor, # Pass separate components
            max_length=config.max_length
        )
        if test_dataset.data:
            num_workers = min(config.num_workers, os.cpu_count() if os.cpu_count() else 1)
            test_loader = DataLoader(
                test_dataset, batch_size=config.batch_size, shuffle=False,
                num_workers=num_workers, pin_memory=True if config.device == torch.device("cuda") else False,
                drop_last=False
            )
            print(f"Test loader created with {len(test_loader)} batches.")
        else: print("Test dataset loaded but is empty.")
    except Exception as e: print(f"Error creating test dataset/loader: {e}")
else: print("Skipping test evaluation: Test JSON, Tokenizer or Image Processor not found/loaded.")

# 2. Load Model for Testing if test_loader was created
if test_loader:
    try:
        # Update checkpoint names
        best_model_path = os.path.join(config.model_path, "phobert_blip_retrieval_best.pt")
        final_model_path = os.path.join(config.model_path, "phobert_blip_retrieval_final_epoch.pt")
        load_path = None
        if os.path.exists(best_model_path): load_path = best_model_path; print(f"\nLoading best model: {load_path}")
        elif os.path.exists(final_model_path): load_path = final_model_path; print(f"\nLoading final model: {load_path}")
        else: print(f"\nWARNING: No checkpoints found in {config.model_path}.")

        if load_path:
            checkpoint = torch.load(load_path, map_location=config.device)
            print("Re-creating model structure for testing...")

            # Create a temporary config for model loading using info from checkpoint if possible
            temp_config_dict = {
                'selected_vision_source': checkpoint.get('vision_model_name', config.selected_vision_source),
                'selected_text_model': checkpoint.get('text_model_name', config.selected_text_model),
                'vision_embedding': checkpoint.get('vision_embedding', config.vision_embedding), # Need these in checkpoint or derived
                'text_embedding': checkpoint.get('text_embedding', config.text_embedding),      # Need these in checkpoint or derived
                'projection_dim': checkpoint.get('projection_dim', config.projection_dim),
                'learnable_temperature': checkpoint.get('learnable_temperature', config.learnable_temperature),
                'temperature': config.temperature # Use current config temp if not saved
            }
            # Create a dummy class or simple namespace to hold these attributes for init
            from types import SimpleNamespace
            temp_config = SimpleNamespace(**temp_config_dict)

            # Instantiate using the possibly loaded config values
            test_image_encoder = ImageEncoder(temp_config).to(config.device)
            test_text_encoder = TextEncoder(temp_config).to(config.device)
            model_to_test = CustomBlipPhobertModel(test_image_encoder, test_text_encoder, temp_config).to(config.device)

            state_dict = checkpoint['model_state_dict']
            if all(k.startswith('module.') for k in state_dict.keys()):
                print("Detected 'module.' prefix, removing.")
                from collections import OrderedDict
                state_dict = OrderedDict((k[7:], v) for k, v in state_dict.items())

            load_result = model_to_test.load_state_dict(state_dict, strict=False)
            print(f"  State dict loading result: {load_result}")
            if load_result.missing_keys: print(f"  Warning: Missing keys: {load_result.missing_keys}")
            if load_result.unexpected_keys: print(f"  Warning: Unexpected keys: {load_result.unexpected_keys}")
            print(f"Model weights loaded successfully.")

            print("\nRunning evaluation on test set...")
            test_results = validate_epoch(model_to_test, test_loader, config.device, epoch_num="Test")

            print("\n--- Test Set Results ---")
            metric_log_str = ""
            sorted_keys = sorted(test_results.keys(), key=lambda x: (x.split()[0], int(x.split('@')[-1]) if '@' in x else -1))
            for name in sorted_keys: metric_log_str += f"  {name}: {test_results[name]:.4f}\n"
            print(metric_log_str.strip())
            print("------------------------")
        else: print("Evaluation skipped (no weights found).")
    except Exception as e:
        print(f"\nERROR during test setup/evaluation: {e}")
        import traceback
        traceback.print_exc()

print("\n================= Evaluation Finished ==================")


Loading test data from: ./json_data/test.json
Attempting to load data from: /home/researcher/huypq69/TuningModels/json_data/test.json
Found 2176 samples in test.json.
Using image size: 384x384
Test loader created with 136 batches.

Evaluation skipped (no weights found).



In [None]:
# === Cell 13: Training Visualization (Adapted) ===
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np

sns.set_style("whitegrid")
plot_dir = "ViBLIP_vivqa_train_plot"
os.makedirs(plot_dir, exist_ok=True)
print(f"Plot directory ensured at: {os.path.abspath(plot_dir)}")

def save_subplot_as_figure(subplot, save_path):
    fig_new = plt.figure(figsize=(8, 6))
    ax_new = fig_new.add_subplot(111)
    lines = subplot.get_lines()
    if not lines: print(f"Warning: No lines in subplot for {save_path}"); plt.close(fig_new); return
    labels = [line.get_label() for line in lines]
    for line in lines:
        ax_new.plot(line.get_xdata(), line.get_ydata(), color=line.get_color(),
                    linestyle=line.get_linestyle(), marker=line.get_marker(), label=line.get_label())
    ax_new.set_title(subplot.get_title()); ax_new.set_xlabel(subplot.get_xlabel()); ax_new.set_ylabel(subplot.get_ylabel())
    ax_new.grid(True)
    if any(label and not label.startswith('_') for label in labels): ax_new.legend()
    plt.tight_layout()
    fig_new.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig_new)

def plot_training_metrics(history):
    if not history or not history.get('train_loss') or not history.get('validation_results'):
        print("No/incomplete training history available."); return
    valid_results = [res for res in history['validation_results'] if res is not None]
    if not valid_results:
        print("No valid validation results. Plotting only training loss.")
        epochs = range(1, len(history['train_loss']) + 1)
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        ax.plot(epochs, history['train_loss'], 'b-o', label='Training Loss')
        ax.set_title('Training Loss over Epochs'); ax.set_xlabel('Epoch'); ax.set_ylabel('Loss'); ax.legend(); ax.grid(True)
        save_path = os.path.join(plot_dir, 'training_loss.png')
        fig.savefig(save_path, bbox_inches='tight', dpi=300); print(f"Saved loss plot to: {save_path}"); plt.close(fig)
        return

    num_epochs_trained = len(history['train_loss']); num_epochs_validated = len(valid_results)
    epochs_train = range(1, num_epochs_trained + 1); epochs_val = range(1, num_epochs_validated + 1)
    fig, axes = plt.subplots(2, 2, figsize=(16, 13))
    fig.suptitle('Training and Validation Metrics', fontsize=16, y=1.02)

    val_loss = [res.get('loss', float('nan')) for res in valid_results]
    axes[0, 0].plot(epochs_train, history['train_loss'], 'b-o', label='Training Loss')
    axes[0, 0].plot(epochs_val, val_loss, 'r-s', label='Validation Loss')
    axes[0, 0].set_title('Loss'); axes[0, 0].set_xlabel('Epoch'); axes[0, 0].set_ylabel('Loss'); axes[0, 0].legend(); axes[0, 0].grid(True)

    metric_key_acc = 'avg acc'
    if metric_key_acc in valid_results[0]:
        val_acc = [res[metric_key_acc] for res in valid_results]
        axes[0, 1].plot(epochs_val, val_acc, 'g-^', label='Avg Accuracy (Val)')
        axes[0, 1].set_title('Avg Accuracy'); axes[0, 1].set_xlabel('Epoch'); axes[0, 1].set_ylabel('Accuracy'); axes[0, 1].legend(); axes[0, 1].grid(True)
    else: axes[0, 1].set_title(f'{metric_key_acc} (Not Found)')

    has_recall = 'i2t recall R@1' in valid_results[0]
    if has_recall:
        for k in [1, 5, 10]:
            axes[1, 0].plot(epochs_val, [res.get(f'i2t recall R@{k}', float('nan')) for res in valid_results], marker='o', label=f'I2T R@{k}')
        axes[1, 0].set_title('I2T Recall (Val)'); axes[1, 0].set_xlabel('Epoch'); axes[1, 0].set_ylabel('Recall'); axes[1, 0].legend(); axes[1, 0].grid(True)
        for k in [1, 5, 10]:
            axes[1, 1].plot(epochs_val, [res.get(f't2i recall R@{k}', float('nan')) for res in valid_results], marker='s', label=f'T2I R@{k}')
        axes[1, 1].set_title('T2I Recall (Val)'); axes[1, 1].set_xlabel('Epoch'); axes[1, 1].set_ylabel('Recall'); axes[1, 1].legend(); axes[1, 1].grid(True)
    else: axes[1, 0].set_title('I2T Recall (Not Found)'); axes[1, 1].set_title('T2I Recall (Not Found)')

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plot_names = ['loss', 'accuracy', 'i2t_recall', 't2i_recall']
    for idx, name in enumerate(plot_names):
        i, j = divmod(idx, 2); save_path = os.path.join(plot_dir, f'training_{name}.png')
        if axes[i, j].has_data(): save_subplot_as_figure(axes[i, j], save_path); print(f"Saved {name} plot to: {save_path}")
        else: print(f"Skipping save for {name} plot (no data).")
    combined_save_path = os.path.join(plot_dir, 'training_metrics_combined.png')
    fig.savefig(combined_save_path, bbox_inches='tight', dpi=300); print(f"Saved combined plot to: {combined_save_path}")
    # plt.show()
    plt.close(fig)

# --- Plotting ---
if 'history' in locals() and isinstance(history, dict) and history.get('train_loss'):
    plot_training_metrics(history)
else:
    print("No training history found. Run training first.")

# --- END OF SCRIPT ---

Plot directory ensured at: /home/researcher/huypq69/TuningModels/ViBLIP_vivqa_train_plot
No training history found. Run training first.
