In [None]:
# === Cell 1: Imports ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# --- Import SigLIP Vision model and its Processor ---
from transformers import SiglipVisionModel, SiglipConfig, SiglipImageProcessor
# --- Keep PhoBERT parts ---
from transformers import AutoModel, AutoTokenizer, AutoConfig

from PIL import Image
import json
import os
import random
import numpy as np
import math
import time
import transformers
import gc
import traceback

try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm

print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Preferred GPU index. Fall back to GPU 0 when the machine has fewer GPUs,
    # because get_device_name()/set_device() raise for an out-of-range index.
    preferred_device_idx = 1
    device_idx = preferred_device_idx if preferred_device_idx < torch.cuda.device_count() else 0
    print(f"CUDA Device Name: {torch.cuda.get_device_name(device_idx)}")
    # Makes this GPU the current device, so an index-less torch.device("cuda")
    # (as used by CFG.device) refers to it.
    torch.cuda.set_device(device_idx)

In [None]:
# === Cell 2: Configuration Class (CFG) - Configured for SigLIP Loss ===
class CFG:
    """Central path / hyper-parameter configuration for SigLIP-loss fine-tuning.

    Values are plain class attributes plus two read-only properties. One instance
    is created as ``config`` and passed around; later cells may attach extra
    attributes to it at runtime (e.g. ``total_training_steps``).
    """

    # ---- Filesystem locations ----
    data_path = "./json_data/"
    image_base_path = "./data/UIT-OpenViIC-dataset/"
    model_path = "./trained_models/ViSiglip_uitopenviic_SigmoidLoss"

    # ---- Pretrained checkpoints (Hugging Face ids) ----
    selected_vision_source = "google/siglip-base-patch16-224"
    selected_text_model = "vinai/phobert-base"
    # The vision tower and its image processor share one checkpoint;
    # the text encoder and its tokenizer share the other.
    vision_model_name = image_processor_name = selected_vision_source
    text_model_name = text_tokenizer_name = selected_text_model

    # Hard-coded backbone hidden sizes.
    @property
    def text_embedding(self):
        return 768

    @property
    def vision_embedding(self):
        return 768

    # Output width of both projection heads (shared embedding space).
    projection_dim = 768

    # ---- SigLIP logit scale / bias ----
    # logits = scale * (image @ text.T) + bias; the scale is used as-is (not exponentiated).
    learnable_temperature = True
    temperature_init = 10.0
    learnable_bias = True
    bias_init = -10.0

    # ---- Batching ----
    seed = 42
    batch_size = 32           # samples per forward pass
    accumulation_steps = 64   # micro-batches per optimizer step (effective batch 2048)
    num_workers = 20

    # ---- Optimisation ----
    # NOTE: projection_lr / vision_encoder_lr / text_encoder_lr are not currently
    # consumed by the optimizer setup, which uses learning_rate for every group.
    learning_rate = 1e-4
    projection_lr = 1e-4
    vision_encoder_lr = 1e-5
    text_encoder_lr = 2e-5
    weight_decay = 0.1

    # ---- LR schedule: "cosine" (with linear warmup) or "reduce_on_plateau" ----
    scheduler_type = "cosine"
    warmup_steps = 1000   # optimizer steps; cosine schedule only
    rop_patience = 3      # ReduceLROnPlateau only
    rop_factor = 0.8

    epochs = 50

    # ---- Runtime ----
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = True

    # ---- Text ----
    max_length = 77   # caption padding / truncation length in tokens

    # ---- Checkpointing, validation, logging ----
    save_best_only = True
    metric_to_track = "avg_acc"
    mode = "max"
    save_interval_steps = 2000
    validation_interval_steps = 1000
    log_interval_steps = 50

    # ---- Early stopping ----
    early_stopping_patience = 10
    early_stopping_min_delta = 0.001

# --- Instantiate Config and Create Output Dir ---
config = CFG()
os.makedirs(config.model_path, exist_ok=True)

# Echo the key run settings once so the notebook output records them.
for _summary_line in (
    f"Using device: {config.device}",
    f"Per-Device Batch Size: {config.batch_size}",
    f"Accumulation Steps: {config.accumulation_steps}",
    f"Effective Batch Size (per optimizer step): {config.batch_size * config.accumulation_steps}",
    f"Model output path: {config.model_path}",
    f"Selected Vision Source: {config.selected_vision_source}",
    f"Selected Text Model: {config.selected_text_model}",
    f"Image base path (for resolving paths in JSON): {os.path.abspath(config.image_base_path)}",
    f"AMP Enabled: {config.use_amp}",
):
    print(_summary_line)

In [None]:
# === Cell 3: Seeding ===
def set_seed(seed=config.seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU + every CUDA device).

    Note: PYTHONHASHSEED is exported for completeness, but setting it at runtime
    does not change hash randomization of the already-running interpreter.
    """
    print(f"Setting seed: {seed}")
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

In [None]:
# === Cell 4: Metric & AvgMeter Utilities ===
# AvgMeter tracks running averages; compute_recall_at_k / compute_metrics give retrieval
# metrics that are used for validation, even though training optimizes the SigLIP loss.
class AvgMeter:
    """Count-weighted running average of scalar values (e.g. a loss over batches)."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.sum, self.count, self.avg = 0, 0, 0.0

    def update(self, val, count=1):
        """Add ``val`` with weight ``count``.

        Tensors are converted with ``.item()``; anything that is not an int/float
        after that is silently ignored.
        """
        value = val.item() if torch.is_tensor(val) else val
        if not isinstance(value, (int, float)):
            return
        self.sum += value * count
        self.count += count
        self.avg = float(self.sum) / self.count if self.count != 0 else 0.0

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def compute_recall_at_k(similarity_matrix, k, dim):
    """Fraction of queries whose ground-truth match is among the top-k candidates.

    The ground truth for query i is index i (paired embeddings lie on the diagonal).
    With a (texts x images) matrix as built by ``compute_metrics``:
      * dim=0 (image->text): for each image column, rank the texts;
      * dim=1 (text->image): for each text row, rank the images.
    ``k`` is clipped to the number of candidates along ``dim``.

    Returns:
        float in [0, 1]; 0.0 for an empty matrix or when k is 0.

    Raises:
        ValueError: if ``dim`` is not 0 or 1.
    """
    # Validate first: previously the check sat after shape indexing, so a bad dim
    # surfaced as an IndexError and the ValueError branch was unreachable.
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1")
    n = similarity_matrix.shape[1 - dim]
    if n == 0:
        return 0.0
    actual_k = min(k, similarity_matrix.shape[dim])
    if actual_k == 0:
        return 0.0
    top_k_indices = torch.topk(similarity_matrix, actual_k, dim=dim).indices
    ground_truth = torch.arange(n, device=similarity_matrix.device)
    # Broadcast each query's ground-truth index against its k candidates in one op,
    # instead of a Python-level `in` test (and device sync) per query.
    hits = (top_k_indices == ground_truth.unsqueeze(dim)).any(dim=dim)
    return hits.sum().item() / n

def compute_metrics(image_embeddings, text_embeddings):
    """Retrieval metrics for paired embeddings (row i of each input is a matching pair).

    Builds a (texts x images) similarity matrix and reports:
      * i2t_acc / t2i_acc: top-1 accuracy in each direction, and their mean avg_acc;
      * avg_cosine_sim: mean of the diagonal (matching-pair) similarities, which is
        a cosine similarity when the inputs are L2-normalized;
      * i2t_recall / t2i_recall: {"R@1", "R@5", "R@10"}; any k larger than the
        number of pairs is reported as 0.0.
    """
    similarity = text_embeddings.float() @ image_embeddings.float().T
    num_pairs = similarity.shape[0]

    if num_pairs == 0:
        def zero_recall():
            return {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}
        return {
            "i2t_acc": 0.0, "t2i_acc": 0.0, "avg_acc": 0.0, "avg_cosine_sim": 0.0,
            "i2t_recall": zero_recall(), "t2i_recall": zero_recall(),
        }

    targets = torch.arange(num_pairs, device=similarity.device)
    # Column-wise argmax: best text for each image; row-wise: best image for each text.
    i2t_acc = (similarity.argmax(dim=0) == targets).float().mean().item()
    t2i_acc = (similarity.argmax(dim=1) == targets).float().mean().item()

    i2t_recall = {}
    t2i_recall = {}
    for k in (1, 5, 10):
        key = f"R@{k}"
        if k <= num_pairs:
            i2t_recall[key] = compute_recall_at_k(similarity, k, dim=0)
            t2i_recall[key] = compute_recall_at_k(similarity, k, dim=1)
        else:
            i2t_recall[key] = 0.0
            t2i_recall[key] = 0.0

    return {
        "i2t_acc": i2t_acc,
        "t2i_acc": t2i_acc,
        "avg_acc": (i2t_acc + t2i_acc) / 2.0,
        "avg_cosine_sim": torch.diag(similarity).mean().item(),
        "i2t_recall": i2t_recall,
        "t2i_recall": t2i_recall,
    }

print("Metric utilities defined.")

In [None]:
# === Cell 5: Dataset Class Definition ===
# (No changes needed from the previous corrected version)

import traceback
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

class CustomImageCaptionDataset(Dataset):
    """Image-caption pairs read from JSON / JSON-lines metadata files.

    Each record should be a JSON object. The image path is taken from
    'image_path' (falling back to 'url', then 'filename' when a key is absent)
    and joined onto ``image_base_path``. The caption is taken from 'caption'
    (falling back to 'text', then 'title'); a list-valued caption uses its first
    element. Any sample whose record, caption, image or tokenization is unusable
    is replaced by an all-zero dummy item (see ``_get_dummy_item``) instead of raising.

    Args:
        json_path_or_list: a directory (every ``*.json`` file inside is read, in
            sorted order), a single JSON file path, or a list of JSON file paths.
        image_base_path: directory that relative image paths are resolved against.
        tokenizer: callable returning a mapping with 'input_ids' and
            'attention_mask' tensors (called with padding='max_length').
        image_processor: callable returning a mapping with a 'pixel_values' tensor.
        max_length: token length every caption is padded/truncated to.

    Raises:
        ValueError: if ``json_path_or_list`` is none of the accepted forms.
    """

    def __init__(self, json_path_or_list, image_base_path, tokenizer, image_processor, max_length):
        super().__init__()
        json_files = self._resolve_json_files(json_path_or_list)

        print("Loading JSON metadata...")
        self.data = []
        total_loaded_count = 0
        for json_path in tqdm(json_files, desc="Loading JSONs"):
            records = self._load_json_file(json_path)
            self.data.extend(records)
            total_loaded_count += len(records)

        print(f"Loaded {total_loaded_count} samples total from {len(json_files)} file(s).")
        # Drop falsy records (e.g. {} or null) up front.
        self.data = [item for item in self.data if item]
        print(f"Dataset size after potential cleaning: {len(self.data)}")

        if not self.data: print("WARNING: No data loaded!")

        self.image_base_path = image_base_path
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.img_size = self._resolve_image_size(image_processor)
        print(f"Using image target size: {self.img_size}x{self.img_size}")
        if not os.path.isdir(self.image_base_path): print(f"WARNING: Image base path does not exist: {os.path.abspath(self.image_base_path)}")

    @staticmethod
    def _resolve_json_files(json_path_or_list):
        """Turn the constructor's path argument into a list of JSON file paths."""
        if isinstance(json_path_or_list, str) and os.path.isdir(json_path_or_list):
            # Sorted so record order (and thus seeded shuffling) does not depend on
            # the filesystem's directory-listing order.
            json_files = sorted(
                os.path.join(json_path_or_list, f)
                for f in os.listdir(json_path_or_list) if f.endswith('.json')
            )
            print(f"Found {len(json_files)} JSON files in {json_path_or_list}")
            return json_files
        if isinstance(json_path_or_list, str) and os.path.isfile(json_path_or_list):
            return [json_path_or_list]
        if isinstance(json_path_or_list, list):
            return json_path_or_list
        raise ValueError("json_path_or_list must be a directory, a single JSON file, or a list of JSON files.")

    @staticmethod
    def _load_json_file(json_path):
        """Read one metadata file as a single JSON document, falling back to JSON-lines.

        Best-effort: problems are printed, not raised, and whatever records were
        parsed before an error are still returned.
        """
        records = []
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                try:
                    file_data = json.load(f)
                    records.extend(file_data if isinstance(file_data, list) else [file_data])
                except json.JSONDecodeError:
                    print(f"  Info: Failed to load {json_path} as single JSON. Attempting JSON-per-line format...")
                    f.seek(0)
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            records.append(json.loads(line))
                        except json.JSONDecodeError as line_err:
                            print(f"  ERROR parsing line in {json_path}: {line_err}. Line content (partial): {line[:100]}...")
                    if not records: print(f"  Failed to load any data using JSON-per-line format from {json_path} either.")
        except Exception as e:
            print(f"ERROR opening or processing file {json_path}: {e}")
        return records

    @staticmethod
    def _resolve_image_size(image_processor):
        """Square side length for dummy items: processor.config.image_size, else
        processor.size['height'], else 224 (with a warning)."""
        try:
            return image_processor.config.image_size
        except AttributeError:
            pass
        try:
            return image_processor.size['height']
        except (AttributeError, TypeError, KeyError):
            print("Warning: Defaulting image size to 224.")
            return 224

    @staticmethod
    def _extract_caption(caption_data):
        """First element of a non-empty list, the string itself, or "" otherwise."""
        if isinstance(caption_data, list):
            return caption_data[0] if caption_data else ""
        if isinstance(caption_data, str):
            return caption_data
        return ""

    def __len__(self): return len(self.data)

    def __getitem__(self, idx):
        if idx >= len(self.data): raise IndexError("Index out of bounds")
        item = self.data[idx]
        if not isinstance(item, dict):
            # Non-object records (e.g. a bare string) have no fields; previously
            # `item.get` raised AttributeError here, outside any error handling.
            return self._get_dummy_item()
        relative_image_path = item.get('image_path', item.get('url', item.get('filename')))
        caption = self._extract_caption(item.get('caption', item.get('text', item.get('title', ''))))
        if not relative_image_path or not caption: return self._get_dummy_item()
        try:
            image_path = os.path.join(self.image_base_path, relative_image_path)
            # Context manager closes the file handle deterministically rather than
            # leaving it to garbage collection.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            image_inputs = self.image_processor(images=image, return_tensors="pt")
            pixel_values = image_inputs['pixel_values'].squeeze(0)
        except Exception: return self._get_dummy_item()
        try:
            text_inputs = self.tokenizer(caption, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')
            input_ids = text_inputs['input_ids'].squeeze(0)
            attention_mask = text_inputs['attention_mask'].squeeze(0)
        except Exception: return self._get_dummy_item()
        return {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}

    def _get_dummy_item(self):
        """All-zero placeholder sample (zero image, zero token ids, fully masked text).

        NOTE(review): dummies are not filtered out before the loss; in the SigLIP
        loss each one is labeled positive with itself and negative against other
        identical dummies in the batch. Consider dropping them in a collate_fn.
        """
        return {"pixel_values": torch.zeros((3, self.img_size, self.img_size), dtype=torch.float), "input_ids": torch.zeros(self.max_length, dtype=torch.long), "attention_mask": torch.zeros(self.max_length, dtype=torch.long)}

print("CustomImageCaptionDataset class defined.")

In [None]:
# === Cell 6: Model Definition (SigLIP Vision + PhoBERT Text + SigLIP Params) ===

class ImageEncoder(nn.Module):
    """SigLIP vision tower followed by a bias-free linear projection.

    The image embedding is the vision model's ``pooler_output``, mapped to
    ``config_train.projection_dim``. L2 normalization is left to the combined model.
    With ``pretrained=False`` the architecture is built from the checkpoint's
    vision config with freshly initialized weights.
    """

    def __init__(self, config_train, pretrained=True):
        super().__init__()
        self.config_train = config_train
        source = config_train.vision_model_name
        print(f"Initializing SigLIP Vision Encoder from: {source}")
        if not pretrained:
            self.vision_model = SiglipVisionModel(SiglipConfig.from_pretrained(source).vision_config)
        else:
            self.vision_model = SiglipVisionModel.from_pretrained(source)
        hidden_size = self.vision_model.config.hidden_size
        self.input_features = hidden_size
        print(f"  Confirmed/Using vision model hidden size: {hidden_size}")
        out_dim = config_train.projection_dim
        self.projection = nn.Linear(hidden_size, out_dim, bias=False)
        print(f"  Added projection head: {hidden_size} -> {out_dim}")

    def forward(self, pixel_values):
        pooled = self.vision_model(pixel_values=pixel_values, return_dict=True).pooler_output
        return self.projection(pooled)

class TextEncoder(nn.Module):
    """Transformer text encoder (PhoBERT by default) with a bias-free linear projection.

    The sentence embedding is the last hidden state of the first token
    (``last_hidden_state[:, 0]``), mapped to ``config_train.projection_dim``.
    L2 normalization is left to the combined model. With ``pretrained=False``
    the architecture is built from the checkpoint's config with fresh weights.
    """

    def __init__(self, config_train, pretrained=True):
        super().__init__()
        self.config_train = config_train
        model_name = config_train.text_model_name
        print(f"Initializing Text Encoder: {model_name}")
        if not pretrained:
            self.model = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
        else:
            self.model = AutoModel.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size
        self.input_features = hidden_size
        print(f"  Confirmed text model hidden size: {hidden_size}")
        out_dim = config_train.projection_dim
        self.projection = nn.Linear(hidden_size, out_dim, bias=False)
        print(f"  Added projection head: {hidden_size} -> {out_dim}")

    def forward(self, input_ids, attention_mask):
        hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
        first_token_state = hidden_states[:, 0, :]
        return self.projection(first_token_state)

# --- Combined Model (Using SigLIP temperature & bias) ---
class ViPhobertSiglipLossModel(nn.Module):
    """Dual encoder: image + text towers plus the SigLIP logit scale and bias.

    ``logit_scale`` and ``logit_bias`` are scalars, registered as learnable
    parameters or as fixed buffers depending on ``learnable_temperature`` /
    ``learnable_bias`` in ``config_train``.

    forward() returns ``(image_features, text_features, logit_scale, logit_bias)``
    with both feature sets L2-normalized along the last dimension; the scalars
    are passed through for the loss function.
    """

    def __init__(self, image_encoder, text_encoder, config_train):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.config_train = config_train

        scale_init = torch.tensor(config_train.temperature_init, dtype=torch.float)
        if not config_train.learnable_temperature:
            self.register_buffer('logit_scale', scale_init)
            print(f"Using fixed temperature: {self.logit_scale.item():.4f}")
        else:
            self.logit_scale = nn.Parameter(scale_init)
            print(f"Using learnable temperature, initialized to {self.logit_scale.item():.4f}")

        bias_init = torch.tensor(config_train.bias_init, dtype=torch.float)
        if not config_train.learnable_bias:
            self.register_buffer('logit_bias', bias_init)
            print(f"Using fixed bias: {self.logit_bias.item():.4f}")
        else:
            self.logit_bias = nn.Parameter(bias_init)
            print(f"Using learnable bias, initialized to {self.logit_bias.item():.4f}")

    def forward(self, pixel_values, input_ids, attention_mask):
        target = self.config_train.device
        pixel_values, input_ids, attention_mask = (
            t.to(target) for t in (pixel_values, input_ids, attention_mask)
        )

        image_features = F.normalize(self.image_encoder(pixel_values), p=2, dim=-1)
        text_features = F.normalize(self.text_encoder(input_ids, attention_mask), p=2, dim=-1)

        out_device = image_features.device
        return (
            image_features,
            text_features,
            self.logit_scale.to(out_device),
            self.logit_bias.to(out_device),
        )

print("ViPhobertSiglipLoss Model components defined.")

In [None]:
# === Cell 7: SigLIP Loss Function ===
def siglip_loss(image_features, text_features, logit_scale, logit_bias):
    """Pairwise sigmoid (SigLIP) loss over all image-text pairs in the batch.

    Every pair (i, j) is an independent binary classification: positive when
    i == j (matching pair), negative otherwise. As in the SigLIP formulation,
    per-pair losses are summed and divided by the batch size n. The previous
    version averaged over all n*n pairs, which reported a loss n times smaller
    than the reference definition.

    Args:
        image_features: (n, d) image embeddings; row i pairs with text row i.
        text_features: (n, d) text embeddings.
        logit_scale: scalar multiplier for the similarities (used as-is, not exponentiated).
        logit_bias: scalar added to the scaled similarities.

    Returns:
        Scalar float32 loss tensor; 0.0 (with requires_grad) for an empty batch.
    """
    # Compute in float32 regardless of autocast for a numerically stable loss.
    image_features = image_features.float()
    text_features = text_features.float()
    logit_scale = logit_scale.float()
    logit_bias = logit_bias.float()

    n = text_features.shape[0]
    if n == 0:
        return torch.tensor(0.0, device=image_features.device, requires_grad=True)

    logits = image_features @ text_features.t() * logit_scale + logit_bias

    # 1 on the diagonal (matching pairs), 0 elsewhere.
    labels = torch.eye(n, device=logits.device, dtype=logits.dtype)

    # BCE-with-logits per pair == -log sigmoid(z_ij * logit_ij) with z = +/-1.
    loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='sum') / n

    return loss

print("SigLIP loss function defined.")

In [None]:
# === Cell 8: Setup - Tokenizer and Image Processor (Using SigLIP Processor) ===
from transformers import AutoTokenizer, SiglipImageProcessor

def _load_component(announce, load, success_msg, failure_prefix):
    """Print ``announce``, call ``load()`` and report the outcome.

    Failures are printed rather than raised so later cells can simply check
    the result for None.
    """
    print(announce)
    try:
        component = load()
        print(success_msg)
        return component
    except Exception as e:
        print(f"{failure_prefix}: {e}")
        return None


tokenizer = _load_component(
    f"Loading Tokenizer: {config.text_tokenizer_name}",
    lambda: AutoTokenizer.from_pretrained(config.text_tokenizer_name),
    "PhoBERT Tokenizer loaded successfully.",
    f"ERROR loading tokenizer '{config.text_tokenizer_name}'",
)

image_processor = _load_component(
    f"Loading Image Processor from: {config.image_processor_name}",
    lambda: SiglipImageProcessor.from_pretrained(config.image_processor_name),
    "SigLIP Image Processor loaded successfully.",
    f"ERROR loading image processor '{config.image_processor_name}'",
)

# The processor resizes/normalizes each image inside the dataset's __getitem__,
# so no separate transform pipeline is built here.
print("Image transforms handled by SiglipImageProcessor.")

In [None]:
# === Cell 9: Setup - Datasets and DataLoaders (FIXED Validation Path) ===
train_loader = None
dev_loader = None
# Initialised up front so the checks below never see stale objects from an
# earlier run of this cell (or fail with NameError on a fresh kernel).
train_dataset = None
dev_dataset = None

validation_json_path = os.path.join(config.data_path, "dev.json")
train_json_path = os.path.join(config.data_path, "train.json")


def _build_loader(dataset, batch_size, shuffle, drop_last, num_workers):
    """Create a DataLoader with the notebook's shared settings.

    Memory is pinned only when training on CUDA. persistent_workers keeps worker
    processes alive across epochs and is only valid when num_workers > 0.
    """
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=config.device.type == "cuda",
        drop_last=drop_last,
        persistent_workers=num_workers > 0,
    )


if tokenizer and image_processor:
    print("\\nCreating datasets...")
    # --- Training Dataset ---
    try:
        print(f"Attempting to load training data from: {train_json_path}")
        train_dataset = CustomImageCaptionDataset(
            json_path_or_list=train_json_path,
            image_base_path=config.image_base_path,
            tokenizer=tokenizer,
            image_processor=image_processor,
            max_length=config.max_length
        )
        if not train_dataset.data: print("\\nERROR: Failed to load training data.")
    except Exception as e:
        print(f"ERROR creating training dataset: {e}")
        train_dataset = None

    # --- Validation Dataset ---
    if os.path.exists(validation_json_path):
        try:
            print(f"Attempting to load validation data from: {validation_json_path}")
            dev_dataset = CustomImageCaptionDataset(
                json_path_or_list=validation_json_path,
                image_base_path=config.image_base_path,
                tokenizer=tokenizer,
                image_processor=image_processor,
                max_length=config.max_length
            )
            if not dev_dataset.data: print("\\nWARNING: Failed to load validation data.")
        except Exception as e:
            print(f"ERROR creating validation dataset: {e}")
            dev_dataset = None
    else:
        print(f"Validation JSON file not found at {validation_json_path}, skipping validation set creation.")
        dev_dataset = None

    print("\\nCreating dataloaders...")
    num_workers = min(config.num_workers, os.cpu_count() or 1)
    print(f"Using {num_workers} workers for DataLoaders.")

    if train_dataset and train_dataset.data:
        train_loader = _build_loader(train_dataset, config.batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
        print(f"Train loader created with {len(train_loader)} batches.")
        if config.scheduler_type == "cosine":
            # The training loop also steps the optimizer on the final, partial
            # accumulation window of each epoch, i.e. ceil(batches / accumulation_steps)
            # steps per epoch. Flooring the product under-counts; once the scheduler is
            # stepped past num_training_steps, transformers' cosine curve rises again.
            config.total_training_steps = math.ceil(len(train_loader) / config.accumulation_steps) * config.epochs
            print(f"Total estimated training steps for Cosine Scheduler: {config.total_training_steps}")
    else:
        print("Skipping train loader creation (no data).")
        config.total_training_steps = 0

    if dev_dataset and dev_dataset.data:
        dev_loader = _build_loader(dev_dataset, config.batch_size * 2, shuffle=False, drop_last=False, num_workers=num_workers)
        print(f"Validation loader created with {len(dev_loader)} batches.")
    else: print("Skipping validation loader creation.")

    if not train_loader: print("\\nERROR: Train loader could not be created.")
else:
    print("ERROR: Tokenizer or Image Processor not loaded. Skipping dataset/loader creation.")

In [None]:
# === Cell 10: Setup - Model, Optimizer, Scheduler (SigLIP Params) ===

import traceback

model = None
optimizer = None
lr_scheduler = None
scaler = None # For AMP

print("\\nInitializing ViPhobertSiglipLoss model components...")
try:
    image_encoder = ImageEncoder(config).to(config.device)
    text_encoder = TextEncoder(config).to(config.device)
    model = ViPhobertSiglipLossModel(image_encoder, text_encoder, config).to(config.device)

    print(f"\\nViPhobertSiglipLoss Model initialized successfully on {config.device}.")
    num_params_total = sum(p.numel() for p in model.parameters())
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {num_params_total / 1e6:.2f} M")
    print(f"Trainable parameters: {num_params_trainable / 1e6:.2f} M")

except Exception as e:
    print(f"ERROR initializing model components: {e}")
    traceback.print_exc()
    model = None

if model and train_loader:
    print("\\nSetting up optimizer...")
    param_optimizer = list(model.named_parameters())
    # Name fragments exempt from weight decay (BERT-style norm naming + SigLIP temp/bias).
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight", "logit_scale", "logit_bias"]

    def _exclude_from_weight_decay(name, param):
        """True for parameters that should not be decayed: biases, norm weights, scalars.

        The name list only matches "LayerNorm"-style names (PhoBERT). Norms with
        other naming (transformers' SigLIP vision tower uses e.g. `layer_norm1`,
        `post_layernorm`) would otherwise be decayed, so every parameter with
        fewer than 2 dims (biases, norm scales, logit_scale/logit_bias) is excluded too.
        """
        return param.ndim < 2 or any(nd in name for nd in no_decay)

    decay_params = [p for n, p in param_optimizer if p.requires_grad and not _exclude_from_weight_decay(n, p)]
    no_decay_params = [p for n, p in param_optimizer if p.requires_grad and _exclude_from_weight_decay(n, p)]

    # NOTE: config.projection_lr / vision_encoder_lr / text_encoder_lr are not used;
    # both groups train at config.learning_rate.
    optimizer_grouped_parameters = [
        {'params': decay_params, 'lr': config.learning_rate, 'weight_decay': config.weight_decay},
        {'params': no_decay_params, 'lr': config.learning_rate, 'weight_decay': 0.0},
    ]
    # Filter empty groups
    optimizer_grouped_parameters = [g for g in optimizer_grouped_parameters if g['params']]

    if not optimizer_grouped_parameters:
        print("ERROR: No trainable parameters found.")
        optimizer = None
    else:
        optimizer = optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
        print(f"Optimizer AdamW initialized with base LR: {config.learning_rate}, weight decay: {config.weight_decay}")

        # --- LR Scheduler ---
        if config.scheduler_type == "cosine":
            if hasattr(config, 'total_training_steps') and config.total_training_steps > 0:
                if config.warmup_steps >= config.total_training_steps:
                    # The whole run would be spent in linear warmup; cosine decay never starts.
                    print(f"WARNING: warmup_steps ({config.warmup_steps}) >= total_training_steps ({config.total_training_steps}); the cosine decay phase will never be reached.")
                lr_scheduler = transformers.get_cosine_schedule_with_warmup(
                    optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_training_steps
                )
                print(f"LR Scheduler: Cosine with Warmup ({config.warmup_steps} steps) initialized.")
            else:
                lr_scheduler = None
                print("ERROR: Cannot init Cosine scheduler.")
        elif config.scheduler_type == "reduce_on_plateau":
            lr_scheduler = ReduceLROnPlateau(optimizer, mode=config.mode, factor=config.rop_factor, patience=config.rop_patience)
            print(f"LR Scheduler: ReduceLROnPlateau initialized (mode='{config.mode}', factor={config.rop_factor}, patience={config.rop_patience})")
        else:
            lr_scheduler = None
            print("No LR Scheduler specified.")

        # --- AMP Scaler ---
        if config.use_amp:
            # Loss scaling only applies on CUDA; when disabled, scale()/step()/update()
            # pass through, so the training code can call them unconditionally.
            scaler = torch.amp.GradScaler('cuda', enabled=config.device.type == 'cuda')
            print("AMP GradScaler initialized.")
        else:
            scaler = None

        # Early stopping setup
        early_stopping_counter = 0
        best_val_metric = -float('inf') if config.mode == "max" else float('inf')
        print(f"Early stopping enabled with patience: {config.early_stopping_patience}")

else:
    print("ERROR: Model or train_loader not available. Skipping optimizer/scheduler setup.")
    optimizer = None; lr_scheduler = None; scaler = None

In [None]:
# === Cell 11: Training and Validation Functions (Using SigLIP Loss) ===
import traceback

def train_step(model, batch, optimizer, scaler, device, use_amp):
    """Forward + backward pass for one micro-batch using the SigLIP loss.

    Gradients are accumulated into the parameters; stepping and zeroing the
    optimizer is left to the caller (gradient accumulation), so ``optimizer``
    is not used in this function. With ``use_amp`` the loss is scaled by
    ``scaler`` before backward.

    Returns:
        float: the unscaled loss of this micro-batch.
    """
    model.train()
    pixels, token_ids, token_mask = batch['pixel_values'], batch['input_ids'], batch['attention_mask']

    with torch.amp.autocast(device_type=device.type, enabled=use_amp):
        img_feats, txt_feats, scale, shift = model(pixels, token_ids, token_mask)
        loss = siglip_loss(img_feats, txt_feats, scale, shift)

    backward_target = scaler.scale(loss) if use_amp else loss
    backward_target.backward()

    return loss.item()

def validate_epoch(model, dataloader, device, use_amp=None):
    """Embed the whole validation set and compute retrieval metrics.

    Args:
        model: returns (image_features, text_features, temp, bias) with normalized features.
        dataloader: yields dicts with 'pixel_values', 'input_ids', 'attention_mask'.
        device: its ``.type`` selects the autocast backend; metrics are computed on it.
        use_amp: enable autocast; None (default) falls back to ``config.use_amp``.

    Returns:
        dict of flattened metrics. Scalar metrics appear under both their
        snake_case name (e.g. "avg_acc", which ``config.metric_to_track`` and the
        fallback dict use) and a space-separated name (e.g. "avg acc"). Recall
        values appear as e.g. "i2t recall R@1". If no embeddings could be
        collected, a fallback dict with "loss": inf and zeroed metrics is returned.
    """
    if use_amp is None:
        use_amp = config.use_amp

    def _fallback_results():
        return {"loss": float('inf'), "avg_acc": 0.0, "avg_cosine_sim": 0.0, "i2t recall R@1": 0.0, "i2t recall R@5": 0.0, "i2t recall R@10": 0.0, "t2i recall R@1": 0.0, "t2i recall R@5": 0.0, "t2i recall R@10": 0.0 }

    model.eval()
    all_image_embeddings = []
    all_text_embeddings = []
    progress_bar = tqdm(dataloader, desc="Validation", leave=False, unit="batch")

    with torch.no_grad():
        for batch in progress_bar:
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                # temp/bias are irrelevant for retrieval metrics.
                image_embeds_norm, text_embeds_norm, _, _ = model(batch['pixel_values'], batch['input_ids'], batch['attention_mask'])
            # Collect on CPU so GPU memory stays flat across the epoch.
            all_image_embeddings.append(image_embeds_norm.cpu())
            all_text_embeddings.append(text_embeds_norm.cpu())

    if not all_image_embeddings or not all_text_embeddings:
        print("Warning: No embeddings collected during validation.")
        return _fallback_results()

    try:
        image_matrix = torch.cat(all_image_embeddings, dim=0)
        text_matrix = torch.cat(all_text_embeddings, dim=0)
    except Exception as e:
        print(f"Error concatenating embeddings: {e}")
        return _fallback_results()

    if hasattr(model, 'logit_scale'):
        print(f"DEBUG: Validation - Current Temp: {model.logit_scale.item():.4f}")
    if hasattr(model, 'logit_bias'):
        print(f"DEBUG: Validation - Current Bias: {model.logit_bias.item():.4f}")

    print(f"\\nComputing metrics over {image_matrix.shape[0]} validation samples...")
    validation_metrics = compute_metrics(image_matrix.to(device), text_matrix.to(device))

    final_results = {}
    for k, v in validation_metrics.items():
        readable_key = k.replace('_', ' ')
        if isinstance(v, dict):
            for recall_k, recall_v in v.items():
                final_results[f"{readable_key} {recall_k}"] = recall_v
        else:
            final_results[readable_key] = v
            # Also keep the snake_case key: config.metric_to_track ("avg_acc") and
            # the fallback dict use it, whereas the readable key contains a space.
            final_results[k] = v

    del image_matrix, text_matrix, all_image_embeddings, all_text_embeddings
    gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()

    return final_results

print("Training step (SigLIP loss) and validation epoch functions defined.")

In [None]:
# === Cell 12: Training Loop (with Early Stopping, SigLIP Loss) ===
# Fine-tunes `model` with gradient accumulation, optional AMP, step-based validation,
# best/periodic checkpointing, LR scheduling and early stopping.
# Uses objects created in earlier cells: model, train_loader, dev_loader, optimizer,
# scaler, lr_scheduler, config, train_step, validate_epoch, AvgMeter, and the
# early-stopping state best_val_metric / early_stopping_counter (initialised in Cell 10).
import datetime


def _optimizer_update():
    """Apply one optimizer step from the accumulated gradients and reset them.

    Goes through the AMP GradScaler when ``config.use_amp`` is set, and advances the
    cosine scheduler once per optimizer update (ReduceLROnPlateau is stepped after
    validation instead).
    """
    if config.use_amp:
        # unscale_ is only required before gradient clipping; GradScaler.step() does not
        # unscale the same optimizer a second time, so calling it here is harmless.
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler and config.scheduler_type == "cosine":
        lr_scheduler.step()


def _checkpoint_dict(step, epoch_number, best_metric, val_results=None):
    """Build the payload passed to torch.save for a training checkpoint.

    Args:
        step: value of global_step (processed micro-batches) at save time.
        epoch_number: 1-based epoch index (0 if no epoch ran).
        best_metric: best value of config.metric_to_track seen so far.
        val_results: validation metrics that triggered this save; omitted when None
            (the final end-of-training save).

    Returns:
        dict with the model/optimizer state, scheduler/AMP-scaler state when present, and
        the hyper-parameters the test-evaluation cell reads back to rebuild the model.
    """
    checkpoint_payload = {
        'step': step, 'epoch': epoch_number,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_metric': best_metric,
        'metric_tracked': config.metric_to_track,
        'vision_model_name': config.vision_model_name,
        'text_model_name': config.text_model_name,
        'projection_dim': config.projection_dim,
        'learnable_temperature': config.learnable_temperature,
        'temperature_init': config.temperature_init,
        'learnable_bias': config.learnable_bias,
        'bias_init': config.bias_init,
        'max_length': config.max_length,
    }
    if val_results is not None:
        checkpoint_payload['current_val_metrics'] = val_results
    if lr_scheduler:
        checkpoint_payload['scheduler_state_dict'] = lr_scheduler.state_dict()
    if scaler:
        checkpoint_payload['scaler_state_dict'] = scaler.state_dict()
    return checkpoint_payload


if model and train_loader and optimizer:
    print(f"\nStarting ViPhobertSiglip fine-tuning (SigLIP Loss) for {config.epochs} epochs...")
    print(f"Target metric for saving best model: '{config.metric_to_track}' (mode: {config.mode})")

    # best_val_metric / early_stopping_counter come from Cell 10 and are updated below.
    os.makedirs(config.model_path, exist_ok=True)  # all checkpoints are written here
    metric_key = config.metric_to_track.replace('_', ' ')  # validate_epoch reports keys with spaces

    validation_enabled = bool(dev_loader) and config.validation_interval_steps > 0
    if dev_loader and not validation_enabled:
        print("  Warning: validation_interval_steps <= 0; no validation, best-model saving or early stopping will run.")

    global_step = 0  # counts processed micro-batches, not optimizer updates
    total_loss_since_log = 0.0
    steps_since_log = 0
    epochs_run = 0  # keeps the final checkpoint's 'epoch' defined even when config.epochs == 0
    start_train_time = time.time()

    history = {'steps': [], 'train_loss': [], 'val_metrics': {}}

    model.train()

    for epoch in range(config.epochs):
        epochs_run = epoch + 1
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch+1}/{config.epochs} ---")
        progress_bar = tqdm(train_loader, desc=f"Training E{epoch+1}", leave=True, unit="batch")
        epoch_loss_meter = AvgMeter(f"Train Loss E{epoch+1}")
        pending_micro_batches = 0  # backward passes accumulated since the last optimizer update
        stop_training = False

        for i, batch in enumerate(progress_bar):
            # Skip batches without images, and short all-zero batches (presumably placeholders
            # emitted for unreadable images by the collate function — confirm).
            if 'pixel_values' not in batch or batch['pixel_values'].shape[0] < config.batch_size and torch.all(batch['pixel_values'] == 0):
                continue

            # --- Training Step using siglip_loss ---
            loss = train_step(model, batch, optimizer, scaler, config.device, config.use_amp)
            pending_micro_batches += 1
            # NOTE(review): rescaling by accumulation_steps implies train_step returns the loss
            # already divided by it — confirm. The epoch meter and the logged loss now share this
            # per-batch value; the log previously divided once more, making it differ from the
            # epoch average by a factor of accumulation_steps**2.
            batch_loss = loss * config.accumulation_steps
            epoch_loss_meter.update(batch_loss, batch['pixel_values'].shape[0])
            total_loss_since_log += batch_loss
            steps_since_log += 1

            # --- Gradient Accumulation & Optimizer Step ---
            if (i + 1) % config.accumulation_steps == 0 or (i + 1) == len(train_loader):
                _optimizer_update()
                pending_micro_batches = 0

            global_step += 1

            # --- Logging ---
            if config.log_interval_steps > 0 and global_step % config.log_interval_steps == 0 and steps_since_log > 0:
                avg_loss = total_loss_since_log / steps_since_log
                current_lr = optimizer.param_groups[0]['lr']
                progress_bar.set_postfix(loss=f"{avg_loss:.4f}", lr=f"{current_lr:.2e}", step=f"{global_step}")
                history['steps'].append(global_step)
                history['train_loss'].append(avg_loss)
                total_loss_since_log = 0.0
                steps_since_log = 0

            # --- Validation & Checkpointing ---
            if not (validation_enabled and global_step % config.validation_interval_steps == 0):
                continue

            print(f"\nRunning validation at step {global_step}...")
            val_start_time = time.time()
            val_results = validate_epoch(model, dev_loader, config.device)
            print(f"Validation finished in {time.time() - val_start_time:.2f}s")

            history['val_metrics'][global_step] = val_results
            metric_summary = " | ".join(f"{name}: {val_results[name]:.4f}" for name in sorted(val_results))
            print(f"  Validation Step {global_step}: {metric_summary}")

            current_val_metric = val_results.get(metric_key, None)
            if lr_scheduler and config.scheduler_type == "reduce_on_plateau":
                if current_val_metric is not None:
                    lr_scheduler.step(current_val_metric)
                    print(f"  RoP Scheduler step called with {config.metric_to_track}={current_val_metric:.4f}")
                else:
                    print(f"  Warning: Metric '{config.metric_to_track}' not found. RoP Scheduler not stepped.")

            if current_val_metric is not None:
                # An improvement must beat the best value by more than early_stopping_min_delta.
                if config.mode == "max":
                    is_best = current_val_metric > best_val_metric + config.early_stopping_min_delta
                else:
                    is_best = current_val_metric < best_val_metric - config.early_stopping_min_delta

                save_path = None
                if is_best:
                    print(f"  Metric '{config.metric_to_track}' improved from {best_val_metric:.4f} to {current_val_metric:.4f}. Saving best model.")
                    best_val_metric = current_val_metric
                    early_stopping_counter = 0
                    save_path = os.path.join(config.model_path, "phobert_sigliploss_best.pt")
                else:
                    early_stopping_counter += 1
                    print(f"  Metric '{config.metric_to_track}' did not improve. Best: {best_val_metric:.4f}. Counter: {early_stopping_counter}/{config.early_stopping_patience}")

                save_path_periodic = None
                if config.save_interval_steps > 0 and global_step % config.save_interval_steps == 0:
                    save_path_periodic = os.path.join(config.model_path, f"phobert_sigliploss_step_{global_step}.pt")
                    print(f"  Saving periodic checkpoint to {save_path_periodic}")

                if save_path or save_path_periodic:
                    save_dict = _checkpoint_dict(global_step, epoch + 1, best_val_metric, val_results)
                    for target_path in (save_path, save_path_periodic):
                        if target_path:
                            torch.save(save_dict, target_path)
            else:
                print(f"  Warning: Metric '{config.metric_to_track}' not found. Cannot save best or check early stopping.")
                early_stopping_counter += 1

            if early_stopping_counter >= config.early_stopping_patience:
                print(f"\nEarly stopping triggered after {early_stopping_counter} validation checks without improvement.")
                stop_training = True
                break  # leave the batch loop; the epoch loop stops below

            model.train()  # validate_epoch presumably switched the model to eval mode

        # Micro-batches accumulated after the last update step (possible when the epoch's last
        # batch(es) were skipped above) would otherwise leak their gradients into the next epoch.
        if pending_micro_batches > 0 and not stop_training:
            _optimizer_update()

        # --- End of Epoch ---
        epoch_end_time = time.time()
        print(f"--- Epoch {epoch+1} Time: {datetime.timedelta(seconds=epoch_end_time - epoch_start_time)} ---")
        print(f"--- Average Train Loss for Epoch {epoch+1}: {epoch_loss_meter.avg:.4f} ---")

        if stop_training or early_stopping_counter >= config.early_stopping_patience:
            break

    # --- End of Training ---
    total_duration = datetime.timedelta(seconds=time.time() - start_train_time)
    print("=============== Fine-tuning (SigLIP Loss) Finished ================")
    print(f"Total Training Time: {total_duration}")

    final_model_path = os.path.join(config.model_path, 'phobert_sigliploss_final.pt')
    torch.save(_checkpoint_dict(global_step, epochs_run, best_val_metric), final_model_path)
    print(f"Final model state saved to {final_model_path}")

    best_model_file = os.path.join(config.model_path, "phobert_sigliploss_best.pt")
    if dev_loader and os.path.exists(best_model_file):
        # NOTE(review): this file may be left over from an earlier run if no new best was saved.
        print(f"Best model based on '{config.metric_to_track}' ({best_val_metric:.4f}) is saved at: {best_model_file}")
    elif dev_loader:
        print("Best model checkpoint file not found.")
    print("=================================================")

else:
    print("ERROR: Prerequisites for training not met. Training loop skipped.")

In [None]:
# === Cell 13: Final Evaluation on Test Set (Updated for SigLIP Loss Model) ===
# Rebuilds ViPhobertSiglipLossModel from this run's best (else final) checkpoint and
# reports validate_epoch's retrieval metrics on test.json.
import traceback
from types import SimpleNamespace
from collections import OrderedDict

print("\n=============== Starting Test Set Evaluation ===============")

test_loader = None
model_to_test = None

test_json_path = os.path.join(config.data_path, "test.json")
test_image_path = config.image_base_path

# 1. Check prerequisites & create the test loader.
if os.path.exists(test_json_path) and 'tokenizer' in globals() and tokenizer and 'image_processor' in globals() and image_processor:
    print(f"Loading test data from: {test_json_path}")
    try:
        test_dataset = CustomImageCaptionDataset(
            json_path=test_json_path, image_base_path=test_image_path,
            tokenizer=tokenizer, image_processor=image_processor,
            max_length=config.max_length
        )
        if test_dataset.data:
            num_workers = min(config.num_workers, os.cpu_count() if os.cpu_count() else 1)
            persist_workers_test = (num_workers > 0)
            # Compatibility probe: fall back if this PyTorch build rejects persistent_workers.
            try: _ = DataLoader(test_dataset, num_workers=num_workers, persistent_workers=persist_workers_test)
            except TypeError: persist_workers_test = False
            # Pin host memory for any CUDA device. Comparing against torch.device("cuda") is
            # False for an indexed device such as "cuda:1", which silently disabled pinning.
            pin_memory_test = str(config.device).startswith("cuda")
            test_loader = DataLoader(
                test_dataset, batch_size=config.batch_size * 2, shuffle=False,
                num_workers=num_workers, pin_memory=pin_memory_test,
                drop_last=False, persistent_workers=persist_workers_test
            )
            print(f"Test loader created with {len(test_loader)} batches.")
        else: print("Test dataset loaded but is empty.")
    except Exception as e: print(f"Error creating test dataset/loader: {e}")
else: print("Skipping test evaluation: Test JSON, Tokenizer or Image Processor not found/loaded.")

# 2. Load the checkpoint and evaluate.
if test_loader:
    try:
        best_model_path = os.path.join(config.model_path, "phobert_sigliploss_best.pt")
        final_model_path = os.path.join(config.model_path, "phobert_sigliploss_final.pt")
        load_path = None

        if os.path.exists(best_model_path): load_path = best_model_path; print(f"\nLoading best model: {load_path}")
        elif os.path.exists(final_model_path): load_path = final_model_path; print(f"\nLoading final model: {load_path}")
        else: print(f"\nWARNING: No checkpoints found in {config.model_path} to evaluate.")

        if load_path:
            # Map to CPU: load_state_dict copies the weights onto the model's device, so the
            # optimizer/scheduler state stored in the checkpoint never has to occupy GPU memory.
            # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; loading may
            # fail if the stored metric dict holds non-builtin types (e.g. numpy scalars).
            checkpoint = torch.load(load_path, map_location="cpu")
            print("Re-creating model structure for testing...")

            # Hyper-parameters stored in the checkpoint take precedence over the current config.
            temp_config = SimpleNamespace(
                device=config.device,
                vision_model_name=checkpoint.get('vision_model_name', config.selected_vision_source),
                text_model_name=checkpoint.get('text_model_name', config.selected_text_model),
                vision_embedding=config.vision_embedding,
                text_embedding=config.text_embedding,
                projection_dim=checkpoint.get('projection_dim', config.projection_dim),
                learnable_temperature=checkpoint.get('learnable_temperature', config.learnable_temperature),
                temperature_init=checkpoint.get('temperature_init', config.temperature_init),
                learnable_bias=checkpoint.get('learnable_bias', config.learnable_bias),
                bias_init=checkpoint.get('bias_init', config.bias_init),
            )

            print(f"  Using Vision Source: {temp_config.vision_model_name}")
            print(f"  Using Text Model: {temp_config.text_model_name}")

            # pretrained=False presumably skips loading hub weights; the checkpoint supplies them.
            test_image_encoder = ImageEncoder(temp_config, pretrained=False).to(config.device)
            test_text_encoder = TextEncoder(temp_config, pretrained=False).to(config.device)
            model_to_test = ViPhobertSiglipLossModel(test_image_encoder, test_text_encoder, temp_config).to(config.device)

            state_dict = checkpoint['model_state_dict']
            # Weights saved from a DataParallel/DDP wrapper carry a 'module.' prefix on every key.
            if all(k.startswith('module.') for k in state_dict.keys()):
                print("Detected 'module.' prefix, removing.")
                state_dict = OrderedDict((k[len('module.'):], v) for k, v in state_dict.items())

            load_result = model_to_test.load_state_dict(state_dict, strict=True)
            print(f"  State dict loading result: {load_result}")
            print("Model weights loaded successfully.")
            # Only the model weights are needed from here on; release the rest of the checkpoint.
            del checkpoint, state_dict
            gc.collect()

            print("\nRunning evaluation on test set...")
            test_results = validate_epoch(model_to_test, test_loader, config.device)  # same metrics as validation

            print("\n--- Test Set Results ---")
            print("\n".join(f"  {name}: {test_results[name]:.4f}" for name in sorted(test_results)))
            print("------------------------")
        else:
            print("Evaluation skipped (no weights found).")
    except Exception as e:
        print(f"\nERROR during test setup/evaluation: {e}")
        traceback.print_exc()

print("\n================= Evaluation Finished ==================")

In [None]:
# === Cell 14: Training Visualization (Adapted for Steps/Epochs) ===
# (No changes needed here, plotting logic is adaptable)

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import math

def plot_training_metrics(history, plot_dir, plot_by='epoch'):
    """Plots training and validation metrics.

    Writes PNG files (dpi=300) into ``plot_dir``, which is created if missing:
      * ``training_loss_{plot_by}.png``: the training-loss curve, overlaid with a
        validation-loss curve when the validation results carry a non-NaN ``'loss'``.
      * ``validation_metrics_{plot_by}.png``: one subplot per validation metric
        (every key of the first result except ``'loss'``), laid out two per row.

    Args:
        history: dict recorded during training. ``plot_by='step'`` reads ``'steps'``,
            ``'train_loss'`` and ``'val_metrics'`` ({step: {metric_name: value}}).
            The epoch layout reads ``'train_loss'`` (one entry per epoch) and
            ``'validation_results'`` (a list of per-epoch dicts, ``None`` entries skipped).
        plot_dir: output directory for the figures.
        plot_by: ``'step'`` or ``'epoch'``. A ``'step'`` request without step data falls
            back to the epoch layout when ``'validation_results'`` is present.

    Returns:
        None. Prints a message and returns early when ``history`` is empty or lacks
        the data needed for either layout.
    """
    os.makedirs(plot_dir, exist_ok=True)
    print(f"Plot directory ensured at: {os.path.abspath(plot_dir)}")
    if not history: print("No history data provided."); return

    # Pick the x-axes. Step layout: logged global steps for training loss, sorted validation
    # steps for metrics. Otherwise 1..N epoch ranges, and plot_by is forced to 'epoch'.
    if plot_by == 'step' and history.get('steps') and history.get('train_loss'):
        x_axis_train = history['steps']; x_label = 'Global Steps'
        x_axis_val = sorted(history.get('val_metrics', {}).keys()) if history.get('val_metrics') else []
    elif history.get('train_loss') and history.get('validation_results'):
         num_epochs_trained = len(history['train_loss']); num_epochs_validated = len([res for res in history['validation_results'] if res is not None])
         x_axis_train = range(1, num_epochs_trained + 1); x_axis_val = range(1, num_epochs_validated + 1); x_label = 'Epoch'; plot_by = 'epoch'
    else: print("Insufficient history data."); return

    # Non-empty validation results: the None-filtered list in epoch mode, or the truthy
    # per-step dicts (in sorted step order) in step mode.
    val_metrics_data = history.get('val_metrics', {}) if plot_by == 'step' else history.get('validation_results', [])
    valid_val_results = [res for res in val_metrics_data if res is not None] if plot_by=='epoch' else \
                        [val_metrics_data.get(step) for step in x_axis_val if val_metrics_data.get(step)]

    # Training Loss
    if history.get('train_loss'):
        plt.figure(figsize=(10, 6)); plt.plot(x_axis_train, history['train_loss'], 'b-', label=f'Training Loss (Avg per {"Log Interval" if plot_by=="step" else "Epoch"})'); plt.xlabel(x_label); plt.ylabel('Loss'); plt.title(f'Training Loss over {x_label.capitalize()}')
        if valid_val_results:
            try:
                val_loss = [val_metrics_data[step].get('loss', float('nan')) for step in x_axis_val] if plot_by == 'step' else [res.get('loss', float('nan')) for res in valid_val_results]
                # Use the full validation axis when lengths match; otherwise 1..N (epoch mode)
                # or a truncated step axis (step mode).
                x_axis_val_loss = x_axis_val if len(val_loss) == len(x_axis_val) else range(1, len(val_loss) + 1) if plot_by=='epoch' else x_axis_val[:len(val_loss)]
                # NOTE(review): validate_epoch notes it does not average a validation loss, so this
                # overlay may only receive data from its error-path results (loss=inf) — confirm.
                if any(not math.isnan(vl) for vl in val_loss): plt.plot(x_axis_val_loss, val_loss, 'r-', label='Validation Loss')
            except (KeyError, TypeError, IndexError): print("Validation loss not found or incorrectly formatted.")
        plt.legend(); plt.grid(True); plt.tight_layout(); save_path_loss = os.path.join(plot_dir, f'training_loss_{plot_by}.png'); plt.savefig(save_path_loss, dpi=300); print(f"Saved training loss plot to: {save_path_loss}"); plt.close()
    else: print("No training loss data to plot.")

    # Validation Metrics
    if valid_val_results and x_axis_val:
        first_valid_val_result = valid_val_results[0]
        if first_valid_val_result and isinstance(first_valid_val_result, dict):
            # The metric set is taken from the first result; missing values later become NaN.
            metrics_to_plot = [k for k in first_valid_val_result.keys() if k != 'loss']
            num_plots = len(metrics_to_plot)
            if num_plots > 0:
                # Two subplots per row; trailing unused axes are removed after the loop.
                ncols = 2; nrows = math.ceil(num_plots / ncols); fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 6 * nrows), squeeze=False); axes = axes.flatten()
                for i, metric_name in enumerate(metrics_to_plot):
                    metric_values = [val_metrics_data[step].get(metric_name, float('nan')) for step in x_axis_val] if plot_by == 'step' else [res.get(metric_name, float('nan')) for res in valid_val_results]
                    x_axis_val_metric = x_axis_val if len(metric_values) == len(x_axis_val) else range(1, len(metric_values) + 1) if plot_by=='epoch' else x_axis_val[:len(metric_values)]
                    if any(not math.isnan(v) for v in metric_values):
                        axes[i].plot(x_axis_val_metric, metric_values, 'r-o', label=f'Validation {metric_name}'); axes[i].set_xlabel(x_label); axes[i].set_ylabel(metric_name.replace('_', ' ').capitalize()); axes[i].set_title(f'Validation {metric_name} over {x_label.capitalize()}'); axes[i].legend(); axes[i].grid(True)
                    else: axes[i].set_title(f'Validation {metric_name} (No Data)'); axes[i].text(0.5, 0.5, 'No Data', ha='center', va='center')
                for j in range(i + 1, len(axes)): fig.delaxes(axes[j])
                fig.suptitle(f'Validation Metrics over {x_label.capitalize()}', fontsize=16, y=1.02 if nrows>1 else 1.05); plt.tight_layout(rect=[0, 0, 1, 0.97])
                save_path_val = os.path.join(plot_dir, f'validation_metrics_{plot_by}.png'); plt.savefig(save_path_val, dpi=300); print(f"Saved validation metrics plot to: {save_path_val}"); plt.close()
            else: print("No validation metrics (excluding loss) found to plot.")
        else: print("No valid validation results found.")
    else: print("No validation metrics found in history to plot.")


# --- Plotting ---
# Figures are written next to this run's checkpoints.
plot_directory = f"{config.model_path}/plots"
plotting_mode = 'step' if config.validation_interval_steps > 0 else 'epoch'

if 'history' in locals() and isinstance(history, dict):
    # The training loop records only step-indexed history ('steps', 'train_loss', 'val_metrics');
    # the epoch layout needs 'validation_results', which that loop never writes. Fall back to the
    # step layout in that case so the training-loss curve is still plotted.
    if plotting_mode == 'epoch' and not history.get('validation_results') and history.get('steps'):
        plotting_mode = 'step'
    plot_training_metrics(history, plot_directory, plot_by=plotting_mode)
else:
    print("No training history found. Run training first.")

# --- END OF SCRIPT ---