# In [None]:
# -*- coding: utf-8 -*-
"""ViBLIP_finetune.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/notebooks/empty.ipynb

Adapts the ViCLIP training script for fine-tuning BLIP-2 on a Vietnamese Image Captioning dataset.
"""

# Cell 1: Installs

# Cell 2: Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms # Still potentially useful for basic loading if not using datasets library
from transformers import AutoProcessor, Blip2ForConditionalGeneration, Blip2Config
from transformers import get_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau # Optional: Can still use ReduceLROnPlateau or use transformers schedulers
from PIL import Image
import json
import os
import random
import numpy as np
from tqdm.notebook import tqdm
import torch.nn.functional as F
import math
import time # For timing epochs
import nltk # For BLEU/ROUGE calculation if not using evaluate
import evaluate # Using Hugging Face evaluate library for metrics

# Report the runtime environment up front so a CPU-only session is obvious.
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {_cuda_ok}")
print(
    f"CUDA Device Name: {torch.cuda.get_device_name(0)}"
    if _cuda_ok
    else "WARNING: CUDA not available, training will be on CPU."
)

# Cell 3: Configuration Class (CFG) - Adapted for ViBLIP

class CFG:
    """Static settings for the ViBLIP (BLIP-2) fine-tuning run.

    All values are plain class attributes; the script instantiates the class
    once and reads the attributes from that instance.
    """

    # ------------------------------ Paths ------------------------------
    # Image root (same location as `image_base_path` below).
    image_path = "./data/OpenViVQA-dataset/"

    # Folder that holds train.json / dev.json / test.json.
    data_path = "./json_data/"
    # Root directory that image paths stored in the JSON files are joined onto.
    image_base_path = "./data/OpenViVQA-dataset/"

    # Where checkpoints, logs and other run artefacts are written.
    output_dir = "./ViBLIP_vivqa"
    logging_dir = f"{output_dir}/logs"
    checkpoint_dir = f"{output_dir}/checkpoints"

    # ------------------------- BLIP-2 checkpoint -------------------------
    # Processor and model must come from the same checkpoint. Swap both if a
    # better Vietnamese-capable BLIP-2 becomes available.
    blip2_processor_name = "Salesforce/blip2-opt-2.7b"
    blip2_model_name = "Salesforce/blip2-opt-2.7b"

    # ------------------------ Optimisation settings ------------------------
    seed = 42
    batch_size = 16  # per-step batch; lower it if GPU memory is tight
    num_workers = 4  # DataLoader worker processes
    learning_rate = 1e-5  # small LR for fine-tuning a large pretrained model
    weight_decay = 1e-4
    # Name passed to the transformers scheduler factory, or
    # "reduce_lr_on_plateau" to use torch's ReduceLROnPlateau instead.
    lr_scheduler_type = "linear"
    num_warmup_steps = 100
    # Only read when lr_scheduler_type == "reduce_lr_on_plateau".
    rlrop_factor = 0.8
    rlrop_patience = 3

    epochs = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Effective batch size = batch_size * gradient_accumulation_steps.
    gradient_accumulation_steps = 2

    # ------------------- Caption generation (evaluation) -------------------
    generation_max_length = 50  # upper bound on generated caption length
    num_beams = 4  # beam-search width

    # ---------------------- Checkpointing and logging ----------------------
    save_best_only = True
    # Metric used for best-model selection and plateau scheduling
    # (e.g. 'bleu', 'rougeL', 'loss').
    metric_to_track = "bleu"
    # Loss is minimised; every other tracked metric is maximised.
    mode = "min" if metric_to_track == "loss" else "max"
    log_interval = 10  # intended training-loss logging period, in steps
    eval_strategy = "epoch"  # when validation runs ('steps' also possible)
    save_strategy = "epoch"  # when checkpoints are written ('steps' also possible)
    save_total_limit = 2  # maximum number of per-epoch checkpoints kept on disk


# --- Instantiate Config and Create Output Dirs ---
config = CFG()
for _run_dir in (config.output_dir, config.logging_dir, config.checkpoint_dir):
    os.makedirs(_run_dir, exist_ok=True)

# Echo the key settings so the notebook output records what this run used.
for _summary_line in (
    f"Using device: {config.device}",
    f"Output path: {os.path.abspath(config.output_dir)}",
    f"Image base path: {os.path.abspath(config.image_base_path)}",
    f"Using BLIP-2 Processor: {config.blip2_processor_name}",
    f"Using BLIP-2 Model: {config.blip2_model_name}",
):
    print(_summary_line)

# Cell 4: Seeding for Reproducibility

def set_seed(seed=config.seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed; defaults to the value configured in ``config.seed``.
    """
    print(f"Setting seed: {seed}")
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers every visible GPU
    # cuDNN determinism is intentionally left off because it can slow training;
    # enable it here if needed:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False


set_seed()

# Cell 5: Metric Calculation Utilities & Setup

class AvgMeter:
    """Running weighted average of a scalar (for example a per-batch loss)."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, sample count and average."""
        self.sum = self.count = self.avg = 0

    def update(self, val, count=1):
        """Fold ``val`` into the average with weight ``count``.

        Tensors are converted with ``.item()``. Values that are neither
        ``int`` nor ``float`` after that conversion are ignored.
        """
        value = val.item() if torch.is_tensor(val) else val
        if not isinstance(value, (int, float)):
            return
        self.sum += value * count
        self.count += count
        # Guard the division: a zero total weight keeps the average at 0.
        self.avg = 0 if self.count == 0 else self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

# Load BLEU and ROUGE from the Hugging Face `evaluate` library. Both handles
# stay None when loading fails, which makes the metric helpers report zeros.
# Further metrics (e.g. meteor, cider, spice) could be loaded the same way;
# some may need extra dependencies.
bleu_metric = None
rouge_metric = None
try:
    _loaded_bleu = evaluate.load("bleu")
    _loaded_rouge = evaluate.load("rouge")
except Exception as e:
    print(f"Error loading evaluation metrics: {e}")
    print("Please ensure 'evaluate' and 'nltk' are installed correctly.")
else:
    bleu_metric, rouge_metric = _loaded_bleu, _loaded_rouge
    print("Evaluation metrics (BLEU, ROUGE) loaded successfully.")


def compute_captioning_metrics(predictions, references):
    """Score generated captions against references with BLEU and ROUGE.

    Args:
        predictions: List of generated caption strings.
        references: List of reference caption strings, one per prediction.

    Returns:
        Dict with a ``'bleu'`` entry plus the ROUGE entries. A metric that is
        not loaded, or whose computation raises, is reported as 0.0.
    """
    rouge_fallback = {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0, 'rougeLsum': 0.0}

    # BLEU starts at 0.0 and is only overwritten when the computation succeeds.
    metrics = {'bleu': 0.0}
    if bleu_metric:
        try:
            # BLEU takes a list of reference lists, so wrap each single reference.
            wrapped_references = [[ref] for ref in references]
            bleu_results = bleu_metric.compute(
                predictions=predictions, references=wrapped_references
            )
            metrics['bleu'] = bleu_results.get('bleu', 0.0)
        except Exception as e:
            print(f"Warning: Could not compute BLEU score: {e}")
            metrics['bleu'] = 0.0

    rouge_scores = rouge_fallback
    if rouge_metric:
        try:
            # ROUGE accepts a flat list of reference strings.
            rouge_results = rouge_metric.compute(predictions=predictions, references=references)
            rouge_scores = dict(rouge_results.items())
        except Exception as e:
            print(f"Warning: Could not compute ROUGE score: {e}")
    metrics.update(rouge_scores)

    # Additional metrics (meteor, cider, spice, ...) would be merged in here.
    return metrics


print("Metric utilities defined.")

# Cell 6: Dataset Class Definition - Adapted for ViBLIP Processor

class ImageCaptionDataset(Dataset):
    """Image/caption pairs read from a JSON file and encoded by a BLIP-2 processor.

    Every JSON record is read through ``item.get('image_path')`` (joined onto
    ``image_base_path``) and ``item.get('caption')``. The caption may be a list
    of strings, of which the first is used, or a single string.

    Args:
        json_path: Path of the JSON annotation file.
        image_base_path: Directory that relative image paths are joined onto.
        processor: Callable processor that turns an image and a text into
            tensors (``pixel_values``, ``input_ids``, ``attention_mask``).
        max_length: Optional token length forwarded to the processor together
            with ``padding="max_length"``. ``None`` (default) leaves the length
            to the processor/tokenizer configuration, as before.
    """

    # Label value written over padded positions so they are left out of the
    # loss (-100 is the default ignore index of PyTorch's cross-entropy loss).
    _IGNORE_INDEX = -100
    # Size of the white placeholder used when an image cannot be loaded.
    _BLANK_IMAGE_SIZE = (224, 224)

    def __init__(self, json_path, image_base_path, processor, max_length=None):
        super().__init__()
        print(f"Attempting to load data from: {os.path.abspath(json_path)}")
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                self.data = json.load(f)
        except FileNotFoundError:
            print(f"ERROR: JSON file not found at {json_path}")
            self.data = []
        except json.JSONDecodeError:
            print(f"Error: Could not decode JSON from {json_path}")
            self.data = []
        except Exception as e:
            # Best effort: an unreadable file yields an empty dataset, which
            # the calling code checks for explicitly.
            print(f"An unexpected error occurred loading {json_path}: {e}")
            self.data = []

        print(f"Found {len(self.data)} samples in {os.path.basename(json_path)}.")
        self.image_base_path = image_base_path
        self.processor = processor
        self.max_length = max_length

        if not os.path.isdir(self.image_base_path):
            print(f"WARNING: Image base path does not exist: {os.path.abspath(self.image_base_path)}")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _first_caption(captions):
        """Return the caption to train on, or "" when none is available.

        A plain string is returned unchanged; indexing it with ``[0]`` would
        silently reduce the caption to its first character.
        """
        if isinstance(captions, str):
            return captions
        if isinstance(captions, (list, tuple)) and captions:
            return captions[0]
        return ""

    @classmethod
    def _build_labels(cls, input_ids, attention_mask=None):
        """Copy ``input_ids`` into labels and mask out padded positions."""
        labels = input_ids.clone()
        if attention_mask is not None:
            # attention_mask == 0 marks padding; those tokens must not be scored.
            labels[attention_mask == 0] = cls._IGNORE_INDEX
        return labels

    @classmethod
    def _blank_image(cls):
        """Create the white placeholder image used for missing/broken files."""
        return Image.new('RGB', cls._BLANK_IMAGE_SIZE, color='white')

    def _load_image(self, relative_image_path, idx):
        """Load one image as RGB, falling back to a blank image on any failure."""
        if not relative_image_path:
            print(f"Warning: Missing 'image_path' for item at index {idx}. Using blank image.")
            return self._blank_image()

        image_path = os.path.normpath(os.path.join(self.image_base_path, relative_image_path))
        try:
            # The context manager closes the file handle; convert() returns a
            # separate in-memory copy that remains valid afterwards.
            with Image.open(image_path) as img:
                return img.convert('RGB')
        except FileNotFoundError:
            print(f"Warning: Img not found: {image_path}. Using blank image.")
        except Exception as e:
            print(f"Warning: Error loading image {image_path}: {e}. Using blank image.")
        return self._blank_image()

    def __getitem__(self, idx):
        """Encode sample ``idx``.

        Returns:
            Dict with the processor outputs (batch dimension removed),
            ``labels`` (token ids with padding masked) and ``raw_caption``
            (the untokenized caption, kept for evaluation).

        Raises:
            IndexError: If ``idx`` is past the end of the dataset.
        """
        if idx >= len(self.data):
            raise IndexError("Index out of bounds")
        item = self.data[idx]
        caption = self._first_caption(item.get('caption', []))
        image = self._load_image(item.get('image_path', None), idx)

        # Only pass max_length when it was set, so the default call is
        # identical to the previous behaviour.
        extra_kwargs = {}
        if self.max_length is not None:
            extra_kwargs["max_length"] = self.max_length

        encoding = self.processor(
            images=image,
            text=caption,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            **extra_kwargs
        )

        # The processor returns tensors with a leading batch dimension of 1.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        # Captioning objective: predict the caption tokens themselves.
        encoding["labels"] = self._build_labels(
            encoding["input_ids"], encoding.get("attention_mask")
        )

        # Keep the raw caption so evaluation can compare against it.
        encoding["raw_caption"] = caption

        return encoding

print("ImageCaptionDataset class defined for BLIP-2.")

# Cell 7: Training and Validation Epoch Functions - Adapted for ViBLIP

def train_epoch(model, dataloader, optimizer, scheduler, device, epoch_num, scaler=None, grad_accumulation_steps=1):
    """Run one training epoch with optional mixed precision and gradient accumulation.

    Args:
        model: Model called with ``pixel_values``, ``input_ids``,
            ``attention_mask`` and ``labels``; its output must expose ``.loss``.
        dataloader: Iterable of dict batches; non-tensor entries are dropped.
        optimizer: Optimizer updating ``model``.
        scheduler: LR scheduler. Schedulers other than ``ReduceLROnPlateau``
            are stepped after every optimizer update; ``ReduceLROnPlateau``
            needs a metric and is therefore left for the caller to step.
        device: Device the batch tensors are moved to.
        epoch_num: Epoch label used in progress output.
        scaler: Optional ``GradScaler``; when given, the forward pass runs
            under CUDA autocast and the scaler drives backward/step.
        grad_accumulation_steps: Number of batches accumulated per update
            (values below 1 are treated as 1).

    Returns:
        Sample-weighted average training loss of the epoch.
    """
    model.train()
    loss_meter = AvgMeter(f"Train Loss E{epoch_num}")
    optimizer.zero_grad() # Zero gradients at the start of the epoch

    accum_steps = max(1, int(grad_accumulation_steps))
    use_amp = bool(scaler)
    # A plateau scheduler cannot be stepped without a validation metric.
    step_scheduler_here = not isinstance(scheduler, ReduceLROnPlateau)

    # The batch count is needed to flush a trailing partial accumulation
    # window; a loader without a length simply skips that flush.
    try:
        num_batches = len(dataloader)
    except TypeError:
        num_batches = None
    remainder = num_batches % accum_steps if num_batches is not None else 0
    first_tail_step = (num_batches - remainder) if remainder else None

    progress_bar = tqdm(dataloader, desc=f"Training E{epoch_num}", leave=False, unit="batch")
    for step, batch in enumerate(progress_bar):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}

        # Autocast is only active when a scaler was supplied.
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(
                pixel_values=batch['pixel_values'],
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            loss = outputs.loss

        # Batches in the trailing partial window are averaged over the
        # window's real size rather than the nominal accumulation factor.
        in_tail = first_tail_step is not None and step >= first_tail_step
        divisor = remainder if in_tail else accum_steps
        if divisor > 1:
            loss = loss / divisor

        # Backpropagation
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Update on every full window and on the very last batch, so that
        # leftover gradients are applied instead of being zeroed next epoch.
        is_last_batch = num_batches is not None and (step + 1) == num_batches
        if (step + 1) % accum_steps == 0 or is_last_batch:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            if step_scheduler_here:
                scheduler.step() # Step scheduler after optimizer step
            optimizer.zero_grad() # Zero gradients after stepping

        # Undo the accumulation scaling so the meter tracks the real loss.
        loss_meter.update(loss.item() * divisor, batch['pixel_values'].size(0))
        # Read the LR from the optimizer: this works for every scheduler type.
        current_lr = optimizer.param_groups[0]['lr']
        progress_bar.set_postfix(loss=f"{loss_meter.avg:.4f}", lr=f"{current_lr:.2e}")

    return loss_meter.avg

def validate_epoch(model, processor, dataloader, device, epoch_num):
    """Compute validation loss and captioning metrics over ``dataloader``.

    For each batch the teacher-forced loss is computed from the reference
    tokens, then captions are generated from the images alone using
    ``config.generation_max_length`` and ``config.num_beams``.

    Args:
        model: Model exposing a forward pass with ``.loss`` and ``generate``.
        processor: Object whose ``batch_decode`` turns generated ids into text.
        dataloader: Iterable of dict batches; a ``raw_caption`` entry, when
            present, supplies the reference strings.
        device: Device (``torch.device`` or string) the tensors are moved to.
        epoch_num: Label used in progress output (an int or e.g. "Test").

    Returns:
        Dict of metrics from ``compute_captioning_metrics`` plus ``'loss'``.
    """
    model.eval()
    loss_meter = AvgMeter(f"Val Loss E{epoch_num}")
    all_predictions = []
    all_references = []

    # Compare on the device *type* so "cuda:0" or a plain string also enables
    # autocast; an equality test against torch.device("cuda") would not.
    use_amp = torch.device(device).type == "cuda"

    progress_bar = tqdm(dataloader, desc=f"Validation E{epoch_num}", leave=False, unit="batch")
    with torch.no_grad():
        for batch in progress_bar:
            # Pop the non-tensor references before moving the batch to the device.
            raw_captions = batch.pop("raw_caption", [])
            all_references.extend(raw_captions)

            # Move tensors to device
            batch_tensors = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}

            # --- Calculate Loss ---
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(
                    pixel_values=batch_tensors['pixel_values'],
                    input_ids=batch_tensors['input_ids'],
                    attention_mask=batch_tensors['attention_mask'],
                    labels=batch_tensors['labels']
                )
                loss = outputs.loss

            loss_meter.update(loss.item(), batch_tensors['pixel_values'].size(0))

            # --- Generate Captions ---
            generated_ids = model.generate(
                pixel_values=batch_tensors['pixel_values'],
                max_length=config.generation_max_length,
                num_beams=config.num_beams,
                early_stopping=True
            )
            # Decode generated captions
            generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
            all_predictions.extend([caption.strip() for caption in generated_captions])

            progress_bar.set_postfix(loss=f"{loss_meter.avg:.4f}")

    # Ensure predictions and references are lists of strings
    all_predictions = [str(p) for p in all_predictions]
    all_references = [str(r) for r in all_references]

    # Mismatched counts would produce unreliable scores, so report them.
    if len(all_predictions) != len(all_references):
        print(
            f"Warning: {len(all_predictions)} predictions but {len(all_references)} references; "
            "text metrics may be unreliable."
        )

    # --- Compute Metrics ---
    validation_metrics = compute_captioning_metrics(all_predictions, all_references)
    validation_metrics['loss'] = loss_meter.avg

    # Show a few generated/reference pairs as a quick sanity check.
    print("\n--- Sample Generation vs Reference ---")
    for i in range(min(3, len(all_predictions), len(all_references))):
        print(f"  Pred {i+1}: {all_predictions[i]}")
        print(f"  Ref  {i+1}: {all_references[i]}")
    print("------------------------------------")

    return validation_metrics


print("Training and Validation epoch functions defined for ViBLIP.")

# Cell 8: Setup - Processor

print(f"Loading BLIP-2 Processor: {config.blip2_processor_name}")
# Stays None when loading fails, so the following cells can skip their work.
processor = None
try:
    processor = AutoProcessor.from_pretrained(config.blip2_processor_name)
except Exception as e:
    print(f"ERROR loading processor '{config.blip2_processor_name}': {e}")
    print("Please ensure the model name is correct and you have internet access or the model is cached.")
else:
    print("Processor loaded successfully.")
    # Some models need decoder_start_token_id set explicitly; if that applies,
    # do it once the model is loaded, e.g.:
    # if getattr(model.config, "decoder_start_token_id", None) is None:
    #     model.config.decoder_start_token_id = processor.tokenizer.pad_token_id

# Cell 9: Setup - Datasets and DataLoaders

def collate_caption_batch(batch):
    """Merge dataset samples (dicts) into one batch dict.

    Token-level tensors are right-padded to the longest sequence in the batch
    (``input_ids`` with the tokenizer pad id, ``attention_mask`` with 0 and
    ``labels`` with -100), every other tensor is stacked, and non-tensor
    values such as ``raw_caption`` are collected into a list.

    This is a module-level function rather than a lambda so that DataLoader
    worker processes can pickle it.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = 0
    pad_values = {"input_ids": pad_token_id, "attention_mask": 0, "labels": -100}

    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if not isinstance(values[0], torch.Tensor):
            collated[key] = values
        elif key in pad_values:
            collated[key] = torch.nn.utils.rnn.pad_sequence(
                values, batch_first=True, padding_value=pad_values[key]
            )
        else:
            collated[key] = torch.stack(values)
    return collated


def _build_loader(dataset, batch_size, shuffle, num_workers):
    """Create a DataLoader with the settings shared by all three splits."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.device(config.device).type == "cuda",
        drop_last=False, # Keep the last, possibly smaller, batch
        collate_fn=collate_caption_batch
    )


print("\nCreating datasets...")
train_json = os.path.join(config.data_path, "train.json")
dev_json = os.path.join(config.data_path, "dev.json")
test_json = os.path.join(config.data_path, "test.json")

train_loader = None
dev_loader = None
test_loader = None

if processor: # Only proceed if processor loaded correctly
    train_dataset = ImageCaptionDataset(
        json_path=train_json,
        image_base_path=config.image_base_path,
        processor=processor
    )
    dev_dataset = ImageCaptionDataset(
        json_path=dev_json,
        image_base_path=config.image_base_path,
        processor=processor
    )

    # Basic checks
    if not train_dataset.data:
        print("\nERROR: Failed to load training data.")
    if not dev_dataset.data:
        print("\nWARNING: Failed to load validation data. Validation steps will be skipped.")

    print("\nCreating dataloaders...")
    num_workers = min(config.num_workers, os.cpu_count() or 1)
    print(f"Using {num_workers} workers for DataLoaders.")

    # Evaluation needs no gradients, so it can usually afford a larger batch.
    eval_batch_size = config.batch_size * 2

    if train_dataset.data:
        train_loader = _build_loader(train_dataset, config.batch_size, True, num_workers)
        print(f"Train loader created with {len(train_loader)} batches (batch size: {config.batch_size}).")

    if dev_dataset.data:
        dev_loader = _build_loader(dev_dataset, eval_batch_size, False, num_workers)
        print(f"Validation loader created with {len(dev_loader)} batches (batch size: {config.batch_size * 2}).")

    # The test split is optional.
    if os.path.exists(test_json):
        test_dataset = ImageCaptionDataset(
            json_path=test_json,
            image_base_path=config.image_base_path,
            processor=processor
        )
        if test_dataset.data:
            test_loader = _build_loader(test_dataset, eval_batch_size, False, num_workers)
            print(f"Test loader created with {len(test_loader)} batches.")
        else:
            print("Test JSON loaded but was empty.")
    else:
        print("Test JSON not found. Skipping test loader creation.")

    if not train_loader:
        print("\nERROR: Train loader could not be created.")

else:
    print("ERROR: Processor not loaded. Cannot create datasets and dataloaders.")

# Cell 10: Setup - Model, Optimizer, Scheduler

# Training objects; each stays None when its prerequisites are missing so the
# training loop can detect an incomplete setup.
model = None
optimizer = None
lr_scheduler = None
scaler = None # For mixed precision

if processor and train_loader: # Only proceed if processor and train_loader exist
    print("\nInitializing BLIP-2 model...")
    try:
        model = Blip2ForConditionalGeneration.from_pretrained(config.blip2_model_name)
        model.to(config.device)
        print(f"BLIP-2 model '{config.blip2_model_name}' loaded successfully on {config.device}.")
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {num_params / 1e6:.2f} M")

        # Optional: Freeze parts of the model if desired (e.g., vision encoder)
        # for param in model.vision_model.parameters():
        #     param.requires_grad = False
        # num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # print(f"Trainable parameters after freezing vision encoder: {num_params / 1e6:.2f} M")

    except Exception as e:
        print(f"ERROR initializing model '{config.blip2_model_name}': {e}")
        print("Check model name, internet connection, and available memory.")
        model = None # Ensure model is None if loading fails

    if model:
        print("\nSetting up optimizer...")
        # Use AdamW optimizer, common for transformers
        optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
        print(f"Optimizer AdamW initialized (LR={config.learning_rate}, WD={config.weight_decay}).")

        # --- LR Scheduler Setup ---
        # The scheduler is stepped once per optimizer update, so the total and
        # warmup step counts are both expressed in updates, not batches.
        accum_steps = max(1, config.gradient_accumulation_steps)
        updates_per_epoch = math.ceil(len(train_loader) / accum_steps)
        num_training_steps = config.epochs * updates_per_epoch
        if config.lr_scheduler_type == 'reduce_lr_on_plateau':
            lr_scheduler = ReduceLROnPlateau(
                optimizer, mode=config.mode, factor=config.rlrop_factor, patience=config.rlrop_patience
            )
            print(f"LR Scheduler ReduceLROnPlateau initialized (mode='{config.mode}', factor={config.rlrop_factor}, patience={config.rlrop_patience})")
        else:
            # Use Hugging Face scheduler
            lr_scheduler = get_scheduler(
                name=config.lr_scheduler_type,
                optimizer=optimizer,
                # Already in update units; scaling by the accumulation factor
                # would stretch warmup beyond the configured length.
                num_warmup_steps=config.num_warmup_steps,
                num_training_steps=num_training_steps
            )
            print(f"LR Scheduler '{config.lr_scheduler_type}' initialized (Warmup: {config.num_warmup_steps}, Total Steps: {num_training_steps})")

        # --- Mixed Precision Setup (Optional but recommended for large models) ---
        if torch.device(config.device).type == "cuda":
            scaler = torch.cuda.amp.GradScaler()
            print("CUDA GradScaler enabled for mixed precision.")

else:
    print("ERROR: Model, Processor or Train Loader not available. Skipping optimizer/scheduler setup.")

# Cell 11: Training Loop

import shutil  # used below to prune old per-epoch checkpoints

if model and train_loader and optimizer and lr_scheduler and processor: # Check prerequisites
    print(f"\nStarting training for {config.epochs} epochs...")
    print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
    print(f"Tracking metric: '{config.metric_to_track}' (mode: {config.mode}) for best model saving.")

    best_val_metric = -float('inf') if config.mode == "max" else float('inf')
    history = {'train_loss': [], 'validation_results': []}
    start_train_time = time.time()
    global_step = 0

    for epoch in range(config.epochs):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch+1}/{config.epochs} ---")

        # --- Training ---
        avg_train_loss = train_epoch(
            model, train_loader, optimizer, lr_scheduler, config.device,
            epoch+1, scaler, config.gradient_accumulation_steps
        )
        history['train_loss'].append(avg_train_loss)
        print(f"Epoch {epoch+1}: Average Train Loss = {avg_train_loss:.4f}")

        # --- Validation ---
        # Reset every epoch: without this the checkpointing code below hits an
        # unbound name (first epoch) or reuses a stale value (later epochs)
        # whenever validation is skipped.
        current_val_metric = None
        val_results = {"loss": float('inf'), config.metric_to_track: best_val_metric} # Default if no validation
        if dev_loader and config.eval_strategy == "epoch":
            print("Running validation...")
            val_results = validate_epoch(model, processor, dev_loader, config.device, epoch+1)
            history['validation_results'].append(val_results)
            # Print validation metrics
            print("  Validation Metrics:")
            metric_log_str = "  "
            for name, value in val_results.items():
                metric_log_str += f"{name}: {value:.4f} | "
            print(metric_log_str.strip(" | "))

            current_val_metric = val_results.get(config.metric_to_track, None)

            # ReduceLROnPlateau is driven by the tracked validation metric.
            if isinstance(lr_scheduler, ReduceLROnPlateau):
                if current_val_metric is not None:
                    lr_scheduler.step(current_val_metric)
                    print(f"  Current LR (ReduceLROnPlateau): {optimizer.param_groups[0]['lr']:.2e}")
                else:
                    print(f"  Warning: Metric '{config.metric_to_track}' not found. ReduceLROnPlateau scheduler not stepped.")

        else:
            print("  Validation skipped for this epoch based on strategy.")
            history['validation_results'].append(None) # Append None if no validation

        # --- Save Checkpoint ---
        # NOTE(review): config.save_best_only is not consulted here; a
        # per-epoch checkpoint is always written. Confirm whether intended.
        if config.save_strategy == "epoch":
            checkpoint_path = os.path.join(config.checkpoint_dir, f"epoch_{epoch+1}")
            model.save_pretrained(checkpoint_path)
            processor.save_pretrained(checkpoint_path)
            print(f"  Checkpoint saved to {checkpoint_path}")

            is_best = False
            if current_val_metric is not None:
                if config.mode == "max" and current_val_metric > best_val_metric:
                    is_best = True
                    best_val_metric = current_val_metric
                elif config.mode == "min" and current_val_metric < best_val_metric:
                    is_best = True
                    best_val_metric = current_val_metric

            if is_best:
                best_checkpoint_path = os.path.join(config.checkpoint_dir, "best_model")
                model.save_pretrained(best_checkpoint_path)
                processor.save_pretrained(best_checkpoint_path)
                print(f"  Saved Best Model (Epoch {epoch+1}, {config.metric_to_track}={current_val_metric:.4f}) to {best_checkpoint_path}")

                # Store the history so far next to the best model.
                history_path = os.path.join(best_checkpoint_path, "training_history.json")
                with open(history_path, 'w', encoding='utf-8') as f:
                    json.dump(history, f, indent=4)

        # --- Manage Checkpoints (Remove older ones if limit is set) ---
        if config.save_total_limit is not None and config.save_strategy == "epoch":
            # Only consider directories named epoch_<number>; anything else
            # would break the numeric sort.
            checkpoints = sorted(
                [
                    d for d in os.listdir(config.checkpoint_dir)
                    if d.startswith("epoch_") and d.split('_')[1].isdigit()
                ],
                key=lambda x: int(x.split('_')[1])
            )
            # Remove every surplus checkpoint, oldest first, so leftovers from
            # earlier runs cannot keep the count above the limit.
            surplus = max(0, len(checkpoints) - config.save_total_limit)
            for stale_name in checkpoints[:surplus]:
                checkpoint_to_remove = os.path.join(config.checkpoint_dir, stale_name)
                print(f"  Removing old checkpoint: {checkpoint_to_remove}")
                shutil.rmtree(checkpoint_to_remove)

        epoch_end_time = time.time()
        print(f"--- Epoch {epoch+1} Time: {epoch_end_time - epoch_start_time:.2f} seconds ---")

    # --- End of Training ---
    end_train_time = time.time()
    total_train_time = end_train_time - start_train_time
    print(f"\n=============== Training Finished ================")
    print(f"Total Training Time: {total_train_time:.2f} seconds ({total_train_time/60:.2f} minutes)")

    # Save final model state
    final_model_path = os.path.join(config.checkpoint_dir, 'final_model')
    model.save_pretrained(final_model_path)
    processor.save_pretrained(final_model_path)
    print(f"Final model state saved to {final_model_path}")

    # Save final training history
    history_path = os.path.join(final_model_path, "training_history.json")
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=4)
    print(f"Final training history saved to {history_path}")

    if os.path.exists(os.path.join(config.checkpoint_dir, "best_model")):
        print(f"Best model based on '{config.metric_to_track}' ({best_val_metric:.4f}) saved in: {os.path.join(config.checkpoint_dir, 'best_model')}")
    print(f"=================================================")

else:
    print("ERROR: Prerequisites for training not met (Model, Processor, Train Loader, Optimizer, or Scheduler). Training loop skipped.")

# Cell 12: Final Evaluation on Test Set

print("\n=============== Starting Test Set Evaluation ===============")

if not (test_loader and processor):
    # Say which prerequisite is missing (possibly both).
    if not test_loader:
        print("Test loader not available. Skipping test evaluation.")
    if not processor:
        print("Processor not available. Skipping test evaluation.")
else:
    # --- Pick the checkpoint to evaluate: best model first, final model second ---
    best_model_path = os.path.join(config.checkpoint_dir, "best_model")
    final_model_path = os.path.join(config.checkpoint_dir, "final_model")
    model_to_test = None
    processor_to_test = None

    load_path = next(
        (candidate for candidate in (best_model_path, final_model_path) if os.path.exists(candidate)),
        None,
    )
    if load_path == best_model_path:
        print(f"Attempting to load best model weights from: {load_path}")
    elif load_path == final_model_path:
        print(f"Best model not found. Attempting to load final model weights from: {load_path}")
    else:
        print("WARNING: No saved model checkpoints ('best_model' or 'final_model') found.")
        print("         Evaluation will not be performed.")

    if load_path:
        try:
            print(f"Loading model from {load_path}...")
            model_to_test = Blip2ForConditionalGeneration.from_pretrained(load_path)
            processor_to_test = AutoProcessor.from_pretrained(load_path)
            model_to_test.to(config.device)
            print("Model and processor for testing loaded successfully.")

            # --- Score the test split with the validation routine ---
            print("\nRunning evaluation on test set...")
            test_results = validate_epoch(model_to_test, processor_to_test, test_loader, config.device, epoch_num="Test")

            print("\n--- Test Set Results ---")
            metric_log_str = "".join(
                f"  {name}: {value:.4f}\n" for name, value in test_results.items()
            )
            print(metric_log_str.strip())
            print("------------------------")

            # Persist the metrics alongside the other run outputs.
            test_results_path = os.path.join(config.output_dir, "test_results.json")
            with open(test_results_path, 'w') as f:
                json.dump(test_results, f, indent=4)
            print(f"Test results saved to {test_results_path}")

        except Exception as e:
            print(f"\nERROR loading model/processor or running evaluation from {load_path}: {e}")
            import traceback
            traceback.print_exc()

print("\n================= Evaluation Finished ==================")

# Cell 13: Training Visualization (Adapted for Captioning Metrics)

import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np

# Make sure the directory that will receive the generated plots exists
plot_dir = os.path.join(config.output_dir, "plots")
os.makedirs(plot_dir, exist_ok=True)
print(f"\nPlot directory created at: {os.path.abspath(plot_dir)}")

# --- Load History ---
# Lookup order: the history stored next to the best model, then the one stored
# next to the final model, and finally the in-memory `history` global (if any).
best_history_path = os.path.join(config.checkpoint_dir, "best_model", "training_history.json")
final_history_path = os.path.join(config.checkpoint_dir, "final_model", "training_history.json")
history_loaded = None

for _history_candidate in (best_history_path, final_history_path):
    if not os.path.exists(_history_candidate):
        continue
    print(f"Loading history from {_history_candidate}")
    with open(_history_candidate, 'r') as f:
        history_loaded = json.load(f)
    break  # First existing file wins
else:
    # Neither history file exists on disk; fall back to the current session
    if 'history' in globals():
        print("Using history from current training run.")
        history_loaded = history  # Use history from the training loop if available


def save_plot(fig, save_path):
    """Save a figure to disk and close it.

    Args:
        fig: The matplotlib figure to save.
        save_path: Destination file path passed to ``fig.savefig``.

    The figure is closed after saving, so it cannot be displayed afterwards.
    """
    # Lay out the figure that was passed in, not whichever figure happens to
    # be "current" in pyplot's global state.
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    print(f"Saved plot to: {save_path}")
    plt.close(fig)  # Close figure after saving


def plot_training_metrics_blip(history_data):
    """Plot training loss and validation metrics from a training history.

    Args:
        history_data: Mapping with a 'train_loss' list and a
            'validation_results' list. Truthy validation entries are read
            with ``.get`` for the keys 'loss', 'bleu' and 'rougeL'; missing
            keys are plotted as NaN.

    Returns:
        None. Prints a message and returns early when the history is missing
        or empty.

    Side effects:
        Writes PNG files into the module-level ``plot_dir``
        ('training_loss.png', 'training_bleu.png', 'training_rougeL.png' and
        'training_metrics_combined.png') and displays the figure via
        ``plt.show()``.
    """
    if not history_data or not history_data.get('train_loss') or not history_data.get('validation_results'):
        print("No valid training history available to plot.")
        return

    train_loss = history_data['train_loss']
    train_epochs = range(1, len(train_loss) + 1)

    # Keep each validation result together with its epoch number so that
    # skipped (falsy) entries leave a gap instead of shifting later points.
    # Assumes entry i of 'validation_results' belongs to epoch i + 1 --
    # TODO confirm against the training loop.
    val_points = [
        (epoch, res)
        for epoch, res in enumerate(history_data['validation_results'], start=1)
        if res
    ]

    if not val_points:
        print("No valid validation results found in history.")
        # Only the training loss can be plotted
        fig_loss, ax_loss = plt.subplots(1, 1, figsize=(8, 6))
        ax_loss.plot(train_epochs, train_loss, 'b-', label='Training Loss')
        ax_loss.set_title('Training Loss over Epochs')
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('Loss')
        ax_loss.legend()
        ax_loss.grid(True)
        # Save without closing first: a closed figure is not displayed by
        # plt.show(), so the figure is closed only after it has been shown.
        fig_loss.tight_layout()
        loss_path = os.path.join(plot_dir, 'training_loss.png')
        fig_loss.savefig(loss_path, bbox_inches='tight', dpi=300)
        print(f"Saved plot to: {loss_path}")
        plt.show()
        plt.close(fig_loss)
        return

    val_epochs = [epoch for epoch, _ in val_points]
    val_results = [res for _, res in val_points]

    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=(20, 6))  # One panel per metric
    fig.suptitle('Training and Validation Metrics', fontsize=16, y=1.02)

    # Plot Loss: each curve uses its own x values, so differing lengths of
    # training and validation history cannot raise a shape mismatch.
    val_loss = [res.get('loss', np.nan) for res in val_results]  # np.nan for missing keys
    axes[0].plot(train_epochs, train_loss, 'b-', label='Training Loss')
    axes[0].plot(val_epochs, val_loss, 'r-', label='Validation Loss')
    axes[0].set_title('Loss over Epochs')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True)
    save_subplot_as_figure(axes[0], os.path.join(plot_dir, 'training_loss.png'))

    # Validation-only metric panels: (axes, result key, line format,
    # legend label, title, y-axis label, output file name)
    metric_panels = (
        (axes[1], 'bleu', 'g-', 'Validation BLEU',
         'Validation BLEU Score over Epochs', 'BLEU', 'training_bleu.png'),
        (axes[2], 'rougeL', 'm-', 'Validation ROUGE-L',
         'Validation ROUGE-L Score over Epochs', 'ROUGE-L', 'training_rougeL.png'),
    )
    for panel, key, fmt, label, title, ylabel, file_name in metric_panels:
        values = [res.get(key, np.nan) for res in val_results]
        panel.plot(val_epochs, values, fmt, label=label)
        panel.set_title(title)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(ylabel)
        panel.legend()
        panel.grid(True)
        save_subplot_as_figure(panel, os.path.join(plot_dir, file_name))

    # Lay out the combined figure explicitly; the per-panel saves above create
    # and close other figures, so pyplot's "current figure" is not relied on.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room for the suptitle

    # Save combined plot
    combined_save_path = os.path.join(plot_dir, 'training_metrics_combined.png')
    fig.savefig(combined_save_path, bbox_inches='tight', dpi=300)
    print(f"Saved combined plot to: {combined_save_path}")

    # Display the plot
    plt.show()


# Cell 14: Helper function to save subplots (needed for plot_training_metrics_blip)
# NOTE: This is defined *before* the plotting call below, because
# plot_training_metrics_blip looks the helper up by name when it runs.

def save_subplot_as_figure(subplot, save_path):
    """Save the contents of one subplot axes as a separate figure.

    The lines of ``subplot`` (data, color, label, line style, marker) are
    re-plotted on a new 8x6 figure together with the title, axis labels and
    grid state, and the result is written through ``save_plot``.

    Args:
        subplot: The matplotlib axes to copy from.
        save_path: Destination file path for the new figure.

    Returns:
        None. Nothing is written when ``subplot`` contains no lines.
    """
    lines = subplot.get_lines()
    if not lines:  # Skip if no lines to plot
        return

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    # Copy lines and labels
    for line in lines:
        ax.plot(line.get_xdata(), line.get_ydata(),
                color=line.get_color(),
                label=line.get_label(),
                linestyle=line.get_linestyle(),
                marker=line.get_marker())  # Copy marker style too

    # Copy title and axis labels
    ax.set_title(subplot.get_title())
    ax.set_xlabel(subplot.get_xlabel())
    ax.set_ylabel(subplot.get_ylabel())

    # Axes objects expose no grid-state getter, so the state is inferred from
    # whether any grid line of the source axes is visible.
    grid_on = any(
        gridline.get_visible()
        for gridline in (*subplot.get_xgridlines(), *subplot.get_ygridlines())
    )
    ax.grid(grid_on)

    # Labels starting with '_' are ignored by legend(); only draw a legend
    # when at least one line carries a real label.
    labels = [line.get_label() for line in lines]
    if any(label and not label.startswith('_') for label in labels):
        ax.legend()

    save_plot(fig, save_path)  # Saves and closes the new figure


print("Subplot saving helper function defined.")


# Plot if history was loaded or exists
if history_loaded:
    plot_training_metrics_blip(history_loaded)
else:
    print("No training history found to plot.")