In [1]:
# === Cell 1: Imports ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR # Keep Cosine for option
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True # Handle potential image loading issues

# --- Import SigLIP Vision model and its Processor ---
from transformers import SiglipVisionModel, SiglipConfig, SiglipImageProcessor
# --- Keep PhoBERT parts ---
from transformers import AutoModel, AutoTokenizer, AutoConfig
# --- Blip models no longer needed for loading ---

from PIL import Image
import json
import os
import random
import numpy as np
import math
import time
import transformers
import gc
import traceback # Import traceback

try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm

# Report library versions and pick the CUDA device used by the rest of the notebook.
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Ensure you're setting the correct device index if needed
    preferred_device_idx = 1 # Or 0
    # Clamp to an index that actually exists so a single-GPU machine does not
    # fail in get_device_name()/set_device().
    device_idx = min(preferred_device_idx, torch.cuda.device_count() - 1)
    if device_idx != preferred_device_idx:
        print(f"Requested CUDA device {preferred_device_idx} is not present; using device {device_idx} instead.")
    print(f"CUDA Device Name: {torch.cuda.get_device_name(device_idx)}")
    torch.cuda.set_device(device_idx)

PyTorch Version: 2.6.0+cu124
Transformers Version: 4.50.0
CUDA Available: True
CUDA Device Name: NVIDIA GeForce RTX 4090


In [2]:
# === Cell 2: Configuration Class (CFG) - Modified for SigLIP Vision + PhoBERT Text ===
class CFG:
    """Settings for fine-tuning a SigLIP vision encoder together with a PhoBERT text encoder.

    Everything is a plain class attribute except ``text_embedding`` and
    ``vision_embedding``, which are read-only properties and therefore need an
    instance (``CFG()``) to be read.
    """
    # --- Paths ---
    data_path = "./json_data/"
    image_base_path = "./data/UIT-OpenViIC-dataset/" # Root used to resolve image paths found in the JSON metadata
    # --- Use a NEW model path ---
    model_path = "./trained_models/ViPhobertSiglip_uitopenviic"

    # --- Model Selection ---
    # --- SigLIP Vision Model ---
    selected_vision_source = "google/siglip-base-patch16-224"
    # --- Keep PhoBERT Text Model ---
    selected_text_model = "vinai/phobert-base"
    text_tokenizer_name = selected_text_model

    # --- Model parameters ---
    vision_model_name = selected_vision_source # For clarity
    text_model_name = selected_text_model   # For clarity
    # --- Image Processor: Use SigLIP's processor ---
    image_processor_name = selected_vision_source

    @property
    def text_embedding(self): return 768 # PhoBERT-base output
    @property
    def vision_embedding(self): return 768 # Siglip-base-patch16-224 output

    projection_dim = 768 # Width of the shared image/text embedding space (output size of both projection heads)

    # --- Fine-tuning parameters ---
    seed = 42
    # Adjust batch size based on VRAM for SigLIP base + PhoBERT base
    batch_size = 128
    num_workers = 20
    accumulation_steps = 1 # Effective batch size = batch_size * accumulation_steps

    # --- Learning Rates for Fine-tuning ---
    projection_lr = 1e-4
    vision_encoder_lr = 1e-5 # Lower LR for SigLIP backbone
    text_encoder_lr = 2e-5   # Slightly higher LR for PhoBERT backbone
    weight_decay = 1e-4     # Lower weight decay for fine-tuning
    learning_rate = 1e-4

    # --- Use standard contrastive loss temperature (like CLIP) ---
    temperature = 0.07
    learnable_temperature = True
    # --- Bias term is NOT used in standard contrastive loss ---
    learnable_bias = False
    bias_init = 0.0

    # --- Scheduler ---
    scheduler_type = "reduce_on_plateau" # RoP often used for fine-tuning
    rop_patience = 2
    rop_factor = 0.8
    # Read only when scheduler_type == "cosine" (cosine schedule with warmup).
    # Previously undefined, which made the cosine option fail with AttributeError.
    # 0 means no warmup phase.
    warmup_steps = 0

    epochs = 5 # Fine-tuning might converge faster

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = True # Keep AMP enabled

    # --- Image/Text parameters ---
    max_length = 77 # Standard CLIP length

    # --- Loss/Saving parameters ---
    save_best_only = True
    metric_to_track = "avg_acc" # Track validation accuracy
    mode = "max"
    # Adjust intervals if needed
    save_interval_steps = 1000 # Save periodically during fine-tuning (optional)
    validation_interval_steps = 1000 # Validate more often during fine-tuning
    log_interval_steps = 50

    early_stopping_patience = 5 # Patience in terms of validation checks
    early_stopping_min_delta = 0.001 # Min change to be considered improvement

# --- Instantiate Config and Create Output Dir ---
config = CFG()
os.makedirs(config.model_path, exist_ok=True)
# Echo the settings that matter most for reproducing this run, one per line.
print("\n".join([
    f"Using device: {config.device}",
    f"Per-Device Batch Size: {config.batch_size}",
    f"Accumulation Steps: {config.accumulation_steps}",
    f"Effective Batch Size (per optimizer step): {config.batch_size * config.accumulation_steps}",
    f"Model output path: {config.model_path}",
    f"Selected Vision Source: {config.selected_vision_source}",
    f"Selected Text Model: {config.selected_text_model}",
    f"Image base path (for resolving paths in JSON): {os.path.abspath(config.image_base_path)}",
    f"AMP Enabled: {config.use_amp}",
]))

Using device: cuda
Per-Device Batch Size: 128
Accumulation Steps: 1
Effective Batch Size (per optimizer step): 128
Model output path: ./trained_models/ViPhobertSiglip_uitopenviic
Selected Vision Source: google/siglip-base-patch16-224
Selected Text Model: vinai/phobert-base
Image base path (for resolving paths in JSON): /home/researcher/huypq69/TuningModels/data/UIT-OpenViIC-dataset
AMP Enabled: True


In [None]:
# === Cell 3: Seeding ===
def set_seed(seed=config.seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA).

    Also exports ``PYTHONHASHSEED`` into the process environment.

    Args:
        seed: Integer seed; defaults to ``config.seed`` as evaluated when this
            function was defined.
    """
    print(f"Setting seed: {seed}")
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

Setting seed: 42


In [4]:
# === Cell 4: Metric & AvgMeter Utilities ===
# (Keep AvgMeter, compute_recall_at_k, compute_metrics as before)
class AvgMeter:
    """Weighted running average of a scalar metric."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.sum, self.count, self.avg = 0, 0, 0.0

    def update(self, val, count=1):
        """Add ``val`` with weight ``count``.

        Tensors are converted with ``.item()``. Anything that is not an
        ``int``/``float`` afterwards is reported with a warning and ignored.
        """
        value = val.item() if torch.is_tensor(val) else val
        if not isinstance(value, (int, float)):
            print(f"Warning: Cannot update AvgMeter '{self.name}' with value type {type(value)}")
            return
        self.sum += value * count
        self.count += count
        if self.count != 0:
            self.avg = float(self.sum) / self.count
        else:
            self.avg = 0.0

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def compute_recall_at_k(similarity_matrix, k, dim):
    """Fraction of queries whose matching item (same index) is among the top-k scores.

    Args:
        similarity_matrix: 2-D score tensor. ``compute_metrics`` in this file
            passes ``text @ image.T``, i.e. rows are texts and columns are images.
        k: Number of top candidates to inspect; clipped to the size of ``dim``.
        dim: Axis searched for candidates. ``0`` ranks along rows for every
            column (I2T), ``1`` ranks along columns for every row (T2I).

    Returns:
        Recall@k as a float in ``[0, 1]``; ``0.0`` when there are no queries or
        no candidates.

    Raises:
        ValueError: If ``dim`` is neither 0 nor 1.
    """
    # Validate first: otherwise an invalid dim fails inside shape indexing/topk
    # with an unrelated error before the intended ValueError can be raised.
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1")
    n = similarity_matrix.shape[1 - dim]
    if n == 0:
        return 0.0
    actual_k = min(k, similarity_matrix.shape[dim])
    if actual_k == 0:
        return 0.0
    top_k_indices = torch.topk(similarity_matrix, actual_k, dim=dim).indices
    ground_truth = torch.arange(n, device=similarity_matrix.device)
    # Query i is a hit when index i appears in its own top-k list. A single
    # broadcast comparison replaces the per-query Python loop, which issued
    # one tensor membership test (and device sync) per sample.
    hits = (top_k_indices == ground_truth.unsqueeze(dim)).any(dim=dim)
    return float(hits.sum().item()) / n

def compute_metrics(image_embeddings, text_embeddings):
    """Retrieval metrics for index-aligned image/text embeddings.

    Similarities are ``text @ image.T`` in float32, and pair ``i`` is treated
    as the only correct match for index ``i``.

    Returns:
        dict with ``i2t_acc``, ``t2i_acc``, ``avg_acc`` (top-1 accuracy),
        ``avg_cosine_sim`` (mean of the similarity diagonal) and the
        ``i2t_recall`` / ``t2i_recall`` dicts keyed ``R@1``, ``R@5``, ``R@10``.
        A recall whose ``k`` exceeds the number of texts is reported as 0.0,
        and every metric is 0.0 when there are no texts.
    """
    sim_matrix = text_embeddings.float() @ image_embeddings.float().T
    n = sim_matrix.shape[0]
    if n == 0:
        return {
            "i2t_acc": 0.0, "t2i_acc": 0.0, "avg_acc": 0.0,
            "avg_cosine_sim": 0.0,
            "i2t_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0},
            "t2i_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}
        }

    targets = torch.arange(n, device=sim_matrix.device)
    i2t_acc = (sim_matrix.argmax(dim=0) == targets).float().mean().item()
    t2i_acc = (sim_matrix.argmax(dim=1) == targets).float().mean().item()

    i2t_recall, t2i_recall = {}, {}
    for k in (1, 5, 10):
        key = f"R@{k}"
        if k <= n:
            i2t_recall[key] = compute_recall_at_k(sim_matrix, k, dim=0)
            t2i_recall[key] = compute_recall_at_k(sim_matrix, k, dim=1)
        else:
            # Not enough samples to rank k candidates: reported as 0.0.
            i2t_recall[key] = 0.0
            t2i_recall[key] = 0.0

    return {
        "i2t_acc": i2t_acc, "t2i_acc": t2i_acc,
        "avg_acc": (i2t_acc + t2i_acc) / 2.0,
        "avg_cosine_sim": torch.diag(sim_matrix).mean().item(),
        "i2t_recall": i2t_recall, "t2i_recall": t2i_recall
    }

print("Metric utilities defined.")

Metric utilities defined.


In [5]:
# === Cell 5: Dataset Class Definition (Corrected JSON Loading & Processor Update) ===

import traceback
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True # Keep allowing truncated images

class CustomImageCaptionDataset(Dataset):
    """
    Image-caption pairs described by JSON metadata.

    The metadata source may be a directory of ``.json`` files, a single JSON
    file, or an explicit list of files. Each file may contain one JSON list,
    one JSON object, or one JSON object per line.

    Every sample is a dict with ``pixel_values``, ``input_ids`` and
    ``attention_mask``. A sample that cannot be built (missing path/caption,
    unreadable image, tokenizer failure, malformed record) is replaced by an
    all-zero placeholder of the same shapes instead of raising.

    NOTE(review): placeholders are returned like ordinary samples (with an
    all-zero attention mask), so whatever consumes the batch sees them as real
    pairs. Confirm that is acceptable for the amount of bad data expected.
    """

    # Per-process cap on "placeholder substituted" warnings, so broken data is
    # visible without flooding the notebook output.
    _max_fallback_warnings = 5
    _fallback_warning_count = 0

    def __init__(self, json_path_or_list, image_base_path, tokenizer, image_processor, max_length):
        """
        Args:
            json_path_or_list: Directory, JSON file path, or list of JSON file paths.
            image_base_path: Directory that image paths in the metadata are joined onto.
            tokenizer: Callable tokenizer returning ``input_ids`` and ``attention_mask``.
            image_processor: Callable image processor returning ``pixel_values``.
            max_length: Fixed token length (padding and truncation target).

        Raises:
            ValueError: If ``json_path_or_list`` is none of the accepted forms.
        """
        super().__init__()
        self.data = []
        if isinstance(json_path_or_list, str) and os.path.isdir(json_path_or_list):
            # Sorted: os.listdir() returns entries in arbitrary order, which
            # would make the sample order differ from one machine to the next.
            json_files = [
                os.path.join(json_path_or_list, f)
                for f in sorted(os.listdir(json_path_or_list))
                if f.endswith('.json')
            ]
            print(f"Found {len(json_files)} JSON files in {json_path_or_list}")
        elif isinstance(json_path_or_list, str) and os.path.isfile(json_path_or_list):
            json_files = [json_path_or_list]
        elif isinstance(json_path_or_list, list):
            json_files = json_path_or_list
        else:
            raise ValueError("json_path_or_list must be a directory, a single JSON file, or a list of JSON files.")

        print("Loading JSON metadata...")
        total_loaded_count = 0
        for json_path in tqdm(json_files, desc="Loading JSONs"):
            total_loaded_count += self._load_json_file(json_path)

        print(f"Loaded {total_loaded_count} samples total from {len(json_files)} file(s).")
        # Drop empty records (None, {}, "", ...).
        self.data = [item for item in self.data if item]
        print(f"Dataset size after potential cleaning: {len(self.data)}")

        if not self.data:
            print("WARNING: No data loaded!")

        self.image_base_path = image_base_path
        self.tokenizer = tokenizer
        self.image_processor = image_processor # Store the passed processor
        self.max_length = max_length

        # Needed to build placeholder samples with the right spatial size.
        self.img_size = self._resolve_image_size(image_processor)
        print(f"Using image target size: {self.img_size}x{self.img_size}")

        if not os.path.isdir(self.image_base_path):
            print(f"WARNING: Image base path does not exist: {os.path.abspath(self.image_base_path)}")

    def _load_json_file(self, json_path):
        """Append the records of one metadata file to ``self.data`` and return how many were counted."""
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                try:
                    file_data = json.load(f)
                except json.JSONDecodeError:
                    print(f"  Info: Failed to load {json_path} as single JSON. Attempting JSON-per-line format...")
                    f.seek(0)
                    return self._load_json_lines(f, json_path)
                if isinstance(file_data, list):
                    self.data.extend(file_data)
                    return len(file_data)
                self.data.append(file_data)
                print(f"  Warning: Loaded single JSON object from {json_path}, expected a list.")
                return 1
        except Exception as e:
            print(f"ERROR opening or processing file {json_path}: {e}")
            return 0

    def _load_json_lines(self, f, json_path):
        """Parse an open file as one JSON value per non-blank line; return the number of parsed lines."""
        count_line_by_line = 0
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                self.data.append(json.loads(line))
                count_line_by_line += 1
            except json.JSONDecodeError as line_err:
                print(f"  ERROR parsing line in {json_path}: {line_err}. Line content (partial): {line[:100]}...")
        if count_line_by_line > 0:
            print(f"  Successfully loaded {count_line_by_line} items using JSON-per-line format from {json_path}.")
        else:
            print(f"  Failed to load any data using JSON-per-line format from {json_path} either.")
        return count_line_by_line

    @staticmethod
    def _resolve_image_size(image_processor):
        """Best-effort lookup of the processor's square target size, defaulting to 224."""
        try:
            # Preferred when the processor exposes a config object with image_size.
            size = image_processor.config.image_size
        except AttributeError:
            try:
                size = image_processor.size
            except AttributeError:
                print("Warning: Could not determine image size from processor, defaulting to 224.")
                return 224
            if isinstance(size, dict):
                size = size.get('height', size.get('shortest_edge', 224))
            elif isinstance(size, (tuple, list)):
                size = size[0]
        if not isinstance(size, int):
            # Placeholder tensors need an integer side length.
            print("Warning: Could not determine image size from processor, defaulting to 224.")
            return 224
        return size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``, or a zero placeholder if it cannot be built."""
        if idx >= len(self.data): raise IndexError("Index out of bounds")
        item = self.data[idx]

        # JSON-per-line files can yield arbitrary JSON values, not only objects.
        if not isinstance(item, dict):
            return self._fallback_item(f"record {idx} is {type(item).__name__}, not a JSON object")

        relative_image_path = item.get('image_path', item.get('url', item.get('filename')))
        caption_data = item.get('caption', item.get('text', item.get('title', '')))
        if isinstance(caption_data, list):
            caption = caption_data[0] if caption_data else ""
        elif isinstance(caption_data, str):
            caption = caption_data
        else:
            caption = ""

        if not relative_image_path or not caption:
            return self._fallback_item(f"record {idx} has no image path or caption")

        # Load Image
        try:
            image_path = os.path.join(self.image_base_path, relative_image_path)
            # The context manager closes the file handle as soon as the pixels
            # have been copied into the converted image.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            image_inputs = self.image_processor(images=image, return_tensors="pt")
            pixel_values = image_inputs['pixel_values'].squeeze(0)
        except Exception as exc: # Broad on purpose: one bad image must not stop training
            return self._fallback_item(f"image for record {idx} could not be loaded: {exc}")

        # Process Text
        try:
            text_inputs = self.tokenizer(
                caption, padding='max_length', truncation=True,
                max_length=self.max_length, return_tensors='pt'
            )
            input_ids = text_inputs['input_ids'].squeeze(0)
            attention_mask = text_inputs['attention_mask'].squeeze(0)
        except Exception as exc:
            return self._fallback_item(f"caption for record {idx} could not be tokenized: {exc}")

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

    def _fallback_item(self, reason):
        """Return a placeholder sample, printing ``reason`` for the first few occurrences."""
        if self._fallback_warning_count < self._max_fallback_warnings:
            self._fallback_warning_count += 1
            print(f"Warning: substituting placeholder sample ({reason}).")
            if self._fallback_warning_count == self._max_fallback_warnings:
                print("Warning: further placeholder-sample warnings are suppressed.")
        return self._get_dummy_item()

    def _get_dummy_item(self):
        """All-zero sample with the same keys, shapes and dtypes as a real one."""
        # Use self.img_size determined in __init__
        return {
            "pixel_values": torch.zeros((3, self.img_size, self.img_size), dtype=torch.float),
            "input_ids": torch.zeros(self.max_length, dtype=torch.long),
            "attention_mask": torch.zeros(self.max_length, dtype=torch.long)
        }

print("CustomImageCaptionDataset class defined.")

CustomImageCaptionDataset class defined.


In [6]:
# === Cell 6: Model Definition (SigLIP Vision + PhoBERT Text) ===

class ImageEncoder(nn.Module):
    """SigLIP vision model followed by a bias-free linear projection head."""

    def __init__(self, config_train, pretrained=True):
        """
        Args:
            config_train: Settings object; reads ``vision_model_name`` and
                ``projection_dim``, and ``vision_embedding`` when present.
            pretrained: Load pretrained weights when True; otherwise build the
                model from the checkpoint's vision config with fresh weights.
        """
        super().__init__()
        self.config_train = config_train
        print(f"Initializing SigLIP Vision Encoder from: {config_train.vision_model_name}")

        if not pretrained:
            print("  Initializing SiglipVisionModel from scratch.")
            full_config = SiglipConfig.from_pretrained(config_train.vision_model_name)
            self.vision_model = SiglipVisionModel(full_config.vision_config)
        else:
            try:
                self.vision_model = SiglipVisionModel.from_pretrained(config_train.vision_model_name)
                print("  SigLIP Vision model loaded successfully.")
            except Exception as e:
                print(f"  ERROR loading pretrained SiglipVisionModel: {e}")
                raise # Nothing useful can be done without the vision model

        # The projection input width comes from the loaded model; the value in
        # config_train is only a fallback.
        try:
            self.input_features = self.vision_model.config.hidden_size
        except AttributeError as e:
            print(f"  ERROR accessing vision_model.config.hidden_size: {e}. Attempting config_train value.")
            self.input_features = config_train.vision_embedding

        width_agrees = (
            not hasattr(config_train, 'vision_embedding')
            or self.input_features == config_train.vision_embedding
        )
        if width_agrees:
            print(f"  Confirmed/Using vision model hidden size: {self.input_features}")
        else:
            print(f"  WARNING: Configured vision_embedding ({config_train.vision_embedding}) doesn't match loaded model hidden size ({self.input_features}). Using actual size.")

        self.projection = nn.Linear(self.input_features, config_train.projection_dim, bias=False)
        print(f"  Added projection head: {self.input_features} -> {config_train.projection_dim}")

    def forward(self, pixel_values):
        """Project the vision model's ``pooler_output`` for ``pixel_values``."""
        pooled = self.vision_model(pixel_values=pixel_values, return_dict=True).pooler_output
        return self.projection(pooled)

# --- TextEncoder remains the same (loading PhoBERT) ---
class TextEncoder(nn.Module):
    """PhoBERT-style text model followed by a bias-free linear projection head."""

    def __init__(self, config_train, pretrained=True):
        """
        Args:
            config_train: Settings object; reads ``text_model_name`` and
                ``projection_dim``, and ``text_embedding`` when present.
            pretrained: Load pretrained weights when True; otherwise build the
                architecture from its config with fresh weights.
        """
        super().__init__()
        self.config_train = config_train
        print(f"Initializing Text Encoder: {config_train.text_model_name}")

        if not pretrained:
            architecture = AutoConfig.from_pretrained(config_train.text_model_name)
            self.model = AutoModel.from_config(architecture)
        else:
            self.model = AutoModel.from_pretrained(config_train.text_model_name)

        # The projection input width comes from the loaded model; the value in
        # config_train is only a fallback.
        try:
            self.input_features = self.model.config.hidden_size
        except AttributeError as e:
            print(f"  ERROR accessing model.config.hidden_size: {e}. Attempting config_train value.")
            self.input_features = config_train.text_embedding

        width_agrees = (
            not hasattr(config_train, 'text_embedding')
            or self.input_features == config_train.text_embedding
        )
        if width_agrees:
            print(f"  Confirmed text model hidden size: {self.input_features}")
        else:
            print(f"  WARNING: Configured text_embedding ({config_train.text_embedding}) doesn't match loaded PhoBERT hidden size ({self.input_features}). Using actual size.")

        self.projection = nn.Linear(self.input_features, config_train.projection_dim, bias=False)
        print(f"  Added projection head: {self.input_features} -> {config_train.projection_dim}")

    def forward(self, input_ids, attention_mask):
        """Project the hidden state at sequence position 0 of every input."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.projection(first_token_state)

# --- Combined Model (CLIP-Style) ---
class ViPhobertSiglipModel(nn.Module): # Renamed for clarity
    """Dual encoder that scores image/text pairs with temperature-scaled cosine similarity."""

    def __init__(self, image_encoder, text_encoder, config_train):
        """
        Args:
            image_encoder: Module mapping ``pixel_values`` to image embeddings.
            text_encoder: Module mapping ``(input_ids, attention_mask)`` to text embeddings.
            config_train: Settings object; reads ``temperature``,
                ``learnable_temperature`` and (in ``forward``) ``device``.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.config_train = config_train

        # The scale is stored as log(1 / temperature) and exponentiated in forward().
        initial_log_scale = torch.tensor(np.log(1 / config_train.temperature), dtype=torch.float)
        if not config_train.learnable_temperature:
            self.register_buffer('logit_scale', initial_log_scale)
            print(f"Using fixed temperature: {config_train.temperature}")
            return
        self.logit_scale = nn.Parameter(initial_log_scale)
        print(f"Using learnable temperature (logit_scale), initialized to {self.logit_scale.exp().item():.4f}")

    def forward(self, pixel_values, input_ids, attention_mask):
        """Encode one batch and score every image against every text.

        Returns:
            ``(logits_per_image, logits_per_text, image_features, text_features)``;
            the logits are float32 and transposes of each other, the features
            are L2-normalised along the last dimension.
        """
        target_device = self.config_train.device
        pixel_values, input_ids, attention_mask = (
            tensor.to(target_device) for tensor in (pixel_values, input_ids, attention_mask)
        )

        image_features = F.normalize(self.image_encoder(pixel_values), p=2, dim=-1)
        text_features = F.normalize(self.text_encoder(input_ids, attention_mask), p=2, dim=-1)

        # The multiplier is capped at 100.
        scale = self.logit_scale.exp().clamp(max=100)

        logits_per_image = (scale.float() * image_features.float()) @ text_features.float().t()
        return logits_per_image, logits_per_image.t(), image_features, text_features

print("ViPhobertSiglip Model components defined.")

ViPhobertSiglip Model components defined.


In [7]:
# === Cell 7: Loss Function (Standard Contrastive Loss) ===
def contrastive_loss(logits_per_image, logits_per_text):
    """Symmetric cross-entropy over image->text and text->image logits.

    Row ``i`` of each logit matrix is scored against target class ``i``, and
    the two directional losses are averaged. An empty batch yields a zero
    scalar that requires grad.
    """
    image_logits = logits_per_image.float()
    text_logits = logits_per_text.float()

    num_pairs = image_logits.shape[0]
    if num_pairs == 0:
        return torch.tensor(0.0, device=image_logits.device, requires_grad=True)

    targets = torch.arange(num_pairs, device=image_logits.device)
    image_term, text_term = (
        F.cross_entropy(logits, targets) for logits in (image_logits, text_logits)
    )
    return (image_term + text_term) / 2.0

print("Standard Contrastive loss function defined.")

Standard Contrastive loss function defined.


In [8]:
# === Cell 8: Setup - Tokenizer and Image Processor (Using SigLIP Processor) ===
# --- Use SiglipImageProcessor ---
from transformers import AutoTokenizer, SiglipImageProcessor

# Both stay None when loading fails; the dataset cell checks for that.
tokenizer = image_processor = None

print(f"Loading Tokenizer: {config.text_tokenizer_name}")
try:
    tokenizer = AutoTokenizer.from_pretrained(config.text_tokenizer_name)
except Exception as load_error:
    print(f"ERROR loading tokenizer '{config.text_tokenizer_name}': {load_error}")
else:
    print("PhoBERT Tokenizer loaded successfully.")

# The image processor is taken from the same checkpoint as the vision encoder.
print(f"Loading Image Processor from: {config.image_processor_name}") # Use SigLIP name from CFG
try:
    image_processor = SiglipImageProcessor.from_pretrained(config.image_processor_name)
except Exception as load_error:
    print(f"ERROR loading image processor '{config.image_processor_name}': {load_error}")
else:
    print("SigLIP Image Processor loaded successfully.")

Loading Tokenizer: vinai/phobert-base


PhoBERT Tokenizer loaded successfully.
Loading Image Processor from: google/siglip-base-patch16-224
SigLIP Image Processor loaded successfully.


In [9]:
# === Cell 9: Setup - Datasets and DataLoaders (FIXED Validation Path) ===
# Uses the dataset class defined above.

train_loader = None
dev_loader = None
# Defined up front so later cells can test these names even when the
# tokenizer/processor check below skips dataset creation.
train_dataset = None
dev_dataset = None

# Define paths
validation_json_path = os.path.join(config.data_path, "dev.json") # <<< CHANGED FILENAME
train_json_path = os.path.join(config.data_path, "train.json")

if tokenizer is not None and image_processor is not None:
    # Fixed: these messages used "\\n", which printed a literal backslash-n.
    print("\nCreating datasets...")
    # --- Training Dataset ---
    try:
        print(f"Attempting to load training data from: {train_json_path}")
        train_dataset = CustomImageCaptionDataset(
            json_path_or_list=train_json_path,
            image_base_path=config.image_base_path,
            tokenizer=tokenizer,
            image_processor=image_processor, # Pass the loaded processor
            max_length=config.max_length
        )
        if not train_dataset.data: print("\nERROR: Failed to load training data.")
    except Exception as e:
        print(f"ERROR creating training dataset: {e}")
        train_dataset = None

    # --- Validation Dataset ---
    if os.path.exists(validation_json_path):
        try:
            print(f"Attempting to load validation data from: {validation_json_path}")
            dev_dataset = CustomImageCaptionDataset(
                json_path_or_list=validation_json_path,
                image_base_path=config.image_base_path, # Assumes same base path
                tokenizer=tokenizer,
                image_processor=image_processor, # Use same processor
                max_length=config.max_length
            )
            if not dev_dataset.data: print("\nWARNING: Failed to load validation data.")
        except Exception as e:
            print(f"ERROR creating validation dataset: {e}")
            dev_dataset = None
    else:
        print(f"Validation JSON file not found at {validation_json_path}, skipping validation set creation.")
        dev_dataset = None

    print("\nCreating dataloaders...")
    num_workers = min(config.num_workers, os.cpu_count() or 1)
    print(f"Using {num_workers} workers for DataLoaders.")

    # Pinned host memory only helps when batches are copied to a CUDA device.
    # Comparing the device *type* also covers indexed devices such as "cuda:1".
    pin_memory = config.device.type == "cuda"
    # persistent_workers is only valid with at least one worker process. The
    # former throw-away probe DataLoaders targeted PyTorch versions that
    # predate the torch.amp API this notebook already relies on.
    persist_workers = num_workers > 0
    persist_workers_dev = persist_workers

    if train_dataset is not None and train_dataset.data:
        train_loader = DataLoader(
            train_dataset, batch_size=config.batch_size, shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True, # Keep drop_last=True for more stable training steps
            persistent_workers=persist_workers
        )
        print(f"Train loader created with {len(train_loader)} batches.")
        # Calculate total training steps (only if using Cosine scheduler)
        if config.scheduler_type == "cosine":
            config.total_training_steps = len(train_loader) * config.epochs // config.accumulation_steps
            print(f"Total estimated training steps for Cosine Scheduler: {config.total_training_steps}")
    else:
        print("Skipping train loader creation (no data).")
        config.total_training_steps = 0 # Set default if no loader

    if dev_dataset is not None and dev_dataset.data:
        # No gradients are kept during validation, so a larger batch fits.
        dev_loader = DataLoader(
            dev_dataset, batch_size=config.batch_size * 2, shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False,
            persistent_workers=persist_workers_dev
        )
        print(f"Validation loader created with {len(dev_loader)} batches.")
    else: print("Skipping validation loader creation.")

    if not train_loader: print("\nERROR: Train loader could not be created.")
else:
    print("ERROR: Tokenizer or Image Processor not loaded. Skipping dataset/loader creation.")

\nCreating datasets...
Attempting to load training data from: ./json_data/train.json
Loading JSON metadata...


Loading JSONs:   0%|          | 0/1 [00:00<?, ?it/s]

Loaded 41236 samples total from 1 file(s).
Dataset size after potential cleaning: 41236
Using image target size: 224x224
Attempting to load validation data from: ./json_data/dev.json
Loading JSON metadata...


Loading JSONs:   0%|          | 0/1 [00:00<?, ?it/s]

Loaded 10002 samples total from 1 file(s).
Dataset size after potential cleaning: 10002
Using image target size: 224x224
\nCreating dataloaders...
Using 20 workers for DataLoaders.
Train loader created with 322 batches.
Validation loader created with 40 batches.


In [10]:
# === Cell 10: Setup - Model, Optimizer, Scheduler (Fine-tuning LRs & Corrected AMP) ===

model = None
optimizer = None
lr_scheduler = None
scaler = None # For AMP

# Fixed: these messages used "\\n", which printed a literal backslash-n.
print("\nInitializing ViPhobertSiglip model components...")
try:
    # Instantiate the encoders and main model
    image_encoder = ImageEncoder(config).to(config.device)
    text_encoder = TextEncoder(config).to(config.device)
    # --- Instantiate the correct model ---
    model = ViPhobertSiglipModel(image_encoder, text_encoder, config).to(config.device)

    print(f"\nViPhobertSiglip Model initialized successfully on {config.device}.")
    num_params_total = sum(p.numel() for p in model.parameters())
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {num_params_total / 1e6:.2f} M")
    print(f"Trainable parameters: {num_params_trainable / 1e6:.2f} M")

except Exception as e:
    print(f"ERROR initializing model components: {e}")
    traceback.print_exc()
    model = None

if model and train_loader: # Check train_loader exists
    print("\nSetting up optimizer...")
    # --- Optimizer with Fine-tuning LRs ---
    # One parameter group per component so each backbone and the heads can
    # have their own learning rate.
    vision_encoder_params = list(model.image_encoder.vision_model.parameters())
    image_head_params = list(model.image_encoder.projection.parameters())
    text_encoder_params = list(model.text_encoder.model.parameters())
    text_head_params = list(model.text_encoder.projection.parameters())
    # logit_scale is a buffer (not trainable) when the temperature is fixed.
    logit_scale_param = [model.logit_scale] if isinstance(model.logit_scale, nn.Parameter) else []

    optimizer_grouped_parameters = [
        {'params': [p for p in vision_encoder_params if p.requires_grad], 'lr': config.vision_encoder_lr, 'weight_decay': config.weight_decay},
        {'params': [p for p in image_head_params if p.requires_grad], 'lr': config.projection_lr, 'weight_decay': config.weight_decay},
        {'params': [p for p in text_encoder_params if p.requires_grad], 'lr': config.text_encoder_lr, 'weight_decay': config.weight_decay},
        {'params': [p for p in text_head_params if p.requires_grad], 'lr': config.projection_lr, 'weight_decay': config.weight_decay},
        {'params': [p for p in logit_scale_param if p.requires_grad], 'lr': config.projection_lr, 'weight_decay': 0.0 }
    ]

    # Drop groups that ended up empty (e.g. a fully frozen component).
    optimizer_grouped_parameters = [g for g in optimizer_grouped_parameters if g['params']]

    if not optimizer_grouped_parameters:
        print("ERROR: No trainable parameters found for the optimizer.")
    else:
        optimizer = optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate) # Base LR used if param not in group
        print(f"Optimizer AdamW initialized with grouped LRs (Vision: {config.vision_encoder_lr}, Text: {config.text_encoder_lr}, Proj/Temp: {config.projection_lr}), WD: {config.weight_decay}")

        # --- LR Scheduler ---
        if config.scheduler_type == "cosine":
            # warmup_steps may be absent from the config; treat that as "no
            # warmup" instead of failing with AttributeError.
            warmup_steps = getattr(config, 'warmup_steps', 0)
            if hasattr(config, 'total_training_steps') and config.total_training_steps > 0:
                lr_scheduler = transformers.get_cosine_schedule_with_warmup(
                    optimizer, num_warmup_steps=warmup_steps, num_training_steps=config.total_training_steps
                )
                print(f"LR Scheduler: Cosine with Warmup ({warmup_steps} steps) initialized.")
            else:
                print("ERROR: total_training_steps not calculated or zero. Cannot init Cosine scheduler.")
                lr_scheduler = None
        elif config.scheduler_type == "reduce_on_plateau":
            lr_scheduler = ReduceLROnPlateau(
                optimizer, mode=config.mode, factor=config.rop_factor, patience=config.rop_patience
            )
            print(f"LR Scheduler: ReduceLROnPlateau initialized (mode='{config.mode}', factor={config.rop_factor}, patience={config.rop_patience})")
        else:
            print("No LR Scheduler specified.")
            lr_scheduler = None

        # --- Automatic Mixed Precision (AMP) Scaler ---
        if config.use_amp:
            scaler = torch.amp.GradScaler('cuda')
            print("AMP GradScaler initialized.")
        else:
            scaler = None

        # Early stopping setup
        early_stopping_counter = 0
        best_val_metric = -float('inf') if config.mode == "max" else float('inf')
        print(f"Early stopping enabled with patience: {config.early_stopping_patience}")

else:
    print("ERROR: Model not initialized or train_loader not available. Skipping optimizer/scheduler setup.")

\nInitializing ViPhobertSiglip model components...
Initializing SigLIP Vision Encoder from: google/siglip-base-patch16-224
  SigLIP Vision model loaded successfully.
  Confirmed/Using vision model hidden size: 768
  Added projection head: 768 -> 768
Initializing Text Encoder: vinai/phobert-base
  Confirmed text model hidden size: 768
  Added projection head: 768 -> 768
Using learnable temperature (logit_scale), initialized to 14.2857
\nViPhobertSiglip Model initialized successfully on cuda.
Total parameters: 229.06 M
Trainable parameters: 229.06 M
\nSetting up optimizer...
Optimizer AdamW initialized with grouped LRs (Vision: 1e-05, Text: 2e-05, Proj/Temp: 0.0001), WD: 0.0001
LR Scheduler: ReduceLROnPlateau initialized (mode='max', factor=0.8, patience=2)
AMP GradScaler initialized.
Early stopping enabled with patience: 5


In [11]:
# === Cell 11: Training and Validation Functions (Using Contrastive Loss) ===
import traceback

def train_step(model, batch, optimizer, scaler, device, use_amp):
    """Run the forward and backward pass for one batch using the contrastive loss.

    Only gradients are produced here. Unscaling, the optimizer step and
    ``zero_grad`` are left to the caller, which is what makes gradient
    accumulation across several calls possible.

    Args:
        model: Called as ``model(pixel_values, input_ids, attention_mask)``; the first
            two returned values are passed to ``contrastive_loss``.
        batch: Mapping providing 'pixel_values', 'input_ids' and 'attention_mask'.
        optimizer: Not used inside this function; kept so existing call sites work.
        scaler: Gradient scaler used for the backward pass when ``use_amp`` is true.
        device: Device the batch tensors are moved to (``None`` skips the transfer).
        use_amp: Enables CUDA autocast and the scaled backward pass.

    Returns:
        float: The loss of this batch (``loss.item()``).
    """
    model.train()

    pixel_values = batch['pixel_values']
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # Make sure the inputs live on the requested device. ``Tensor.to`` returns the
    # tensor unchanged when it is already there, so this is free in that case.
    if device is not None:
        pixel_values, input_ids, attention_mask = (
            t.to(device, non_blocking=True) if torch.is_tensor(t) else t
            for t in (pixel_values, input_ids, attention_mask)
        )

    # torch.cuda.amp.autocast is deprecated (it emits a FutureWarning on every call);
    # torch.amp.autocast('cuda', ...) is the supported equivalent.
    with torch.amp.autocast('cuda', enabled=use_amp):
        logits_per_image, logits_per_text, _, _ = model(pixel_values, input_ids, attention_mask)
        loss = contrastive_loss(logits_per_image, logits_per_text)

    if use_amp and scaler is not None:
        # Scale the loss so small half-precision gradients do not underflow.
        scaler.scale(loss).backward()
    else:
        loss.backward()

    return loss.item()

def _empty_validation_results():
    """Return placeholder metrics for a validation pass that produced no embeddings.

    Keys use spaces rather than underscores so they match the names produced by the
    normal result-formatting path of ``validate_epoch``; callers look metrics up by
    those names.
    """
    return {
        "loss": float('inf'), "avg acc": 0.0, "avg cosine sim": 0.0,
        "i2t recall R@1": 0.0, "i2t recall R@5": 0.0, "i2t recall R@10": 0.0,
        "t2i recall R@1": 0.0, "t2i recall R@5": 0.0, "t2i recall R@10": 0.0,
    }


def validate_epoch(model, dataloader, device, use_amp=None):
    """Embed every batch of ``dataloader`` and compute metrics over the full set.

    Args:
        model: Called as ``model(pixel_values, input_ids, attention_mask)``; the third
            and fourth returned values are collected as image and text embeddings.
        dataloader: Iterable of batches with 'pixel_values', 'input_ids' and
            'attention_mask'.
        device: Device used for the batch tensors and for ``compute_metrics``.
        use_amp: Whether to run the forward pass under CUDA autocast. ``None``
            (default) falls back to the notebook-level ``config.use_amp``.

    Returns:
        dict: Flat mapping of metric name to value. Underscores in the names returned
        by ``compute_metrics`` are replaced with spaces, and nested dicts are flattened
        to ``"<name> <sub-key>"``. A placeholder dict is returned when no embeddings
        could be collected or concatenated.

    Note:
        The model is left in eval mode; callers that continue training must call
        ``model.train()`` afterwards.
    """
    model.eval()
    amp_enabled = config.use_amp if use_amp is None else use_amp

    image_chunks = []
    text_chunks = []

    progress_bar = tqdm(dataloader, desc=f"Validation", leave=False, unit="batch")

    with torch.no_grad():
        for batch in progress_bar:
            pixel_values = batch['pixel_values']
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']

            # No-op for tensors that are already on ``device``.
            if device is not None:
                pixel_values, input_ids, attention_mask = (
                    t.to(device, non_blocking=True) if torch.is_tensor(t) else t
                    for t in (pixel_values, input_ids, attention_mask)
                )

            # torch.cuda.amp.autocast is deprecated; use the torch.amp namespace instead.
            with torch.amp.autocast('cuda', enabled=amp_enabled):
                _, _, image_embeds_norm, text_embeds_norm = model(pixel_values, input_ids, attention_mask)

            # Park per-batch embeddings on the CPU so GPU memory stays flat while iterating.
            image_chunks.append(image_embeds_norm.cpu())
            text_chunks.append(text_embeds_norm.cpu())

    if not image_chunks or not text_chunks:
        print("Warning: No embeddings collected during validation.")
        return _empty_validation_results()

    try:
        all_image_embeddings = torch.cat(image_chunks, dim=0)
        all_text_embeddings = torch.cat(text_chunks, dim=0)
    except Exception as e:
        print(f"Error concatenating embeddings: {e}")
        return _empty_validation_results()

    # --- Log Temp (if learnable) ---
    logit_scale = getattr(model, "logit_scale", None)
    if isinstance(logit_scale, nn.Parameter):
        current_temp_val = logit_scale.exp().item()
    else:
        current_temp_val = 1 / config.temperature
    print(f"DEBUG: Validation - Current Temp (exp(logit_scale)): {current_temp_val:.4f}")

    # A real newline: the previous "\\n" printed a literal backslash-n.
    print(f"\nComputing metrics over {all_image_embeddings.shape[0]} validation samples...")
    validation_metrics = compute_metrics(all_image_embeddings.to(device), all_text_embeddings.to(device))

    # Flatten nested metric dicts into "<name> <sub-key>" entries.
    final_results = {}
    for metric_name, metric_value in validation_metrics.items():
        display_name = metric_name.replace('_', ' ')
        if isinstance(metric_value, dict):
            for sub_name, sub_value in metric_value.items():
                final_results[f"{display_name} {sub_name}"] = sub_value
        else:
            final_results[display_name] = metric_value

    del all_image_embeddings, all_text_embeddings, image_chunks, text_chunks
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return final_results

# Confirmation message so the cell output shows that train_step / validate_epoch were (re)defined.
print("Training step (contrastive) and validation epoch functions defined.")

Training step (contrastive) and validation epoch functions defined.


In [12]:
# === Cell 12: Fine-tuning Loop ===
import datetime

# Fine-tuning driver: runs training batches with gradient accumulation, validates every
# `config.validation_interval_steps` steps, steps the LR scheduler, writes best / periodic /
# final checkpoints and stops early once the tracked metric stops improving.
if model and train_loader and optimizer:  # Basic check
    print(f"\nStarting ViPhobertSiglip fine-tuning for {config.epochs} epochs...")
    print(f"Target metric for saving best model: '{config.metric_to_track}' (mode: {config.mode})")

    # torch.save does not create missing directories; make sure the target exists up front
    # instead of failing at the first checkpoint.
    os.makedirs(config.model_path, exist_ok=True)

    # validate_epoch() reports metric names with spaces instead of underscores.
    tracked_metric_key = config.metric_to_track.replace('_', ' ')

    best_val_metric = -float('inf') if config.mode == "max" else float('inf')
    global_step = 0
    total_loss_since_log = 0.0
    steps_since_log = 0
    start_train_time = time.time()
    early_stopping_counter = 0
    last_epoch_number = 0  # Stays 0 if the epoch loop never runs (config.epochs == 0)

    history = {'steps': [], 'train_loss': [], 'val_metrics': {}}  # Use steps for logging x-axis

    model.train()
    # Drop gradients that may be left over from an interrupted earlier run of this cell,
    # so they are not folded into the first accumulated update.
    optimizer.zero_grad()

    for epoch in range(config.epochs):
        last_epoch_number = epoch + 1
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch+1}/{config.epochs} ---")
        progress_bar = tqdm(train_loader, desc=f"Training E{epoch+1}", leave=True, unit="batch")
        epoch_loss_meter = AvgMeter(f"Train Loss E{epoch+1}")  # Track epoch average loss

        for batch in progress_bar:
            # Skip dummy batches if any errors occurred during data loading
            if batch['pixel_values'].shape[0] < config.batch_size and torch.all(batch['pixel_values'] == 0):
                continue

            # --- Training Step (forward + backward only) ---
            loss = train_step(model, batch, optimizer, scaler, config.device, config.use_amp)
            epoch_loss_meter.update(loss, batch['pixel_values'].shape[0])

            # Accumulate the raw per-batch loss. It is averaged over `steps_since_log` below,
            # so dividing by accumulation_steps as well would understate the logged loss and
            # make it disagree with the epoch average.
            total_loss_since_log += loss
            steps_since_log += 1

            # --- Gradient Accumulation & Optimizer Step ---
            is_update_step = (global_step + 1) % config.accumulation_steps == 0
            if is_update_step:
                if config.use_amp:
                    # Unscale before the step so gradient clipping (if enabled) sees true magnitudes.
                    scaler.unscale_(optimizer)
                    # Optional: Gradient Clipping
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    # Optional: Gradient Clipping
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                optimizer.zero_grad()

                # Cosine schedule advances once per optimizer step; ReduceLROnPlateau is
                # stepped after validation instead.
                if lr_scheduler and config.scheduler_type == "cosine":
                    lr_scheduler.step()

            global_step += 1  # Counts processed batches, not optimizer updates

            # --- Logging ---
            if global_step % config.log_interval_steps == 0 and steps_since_log > 0:
                avg_loss = total_loss_since_log / steps_since_log
                current_lr = optimizer.param_groups[0]['lr']  # First group's LR only
                progress_bar.set_postfix(loss=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}", step=f"{global_step}")
                history['steps'].append(global_step)
                history['train_loss'].append(avg_loss)
                total_loss_since_log = 0.0
                steps_since_log = 0

            # --- Validation & Checkpointing (Based on Steps) ---
            if dev_loader and global_step % config.validation_interval_steps == 0 and global_step > 0:
                print(f"\nRunning validation at step {global_step}...")
                val_start_time = time.time()
                val_results = validate_epoch(model, dev_loader, config.device)
                val_end_time = time.time()
                print(f"Validation finished in {val_end_time - val_start_time:.2f}s")

                history['val_metrics'][global_step] = val_results
                metric_log_str = f"  Validation Step {global_step}: "
                for name in sorted(val_results.keys()):
                    metric_log_str += f"{name}: {val_results[name]:.4f} | "
                print(metric_log_str.strip(" | "))

                # One lookup serves both the scheduler and the checkpoint logic.
                current_val_metric = val_results.get(tracked_metric_key, None)

                # --- Scheduler Step (ReduceLROnPlateau) ---
                if lr_scheduler and config.scheduler_type == "reduce_on_plateau":
                    if current_val_metric is not None:
                        lr_scheduler.step(current_val_metric)
                        print(f"  RoP Scheduler step called with {config.metric_to_track}={current_val_metric:.4f}")
                    else:
                        print(f"  Warning: Metric '{config.metric_to_track}' not found. RoP Scheduler not stepped.")

                # --- Save Checkpoint Logic ---
                is_best = False
                save_path = None
                save_path_periodic = None

                if current_val_metric is not None:
                    # An improvement must beat the best value by at least min_delta.
                    if config.mode == "max":
                        is_best = current_val_metric > best_val_metric + config.early_stopping_min_delta
                    else:
                        is_best = current_val_metric < best_val_metric - config.early_stopping_min_delta

                    if is_best:
                        print(f"  Metric '{config.metric_to_track}' improved from {best_val_metric:.4f} to {current_val_metric:.4f}. Saving best model.")
                        best_val_metric = current_val_metric
                        early_stopping_counter = 0  # Reset counter
                        save_path = os.path.join(config.model_path, "phobert_siglip_best.pt")
                    else:
                        early_stopping_counter += 1
                        print(f"  Metric '{config.metric_to_track}' did not improve. Best: {best_val_metric:.4f}. Counter: {early_stopping_counter}/{config.early_stopping_patience}")

                    # Periodic checkpoint (only evaluated on validation steps)
                    if global_step % config.save_interval_steps == 0:
                        periodic_save_path = os.path.join(config.model_path, f"phobert_siglip_step_{global_step}.pt")
                        if save_path != periodic_save_path:
                            print(f"  Saving periodic checkpoint to {periodic_save_path}")
                            save_path_periodic = periodic_save_path

                    # Prepare Save Dictionary & Save
                    if save_path or save_path_periodic:
                        save_dict = {
                            'step': global_step, 'epoch': epoch + 1,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'best_val_metric': best_val_metric,
                            'metric_tracked': config.metric_to_track,
                            'current_val_metrics': val_results,
                            # Config values needed to rebuild the model at load time
                            'vision_model_name': config.vision_model_name,
                            'text_model_name': config.text_model_name,
                            'projection_dim': config.projection_dim,
                            'learnable_temperature': config.learnable_temperature,
                            'temperature': config.temperature,  # Base temperature
                            'max_length': config.max_length,
                        }
                        if lr_scheduler: save_dict['scheduler_state_dict'] = lr_scheduler.state_dict()
                        if scaler: save_dict['scaler_state_dict'] = scaler.state_dict()
                        if save_path: torch.save(save_dict, save_path)
                        if save_path_periodic: torch.save(save_dict, save_path_periodic)
                else:
                    print(f"  Warning: Metric '{config.metric_to_track}' not found. Cannot save best or check early stopping.")
                    early_stopping_counter += 1  # Still count as no improvement

                # --- Early Stopping Check ---
                if early_stopping_counter >= config.early_stopping_patience:
                    print(f"\nEarly stopping triggered after {early_stopping_counter} validation checks without improvement.")
                    break  # Break INNER loop; the outer loop re-checks the counter below

                model.train()  # validate_epoch() leaves the model in eval mode

        # --- End of Epoch ---
        epoch_end_time = time.time()
        print(f"--- Epoch {epoch+1} Time: {datetime.timedelta(seconds=epoch_end_time - epoch_start_time)} ---")
        print(f"--- Average Train Loss for Epoch {epoch+1}: {epoch_loss_meter.avg:.4f} ---")

        # Break OUTER loop if early stopping was triggered
        if early_stopping_counter >= config.early_stopping_patience:
            break

    # --- End of Training ---
    end_train_time = time.time()
    total_duration = datetime.timedelta(seconds=end_train_time - start_train_time)
    print(f"=============== Fine-tuning Finished ================")
    print(f"Total Training Time: {total_duration}")

    # Save final model state
    final_model_path = os.path.join(config.model_path, 'phobert_siglip_final.pt')
    final_save_dict = {
        'step': global_step, 'epoch': last_epoch_number,  # Last epoch that was started
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_metric': best_val_metric,
        'metric_tracked': config.metric_to_track,
        'vision_model_name': config.vision_model_name,
        'text_model_name': config.text_model_name,
        'projection_dim': config.projection_dim,
        'learnable_temperature': config.learnable_temperature,
        'temperature': config.temperature,
        'max_length': config.max_length,
    }
    if lr_scheduler: final_save_dict['scheduler_state_dict'] = lr_scheduler.state_dict()
    if scaler: final_save_dict['scaler_state_dict'] = scaler.state_dict()
    torch.save(final_save_dict, final_model_path)
    print(f"Final model state saved to {final_model_path}")

    best_model_file = os.path.join(config.model_path, "phobert_siglip_best.pt")
    if dev_loader and os.path.exists(best_model_file):
        print(f"Best model based on '{config.metric_to_track}' ({best_val_metric:.4f}) is saved at: {best_model_file}")
    elif dev_loader: print("Best model checkpoint file not found (or validation was skipped/no improvement).")
    print(f"=================================================")

else:
    print("ERROR: Prerequisites for training (model, dataloader, optimizer) not met. Training loop skipped.")

\nStarting ViPhobertSiglip fine-tuning for 5 epochs...
Target metric for saving best model: 'avg_acc' (mode: max)
\n--- Epoch 1/5 ---


Training E1:   0%|          | 0/322 [00:00<?, ?batch/s]

  with torch.cuda.amp.autocast(enabled=use_amp):


--- Epoch 1 Time: 0:01:31.161567 ---
--- Average Train Loss for Epoch 1: 2.0360 ---
\n--- Epoch 2/5 ---


Training E2:   0%|          | 0/322 [00:00<?, ?batch/s]

In [13]:
# === Cell 13: Final Evaluation on Test Set (Updated for Phobert+Siglip) ===
import traceback
from types import SimpleNamespace
from collections import OrderedDict

# Test-set evaluation driver: builds the test DataLoader, rebuilds the model from the best
# (or, failing that, final) checkpoint and reports the metrics from validate_epoch().
# Real "\n" escapes are used throughout; the previous "\\n" printed a literal backslash-n.
print("\n=============== Starting Test Set Evaluation ===============")

test_loader = None
model_to_test = None

test_json_path = os.path.join(config.data_path, "test.json")
test_image_path = config.image_base_path

# 1. Check prerequisites & Create Test Loader
if os.path.exists(test_json_path) and 'tokenizer' in globals() and tokenizer and 'image_processor' in globals() and image_processor:
    print(f"Loading test data from: {test_json_path}")
    try:
        test_dataset = CustomImageCaptionDataset(
            json_path_or_list=test_json_path, image_base_path=test_image_path,
            tokenizer=tokenizer, image_processor=image_processor,
            max_length=config.max_length
        )
        if test_dataset.data:
            num_workers = min(config.num_workers, os.cpu_count() if os.cpu_count() else 1)
            persist_workers_test = (num_workers > 0)
            # Probe whether this DataLoader accepts `persistent_workers`; fall back if not.
            try: _ = DataLoader(test_dataset, num_workers=num_workers, persistent_workers=persist_workers_test)
            except TypeError: persist_workers_test = False

            test_loader = DataLoader(
                test_dataset, batch_size=config.batch_size * 2, shuffle=False,
                num_workers=num_workers, pin_memory=(config.device == torch.device("cuda")),
                drop_last=False, persistent_workers=persist_workers_test
            )
            print(f"Test loader created with {len(test_loader)} batches.")
        else: print("Test dataset loaded but is empty.")
    except Exception as e:
        print(f"Error creating test dataset/loader: {e}")
        traceback.print_exc()  # Keep the stack so the failure can actually be diagnosed
else: print("Skipping test evaluation: Test JSON, Tokenizer or Image Processor not found/loaded.")


# 2. Load Model for Testing
if test_loader:
    try:
        best_model_path = os.path.join(config.model_path, "phobert_siglip_best.pt")
        final_model_path = os.path.join(config.model_path, "phobert_siglip_final.pt")
        load_path = None

        if os.path.exists(best_model_path):
            load_path = best_model_path
            print(f"\nLoading best model: {load_path}")
        elif os.path.exists(final_model_path):
            load_path = final_model_path
            print(f"\nLoading final model (best not found): {load_path}")
        else:
            print(f"\nWARNING: No checkpoints found in {config.model_path} to evaluate.")

        if load_path:
            checkpoint = torch.load(load_path, map_location=config.device)
            print("Re-creating model structure for testing...")

            # Rebuild the model config from the checkpoint, falling back to the current
            # notebook config for anything the checkpoint does not carry.
            temp_config_dict = {
                'device': config.device,
                'vision_model_name': checkpoint.get('vision_model_name', config.selected_vision_source),
                'text_model_name': checkpoint.get('text_model_name', config.selected_text_model),
                # Embedding sizes come from the current config (should match base models)
                'vision_embedding': config.vision_embedding,
                'text_embedding': config.text_embedding,
                'projection_dim': checkpoint.get('projection_dim', config.projection_dim),
                'learnable_temperature': checkpoint.get('learnable_temperature', config.learnable_temperature),
                'temperature': checkpoint.get('temperature', config.temperature),
                # Bias is not used/saved in this setup
                'learnable_bias': False,
                'bias_init': 0.0,
            }
            temp_config = SimpleNamespace(**temp_config_dict)

            print(f"  Using Vision Source: {temp_config.vision_model_name}")
            print(f"  Using Text Model: {temp_config.text_model_name}")

            test_image_encoder = ImageEncoder(temp_config).to(config.device)
            test_text_encoder = TextEncoder(temp_config).to(config.device)
            model_to_test = ViPhobertSiglipModel(test_image_encoder, test_text_encoder, temp_config).to(config.device)

            state_dict = checkpoint['model_state_dict']
            # Strip the 'module.' prefix that wrapper modules add to every key.
            if state_dict and all(k.startswith('module.') for k in state_dict.keys()):
                print("Detected 'module.' prefix, removing.")
                state_dict = OrderedDict((k[7:], v) for k, v in state_dict.items())

            load_result = model_to_test.load_state_dict(state_dict, strict=False)
            print(f"  State dict loading result: {load_result}")
            if load_result.missing_keys: print(f"  Warning: Missing keys: {load_result.missing_keys}")
            if load_result.unexpected_keys: print(f"  Warning: Unexpected keys: {load_result.unexpected_keys}")
            # strict=False never raises, so only claim success when every key matched.
            if load_result.missing_keys or load_result.unexpected_keys:
                print("Model weights loaded with key mismatches (see warnings above).")
            else:
                print(f"Model weights loaded successfully.")

            print("\nRunning evaluation on test set...")
            test_results = validate_epoch(model_to_test, test_loader, config.device)

            print("\n--- Test Set Results ---")
            # One metric per line, sorted by name, with consistent indentation.
            print("\n".join(f"  {name}: {test_results[name]:.4f}" for name in sorted(test_results.keys())))
            print("------------------------")
        else:
            print("Evaluation skipped (no weights found).")
    except Exception as e:
        print(f"\nERROR during test setup/evaluation: {e}")
        traceback.print_exc()

print("\n================= Evaluation Finished ==================")

Loading test data from: ./json_data/test.json
Loading JSON metadata...


Loading JSONs:   0%|          | 0/1 [00:00<?, ?it/s]

Loaded 10001 samples total from 1 file(s).
Dataset size after potential cleaning: 10001
Using image target size: 224x224
Test loader created with 79 batches.
\nLoading best model: ./trained_models/ViPhobertSiglip_uitopenviic/phobert_siglip_best.pt
Re-creating model structure for testing...
  Using Vision Source: google/siglip-base-patch16-224
  Using Text Model: vinai/phobert-base
Initializing SigLIP Vision Encoder from: google/siglip-base-patch16-224
  SigLIP Vision model loaded successfully.
  Confirmed/Using vision model hidden size: 768
  Added projection head: 768 -> 512
Initializing Text Encoder: vinai/phobert-base
  Confirmed text model hidden size: 768
  Added projection head: 768 -> 512
Using learnable temperature (logit_scale), initialized to 14.2857
  State dict loading result: <All keys matched successfully>
Model weights loaded successfully.
\nRunning evaluation on test set...


Validation:   0%|          | 0/79 [00:00<?, ?batch/s]

  with torch.cuda.amp.autocast(enabled=config.use_amp):


DEBUG: Validation - Current Temp (exp(logit_scale)): 16.5427
\nComputing metrics over 10001 validation samples...
\n--- Test Set Results ---
avg acc: 0.0819\n  avg cosine sim: 0.5634\n  i2t acc: 0.0967\n  i2t recall R@1: 0.0967\n  i2t recall R@10: 0.4508\n  i2t recall R@5: 0.3272\n  t2i acc: 0.0672\n  t2i recall R@1: 0.0672\n  t2i recall R@10: 0.4499\n  t2i recall R@5: 0.3293\n
------------------------


In [14]:
# === Cell 14: Training Visualization (Adapted for Steps/Epochs) ===
# This function plots based on epochs if validation runs per epoch,
# or steps if validation runs based on steps.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import math # Import math

def _metric_series(records, metric_name):
    """Return one float per record for ``metric_name`` (NaN where missing or non-numeric).

    Every record yields exactly one value, so the series always has the same length
    as the x-axis it is plotted against.
    """
    series = []
    for record in records:
        value = record.get(metric_name, float('nan')) if isinstance(record, dict) else float('nan')
        try:
            series.append(float(value))
        except (TypeError, ValueError):
            series.append(float('nan'))
    return series


def plot_training_metrics(history, plot_dir, plot_by='epoch'):
    """Plot the training loss and the validation metrics found in ``history`` as PNG files.

    Two history layouts are understood:

    * step based (``plot_by='step'``): ``history['steps']`` and ``history['train_loss']``
      hold the logged points, ``history['val_metrics']`` maps a step to a metrics dict.
    * epoch based: ``history['train_loss']`` has one entry per epoch and
      ``history['validation_results']`` is a sequence of metrics dicts.

    If step data is requested but missing, the epoch layout is used when available.

    Args:
        history: Training history dict (see layouts above).
        plot_dir: Directory the figures are written to; created if missing.
        plot_by: ``'step'`` or ``'epoch'``.

    Returns:
        None. Writes ``training_loss_<mode>.png`` and, when validation metrics exist,
        ``validation_metrics_<mode>.png`` into ``plot_dir``.
    """
    os.makedirs(plot_dir, exist_ok=True)
    print(f"Plot directory ensured at: {os.path.abspath(plot_dir)}")

    if not history:
        print("No history data provided.")
        return

    # Determine x-axes and collect validation records in x-axis order.
    if plot_by == 'step' and history.get('steps') and history.get('train_loss'):
        x_axis_train = history['steps']
        x_label = 'Global Steps'
        step_metrics = history.get('val_metrics') or {}
        x_axis_val = sorted(step_metrics.keys())
        val_records = [step_metrics[step] for step in x_axis_val]
    elif history.get('train_loss') and history.get('validation_results'):
        num_epochs = len(history['train_loss'])
        x_axis_train = range(1, num_epochs + 1)
        val_records = list(history['validation_results'])
        x_axis_val = range(1, len(val_records) + 1)
        x_label = 'Epoch'
        plot_by = 'epoch'  # Force epoch plotting if step data is missing
    else:
        print("Insufficient history data (need train_loss and either steps or validation_results per epoch).")
        return

    # --- Training Loss ---
    loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
    loss_ax.plot(x_axis_train, history['train_loss'], 'b-', label='Training Loss (Avg per Log Interval if steps)')
    loss_ax.set_xlabel(x_label)
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title(f'Training Loss over {x_label.capitalize()}')

    # Overlay the validation loss only if at least one record actually carries one.
    if val_records:
        val_loss = _metric_series(val_records, 'loss')
        if any(not math.isnan(vl) for vl in val_loss):
            loss_ax.plot(x_axis_val, val_loss, 'r-', label='Validation Loss')

    loss_ax.legend()
    loss_ax.grid(True)
    loss_fig.tight_layout()
    save_path_loss = os.path.join(plot_dir, f'training_loss_{plot_by}.png')
    loss_fig.savefig(save_path_loss, dpi=300)
    print(f"Saved training loss plot to: {save_path_loss}")
    plt.close(loss_fig)

    # --- Validation Metrics ---
    if not val_records:
        print("No validation metrics found in history to plot.")
        return

    # Metric names come from the first non-empty dict record.
    first_valid_val_result = next((res for res in val_records if res and isinstance(res, dict)), None)
    if not first_valid_val_result:
        print("No valid validation results found.")
        return

    metrics_to_plot = [k for k in first_valid_val_result.keys() if k != 'loss']  # Loss is plotted above
    num_plots = len(metrics_to_plot)
    if num_plots == 0:
        print("No validation metrics (excluding loss) found to plot.")
        return

    ncols = 2
    nrows = math.ceil(num_plots / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 6 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax, metric_name in zip(axes, metrics_to_plot):
        metric_values = _metric_series(val_records, metric_name)
        if any(not math.isnan(v) for v in metric_values):
            ax.plot(x_axis_val, metric_values, 'r-o', label=f'Validation {metric_name}')
            ax.set_xlabel(x_label)
            ax.set_ylabel(metric_name.replace('_', ' ').capitalize())
            ax.set_title(f'Validation {metric_name} over {x_label.capitalize()}')
            ax.legend()
            ax.grid(True)
        else:
            ax.set_title(f'Validation {metric_name} (No Data)')
            ax.text(0.5, 0.5, 'No Data', ha='center', va='center')

    # Remove the unused cells of the subplot grid.
    for unused_ax in axes[num_plots:]:
        fig.delaxes(unused_ax)

    fig.suptitle(f'Validation Metrics over {x_label.capitalize()}', fontsize=16, y=1.02)
    fig.tight_layout(rect=[0, 0, 1, 0.98])
    save_path_val = os.path.join(plot_dir, f'validation_metrics_{plot_by}.png')
    fig.savefig(save_path_val, dpi=300)
    print(f"Saved validation metrics plot to: {save_path_val}")
    plt.close(fig)


# --- Plotting ---
# The x-axis follows the validation schedule: step-based validation (a positive
# validation interval) is plotted against global steps, anything else against epochs.
plot_directory = "./train_plot/ViSigLIP_uitopenviic"
if config.validation_interval_steps > 0:
    plotting_mode = 'step'
else:
    plotting_mode = 'epoch'

# `history` only exists once the training cell has run in this kernel session.
if not isinstance(globals().get('history'), dict):
    print("No training history found. Run training first.")
else:
    plot_training_metrics(history, plot_directory, plot_by=plotting_mode)

# --- END OF SCRIPT ---

Plot directory ensured at: /home/researcher/huypq69/TuningModels/train_plot/ViSigLIP_uitopenviic
Saved training loss plot to: ./train_plot/ViSigLIP_uitopenviic/training_loss_step.png
Saved validation metrics plot to: ./train_plot/ViSigLIP_uitopenviic/validation_metrics_step.png
