# -*- coding: utf-8 -*-
# ViBLIP Fine-tuning for Vietnamese Image Retrieval by Text
# Generated from Jupyter Notebook


# === Cell 1: Installs and Imports ===
# !pip install -q transformers torch torchvision torchaudio Pillow tqdm accelerate bitsandbytes sentencepiece

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import Blip2Processor, Blip2Model, Blip2Config, AutoTokenizer # AutoTokenizer might be needed for processor loading edge cases
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import json
import os
import random
import numpy as np
from tqdm.notebook import tqdm # Use standard tqdm if not in notebook: from tqdm import tqdm
import torch.nn.functional as F
import math
import time # For timing epochs
import transformers
from torch.cuda.amp import GradScaler, autocast # For mixed precision

# Log library versions and GPU availability up front so run logs are self-describing.
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only the device at index 0 is reported.
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")


# === Cell 2: Configuration Class (CFG) ===
class CFG:
    """Paths and hyper-parameters read (as attributes) by the rest of this script."""
    # --- Paths ---
    # Directory that is joined with "train.json" and "dev.json" when the datasets are built
    data_path = "./json_data/" # ADJUST THIS PATH
    # Base directory that each item's 'image_path' from the JSON is joined onto
    image_path = "./data/" # ADJUST THIS PATH

    # Output directory for saved models (created right after CFG is instantiated)
    model_path = "./ViBLIP_retrieval"

    # --- BLIP Model Selection ---
    # Common options: Salesforce/blip2-opt-2.7b, Salesforce/blip2-flan-t5-xl
    # Choose based on available resources. OPT models might be slightly smaller.
    selected_blip_model = "Salesforce/blip2-opt-2.7b" # ADJUST IF NEEDED

    # --- Model parameters ---
    # Model and processor are loaded from the same checkpoint name.
    blip_model_name = selected_blip_model
    blip_processor_name = selected_blip_model
    projection_dim = 256 # Shared latent space dimension (e.g., 256, 512, 768)
    # Freeze parts of the model? BLIP2's LLM is very large.
    freeze_vision_model = False
    freeze_language_model = True # HIGHLY RECOMMENDED to freeze LLM unless you have >> 40GB VRAM
    freeze_qformer = False
    # Quantization (Requires bitsandbytes). Reduces memory significantly.
    load_in_8bit = False # Set to True if memory is tight

    # --- Training parameters ---
    seed = 42
    # Reduce batch size significantly compared to CLIP due to BLIP2's size
    batch_size = 16  # START LOW (e.g., 4, 8, 16) and increase based on GPU memory
    num_workers = 4  # Upper bound; the DataLoader setup caps it at os.cpu_count()
    # Learning rates for different components (tune these); each becomes its own optimizer group
    vision_encoder_lr = 1e-5
    qformer_lr = 2e-5
    language_model_lr = 1e-6 # Only relevant if freeze_language_model=False
    projection_lr = 1e-4 # Projection head can often learn faster
    weight_decay = 1e-3 # Applied to every optimizer group
    patience = 3 # ReduceLROnPlateau patience
    factor = 0.8 # ReduceLROnPlateau reduction factor
    epochs = 20 # Adjust as needed
    # Evaluated once, when the class body runs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use mixed precision to save memory and potentially speed up training
    use_amp = True

    # --- Image/Text parameters (mostly handled by processor) ---
    # Processor determines image size (usually 224 for BLIP2)
    max_length = 64 # Captions are padded/truncated to this many tokens

    # --- Loss/Saving parameters ---
    temperature = 0.07 # Initial temperature; the model stores log(1 / temperature) as its logit scale
    learnable_temperature = True # Whether the logit_scale is learnable
    save_best_only = True
    # The training loop looks this name up in the validation results after replacing
    # '_' with ' ', so it must match a key produced by validate_epoch, e.g.
    # 'avg_R@1' -> 'avg R@1', 'avg_acc' -> 'avg acc', 'i2t_recall R@1' -> 'i2t recall R@1'.
    metric_to_track = "avg_R@1"
    mode = "max" # Mode for scheduler/saving based on metric_to_track ('max' for recall/acc, 'min' for loss)
    early_stopping_patience = 5 # Epochs with no improvement before stopping
    early_stopping_min_delta = 0.001 # Minimum change to qualify as improvement

# --- Instantiate Config and Create Output Dir ---
# Single shared config instance; every later cell reads its settings from `config`.
config = CFG()
# exist_ok=True makes re-running this cell safe when the directory already exists.
os.makedirs(config.model_path, exist_ok=True)
print(f"Using device: {config.device}")
print(f"Model output path: {config.model_path}")
print(f"Selected BLIP2 Model: {config.blip_model_name}")
print(f"Image base path (for resolving paths in JSON): {os.path.abspath(config.image_path)}")


# === Cell 3: Seeding for Reproducibility ===
def set_seed(seed=config.seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed. The default is captured from ``config.seed`` when
            this function is defined.
    """
    print(f"Setting seed: {seed}")
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # Covers every visible GPU
    # Full determinism is intentionally left off because it can cost performance:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False


set_seed()


# === Cell 4: Metric Calculation Utilities ===
class AvgMeter:
    """Running weighted mean of a scalar, tracked as ``sum``, ``count`` and ``avg``."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, the sample count and the average."""
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, count=1):
        """Fold ``val`` (weighted by ``count``) into the running average.

        Tensors are converted with ``.item()``. Values that are neither
        ``int`` nor ``float`` after that conversion are ignored.
        """
        value = val.item() if torch.is_tensor(val) else val
        if not isinstance(value, (int, float)):
            return
        self.sum += value * count
        self.count += count
        self.avg = self.sum / self.count if self.count != 0 else 0

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def compute_recall_at_k(similarity_matrix, k, dim):
    """Return Recall@k for a similarity matrix whose matching pairs lie on the diagonal.

    Args:
        similarity_matrix: 2-D tensor; callers in this file pass rows = texts
            and columns = images.
        k: Number of top-ranked candidates to inspect per query. Values larger
            than the candidate pool are clamped to the pool size.
        dim: Axis along which candidates are ranked. ``0`` ranks rows for each
            column (I2T), ``1`` ranks columns for each row (T2I).

    Returns:
        Fraction (Python float) of queries whose own index appears among their
        top-k candidates; ``0.0`` when there are no queries or ``k <= 0``.

    Raises:
        ValueError: If ``dim`` is not 0 or 1.
    """
    # Validate before touching shapes so a bad `dim` gives the intended error.
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1")

    n = similarity_matrix.shape[1 - dim]  # Number of queries
    k_eff = min(k, similarity_matrix.shape[dim])  # Cannot exceed the candidate pool
    if k_eff <= 0 or n == 0:
        return 0.0

    top_k_indices = torch.topk(similarity_matrix, k_eff, dim=dim).indices
    ground_truth = torch.arange(n, device=similarity_matrix.device)

    # Broadcast each query's own index against its top-k list; one comparison
    # replaces the per-query Python loop.
    if dim == 0:
        hits = (top_k_indices == ground_truth.unsqueeze(0)).any(dim=0)
    else:
        hits = (top_k_indices == ground_truth.unsqueeze(1)).any(dim=1)

    return hits.sum().item() / n

def compute_metrics(image_embeddings, text_embeddings):
    """Compute retrieval metrics from paired image and text embeddings.

    Row ``i`` of each input is treated as a matching pair. The similarity
    matrix is ``text @ image.T`` (rows = texts, columns = images).

    Returns:
        Dict with top-1 accuracies (``i2t_acc``, ``t2i_acc``, ``avg_acc``), the
        mean diagonal similarity (``avg_cosine_sim``), nested recall dicts
        (``i2t_recall``, ``t2i_recall``, keyed ``R@1``/``R@5``/``R@10``) and
        the flat ``avg_R@k`` entries. All values are zero for empty input.
    """
    if text_embeddings.device != image_embeddings.device:
        text_embeddings = text_embeddings.to(image_embeddings.device)

    similarity = (text_embeddings @ image_embeddings.T).float()
    num_pairs = similarity.shape[0]

    if num_pairs == 0:
        return {
            "i2t_acc": 0.0, "t2i_acc": 0.0, "avg_acc": 0.0,
            "avg_cosine_sim": 0.0,
            "i2t_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0},
            "t2i_recall": {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0},
            "avg_R@1": 0.0, "avg_R@5": 0.0, "avg_R@10": 0.0
        }

    targets = torch.arange(num_pairs, device=similarity.device)
    # argmax over columns answers "best image per text"; over rows, "best text per image".
    t2i_acc = (similarity.argmax(dim=1) == targets).float().mean().item()
    i2t_acc = (similarity.argmax(dim=0) == targets).float().mean().item()
    mean_diagonal = torch.diagonal(similarity).mean().item()

    i2t_recall, t2i_recall, avg_recall = {}, {}, {}
    for k in (1, 5, 10):
        label = f"R@{k}"
        i2t_recall[label] = compute_recall_at_k(similarity, k, dim=0)
        t2i_recall[label] = compute_recall_at_k(similarity, k, dim=1)
        avg_recall[f"avg_{label}"] = (i2t_recall[label] + t2i_recall[label]) / 2

    return {
        "i2t_acc": i2t_acc,
        "t2i_acc": t2i_acc,
        "avg_acc": (i2t_acc + t2i_acc) / 2,
        "avg_cosine_sim": mean_diagonal,
        "i2t_recall": i2t_recall,
        "t2i_recall": t2i_recall,
        **avg_recall,
    }


print("Metric utilities defined.")


# === Cell 5: Dataset Class Definition ===
class Blip2ImageCaptionDataset(Dataset):
    """Image/caption pairs read from a JSON file and preprocessed for BLIP-2.

    Each JSON item is read through ``item.get('image_path')`` and
    ``item.get('caption')``. Loading is best-effort: an unreadable JSON file
    yields an empty dataset, and an unreadable image is replaced by a
    processed blank image, so one bad sample does not abort an epoch.

    Args:
        json_path: Path of the JSON annotation file.
        image_base_path: Directory that each item's ``image_path`` is joined onto.
        processor: Object called as ``processor(images=..., text=..., ...)``
            that returns ``pixel_values`` / ``input_ids`` / ``attention_mask``.
        max_length: Length captions are padded/truncated to.
    """

    def __init__(self, json_path, image_base_path, processor, max_length):
        super().__init__()
        print(f"Attempting to load data from: {os.path.abspath(json_path)}")
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                self.data = json.load(f)
        except FileNotFoundError:
            print(f"ERROR: JSON file not found at {json_path}")
            self.data = []
        except json.JSONDecodeError as e:
            print(f"Error: Could not decode JSON from {json_path}: {e}")
            self.data = []
        except Exception as e:
            # Deliberate best-effort: callers check `dataset.data` for emptiness.
            print(f"An unexpected error occurred loading {json_path}: {e}")
            self.data = []

        print(f"Found {len(self.data)} samples in {os.path.basename(json_path)}.")
        self.image_base_path = image_base_path
        self.processor = processor
        self.max_length = max_length
        # Built lazily, the first time an image fails to load.
        self._dummy_pixel_values = None

        self.img_size = self._resolve_image_size(processor)
        print(f"Using image size: {self.img_size}x{self.img_size}")

        if not os.path.isdir(self.image_base_path):
            print(f"WARNING: Image base path does not exist: {os.path.abspath(self.image_base_path)}")

    @staticmethod
    def _resolve_image_size(processor, default=224):
        """Read the square image size from ``processor.image_processor.size``."""
        try:
            size = processor.image_processor.size
        except AttributeError:
            print("Warning: Could not determine image size from processor, defaulting to 224.")
            return default

        if isinstance(size, dict):
            # Some processors describe their size by 'shortest_edge' instead of 'height'.
            size = size.get('height', size.get('shortest_edge'))
        elif isinstance(size, (tuple, list)):
            size = size[0] if size else None

        if not isinstance(size, int) or size <= 0:
            print("Warning: Could not determine image size from processor, defaulting to 224.")
            return default
        return size

    @staticmethod
    def _first_caption(item):
        """Return the caption to train on: the string itself, or the first list entry."""
        captions = item.get('caption', [])
        if isinstance(captions, str):
            # Indexing a plain string would return only its first character.
            return captions
        return captions[0] if captions else ""

    def _get_dummy_pixel_values(self):
        """Return pixel values for a blank image, computing them once per instance."""
        if getattr(self, '_dummy_pixel_values', None) is None:
            try:
                dummy_image = Image.new('RGB', (self.img_size, self.img_size))
                processed = self.processor(images=dummy_image, return_tensors="pt")
                self._dummy_pixel_values = processed['pixel_values'].squeeze(0)
            except Exception as e:
                print(f"Error processing dummy image: {e}")
                self._dummy_pixel_values = torch.zeros((3, self.img_size, self.img_size))
        # Clone so no caller can mutate the cached tensor.
        return self._dummy_pixel_values.clone()

    def _load_pixel_values(self, relative_image_path, idx):
        """Load and preprocess one image, falling back to the blank image on failure."""
        if not relative_image_path:
            print(f"Warning: Missing 'image_path' for item at index {idx}. Using dummy image.")
            return self._get_dummy_pixel_values()

        image_path = os.path.normpath(os.path.join(self.image_base_path, relative_image_path))
        try:
            # The context manager closes the file handle once the pixels are decoded.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            processed_output = self.processor(images=image, text=None, return_tensors="pt")
            # squeeze(0) drops only the batch axis.
            return processed_output['pixel_values'].squeeze(0)
        except FileNotFoundError:
            print(f"Warning: Img not found at {image_path}. Using dummy image for idx {idx}.")
        except Exception as e:
            print(f"Warning: Error loading image {image_path}: {e}. Using dummy image for idx {idx}.")
        return self._get_dummy_pixel_values()

    def _tokenize(self, caption, idx):
        """Tokenize one caption into 1-D ``input_ids`` and ``attention_mask`` tensors."""
        try:
            text_inputs = self.processor(
                images=None,  # Text only; the image is processed separately
                text=caption,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            )
            input_ids = text_inputs['input_ids'].reshape(-1)
            attention_mask = text_inputs['attention_mask'].reshape(-1)
        except Exception as e:
            print(f"Error processing text '{caption}' for idx {idx}: {e}")
            # All-zero placeholders keep the batch shape valid.
            input_ids = torch.zeros(self.max_length, dtype=torch.long)
            attention_mask = torch.zeros(self.max_length, dtype=torch.long)
        return input_ids, attention_mask

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``pixel_values``, ``input_ids`` and ``attention_mask`` for one sample.

        Raises:
            IndexError: If ``idx`` is past the end of the data.
        """
        if idx >= len(self.data):
            raise IndexError("Index out of bounds")
        item = self.data[idx]

        pixel_values = self._load_pixel_values(item.get('image_path'), idx)
        input_ids, attention_mask = self._tokenize(self._first_caption(item), idx)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }


print("Blip2ImageCaptionDataset class defined.")


# === Cell 6: Model Definition (BLIP2 Retrieval Model, Loss) ===
class Blip2RetrievalModel(nn.Module):
    """Dual-encoder retrieval head on top of a pretrained BLIP-2 backbone.

    Image side: vision encoder -> Q-Former (driven by the learned query
    tokens) -> first query token -> ``image_projection``.
    Text side: language-model backbone -> attention-masked mean over tokens ->
    ``text_projection``.
    Both embeddings are L2-normalised and compared with a temperature-scaled
    dot product.

    NOTE(review): the backbone is driven through ``vision_model``,
    ``query_tokens``, ``qformer`` and ``language_model`` directly rather than
    through ``get_image_features`` / ``get_text_features``. Those helpers
    return the raw vision-tower output and a language-model output, neither of
    which matches a Q-Former-sized projection — confirm against the installed
    transformers version.

    Args:
        config: Training config; the attributes read are ``blip_model_name``,
            ``load_in_8bit``, ``device``, the ``freeze_*`` flags,
            ``projection_dim``, ``temperature`` and ``learnable_temperature``.
    """

    def __init__(self, config):
        super().__init__()
        self.config_train = config  # Keep the training config for later inspection
        print(f"Initializing BLIP2 Model: {config.blip_model_name}")

        load_kwargs = {}
        if config.load_in_8bit:
            if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
                print("Warning: 8-bit loading requested but not supported on this GPU or CUDA version. Loading in default precision.")
            else:
                print("Attempting to load model in 8-bit.")
                load_kwargs['load_in_8bit'] = True
                load_kwargs['device_map'] = 'auto'  # device_map needed for 8-bit
        uses_device_map = 'device_map' in load_kwargs

        # Blip2Model (not Blip2ForConditionalGeneration) is enough for feature extraction.
        try:
            self.blip_model = Blip2Model.from_pretrained(
                config.blip_model_name,
                **load_kwargs
            )
            if uses_device_map:
                print(f"Model loaded with device_map: {self.blip_model.hf_device_map}")
            elif config.device != torch.device('cpu'):
                print(f"Manually moving model components to {config.device}")
                self.blip_model.to(config.device)
        except ImportError as e:
            if 'bitsandbytes' in str(e):
                print("ERROR: bitsandbytes library not found. Please install it (`pip install bitsandbytes`) to use 8-bit loading.")
            else:
                print(f"ERROR loading base BLIP2 model: {e}")
            raise
        except Exception as e:
            print(f"ERROR loading base BLIP2 model: {e}")
            print("Check model name, internet connection, and available memory.")
            raise

        self._freeze_components(config)

        # --- Feature sizes feeding the projection heads ---
        qformer_hidden_size = self.blip_model.config.qformer_config.hidden_size
        text_cfg = self.blip_model.config.text_config
        # OPT-style configs can project the decoder output to `word_embed_proj_dim`;
        # other configs only expose `hidden_size`.
        text_hidden_size = getattr(text_cfg, 'word_embed_proj_dim', None) or text_cfg.hidden_size
        print(f"  Q-Former hidden size (input to projection): {qformer_hidden_size}")
        print(f"  Language model hidden size (input to text projection): {text_hidden_size}")

        # --- Projection Heads ---
        self.image_projection = nn.Linear(qformer_hidden_size, config.projection_dim, bias=False)
        self.text_projection = nn.Linear(text_hidden_size, config.projection_dim, bias=False)
        print(f"  Added projection heads: {qformer_hidden_size} -> {config.projection_dim}")

        # --- Temperature (stored as log(1 / T) and exponentiated in forward) ---
        init_scale = torch.tensor(math.log(1 / config.temperature), dtype=torch.float32)
        if config.learnable_temperature:
            self.logit_scale = nn.Parameter(init_scale.clone())
            print(f"  Using learnable temperature (logit scale), initialized to {self.logit_scale.exp().item():.4f}")
        else:
            # A non-persistent buffer is kept out of the state dict.
            self.register_buffer('logit_scale', init_scale, persistent=False)
            print(f"  Using fixed temperature: {config.temperature}")

        # Batches are moved to `config.device` by the training code, so the heads
        # live there regardless of how the backbone was placed.
        self.image_projection.to(config.device)
        self.text_projection.to(config.device)
        # `.to()` on a Parameter returns a new tensor, so assign through `.data`
        # to move it in place and keep the same Parameter object.
        self.logit_scale.data = self.logit_scale.data.to(config.device)

    def _freeze_components(self, config):
        """Disable gradients for the backbone parts selected by the ``freeze_*`` flags."""
        if config.freeze_vision_model:
            print("  Freezing Vision Model parameters.")
            for param in self.blip_model.vision_model.parameters():
                param.requires_grad = False
        if config.freeze_qformer:
            print("  Freezing Q-Former parameters.")
            for param in self.blip_model.qformer.parameters():
                param.requires_grad = False
            # The query tokens feed the Q-Former but are stored on the parent model.
            query_tokens = getattr(self.blip_model, 'query_tokens', None)
            if isinstance(query_tokens, nn.Parameter):
                query_tokens.requires_grad = False
        if config.freeze_language_model:
            if hasattr(self.blip_model, 'language_model') and self.blip_model.language_model is not None:
                print("  Freezing Language Model parameters.")
                for param in self.blip_model.language_model.parameters():
                    param.requires_grad = False
                # Also freeze the language projection if it exists (maps Q-Former to LM input)
                if hasattr(self.blip_model, 'language_projection') and self.blip_model.language_projection is not None:
                    for param in self.blip_model.language_projection.parameters():
                        param.requires_grad = False
            else:
                print("  Language model component not found or is None, skipping freeze.")

    @staticmethod
    def _masked_mean_pool(hidden_states, attention_mask):
        """Average token states over positions where ``attention_mask`` is non-zero.

        Args:
            hidden_states: ``[batch, seq, dim]`` tensor.
            attention_mask: ``[batch, seq]`` tensor of 0/1.

        Returns:
            ``[batch, dim]`` float32 tensor; rows with an all-zero mask are zero.
        """
        hidden_states = hidden_states.float()
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # Avoid 0/0 on fully-masked rows
        return summed / counts

    def _encode_image(self, pixel_values):
        """Return the first Q-Former query token for each image."""
        vision_outputs = self.blip_model.vision_model(pixel_values=pixel_values, return_dict=True)
        patch_embeds = vision_outputs.last_hidden_state
        patch_mask = torch.ones(patch_embeds.shape[:-1], dtype=torch.long, device=patch_embeds.device)
        query_tokens = self.blip_model.query_tokens.expand(patch_embeds.shape[0], -1, -1)
        query_outputs = self.blip_model.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=patch_embeds,
            encoder_attention_mask=patch_mask,
            return_dict=True,
        )
        return query_outputs.last_hidden_state[:, 0, :]

    def _encode_text(self, input_ids, attention_mask):
        """Return a mean-pooled sentence vector from the language-model backbone."""
        language_model = self.blip_model.language_model
        if getattr(self.blip_model.config, 'use_decoder_only_language_model', True):
            # `base_model` is the network without the LM head, so no vocabulary
            # logits are computed.
            backbone = language_model.base_model
        else:
            backbone = language_model.get_encoder()
        outputs = backbone(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        # Position 0 of a causal decoder attends only to itself, so it cannot
        # distinguish captions; pool over all real tokens instead.
        return self._masked_mean_pool(outputs.last_hidden_state, attention_mask)

    def forward(self, pixel_values, input_ids, attention_mask):
        """Embed a batch of image/caption pairs and score every pairing.

        Returns:
            Tuple ``(logits_per_image, logits_per_text, image_embeds, text_embeds)``.
            The logits are ``[batch, batch]`` scaled similarities and the
            embeddings are L2-normalised ``[batch, projection_dim]`` tensors.
        """
        image_features = self._encode_image(pixel_values)
        text_features = self._encode_text(input_ids, attention_mask)

        # Align device/dtype with the heads (the backbone may be quantised or
        # placed by device_map).
        image_weight = self.image_projection.weight
        text_weight = self.text_projection.weight
        image_embeds = self.image_projection(
            image_features.to(device=image_weight.device, dtype=image_weight.dtype))
        text_embeds = self.text_projection(
            text_features.to(device=text_weight.device, dtype=text_weight.dtype))

        image_embeds_norm = F.normalize(image_embeds, p=2, dim=-1)
        text_embeds_norm = F.normalize(text_embeds, p=2, dim=-1)

        # Cosine similarity scaled by the (possibly learned) inverse temperature.
        current_logit_scale = self.logit_scale.exp().to(image_embeds_norm.device)
        logits_per_image = current_logit_scale * image_embeds_norm @ text_embeds_norm.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text, image_embeds_norm, text_embeds_norm

# --- Loss Function (Contrastive Loss) ---
def contrastive_loss(logits_per_image, logits_per_text):
    """Symmetric cross-entropy over an in-batch similarity matrix.

    Entry ``i`` of each logits row is treated as the correct class for row
    ``i``. The image-to-text and text-to-image losses are averaged.

    Returns:
        Scalar loss tensor; a zero tensor with ``requires_grad=True`` when the
        batch is empty.
    """
    device = logits_per_image.device
    num_pairs = logits_per_image.shape[0]
    if num_pairs == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)

    targets = torch.arange(num_pairs, device=device)
    image_to_text = F.cross_entropy(logits_per_image, targets)
    text_to_image = F.cross_entropy(logits_per_text, targets)
    return (image_to_text + text_to_image) / 2.0


print("BLIP2 Retrieval Model and Loss Function defined.")


# === Cell 7: Training and Validation Epoch Functions ===
def train_epoch(model, dataloader, optimizer, device, epoch_num, scaler):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(pixel_values, input_ids, attention_mask)``
            that returns the image and text logits first.
        dataloader: Iterable of dict batches with ``pixel_values``,
            ``input_ids`` and ``attention_mask`` tensors.
        optimizer: Optimizer stepped through ``scaler``.
        device: Device the batch tensors are moved to.
        epoch_num: Epoch number used only in labels.
        scaler: Gradient scaler used for the backward pass and the step.

    Returns:
        Sample-weighted mean training loss for the epoch.
    """
    model.train()
    loss_meter = AvgMeter(f"Train Loss E{epoch_num}")

    # Prefer the notebook progress bar; building it sits inside the `try` so an
    # ImportError raised at that point also falls back to the console bar.
    bar_options = {"desc": f"Training E{epoch_num}", "leave": False, "unit": "batch"}
    try:
        from tqdm.notebook import tqdm as tqdm_notebook
        progress_bar = tqdm_notebook(dataloader, **bar_options)
    except ImportError:
        from tqdm import tqdm
        progress_bar = tqdm(dataloader, **bar_options)

    for batch in progress_bar:
        optimizer.zero_grad()

        pixel_values, input_ids, attention_mask = (
            batch[field].to(device)
            for field in ('pixel_values', 'input_ids', 'attention_mask')
        )
        num_samples = pixel_values.size(0)
        if num_samples == 0:
            continue

        # The forward pass and loss run under autocast when AMP is enabled in `config`.
        with autocast(enabled=config.use_amp):
            logits_per_image, logits_per_text, _, _ = model(pixel_values, input_ids, attention_mask)
            loss = contrastive_loss(logits_per_image, logits_per_text)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_meter.update(loss.item(), num_samples)
        progress_bar.set_postfix(loss=f"{loss_meter.avg:.4f}")

    return loss_meter.avg

def validate_epoch(model, dataloader, device, epoch_num):
    """Evaluate ``model`` on ``dataloader`` and compute retrieval metrics.

    Embeddings from all batches are collected, then scored together so recall
    is measured against the whole validation set rather than per batch.

    Args:
        model: Module returning ``(logits_per_image, logits_per_text,
            image_embeds, text_embeds)``.
        dataloader: Iterable of dict batches with ``pixel_values``,
            ``input_ids`` and ``attention_mask`` tensors.
        device: Device used for the forward passes and the metric computation.
        epoch_num: Epoch number used only in labels.

    Returns:
        Flat dict with ``"loss"`` plus every entry of ``compute_metrics``.
        Underscores in keys are replaced by spaces and nested recall dicts are
        flattened (for example ``"avg R@1"``, ``"i2t recall R@5"``). The same
        key set is returned when no batch was processed, with all metrics zero.
    """
    model.eval()
    loss_meter = AvgMeter(f"Val Loss E{epoch_num}")
    all_image_embeddings = []
    all_text_embeddings = []

    def _flatten(metrics):
        """Merge the loss with `metrics`, flattening nested dicts and renaming keys."""
        flat = {"loss": loss_meter.avg}
        for name, value in metrics.items():
            label = name.replace('_', ' ')
            if isinstance(value, dict):
                for sub_name, sub_value in value.items():
                    flat[f"{label} {sub_name}"] = sub_value
            else:
                flat[label] = value
        return flat

    # Building the bar sits inside the `try` so an ImportError raised at that
    # point also falls back to the console bar.
    try:
        from tqdm.notebook import tqdm as tqdm_notebook
        progress_bar = tqdm_notebook(dataloader, desc=f"Validation E{epoch_num}", leave=False, unit="batch")
    except ImportError:
        from tqdm import tqdm
        progress_bar = tqdm(dataloader, desc=f"Validation E{epoch_num}", leave=False, unit="batch")

    with torch.no_grad():
        for batch in progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            batch_size = pixel_values.size(0)
            if batch_size == 0:
                continue

            with autocast(enabled=config.use_amp):
                logits_per_image, logits_per_text, image_embeds, text_embeds = model(pixel_values, input_ids, attention_mask)
                loss = contrastive_loss(logits_per_image, logits_per_text)

            loss_meter.update(loss.item(), batch_size)

            # Park embeddings on the CPU so they do not pile up in GPU memory.
            all_image_embeddings.append(image_embeds.cpu())
            all_text_embeddings.append(text_embeds.cpu())

            progress_bar.set_postfix(loss=f"{loss_meter.avg:.4f}")

    if not all_image_embeddings or not all_text_embeddings:
        print("Warning: No embeddings collected during validation.")
        # Route the empty case through compute_metrics so the returned keys are
        # exactly those of a normal run.
        empty = torch.empty(0, 1)
        return _flatten(compute_metrics(empty, empty))

    all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
    all_text_embeddings = torch.cat(all_text_embeddings, dim=0)

    print(f"\nComputing metrics over {all_image_embeddings.shape[0]} validation samples...")
    validation_metrics = compute_metrics(all_image_embeddings.to(device), all_text_embeddings.to(device))

    return _flatten(validation_metrics)


print("Training and Validation epoch functions defined.")


# === Cell 8: Setup - BLIP2 Processor ===
# Load the BLIP-2 processor used by the datasets. On any failure `processor` is left
# as None so the dataset cell below can skip itself instead of crashing.
print(f"Loading BLIP2 Processor: {config.blip_processor_name}")
try:
    processor = Blip2Processor.from_pretrained(config.blip_processor_name)
    print("Processor loaded successfully.")
except Exception as e:
    print(f"ERROR loading processor '{config.blip_processor_name}': {e}")
    processor = None # Ensure processor is None if loading fails
    # raise e # Optionally stop execution


# === Cell 9: Setup - Datasets and DataLoaders ===
# Both loaders default to None; later cells test them for truthiness before use.
train_loader = None
dev_loader = None

if processor:
    print("\nCreating datasets...")
    # Annotation files are expected directly inside config.data_path.
    train_json = os.path.join(config.data_path, "train.json")
    dev_json = os.path.join(config.data_path, "dev.json")

    train_dataset = Blip2ImageCaptionDataset(
        json_path=train_json,
        image_base_path=config.image_path,
        processor=processor,
        max_length=config.max_length
    )
    dev_dataset = Blip2ImageCaptionDataset(
        json_path=dev_json,
        image_base_path=config.image_path,
        processor=processor,
        max_length=config.max_length
    )

    # The dataset class leaves `.data` empty when its JSON could not be read.
    if not train_dataset.data:
        print("\nERROR: Failed to load training data. Check 'train_json' path and format.")
    if not dev_dataset.data:
         print("\nWARNING: Failed to load validation data. Validation steps will be skipped or may error.")

    print("\nCreating dataloaders...")
    # Cap the worker count at the number of CPUs (falling back to 1 if it is unknown).
    num_workers = min(config.num_workers, os.cpu_count() if os.cpu_count() else 1)
    print(f"Using {num_workers} workers for DataLoaders.")

    if train_dataset.data:
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True if config.device == torch.device("cuda") else False,
            drop_last=True # Drop last incomplete batch for more stable training steps
        )
        print(f"Train loader created with {len(train_loader)} batches.")
    else:
        print("Skipping train loader creation due to missing training data.")

    if dev_dataset.data:
        dev_loader = DataLoader(
            dev_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True if config.device == torch.device("cuda") else False,
            drop_last=False # Keep last batch for full validation
        )
        print(f"Validation loader created with {len(dev_loader)} batches.")
    else:
        print("Skipping validation loader creation due to missing validation data.")

    # NOTE(review): presumably a DataLoader with zero batches is also falsy here
    # (e.g. fewer samples than batch_size with drop_last=True) — confirm.
    if not train_loader:
         print("\nERROR: Train loader could not be created. Cannot proceed with training.")

else:
     print("ERROR: BLIP2 Processor not loaded. Skipping dataset and dataloader creation.")


# === Cell 10: Setup - Model, Optimizer, Scheduler, AMP Scaler ===
# All four handles default to None; the training loop checks them before starting.
model = None
optimizer = None
lr_scheduler = None
scaler = None

print("\nInitializing model components...")
try:
    # The model places itself on the target device inside its __init__.
    model = Blip2RetrievalModel(config)
    print(f"\nBlip2RetrievalModel initialized.")
    # Count after construction so frozen components are reflected.
    num_params_total = sum(p.numel() for p in model.parameters())
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {num_params_total / 1e6:.2f} M")
    print(f"Trainable parameters: {num_params_trainable / 1e6:.2f} M")

except Exception as e:
    print(f"ERROR initializing BLIP2 model: {e}")
    model = None # Ensure model is None if init fails

# --- Optimizer Setup ---
if model: # Check if model was created
    print("\nSetting up optimizer...")
    # One parameter list per component so each can get its own learning rate.
    vision_params = [p for n, p in model.blip_model.vision_model.named_parameters() if p.requires_grad]
    qformer_params = [p for n, p in model.blip_model.qformer.named_parameters() if p.requires_grad]

    # The learned query tokens are stored on the BLIP-2 model itself, not inside
    # `qformer`, so they must be added explicitly or they would never be updated.
    _query_tokens = getattr(model.blip_model, 'query_tokens', None)
    if (isinstance(_query_tokens, nn.Parameter) and _query_tokens.requires_grad
            and all(p is not _query_tokens for p in qformer_params)):
        qformer_params.append(_query_tokens)

    language_params = []
    if hasattr(model.blip_model, 'language_model') and model.blip_model.language_model is not None:
        language_params.extend([p for n, p in model.blip_model.language_model.named_parameters() if p.requires_grad])
    if hasattr(model.blip_model, 'language_projection') and model.blip_model.language_projection is not None:
        language_params.extend([p for n, p in model.blip_model.language_projection.named_parameters() if p.requires_grad])

    # Projection heads plus the logit scale (when it is a trainable parameter).
    projection_params = [p for n, p in model.named_parameters() if ('image_projection' in n or 'text_projection' in n or 'logit_scale' in n) and p.requires_grad]

    print(f"  Param counts (Trainable): Vision={len(vision_params)}, QFormer={len(qformer_params)}, LM={len(language_params)}{' (FROZEN)' if config.freeze_language_model and not language_params else ''}, Projection={len(projection_params)}")

    optimizer_grouped_parameters = [
        {"params": vision_params, "lr": config.vision_encoder_lr, "weight_decay": config.weight_decay},
        {"params": qformer_params, "lr": config.qformer_lr, "weight_decay": config.weight_decay},
        {"params": language_params, "lr": config.language_model_lr, "weight_decay": config.weight_decay},
        {"params": projection_params, "lr": config.projection_lr, "weight_decay": config.weight_decay},
    ]

    # Drop groups that ended up empty (e.g. a fully frozen component).
    optimizer_grouped_parameters = [g for g in optimizer_grouped_parameters if g['params']]

    if not optimizer_grouped_parameters:
        print("ERROR: No trainable parameters found for the optimizer. Check freezing flags and model structure.")
    else:
        optimizer = optim.AdamW(optimizer_grouped_parameters)
        print(f"Optimizer AdamW initialized.")

        # --- LR Scheduler Setup ---
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
            mode=config.mode,
            factor=config.factor,
            patience=config.patience
        )
        print(f"LR Scheduler ReduceLROnPlateau initialized (mode='{config.mode}', factor={config.factor}, patience={config.patience})")

        # --- Early Stopping Setup ---
        early_stopping_counter = 0
        print(f"Early stopping initialized (patience={config.early_stopping_patience}, min_delta={config.early_stopping_min_delta})")

        # --- AMP GradScaler Setup ---
        scaler = GradScaler(enabled=config.use_amp)
        print(f"AMP GradScaler initialized ({'enabled' if config.use_amp else 'disabled'}).")

else:
    print("ERROR: Model not initialized. Skipping optimizer/scheduler/scaler setup.")


# === Cell 11: Training Loop ===
# Runs the per-epoch cycle: train -> validate -> scheduler step -> checkpoint,
# and stops early once the tracked validation metric stops improving.
# Relies on objects created by earlier cells (model, loaders, optimizer,
# lr_scheduler, scaler, config, train_epoch, validate_epoch).
# Check prerequisites exist before starting loop
if model and train_loader and optimizer and lr_scheduler and scaler:
    print(f"\nStarting training for {config.epochs} epochs...")
    print(f"Tracking metric: '{config.metric_to_track}' (mode: {config.mode})")
    print(f"Using AMP: {config.use_amp}")

    # The config names the metric with underscores; the validation results are
    # looked up with spaces instead. Resolve the lookup key once.
    tracked_key = config.metric_to_track.replace('_', ' ')
    # Sentinel that any real measurement beats, in the direction given by config.mode.
    worst_metric = -float('inf') if config.mode == "max" else float('inf')

    best_val_metric = worst_metric
    history = {'train_loss': [], 'validation_results': []}
    start_train_time = time.time()
    no_improve_epochs = 0
    save_dict = None  # Stays None if the loop body never runs (config.epochs == 0)

    # torch.save does not create missing directories.
    os.makedirs(config.model_path, exist_ok=True)
    best_checkpoint_path = os.path.join(config.model_path, "blip2_retrieval_best.pt")

    # Label the optimizer groups once, while their LRs still equal the configured
    # values. Looking the label up by the *current* LR (as done previously) stops
    # working as soon as ReduceLROnPlateau lowers a learning rate.
    lr_map = {config.vision_encoder_lr: 'Vision', config.qformer_lr: 'QF', config.language_model_lr: 'LM', config.projection_lr: 'Proj'}
    lr_group_labels = [lr_map.get(group['lr'], f'Group{i}') for i, group in enumerate(optimizer.param_groups)]

    for epoch in range(config.epochs):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch+1}/{config.epochs} ---")

        # --- Training ---
        train_loss = train_epoch(model, train_loader, optimizer, config.device, epoch+1, scaler)
        history['train_loss'].append(train_loss)
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}")

        # --- Validation ---
        # Placeholder results used when no validation runs this epoch.
        val_results = {"loss": float('inf'), tracked_key: worst_metric}
        if dev_loader:
            val_results = validate_epoch(model, dev_loader, config.device, epoch+1)
            history['validation_results'].append(val_results)
            print("  Validation Metrics:")
            # Sort keys for consistent printing: by first word, then by the k in "...@k"
            sorted_keys = sorted(val_results.keys(), key=lambda x: (x.split()[0], int(x.split('@')[-1]) if '@' in x else -1))
            print("  " + " | ".join(f"{name}: {val_results[name]:.4f}" for name in sorted_keys))

            # --- Scheduler Step ---
            current_val_metric_for_scheduler = val_results.get(tracked_key, None)
            if current_val_metric_for_scheduler is not None:
                lr_scheduler.step(current_val_metric_for_scheduler)
                lr_str = "  Current LRs: " + ", ".join(
                    f"{label}={group['lr']:.2e}"
                    for label, group in zip(lr_group_labels, optimizer.param_groups)
                )
                print(lr_str)
            else:
                print(f"  Warning: Metric '{config.metric_to_track}' not found in validation results. Scheduler not stepped.")
        else:
            print("  Validation skipped (no dev_loader).")
            history['validation_results'].append(None)

        # --- Save Checkpoint & Early Stopping Logic ---
        # A missing metric falls back to the worst value, so it never counts as an improvement.
        current_val_metric = val_results.get(tracked_key, worst_metric)
        is_best = False

        if dev_loader:
            if config.mode == "max":
                is_best = current_val_metric > best_val_metric + config.early_stopping_min_delta
            else:  # config.mode == "min"
                is_best = current_val_metric < best_val_metric - config.early_stopping_min_delta

            if is_best:
                print(f"  Metric '{config.metric_to_track}' improved from {best_val_metric:.4f} to {current_val_metric:.4f}")
                best_val_metric = current_val_metric
                no_improve_epochs = 0
            else:
                no_improve_epochs += 1
                print(f"  Metric '{config.metric_to_track}' did not improve. Best: {best_val_metric:.4f}. Counter: {no_improve_epochs}/{config.early_stopping_patience}")

        # Prepare save dictionary - saving state_dict is generally preferred for large models
        save_dict = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'train_loss': train_loss,
            'validation_results': val_results,
            'best_val_metric': best_val_metric,
            'metric_tracked': config.metric_to_track,
            'blip_config_dict': model.blip_model.config.to_dict()  # Save base model config
        }

        # Save logic
        final_epoch_path = os.path.join(config.model_path, f"blip2_retrieval_epoch_{epoch+1}.pt")

        if config.save_best_only and dev_loader:
            if is_best:
                torch.save(save_dict, best_checkpoint_path)
                print(f"  Saved Best Model (Epoch {epoch+1}) to {best_checkpoint_path}")
        else:
            # Without validation there is no notion of "best", so every epoch is kept.
            torch.save(save_dict, final_epoch_path)
            print(f"  Saved Epoch {epoch+1} Checkpoint to {final_epoch_path}")
            if is_best and dev_loader:
                torch.save(save_dict, best_checkpoint_path)
                print("  (Also saved as best model)")

        epoch_end_time = time.time()
        print(f"--- Epoch {epoch+1} Time: {epoch_end_time - epoch_start_time:.2f} seconds ---")

        # Early Stopping Check
        if dev_loader and no_improve_epochs >= config.early_stopping_patience:
            print(f"\nEarly stopping triggered after {config.early_stopping_patience} epochs without improvement.")
            break

    # --- End of Training ---
    end_train_time = time.time()
    total_train_time = end_train_time - start_train_time
    print("\n=============== Training Finished ================")
    print(f"Total Training Time: {total_train_time:.2f} seconds ({total_train_time/60:.2f} minutes)")

    # Save the final epoch's state dictionary separately.
    # save_dict holds the state from the *last* completed epoch, if any.
    final_model_path = os.path.join(config.model_path, 'blip2_retrieval_final_epoch.pt')
    if save_dict is not None:
        torch.save(save_dict, final_model_path)
        print(f"Final epoch model state saved to {final_model_path}")
    else:
        print("No epoch was run; no final checkpoint was written.")

    best_model_file = os.path.join(config.model_path, "blip2_retrieval_best.pt")
    if dev_loader and os.path.exists(best_model_file):
        print(f"Best model based on '{config.metric_to_track}' ({best_val_metric:.4f}) is saved at: {best_model_file}")
    elif dev_loader:
        print("Best model checkpoint file not found. The final epoch model is saved.")
    print("=================================================")

else:
    print("ERROR: Prerequisites for training (model, dataloader, optimizer, scheduler, scaler) not met. Training loop skipped.")


# === Cell 12: Final Evaluation on Test Set ===
# Builds a DataLoader over test.json, restores the best (or, failing that, the
# final-epoch) checkpoint into a freshly constructed model and reports the
# metrics returned by validate_epoch.
print("\n=============== Starting Test Set Evaluation ===============")

test_loader = None
model_to_test = None
test_json_path = os.path.join(config.data_path, "test.json")

# 1. Check if test data and processor exist
if os.path.exists(test_json_path) and 'processor' in globals() and processor:
    print(f"Loading test data from: {test_json_path}")
    try:
        test_dataset = Blip2ImageCaptionDataset(
            json_path=test_json_path,
            image_base_path=config.image_path,
            processor=processor,
            max_length=config.max_length
        )

        if test_dataset.data:
            # os.cpu_count() may return None; fall back to a single worker.
            num_workers = min(config.num_workers, os.cpu_count() or 1)
            test_loader = DataLoader(
                test_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=num_workers,
                # Compare the device *type* so that indexed devices such as
                # "cuda:0" also enable pinned memory.
                pin_memory=torch.device(config.device).type == "cuda",
                drop_last=False
            )
            print(f"Test loader created with {len(test_loader)} batches.")
        else:
            print("Test dataset loaded but is empty. Skipping evaluation.")

    except Exception as e:
        print(f"Error creating test dataset/loader: {e}")
else:
    print("Skipping test evaluation: Test JSON or Processor not found/loaded.")

# 2. Load Model for Testing if test_loader was created
if test_loader:
    try:
        # Determine which model weights to load (best or final)
        best_model_path = os.path.join(config.model_path, "blip2_retrieval_best.pt")
        final_model_path = os.path.join(config.model_path, "blip2_retrieval_final_epoch.pt")

        load_path = None
        if os.path.exists(best_model_path):
            load_path = best_model_path
            print(f"\nAttempting to load best model weights from: {load_path}")
        elif os.path.exists(final_model_path):
            load_path = final_model_path
            print(f"\nBest model not found. Attempting to load final epoch weights from: {load_path}")
        else:
            print(f"\nWARNING: No saved model checkpoints ('best' or 'final') found in {config.model_path}.")

        if load_path:
            checkpoint = torch.load(load_path, map_location=config.device)

            # Re-create model using saved config, then load state_dict
            print("Re-creating model structure for testing...")
            # The checkpoint's base-model name is applied only while the model is
            # built; the shared config is restored afterwards so this cell does
            # not permanently change it.
            original_blip_model_name = config.blip_model_name
            try:
                if 'blip_config_dict' in checkpoint:
                    saved_blip_config = Blip2Config.from_dict(checkpoint['blip_config_dict'])
                    saved_name = getattr(saved_blip_config, '_name_or_path', '')
                    if saved_name:
                        config.blip_model_name = saved_name
                    else:
                        # An empty name cannot be resolved to a model; keep the current one.
                        print("  Warning: checkpoint config has no base model name, keeping current CFG value.")
                    print(f"  Using base model config from checkpoint: {config.blip_model_name}")
                else:
                    print("Warning: Blip config not found in checkpoint, using current CFG.")
                model_to_test = Blip2RetrievalModel(config)
            finally:
                config.blip_model_name = original_blip_model_name

            state_dict = checkpoint['model_state_dict']
            # Handle 'module.' prefix (keys saved from a wrapped model)
            if state_dict and all(k.startswith('module.') for k in state_dict.keys()):
                print("Detected 'module.' prefix, removing for loading.")
                from collections import OrderedDict
                state_dict = OrderedDict((k[len('module.'):], v) for k, v in state_dict.items())

            # strict=False: mismatches are reported below instead of raising.
            load_result = model_to_test.load_state_dict(state_dict, strict=False)
            print(f"  State dict loading result: {load_result}")
            if load_result.missing_keys:
                print(f"  Warning: Missing keys: {load_result.missing_keys}")
            if load_result.unexpected_keys:
                print(f"  Warning: Unexpected keys: {load_result.unexpected_keys}")

            # Ensure model is on device if not using device_map
            # NOTE(review): this tests for a key literally named 'device_map' inside
            # hf_device_map, so the branch looks like it is always taken - confirm
            # whether device-mapped models should skip the move before changing it.
            if 'device_map' not in getattr(model_to_test.blip_model, 'hf_device_map', {}):
                model_to_test.to(config.device)

            print(f"Model weights loaded successfully from {load_path}")

            # --- Run Evaluation ---
            print("\nRunning evaluation on test set...")
            test_results = validate_epoch(model_to_test, test_loader, config.device, epoch_num="Test")

            print("\n--- Test Set Results ---")
            # Sort by first word of the metric name, then by the k in "...@k"
            sorted_keys = sorted(test_results.keys(), key=lambda x: (x.split()[0], int(x.split('@')[-1]) if '@' in x else -1))
            # Join instead of accumulate-and-strip so every line keeps its indent.
            print("\n".join(f"  {name}: {test_results[name]:.4f}" for name in sorted_keys))
            print("------------------------")

        else:
            print("Evaluation skipped as no model weights were found to load.")

    except Exception as e:
        print(f"\nERROR during test setup or evaluation: {e}")
        import traceback
        traceback.print_exc()

print("\n================= Evaluation Finished =================")


# === Cell 13: Training Visualization (Adapted) ===
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np

# Output folder for every PNG written by the plotting helpers below; created up
# front so the save calls never hit a missing directory.
plot_dir = "train_plot"
os.makedirs(plot_dir, exist_ok=True)
# Seaborn grid theme applied to all figures created from here on.
sns.set_style("whitegrid")
print(f"Plot directory ensured at: {os.path.abspath(plot_dir)}")

def save_subplot_as_figure(subplot, save_path):
    """Redraw the line plots of ``subplot`` on a standalone figure and save it.

    Args:
        subplot: Matplotlib axes whose lines, title and axis labels are copied.
        save_path: Destination file for the new 8x6 figure (written at 300 dpi).

    If ``subplot`` holds no lines, a warning is printed and nothing is written.
    """
    standalone_fig = plt.figure(figsize=(8, 6))
    standalone_ax = standalone_fig.add_subplot(111)

    source_lines = subplot.get_lines()
    if not source_lines:
        print(f"Warning: No lines found in subplot for {save_path}")
        plt.close(standalone_fig)
        return

    # Replicate each curve together with the styling it had on the source axes.
    for src in source_lines:
        standalone_ax.plot(
            src.get_xdata(),
            src.get_ydata(),
            color=src.get_color(),
            linestyle=src.get_linestyle(),
            marker=src.get_marker(),
            label=src.get_label(),
        )

    standalone_ax.set_title(subplot.get_title())
    standalone_ax.set_xlabel(subplot.get_xlabel())
    standalone_ax.set_ylabel(subplot.get_ylabel())
    standalone_ax.grid(True)

    # Only draw a legend when at least one curve has a non-empty label that does
    # not start with an underscore.
    needs_legend = False
    for src in source_lines:
        text = src.get_label()
        if text and not text.startswith('_'):
            needs_legend = True
            break
    if needs_legend:
        standalone_ax.legend()

    plt.tight_layout()
    standalone_fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(standalone_fig)

def plot_training_metrics(history):
    """Plot training and validation curves from ``history`` and save them as PNGs.

    Args:
        history: Dict with 'train_loss' (one loss per epoch) and
            'validation_results' (one entry per epoch: a metrics dict, or None
            when that epoch was not validated). The metric keys read here are
            'loss', 'avg acc', 'i2t recall R@k' and 't2i recall R@k' for
            k in (1, 5, 10).

    Returns:
        None. Files are written into the module-level ``plot_dir``: one PNG per
        panel that has data, plus a combined 2x2 figure. If no epoch has
        validation results, only the training-loss plot is written. If the
        history is empty or incomplete, a message is printed and nothing is
        plotted.
    """
    if not history or not history.get('train_loss') or not history.get('validation_results'):
        print("No training history available or history is incomplete.")
        return

    train_loss = history['train_loss']
    epochs_train = range(1, len(train_loss) + 1)

    # Keep the real epoch number next to each validation result, so the curves
    # stay aligned with the training curve even if some epochs have no result.
    validated = [
        (epoch_num, res)
        for epoch_num, res in enumerate(history['validation_results'], start=1)
        if res is not None
    ]

    if not validated:
        print("No valid validation results found. Plotting only training loss.")
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        ax.plot(epochs_train, train_loss, 'b-o', label='Training Loss')
        ax.set_title('Training Loss over Epochs')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)
        save_path = os.path.join(plot_dir, 'training_loss.png')
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Saved loss plot to: {save_path}")
        # plt.show() # Don't show in script mode
        plt.close(fig)
        return

    epochs_val = [epoch_num for epoch_num, _ in validated]
    valid_results = [res for _, res in validated]
    nan = float('nan')  # Placeholder for a metric missing from a single epoch

    fig, axes = plt.subplots(2, 2, figsize=(16, 13))
    fig.suptitle('Training and Validation Metrics', fontsize=16, y=1.02)

    # --- Plot Loss ---
    val_loss = [res.get('loss', nan) for res in valid_results]
    axes[0, 0].plot(epochs_train, train_loss, 'b-o', label='Training Loss')
    axes[0, 0].plot(epochs_val, val_loss, 'r-s', label='Validation Loss')
    axes[0, 0].set_title('Loss over Epochs')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # --- Plot Accuracy (Average Accuracy) ---
    metric_key_acc = 'avg acc'  # Key from validate_epoch output
    if any(metric_key_acc in res for res in valid_results):
        # .get keeps one incomplete epoch from aborting the whole plot.
        val_acc = [res.get(metric_key_acc, nan) for res in valid_results]
        axes[0, 1].plot(epochs_val, val_acc, 'g-^', label='Average Accuracy (Val)')
        axes[0, 1].set_title('Validation Average Accuracy over Epochs')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
    else:
        axes[0, 1].set_title(f'Validation Acc ({metric_key_acc}) (Not Found)')

    # --- Plot Recall Metrics ---
    # 'i2t recall R@1' serves as the representative key for the recall family.
    has_recall = any('i2t recall R@1' in res for res in valid_results)

    if has_recall:
        recall_panels = (
            (axes[1, 0], 'i2t', 'I2T', 'o', 'Image-to-Text Recall (Val)'),
            (axes[1, 1], 't2i', 'T2I', 's', 'Text-to-Image Recall (Val)'),
        )
        for ax, key_prefix, label_prefix, marker, title in recall_panels:
            for k in [1, 5, 10]:
                key = f'{key_prefix} recall R@{k}'
                values = [res.get(key, nan) for res in valid_results]
                ax.plot(epochs_val, values, marker=marker, label=f'{label_prefix} R@{k}')
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Recall')
            ax.legend()
            ax.grid(True)
    else:
        axes[1, 0].set_title('I2T Recall (Not Found)')
        axes[1, 1].set_title('T2I Recall (Not Found)')

    # Leave headroom at the top for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save individual plots (panel order matches the 2x2 grid, row by row)
    plot_names = ['loss', 'accuracy', 'i2t_recall', 't2i_recall']
    for name, panel in zip(plot_names, axes.flat):
        save_path = os.path.join(plot_dir, f'training_{name}.png')
        if panel.has_data():
            save_subplot_as_figure(panel, save_path)
            print(f"Saved {name} plot to: {save_path}")
        else:
            print(f"Skipping save for {name} plot (no data).")

    # Save combined plot
    combined_save_path = os.path.join(plot_dir, 'training_metrics_combined.png')
    fig.savefig(combined_save_path, bbox_inches='tight', dpi=300)
    print(f"Saved combined plot to: {combined_save_path}")

    # plt.show() # Avoid showing plots in a script run
    plt.close(fig)

# --- Plotting ---
# `history` only exists when the training cell actually ran, so look it up by
# name instead of referencing it directly (which would raise NameError).
_recorded_history = locals().get('history')
_has_usable_history = isinstance(_recorded_history, dict) and bool(_recorded_history.get('train_loss'))

if not _has_usable_history:
    print("No training history found or history is empty. Run training first to generate history.")
else:
    plot_training_metrics(_recorded_history)


# --- END OF SCRIPT ---