In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast

from transformers import (
    Blip2Processor,
    Blip2Model,
    PhobertTokenizer, # Using PhoBERT tokenizer
    AutoImageProcessor,
    logging
)
from datasets import load_dataset # Can be useful for handling data if needed later

# Silence transformers' info/warning chatter (only errors are shown) so the
# notebook output stays readable during model/tokenizer loading and training.
logging.set_verbosity_error()

def seed_everything(seed):
    """Seed every RNG this notebook touches and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG and
            PyTorch (CPU generator plus the current and all CUDA devices).

    Note:
        Assigning ``PYTHONHASHSEED`` at runtime does not change hashing in the
        already-running interpreter; it is only picked up by Python processes
        started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Each seeder owns an independent generator, so the call order is irrelevant.
    rng_seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,       # current CUDA device
        torch.cuda.manual_seed_all,   # every CUDA device (multi-GPU)
    )
    for seeder in rng_seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Record the PyTorch / CUDA stack in the notebook output for reproducibility.
cuda_available = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {cuda_available}")
if cuda_available:
    # Only device 0 is reported; multi-GPU setups are not enumerated here.
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Capability: {torch.cuda.get_device_capability(0)}")

In [None]:
# %% [markdown]
# ## 2. Configuration

# %%
class CFG:
    """Central configuration (paths, model names, hyper-parameters) for the notebook.

    Every setting is a plain class attribute, so it can be read from the class
    itself or from an instance such as ``config = CFG()``.
    """

    # ---- Filesystem layout: adjust to your directory structure ----
    data_path = "./"  # folder holding train.json / dev.json / test.json
    image_base_path = "./"  # root that each JSON ``image_path`` is joined onto
    model_path = "./ViBLIP_QFormer_Trained_Retrieval"  # where checkpoints are written

    # ---- Pretrained components ----
    blip2_model_name = "Salesforce/blip2-opt-2.7b"
    text_tokenizer_name = "vinai/phobert-base"  # Vietnamese PhoBERT tokenizer

    # ---- Optimisation / schedule ----
    seed = 42
    batch_size = 8  # kept small because the full BLIP-2 model is resident in GPU memory
    num_workers = 4
    qformer_lr = 1e-5
    weight_decay = 0.05
    patience = 2  # ReduceLROnPlateau: epochs without val-loss improvement before decaying LR
    factor = 0.8  # ReduceLROnPlateau: multiplicative LR decay
    epochs = 10
    early_stop_patience = 3  # epochs without tracked-metric improvement before stopping
    use_amp = True  # automatic mixed precision (torch.cuda.amp)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- Inputs ----
    image_size = 224  # side length of the zero placeholder used for unreadable images
    max_length = 77  # caption token budget (pad / truncate target)

    # ---- Loss, checkpointing, model selection ----
    temperature = 0.07  # contrastive-loss temperature
    save_best_only = True
    # NOTE(review): compute_metrics reports this value under the key "avg_acc" (no "val_"
    # prefix); make sure the lookup in the training loop resolves this name.
    metric_to_track = "val_avg_acc"
    mode = "max"  # "max": larger tracked metric is better; "min": smaller is better

# Instantiate the configuration, seed every RNG, and make sure the checkpoint folder exists.
config = CFG()
seed_everything(config.seed)
os.makedirs(config.model_path, exist_ok=True)

# Echo the run configuration so it is captured in the notebook output
# (one argument per line; sep="\n" reproduces one print() per line).
print(
    f"--- ViBLIP Q-Former Retrieval Training Configuration ---",
    f"Device: {config.device}",
    f"Base BLIP-2 Model: {config.blip2_model_name}",
    f"Text Tokenizer: {config.text_tokenizer_name}",
    f"Batch Size: {config.batch_size}",
    f"Use AMP: {config.use_amp}",
    f"Epochs: {config.epochs}",
    f"Q-Former LR: {config.qformer_lr}",
    f"Early Stop Patience: {config.early_stop_patience}",
    f"Output Path: {config.model_path}",
    f"Data Path (JSONs): {os.path.abspath(config.data_path)}",
    f"Image Base Path: {os.path.abspath(config.image_base_path)}",
    f"Metric to Track for Best Model: {config.metric_to_track} ({config.mode})",
    f"------------------------------------------------------\n",
    sep="\n",
)

# Warn, but do not abort, when either input location is missing.
if not (os.path.exists(config.data_path) and os.path.exists(config.image_base_path)):
    print(f"WARNING: Data path ({config.data_path}) or Image base path ({config.image_base_path}) does not exist.")
    print("Please ensure the paths are correct and the dataset is present.")

In [None]:


# %% [markdown]
# ## 3. Metric Calculation Utilities

# %%
class AvgMeter:
    """Running, count-weighted average of a scalar metric (e.g. per-batch loss)."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, sample count and average."""
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, count=1):
        """Add ``val`` as if it had been observed ``count`` times.

        Args:
            val: Python int/float, NumPy scalar, or single-element tensor.
                ``None`` is ignored silently; any other type is skipped with a warning.
            count: Weight of this observation (typically the batch size).
        """
        if torch.is_tensor(val):
            val = val.item()  # single-element tensor -> Python number
        elif isinstance(val, np.generic):
            # NumPy scalars such as np.float32 are not subclasses of float, so
            # without this conversion they would be skipped below.
            val = val.item()

        if isinstance(val, (int, float)):
            self.sum += val * count
            self.count += count
            self.avg = self.sum / self.count if self.count != 0 else 0
        elif val is not None:
            print(f"Warning: AvgMeter received unexpected type {type(val)}. Skipping update.")

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def compute_recall_at_k(similarity_matrix, k, dim):
    """Compute Recall@k for paired retrieval where query ``j`` matches candidate ``j``.

    ``similarity_matrix[t, i]`` scores text ``t`` against image ``i`` (rows = texts,
    columns = images), as built by ``text_embeddings @ image_embeddings.T``.

    Args:
        similarity_matrix: 2-D tensor of similarity scores.
        k: Number of top-ranked candidates to inspect; clamped to the number of
            available candidates. ``k <= 0`` yields 0.0.
        dim: Axis holding the *candidates*:
            * ``0`` -> image-to-text: each image (column) ranks the texts (rows).
            * ``1`` -> text-to-image: each text (row) ranks the images (columns).

    Returns:
        Fraction of queries whose matching candidate appears in their top-k
        (0.0 when there are no queries or no candidates).

    Raises:
        ValueError: if ``dim`` is not 0 or 1.
    """
    if dim not in (0, 1):
        raise ValueError("dim must be 0 (I2T) or 1 (T2I)")

    num_candidates = similarity_matrix.shape[dim]
    num_queries = similarity_matrix.shape[1 - dim]
    if num_queries == 0 or num_candidates == 0 or k <= 0:
        return 0.0

    # k can only be as large as the axis topk runs along (the candidates).
    effective_k = min(k, num_candidates)
    top_k_indices = torch.topk(similarity_matrix, effective_k, dim=dim).indices
    ground_truth = torch.arange(num_queries, device=similarity_matrix.device)

    if dim == 0:
        # top_k_indices: (effective_k, num_images) -> one column of text ids per image.
        hits = (top_k_indices == ground_truth.unsqueeze(0)).any(dim=0)
    else:
        # top_k_indices: (num_texts, effective_k) -> one row of image ids per text.
        hits = (top_k_indices == ground_truth.unsqueeze(1)).any(dim=1)

    # Integer count / int keeps the result an exact Python float.
    return hits.sum().item() / num_queries


def compute_metrics(image_embeddings, text_embeddings, temp=1.0):
    """Compute the symmetric ITC loss and retrieval metrics for paired embeddings.

    Row ``i`` of ``image_embeddings`` is treated as the positive match for row ``i``
    of ``text_embeddings``. The embeddings are expected to be L2-normalised already
    (so dot products are cosine similarities); this is not checked here.

    Args:
        image_embeddings: 2-D tensor ``(num_images, dim)``.
        text_embeddings: 2-D tensor ``(num_texts, dim)``.
        temp: Softmax temperature ``tau > 0``. Only the loss depends on it; its logits
            are ``similarity / tau`` (CLIP / InfoNCE convention). Accuracy, recall and
            ``avg_cosine_sim`` are computed on the raw similarities.

    Returns:
        dict with ``loss`` (float; ``None`` for empty input, NaN on internal error),
        ``i2t_acc``, ``t2i_acc``, ``avg_acc``, ``avg_cosine_sim`` and the
        ``i2t_recall`` / ``t2i_recall`` dicts keyed ``"R@1"``, ``"R@5"``, ``"R@10"``.

    Raises:
        ValueError: if ``temp`` is not positive.
    """
    default_metrics = {
        "loss": None,  # loss is left unset for empty input
        "i2t_acc": 0.0, "t2i_acc": 0.0, "avg_acc": 0.0,
        "avg_cosine_sim": 0.0,
        "i2t_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0},
        "t2i_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}
    }
    if temp <= 0:
        raise ValueError(f"temp must be positive, got {temp}")

    # Check emptiness BEFORE the matmul: 1-D empty tensors (e.g. torch.tensor([]))
    # would otherwise give a 0-dim result whose .shape[0] raises outside the try.
    if image_embeddings.numel() == 0 or text_embeddings.numel() == 0:
        print("Warning: compute_metrics received empty embeddings.")
        return default_metrics

    image_embeddings = image_embeddings.float()  # fp32 for a numerically stable softmax
    text_embeddings = text_embeddings.float()

    # similarity[t, i] = <text_t, image_i>: rows index texts, columns index images.
    similarity = text_embeddings @ image_embeddings.T
    logits = similarity / temp

    try:
        num_texts, num_images = similarity.shape
        device = similarity.device

        # --- Top-1 accuracy (scale-invariant, so computed on raw similarities) ---
        # I2T: for each image (column), is its paired text the best-scoring row?
        i2t_preds = torch.argmax(similarity, dim=0)
        i2t_acc = (i2t_preds == torch.arange(num_images, device=device)).float().mean().item()
        # T2I: for each text (row), is its paired image the best-scoring column?
        t2i_preds = torch.argmax(similarity, dim=1)
        t2i_acc = (t2i_preds == torch.arange(num_texts, device=device)).float().mean().item()
        avg_acc = (i2t_acc + t2i_acc) / 2

        # Mean cosine similarity of the positive (diagonal) pairs.
        avg_cosine_sim = torch.diagonal(similarity).mean().item()

        # --- Recall@k ---
        i2t_recall = {}
        t2i_recall = {}
        for k in [1, 5, 10]:
            k_str = f"R@{k}"
            i2t_recall[k_str] = compute_recall_at_k(similarity, k, dim=0)  # images rank texts
            t2i_recall[k_str] = compute_recall_at_k(similarity, k, dim=1)  # texts rank images

        # --- Symmetric cross-entropy (ITC) loss on temperature-scaled logits ---
        loss_i2t = F.cross_entropy(logits.T, torch.arange(num_images, device=device))
        loss_t2i = F.cross_entropy(logits, torch.arange(num_texts, device=device))
        loss = (loss_i2t + loss_t2i) / 2.0

        return {
            "loss": loss.item(),
            "i2t_acc": i2t_acc, "t2i_acc": t2i_acc, "avg_acc": avg_acc,
            "avg_cosine_sim": avg_cosine_sim,
            "i2t_recall": i2t_recall, "t2i_recall": t2i_recall
        }
    except Exception as e:
        # Best-effort: report and return zeroed metrics with a NaN loss instead of
        # aborting the whole validation pass.
        print(f"Error during metric calculation: {e}")
        print(f"Shapes: ImgEmb={image_embeddings.shape}, TxtEmb={text_embeddings.shape}, Logits={logits.shape}")
        default_metrics["loss"] = float('nan')
        return default_metrics

# Progress marker so the notebook output shows this cell finished defining the helpers.
print("Metric utilities defined.")

In [None]:
# %% [markdown]
# ## 4. Dataset and DataLoader

# %%
class VietnameseImageRetrievalDataset(Dataset):
    """Image–caption pairs read from a JSON list, for contrastive retrieval training.

    Each JSON item must provide ``image_path`` (joined onto ``image_base_path``)
    and ``caption`` (a non-empty list of strings; a bare string is also accepted).
    Only the first item is schema-checked at load time. Any load failure is
    reported and leaves the dataset empty instead of raising.
    """

    def __init__(self, json_path, image_base_path, image_processor, tokenizer, max_length, is_train=True, image_size=None):
        """
        Args:
            json_path: Path to the JSON annotation file (a list of items).
            image_base_path: Directory that relative ``image_path`` values are joined onto.
            image_processor: Callable ``(images=..., return_tensors="pt")`` returning a
                mapping with ``pixel_values`` (e.g. a Hugging Face image processor).
            tokenizer: Callable tokenizer returning ``input_ids`` / ``attention_mask`` tensors.
            max_length: Captions are padded / truncated to this many tokens.
            is_train: Stored for callers; not used by the loading logic.
            image_size: Side length of the zero placeholder returned for unreadable
                images; ``None`` falls back to the global ``config.image_size``.
        """
        self.json_path = json_path
        self.image_base_path = image_base_path
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_train = is_train
        self.image_size = image_size

        print(f"Loading data from: {self.json_path}")
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                self.data = json.load(f)
            print(f"Loaded {len(self.data)} samples.")
            # Basic validation of the first item only.
            if self.data:
                first_item = self.data[0]
                if "image_path" not in first_item or "caption" not in first_item:
                    raise ValueError("JSON items must contain 'image_path' and 'caption' keys.")
                if not isinstance(first_item["caption"], list) or not first_item["caption"]:
                    raise ValueError("'caption' must be a non-empty list of strings.")
            else:
                print("Warning: JSON file is empty.")

        except FileNotFoundError:
            print(f"ERROR: JSON file not found at {json_path}")
            self.data = []
        except json.JSONDecodeError:
            print(f"ERROR: Could not decode JSON from {json_path}")
            self.data = []
        except Exception as e:
            # Includes the schema ValueErrors raised above: best-effort, dataset stays empty.
            print(f"An unexpected error occurred loading JSON: {e}")
            self.data = []

    def __len__(self):
        return len(self.data)

    def _placeholder_image(self):
        """Zero image tensor ``(3, size, size)`` used when an image cannot be loaded."""
        size = self.image_size if self.image_size is not None else config.image_size
        return torch.zeros((3, size, size))

    @staticmethod
    def _select_caption(captions):
        """Return the first caption of a list, or the caption itself if it is a plain string."""
        if isinstance(captions, str):
            # Indexing a str with [0] would silently keep only its first character.
            return captions
        return captions[0]

    def __getitem__(self, idx):
        """Return ``pixel_values``, ``input_ids`` and ``attention_mask`` for item ``idx``.

        NOTE(review): a missing/unreadable image yields a zero placeholder that is
        still paired with its caption, so it acts as a positive pair in the
        contrastive loss; consider filtering such items if they are common.
        """
        item = self.data[idx]
        image_rel_path = item['image_path']
        # Normalise "./images/x.jpg" to "images/x.jpg" before joining.
        if image_rel_path.startswith('./'):
            image_rel_path = image_rel_path[2:]
        image_full_path = os.path.join(self.image_base_path, image_rel_path)

        try:
            # The context manager closes the file handle deterministically.
            with Image.open(image_full_path) as image:
                rgb_image = image.convert("RGB")
            processed_image = self.image_processor(images=rgb_image, return_tensors="pt")['pixel_values'].squeeze(0)  # drop batch dim
        except FileNotFoundError:
            print(f"Warning: Image not found at {image_full_path}. Returning placeholder.")
            processed_image = self._placeholder_image()
        except Exception as e:
            print(f"Warning: Error loading image {image_full_path}: {e}. Returning placeholder.")
            processed_image = self._placeholder_image()

        # Only the first caption is used when several are provided.
        # NOTE(review): PhoBERT was pre-trained on word-segmented Vietnamese; confirm
        # whether captions in the JSON are already segmented.
        caption = self._select_caption(item['caption'])

        tokenized_caption = self.tokenizer(
            caption,
            padding='max_length',  # fixed-length tensors so default collation works
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # squeeze(0) removes only the batch dim the tokenizer adds (never the sequence dim).
        input_ids = tokenized_caption['input_ids'].squeeze(0)
        attention_mask = tokenized_caption['attention_mask'].squeeze(0)

        return {
            "pixel_values": processed_image,
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

In [None]:
# %% [markdown]
# ## 5. Initialize Processors, Tokenizer, Datasets, and DataLoaders

# %%
# Image preprocessing must match the BLIP-2 checkpoint's pre-training pipeline,
# so the processor is loaded from the same checkpoint as the model.
try:
    image_processor = AutoImageProcessor.from_pretrained(config.blip2_model_name)
    print(f"Image processor loaded from {config.blip2_model_name}")
except Exception as e:
    print(f"Error loading image processor: {e}. Using default Blip2Processor.")
    # Fallback: use the image processor bundled inside the full Blip2Processor.
    fallback_processor = Blip2Processor.from_pretrained(config.blip2_model_name)
    image_processor = fallback_processor.image_processor


# The Vietnamese PhoBERT tokenizer is mandatory: stop the notebook if it cannot be loaded.
try:
    text_tokenizer = PhobertTokenizer.from_pretrained(config.text_tokenizer_name)
    # Padding to max_length needs a pad token; fall back to EOS if none is defined.
    if text_tokenizer.pad_token is None:
        print("PhoBERT tokenizer does not have a pad token. Setting it to eos_token.")
        text_tokenizer.pad_token = text_tokenizer.eos_token
    print(f"Text tokenizer loaded: {config.text_tokenizer_name}")
except Exception as e:
    print(f"FATAL ERROR: Could not load PhoBERT tokenizer '{config.text_tokenizer_name}'. {e}")
    raise  # bare raise keeps the original traceback


# --- Create Datasets ---
train_json_path = os.path.join(config.data_path, "train.json")
dev_json_path = os.path.join(config.data_path, "dev.json")
# test_json_path = os.path.join(config.data_path, "test.json")  # optional, for final testing

train_dataset = VietnameseImageRetrievalDataset(
    json_path=train_json_path,
    image_base_path=config.image_base_path,
    image_processor=image_processor,
    tokenizer=text_tokenizer,
    max_length=config.max_length,
    is_train=True
)

val_dataset = VietnameseImageRetrievalDataset(
    json_path=dev_json_path,
    image_base_path=config.image_base_path,
    image_processor=image_processor,
    tokenizer=text_tokenizer,
    max_length=config.max_length,
    is_train=False
)

# The dataset only prints JSON loading problems and ends up empty. A shuffled
# DataLoader over an empty dataset raises an opaque RandomSampler "num_samples=0"
# error, so fail here with an actionable message instead.
if len(train_dataset) == 0:
    raise ValueError(f"Training dataset is empty - check that {train_json_path} exists and contains valid samples.")

# --- Create DataLoaders ---
# Page-locked host memory only speeds up host->GPU copies, so pin only when training on CUDA.
# NOTE(review): on platforms that spawn DataLoader workers (Windows/macOS), classes defined
# inside a notebook may fail to pickle; set num_workers=0 there if workers crash or hang.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=config.device.type == "cuda",
    drop_last=False  # keep the last, possibly smaller, batch
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,  # deterministic order for validation
    num_workers=config.num_workers,
    pin_memory=config.device.type == "cuda",
    drop_last=False
)

print(f"\nDataLoaders created:")
print(f"Train dataset size: {len(train_dataset)}, Train loader batches: {len(train_loader)}")
print(f"Validation dataset size: {len(val_dataset)}, Validation loader batches: {len(val_loader)}")


In [None]:


# %% [markdown]
# ## 6. Model Loading and Preparation

# %%
print(f"Loading BLIP-2 model: {config.blip2_model_name}...")
# Blip2Model bundles the vision encoder, the Q-Former and the language model.
# Memory-saving options that can be passed to from_pretrained if needed:
#   load_in_8bit=True          -> requires `bitsandbytes`
#   device_map="auto"          -> multi-GPU placement / CPU offloading
#   torch_dtype=torch.float16  -> half-precision weights
try:
    model = Blip2Model.from_pretrained(config.blip2_model_name)
    print("BLIP-2 model loaded successfully.")
except Exception as e:
    print(f"Error loading model {config.blip2_model_name}: {e}")
    raise e


# Freeze the whole model, then re-enable gradients only for Q-Former parameters.
print("Freezing model parameters initially...")
model.requires_grad_(False)

print("Unfreezing Q-Former parameters for fine-tuning...")
# NOTE(review): parameters registered outside the "qformer." namespace (e.g. the learnable
# query tokens, if they live at the top level) stay frozen under this filter - confirm intent.
for name, param in model.named_parameters():
    param.requires_grad = "qformer." in name

total_param_count = sum(param.numel() for param in model.parameters())
qformer_param_count = sum(param.numel() for name, param in model.named_parameters() if "qformer." in name)
trainable_param_count = sum(param.numel() for param in model.parameters() if param.requires_grad)

print(f"Total parameters: {total_param_count:,}")
print(f"Q-Former parameters: {qformer_param_count:,}")
print(f"Trainable parameters (should match Q-Former): {trainable_param_count:,}")

if trainable_param_count == 0:
    print("WARNING: No parameters were unfrozen for training. Check model structure and parameter names.")

# --- Move model to device ---
# (With device_map="auto" or 8-bit loading, placement is already handled by from_pretrained.)
model.to(config.device)
print(f"Model moved to {config.device}")

# Optional (PyTorch 2.0+): `model = torch.compile(model)` may speed up training,
# but can fail on some models; wrap it in try/except if enabled.

In [None]:
# %% [markdown]
# ## 7. Training and Validation Functions

# %%
def train_one_epoch(model, dataloader, optimizer, scaler, device, temperature, use_amp=None):
    """Run one optimisation pass over ``dataloader`` with a symmetric InfoNCE loss.

    Args:
        model: Model exposing ``get_image_features`` / ``get_text_features``.
        dataloader: Yields dicts with ``pixel_values``, ``input_ids``, ``attention_mask``.
        optimizer: Optimiser over the trainable parameters.
        scaler: ``GradScaler``; when it is disabled, a plain backward/step is used.
        device: Device (or device string) the batch tensors are moved to.
        temperature: Softmax temperature ``tau > 0``; logits are cosine similarity / tau.
        use_amp: Enable autocast; ``None`` falls back to the global ``config.use_amp``.
            Autocast is applied only when ``device`` is a CUDA device.

    Returns:
        Sample-weighted mean training loss for the epoch (0 if the loader is empty).

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if use_amp is None:
        use_amp = config.use_amp
    # torch.cuda.amp.autocast only acts on CUDA tensors; avoid its warning on CPU runs.
    amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"

    model.train()
    loss_meter = AvgMeter(name="Train Loss")
    start_time = time.time()
    pbar = tqdm(dataloader, desc="Training", leave=False)

    for batch in pbar:
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=amp_enabled):
            # NOTE(review): in the transformers releases I know of, Blip2Model.get_image_features
            # returns the *vision encoder* output object and get_text_features returns the OPT
            # *language model* output (token logits); neither runs the Q-Former, and PhoBERT
            # token ids are not OPT vocabulary ids. If that holds for the installed version,
            # F.normalize below fails on a ModelOutput and the unfrozen Q-Former receives no
            # gradient. Verify the return types (or use Q-Former features) before training.
            image_features = model.get_image_features(pixel_values=pixel_values)
            text_features = model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

            # L2-normalise so the dot products below are cosine similarities.
            image_features = F.normalize(image_features, p=2, dim=-1)
            text_features = F.normalize(text_features, p=2, dim=-1)

        # Symmetric InfoNCE in fp32: logits = cosine similarity / temperature
        # (dividing by a small tau sharpens the softmax, as in CLIP).
        logits = (text_features.float() @ image_features.float().T) / temperature
        targets = torch.arange(logits.shape[0], device=logits.device)  # text i <-> image i

        loss_i2t = F.cross_entropy(logits.T, targets)  # each image classifies its text
        loss_t2i = F.cross_entropy(logits, targets)    # each text classifies its image
        loss = (loss_i2t + loss_t2i) / 2.0

        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        loss_meter.update(loss.item(), count=pixel_values.size(0))
        pbar.set_postfix({"Loss": loss_meter.avg, "LR": optimizer.param_groups[0]['lr']})

    elapsed_time = time.time() - start_time
    print(f"Train Epoch Completed - Time: {elapsed_time:.2f}s, Avg Loss: {loss_meter.avg:.4f}")
    return loss_meter.avg


def validate_one_epoch(model, dataloader, device, temperature, use_amp=None):
    """Evaluate on ``dataloader``: mean batch loss plus retrieval metrics on the full set.

    Args:
        model: Model exposing ``get_image_features`` / ``get_text_features``.
        dataloader: Yields dicts with ``pixel_values``, ``input_ids``, ``attention_mask``.
        device: Device (or device string) the batch tensors are moved to.
        temperature: Softmax temperature ``tau > 0``; loss logits are cosine similarity / tau.
        use_amp: Enable autocast; ``None`` falls back to the global ``config.use_amp``.
            Autocast is applied only when ``device`` is a CUDA device.

    Returns:
        The ``compute_metrics`` dict computed over all validation pairs, with
        ``"loss"`` replaced by the sample-weighted mean of the per-batch losses
        (NaN if no batch was seen).

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if use_amp is None:
        use_amp = config.use_amp
    amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"

    model.eval()
    all_image_features = []
    all_text_features = []
    loss_meter = AvgMeter(name="Val Loss")
    start_time = time.time()
    pbar = tqdm(dataloader, desc="Validation", leave=False)

    with torch.no_grad():
        for batch in pbar:
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            with autocast(enabled=amp_enabled):
                # Same feature extraction as in training (see the NOTE(review) there).
                image_features = model.get_image_features(pixel_values=pixel_values)
                text_features = model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

                image_features_norm = F.normalize(image_features, p=2, dim=-1)
                text_features_norm = F.normalize(text_features, p=2, dim=-1)

            # fp32 copies: stable loss here and consistent precision for the full-set metrics.
            image_features_norm = image_features_norm.float()
            text_features_norm = text_features_norm.float()

            # Accumulate on CPU to keep GPU memory flat across the validation set.
            all_image_features.append(image_features_norm.cpu())
            all_text_features.append(text_features_norm.cpu())

            # Batch-level loss for monitoring (same temperature convention as training).
            logits = (text_features_norm @ image_features_norm.T) / temperature
            targets = torch.arange(logits.shape[0], device=logits.device)
            loss_i2t = F.cross_entropy(logits.T, targets)
            loss_t2i = F.cross_entropy(logits, targets)
            loss = (loss_i2t + loss_t2i) / 2.0
            loss_meter.update(loss.item(), count=pixel_values.size(0))
            pbar.set_postfix({"Loss": loss_meter.avg})

    if not all_image_features or not all_text_features:
        print("Warning: No features collected during validation.")
        # 2-D empties keep the similarity matmul inside compute_metrics well-defined
        # (1-D torch.tensor([]) inputs would collapse to a 0-dim result).
        metrics = compute_metrics(torch.empty((0, 1)), torch.empty((0, 1)), temp=temperature)
        metrics["loss"] = loss_meter.avg if loss_meter.count > 0 else float('nan')
        return metrics

    all_image_features = torch.cat(all_image_features, dim=0).to(device)
    all_text_features = torch.cat(all_text_features, dim=0).to(device)

    # --- Metrics over the ENTIRE validation set (every image vs. every caption) ---
    print(f"\nCalculating metrics on {all_image_features.shape[0]} validation samples...")
    metrics = compute_metrics(all_image_features, all_text_features, temp=temperature)
    # Report the averaged batch loss, which is comparable to the training loss.
    metrics["loss"] = loss_meter.avg

    elapsed_time = time.time() - start_time
    print(f"Validation Completed - Time: {elapsed_time:.2f}s")

    return metrics



In [None]:

# %% [markdown]
# ## 8. Main Training Loop

# %%
print("--- Starting Training ---")

# Only the unfrozen parameters are handed to the optimiser.
trainable_params = [p for p in model.parameters() if p.requires_grad]
if not trainable_params:
    raise ValueError("No trainable parameters found. Check model freezing/unfreezing steps.")

optimizer = AdamW(
    trainable_params,
    lr=config.qformer_lr,
    weight_decay=config.weight_decay
)

# The LR decays when the validation *loss* plateaus, while checkpointing and early
# stopping follow the tracked retrieval metric. LR changes are logged explicitly
# below instead of via the `verbose` flag, which recent PyTorch deprecates.
lr_scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config.factor,
    patience=config.patience
)

# Gradient scaler for automatic mixed precision (a no-op when disabled).
scaler = GradScaler(enabled=config.use_amp)

# Tracking of the best model and early stopping.
best_metric_value = -np.inf if config.mode == "max" else np.inf
epochs_no_improve = 0
history = {'train_loss': [], 'val_loss': [], f'{config.metric_to_track}': []}


def _fmt_metric(value):
    """Format a number with 4 decimals; pass non-numeric placeholders (e.g. 'N/A') through."""
    return f"{value:.4f}" if isinstance(value, (int, float)) else str(value)


def _lookup_tracked_metric(metrics, key):
    """Return ``metrics[key]``, falling back to ``key`` without a leading ``"val_"``.

    compute_metrics reports bare names such as ``"avg_acc"``, while the config may
    name the metric ``"val_avg_acc"``. Without this fallback the lookup always
    misses: no checkpoint is ever saved and early stopping fires after
    ``early_stop_patience`` epochs regardless of progress.
    """
    if key in metrics:
        return metrics[key]
    if key.startswith("val_"):
        return metrics.get(key[len("val_"):])
    return None


for epoch in range(1, config.epochs + 1):
    print(f"\n===== Epoch {epoch}/{config.epochs} =====")

    # --- Training ---
    train_loss = train_one_epoch(model, train_loader, optimizer, scaler, config.device, config.temperature)
    history['train_loss'].append(train_loss)

    # --- Validation ---
    val_metrics = validate_one_epoch(model, val_loader, config.device, config.temperature)
    val_loss = val_metrics["loss"]
    val_tracked_metric = _lookup_tracked_metric(val_metrics, config.metric_to_track)

    history['val_loss'].append(val_loss)
    if val_tracked_metric is not None:
        history[config.metric_to_track].append(val_tracked_metric)

    i2t_rec = val_metrics.get('i2t_recall', {})
    t2i_rec = val_metrics.get('t2i_recall', {})
    recall_keys = ("R@1", "R@5", "R@10")
    print(f"Epoch {epoch} Summary:")
    print(f"  Train Loss: {_fmt_metric(train_loss)}")
    print(f"  Val Loss: {_fmt_metric(val_loss)}")
    print(f"  Val Metrics:")
    print(f"    Avg Accuracy (I2T+T2I)/2: {_fmt_metric(val_metrics.get('avg_acc', 'N/A'))}")
    print(f"    Accuracy (I2T): {_fmt_metric(val_metrics.get('i2t_acc', 'N/A'))}, Accuracy (T2I): {_fmt_metric(val_metrics.get('t2i_acc', 'N/A'))}")
    print("    I2T Recall: " + ", ".join(f"{k}={_fmt_metric(i2t_rec.get(k, 'N/A'))}" for k in recall_keys))
    print("    T2I Recall: " + ", ".join(f"{k}={_fmt_metric(t2i_rec.get(k, 'N/A'))}" for k in recall_keys))
    print(f"    Avg Cosine Sim (Pos Pairs): {_fmt_metric(val_metrics.get('avg_cosine_sim', 'N/A'))}")

    # --- Learning-rate scheduling (driven by validation loss) ---
    lr_before = optimizer.param_groups[0]['lr']
    lr_scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    # --- Model saving & early stopping (driven by the tracked metric) ---
    if val_tracked_metric is None:
        print(f"Warning: Metric '{config.metric_to_track}' not found in validation results. Cannot save best model or check early stopping based on it.")
        current_is_best = False
    elif config.mode == "max":
        current_is_best = val_tracked_metric > best_metric_value
    else:  # mode == "min"
        current_is_best = val_tracked_metric < best_metric_value

    if current_is_best:
        best_metric_value = val_tracked_metric
        epochs_no_improve = 0
        print(f"*** New best model found! ({config.metric_to_track}: {best_metric_value:.4f}) Saving model... ***")
        # Save only the Q-Former weights (the trained part). Keys keep their "qformer." prefix.
        qformer_state_dict = {k: v for k, v in model.state_dict().items() if "qformer." in k}
        save_path = os.path.join(config.model_path, "best_qformer_weights.pth")
        torch.save(qformer_state_dict, save_path)
        print(f"Q-Former weights saved to {save_path}")
        # A full-model checkpoint (model.state_dict()) could be saved here as well,
        # e.g. when config.save_best_only is False, at the cost of a much larger file.
    else:
        epochs_no_improve += 1
        print(f"No improvement in {config.metric_to_track} for {epochs_no_improve} epoch(s). Best: {best_metric_value:.4f}")
        if config.early_stop_patience > 0 and epochs_no_improve >= config.early_stop_patience:
            print(f"\n--- Early stopping triggered after {epoch} epochs. ---")
            break

# %% [markdown]
# ## 9. Training Finished
#
# The training loop has completed. The best Q-Former weights (selected by the validation metric named in `config.metric_to_track`) are saved as `best_qformer_weights.pth` inside `config.model_path`, provided at least one epoch improved that metric.
#
# You can now load these weights back into the Q-Former part of a `Blip2Model` for inference or further use. Example loading:
#
# ```python
# # Load the base model architecture again
# inference_model = Blip2Model.from_pretrained(config.blip2_model_name)
#
# # Load the fine-tuned Q-Former weights
# qformer_weights_path = os.path.join(config.model_path, "best_qformer_weights.pth")
# if os.path.exists(qformer_weights_path):
#     qformer_state_dict = torch.load(qformer_weights_path, map_location='cpu', weights_only=True) # Load to CPU first; weights_only avoids unpickling arbitrary objects
#
#     # The checkpoint was filtered from model.state_dict(), so every key still carries the
#     # "qformer." prefix. Strip it before loading into the qformer submodule; otherwise
#     # strict=False matches nothing and silently keeps the pretrained weights.
#     qformer_state_dict = {k[len("qformer."):]: v for k, v in qformer_state_dict.items() if k.startswith("qformer.")}
#     missing_keys, unexpected_keys = inference_model.qformer.load_state_dict(qformer_state_dict, strict=False)
#     print("Q-Former weights loaded.")
#     if missing_keys: print(f"Missing keys: {missing_keys}")
#     if unexpected_keys: print(f"Unexpected keys: {unexpected_keys}") # Should be empty once the prefix is stripped
#
#     inference_model.to(config.device) # Move to GPU/CPU for inference
#     inference_model.eval()
# else:
#     print(f"ERROR: Saved Q-Former weights not found at {qformer_weights_path}")
#
# # Now 'inference_model' has the fine-tuned Q-Former and can be used for retrieval tasks.
# ```

# %%
# Final marker so the notebook output shows the training cell ran to completion.
print("\n--- Training Script Finished ---")

In [None]:
# Optional: plot the loss curves and the tracked validation metric side by side.
try:
    import matplotlib.pyplot as plt

    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: training vs. validation loss per epoch.
    loss_ax.plot(history['train_loss'], label='Train Loss')
    loss_ax.plot(history['val_loss'], label='Validation Loss')
    loss_ax.set_title('Loss History')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Right panel: the metric used for model selection, or a notice if it was never recorded.
    metric_key = config.metric_to_track
    if history.get(metric_key):
        metric_ax.plot(history[metric_key], label=f'Validation {metric_key}')
        metric_ax.set_title(f'{metric_key} History')
        metric_ax.set_xlabel('Epoch')
        metric_ax.set_ylabel(metric_key)
        metric_ax.legend()
        metric_ax.grid(True)
    else:
        metric_ax.text(0.5, 0.5, f'Metric "{metric_key}"\n not recorded', horizontalalignment='center', verticalalignment='center')

    fig.tight_layout()
    plt.show()
except ImportError:
    print("Install matplotlib to plot training history: pip install matplotlib")
except Exception as e:
    print(f"Couldn't plot history: {e}")