In [None]:
import copy
import json
import os
import random
import re # For clean_caption

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from IPython import get_ipython
from IPython.display import display
from nltk.translate.bleu_score import corpus_bleu
from PIL import Image, UnidentifiedImageError
from sklearn.model_selection import train_test_split # For proper data splitting
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms
from tqdm import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_scheduler, GPT2Config

# --- 0. Configuration and Paths ---
# Adjust these paths to where your models are saved and where Flickr30k is located.

# FLICKR_BASE_PATH must be the directory that DIRECTLY contains the image files
# (e.g. 1000092795.jpg). In this Drive layout that is the innermost of the three
# nested 'flickr30k_images' folders.
FLICKR_BASE_PATH = '/content/MyDrive/MyDrive/M.Sc Data Science/Sem 4/CV/flickr30k_images/flickr30k_images/flickr30k_images'

CAPTIONS_FILE = '/content/MyDrive/MyDrive/M.Sc Data Science/Sem 4/CV/flickr30k_images/flickr_captions.json'
# state_dict of the trained emotion classifier (loaded with load_state_dict below).
EMOTION_MODEL_PATH = '/content/MyDrive/MyDrive/M.Sc Data Science/Sem 4/CV/eff_fine_tuned_model.pth'
# Single file the best emotion-aware captioning checkpoint is written to and
# read back from for evaluation.
CAPTIONING_MODEL_SAVE_PATH = '/content/MyDrive/MyDrive/M.Sc Data Science/Sem 4/CV/fine_tuned_best_model/best_caption_model_finetuned_emotion_aware.pt'

# Ensure the directory for saving the model exists.
os.makedirs(os.path.dirname(CAPTIONING_MODEL_SAVE_PATH), exist_ok=True)

# --- Path validation for the image directory ---
# This only warns (it never stops the notebook). The dataset silently substitutes
# a black image for any file it cannot open, so a wrong path would otherwise go
# unnoticed.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
if not os.path.isdir(FLICKR_BASE_PATH):
    print(f"Error: FLICKR_BASE_PATH '{FLICKR_BASE_PATH}' does not exist or is not a directory.")
    print("Please verify the path to your Flickr30k image folder in Google Drive.")
elif not any(fname.lower().endswith(_IMAGE_EXTENSIONS) for fname in os.listdir(FLICKR_BASE_PATH)):
    print(f"Warning: FLICKR_BASE_PATH '{FLICKR_BASE_PATH}' does not appear to contain image files directly.")
    print("This might mean you need to adjust the path to the actual image subdirectory (e.g., 'path/to/flickr30k_images/flickr30k_images' if it's nested).")
    print("The script will attempt to proceed, but you might encounter many 'Image not found' fallbacks.")
else:
    print(f"FLICKR_BASE_PATH '{FLICKR_BASE_PATH}' seems valid and contains images.")

# Hyperparameters
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 16
NUM_CLASSES = 7  # must match the output size of the emotion checkpoint at EMOTION_MODEL_PATH
MAX_CAPTION_LENGTH = 25  # captions are padded/truncated to this many GPT-2 tokens
# NOTE: with a single epoch, early stopping (PATIENCE below) can never trigger.
NUM_EPOCHS = 1
LEARNING_RATE = 5e-5
PATIENCE = 5  # epochs without validation-loss improvement before EarlyStopping stops training

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- 1. Utility Functions ---

# Caption normalisation shared by BLEU references and generated hypotheses
# (adapted from the earlier image-captioning notebook).
def clean_caption(text):
    """Normalise a caption for BLEU scoring.

    Steps:
      1. Collapse every run of two or more ``.``/``!``/``?`` characters into a
         single ``.``.
      2. Keep only the first sentence: everything up to and including the first
         ``.``, ``!`` or ``?``.
      3. Strip leading/trailing backticks, dots and single quotes. A terminating
         ``.`` is removed by this step; ``!`` and ``?`` are kept.
      4. Lower-case the result.

    Args:
        text: raw caption string.

    Returns:
        The cleaned, lower-cased caption (possibly the empty string).
    """
    text = re.sub(r'([.!?]{1,3}){2,}', ".", text)
    first_sentence_end = re.search(r'[.!?]', text)
    if first_sentence_end:
        text = text[:first_sentence_end.end()]
    # The previous version upper-cased the first character at this point, but the
    # final lower() immediately undid it, so that dead step has been removed.
    return text.strip("`.'").lower()

# EarlyStopping class (adapted from the emotion-model notebook).
class EarlyStopping:
    """Track validation loss and signal when training should stop.

    Each call records one validation loss. The first call, and any later call
    whose loss beats the best so far by more than ``delta``, resets the patience
    counter and snapshots the supplied state. After ``patience`` consecutive
    non-improving calls, ``early_stop`` becomes True.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: print progress messages when True.
        delta: minimum decrease in loss required to count as an improvement.
    """

    def __init__(self, patience=5, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive non-improving calls so far
        self.best_score = None  # negated best validation loss; None before the first call
        self.early_stop = False  # True once counter reaches patience
        self.val_loss_min = float('inf')  # best validation loss seen
        self.delta = delta
        self.best_state = None  # deep copy of the state passed with the best loss

    def __call__(self, val_loss, model_state_dict):
        """Record ``val_loss``; if it is the best so far, snapshot ``model_state_dict``.

        The state is deep-copied. ``nn.Module.state_dict()`` returns tensors that
        share storage with the live parameters, so keeping a reference would let
        later optimizer steps silently overwrite the "best" weights.
        """
        score = -val_loss  # minimising loss == maximising -loss

        if self.best_score is not None and score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # First call or a genuine improvement. Report old -> new BEFORE updating,
        # otherwise the message always shows the new value twice.
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving model ...")
        self.best_score = score
        self.val_loss_min = val_loss
        self.best_state = copy.deepcopy(model_state_dict)
        self.counter = 0


# --- 2. Data Preparation ---

# Standard ImageNet per-channel statistics used for normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Shared preprocessing for the emotion classifier, the caption encoder and
# evaluation. It resizes to a square IMG_SIZE x IMG_SIZE image, converts to a
# tensor, then normalises each channel.
image_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Custom Dataset class (adapted from the earlier image-captioning notebook).
class ImageCaptionDataset(Dataset):
    """Dataset of (image, caption) pairs for caption training and evaluation.

    Each item is a 5-tuple ``(image, input_ids, attention_mask, caption, img_name)``:

    - ``image`` is ``transform`` applied to the RGB image.
    - ``input_ids`` and ``attention_mask`` come from the tokenizer, with the
      caption padded to ``max_length`` and truncated if longer.

    Missing or unreadable images are replaced by a black ``IMG_SIZE x IMG_SIZE``
    image instead of raising. A wrong image path therefore degrades silently.

    Args:
        image_folder: directory that directly contains the image files.
        samples: list of ``(img_name, caption)`` pairs.
        tokenizer: Hugging Face tokenizer used to encode captions (padding to
            ``max_length`` requires it to have a pad token).
        transform: callable applied to the PIL image.
        max_length: optional token length captions are padded/truncated to;
            ``None`` (the default) means ``MAX_CAPTION_LENGTH``.
    """

    def __init__(self, image_folder, samples, tokenizer, transform, max_length=None):
        self.image_folder = image_folder
        self.samples = samples  # list of (img_name, caption)
        self.tokenizer = tokenizer
        self.transform = transform
        # Resolved lazily so the class can be defined independently of the global.
        self.max_length = MAX_CAPTION_LENGTH if max_length is None else max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, caption = self.samples[idx]
        img_path = os.path.join(self.image_folder, img_name)

        try:
            # The context manager closes the file handle even if decoding fails.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except (OSError, UnidentifiedImageError):
            # Fallback to a black image if missing/unreadable.
            image = Image.new('RGB', (IMG_SIZE, IMG_SIZE), (0, 0, 0))

        image = self.transform(image)
        tokens = self.tokenizer(caption, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)
        # The tokenizer returns a batch of one; drop that leading dimension.
        input_ids = tokens['input_ids'].squeeze(0)
        attention_mask = tokens['attention_mask'].squeeze(0)

        # The training loops unpack only the first three items; evaluation code
        # may use the raw caption and file name.
        return image, input_ids, attention_mask, caption, img_name

# Load captions (Flickr30k). Presumably a JSON object mapping image file name ->
# list of caption strings (that is how it is iterated below); verify against the file.
with open(CAPTIONS_FILE, 'r', encoding='utf-8') as f:
    captions_data = json.load(f)

# Flatten into (img_name, caption) pairs. Every caption of every image is kept.
all_image_caption_pairs = [
    (img_name, cap)
    for img_name, caps in captions_data.items()
    for cap in caps
]

print(f"Total image-caption pairs loaded: {len(all_image_caption_pairs)}")

# Create tokenizer. GPT-2 has no dedicated pad token, so EOS doubles as padding.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.bos_token or tokenizer.eos_token

# --- Train/Validation/Test split by IMAGE ---
# Splitting unique images (not caption pairs) keeps all captions of one image in
# the same split, so no test image is seen during training.
unique_image_names = list(captions_data.keys())
train_img_names, test_img_names = train_test_split(unique_image_names, test_size=0.1, random_state=42)
# 0.1 / 0.9 of the remaining 90% == 10% of all images for validation.
train_img_names, val_img_names = train_test_split(train_img_names, test_size=0.1/0.9, random_state=42)

# Set lookups are O(1). Testing `img in <list>` for every caption pair was
# O(pairs x images) and dominated the cell's runtime.
train_img_set = set(train_img_names)
val_img_set = set(val_img_names)
test_img_set = set(test_img_names)

train_samples = [(img, cap) for img, cap in all_image_caption_pairs if img in train_img_set]
val_samples = [(img, cap) for img, cap in all_image_caption_pairs if img in val_img_set]
test_samples = [(img, cap) for img, cap in all_image_caption_pairs if img in test_img_set]

print(f"Train samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")
print(f"Test samples: {len(test_samples)}")

train_dataset = ImageCaptionDataset(FLICKR_BASE_PATH, train_samples, tokenizer, image_transform)
val_dataset = ImageCaptionDataset(FLICKR_BASE_PATH, val_samples, tokenizer, image_transform)
test_dataset = ImageCaptionDataset(FLICKR_BASE_PATH, test_samples, tokenizer, image_transform)

# Cap worker processes at the available CPUs. Requesting more than the machine
# has (e.g. 16 on a 2-core Colab VM) only adds overhead and triggers a warning.
NUM_WORKERS = min(16, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


# --- 3. Model Definitions ---

# 3.1. Emotion classification model (same architecture as the emotion-model notebook)
def create_emotion_model(base_model_name='efficientnet_b0', num_classes=NUM_CLASSES, pretrained=True):
    """Build the emotion classifier: a frozen ImageNet backbone plus a small trainable head.

    The head layouts (Sequential indices included) are kept exactly as before,
    because the efficientnet variant is loaded from a saved state_dict below.

    Args:
        base_model_name: ``'efficientnet_b0'`` or ``'resnet50'``.
        num_classes: number of emotion classes output by the head.
        pretrained: load ImageNet ``IMAGENET1K_V1`` backbone weights. These are
            the weights the deprecated ``pretrained=True`` flag used to select.

    Returns:
        The model, moved to the global ``device``. Only the new head's
        parameters have ``requires_grad=True``.

    Raises:
        ValueError: if ``base_model_name`` is not supported.
    """
    if base_model_name == 'efficientnet_b0':
        # torchvision's `weights=` API replaces the deprecated `pretrained=` flag.
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.efficientnet_b0(weights=weights)
        model.requires_grad_(False)  # freeze the backbone; the head created below stays trainable
        in_features = model.classifier[1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
    elif base_model_name == 'resnet50':
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet50(weights=weights)
        model.requires_grad_(False)  # freeze the backbone; the head created below stays trainable
        in_features = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
    else:
        raise ValueError("Unsupported model name")
    return model.to(device)

# Load the trained emotion classification model (architecture must match the checkpoint).
emotion_classifier = create_emotion_model('efficientnet_b0', NUM_CLASSES, pretrained=False)
# The file is consumed by load_state_dict, i.e. it is a plain tensor state_dict,
# so the safe weights_only loader is sufficient.
emotion_classifier.load_state_dict(torch.load(EMOTION_MODEL_PATH, map_location=device, weights_only=True))
# eval() disables dropout and makes BatchNorm use its stored statistics.
# requires_grad_(False) is what actually keeps the weights from being trained.
emotion_classifier.eval()
emotion_classifier.requires_grad_(False)
print(f"Loaded emotion classification model from {EMOTION_MODEL_PATH}")

# 3.2. Core captioning model components (from the image-captioning notebook)
# EfficientNet-B0 with its classifier head removed, used as a frozen feature
# extractor (1280-d output, matching the projector below). IMAGENET1K_V1 is what
# the deprecated `pretrained=True` flag selected.
encoder_cnn = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
encoder_cnn.classifier = nn.Identity()
encoder_cnn.to(device)
encoder_cnn.eval()
encoder_cnn.requires_grad_(False)  # eval() alone does not freeze weights

decoder = GPT2LMHeadModel.from_pretrained('gpt2')
decoder.resize_token_embeddings(len(tokenizer))
decoder.to(device)

# Projects encoder features to GPT-2's embedding width (one prefix token per image).
projector = nn.Linear(1280, decoder.config.n_embd).to(device)

# --- 4. Emotion-Aware Model Definition (NEW COMBINED MODEL) ---

class EmotionAwareCaptioningModel(nn.Module):
    """GPT-2 captioner conditioned on an image token and an emotion token.

    The decoder input sequence is
    ``[image_token, emotion_token, BOS, caption tokens ...]``:

    - ``image_token`` is the projected CNN feature vector.
    - ``emotion_token`` is the projected softmax output of the emotion classifier.

    Trainable parts are ``projector``, ``emotion_projection`` and ``decoder``.
    ``encoder_cnn`` and ``emotion_classifier`` are frozen. They run under
    ``torch.no_grad()`` and are kept in eval mode even when ``train()`` is called.

    Args:
        encoder_cnn: image backbone returning one feature vector per image.
        emotion_classifier: module mapping images to ``num_emotions`` logits.
        projector: maps encoder features to ``decoder_embedding_dim``.
        decoder: GPT-2 language model (``GPT2LMHeadModel``).
        num_emotions: length of the emotion probability vector.
        decoder_embedding_dim: GPT-2 embedding width (``n_embd``).
    """

    def __init__(self, encoder_cnn, emotion_classifier, projector, decoder, num_emotions, decoder_embedding_dim):
        super().__init__()
        self.encoder_cnn = encoder_cnn
        self.emotion_classifier = emotion_classifier
        self.projector = projector
        self.decoder = decoder
        # Projects the emotion probability vector into the decoder's embedding space.
        self.emotion_projection = nn.Linear(num_emotions, decoder_embedding_dim).to(device)

    def train(self, mode=True):
        """Set train/eval mode while always keeping the frozen backbones in eval mode.

        ``nn.Module.train()`` recurses into every child. Without this override,
        ``model.train()`` would enable dropout in the frozen encoder and emotion
        classifier. BatchNorm would also update its running statistics, which it
        does even under ``torch.no_grad()``.
        """
        super().train(mode)
        self.encoder_cnn.eval()
        self.emotion_classifier.eval()
        return self

    def _prefix_embeddings(self, images):
        """Return the image and emotion prefix embeddings, shape (batch, 2, embd)."""
        # Only the frozen backbones run without autograd.
        with torch.no_grad():
            image_features = self.encoder_cnn(images)
            emotion_probs = torch.softmax(self.emotion_classifier(images), dim=-1)
        # The projections are trainable, so they must run OUTSIDE no_grad.
        # Otherwise they never receive gradients.
        image_embeddings = self.projector(image_features).unsqueeze(1)
        emotion_embedding = self.emotion_projection(emotion_probs).unsqueeze(1)
        return torch.cat([image_embeddings, emotion_embedding], dim=1)

    def forward(self, images, input_ids=None, attention_mask=None):
        """Run teacher-forced decoding (training) or caption generation (inference).

        Args:
            images: normalised image batch accepted by both backbones.
            input_ids: caption token ids for teacher forcing, or ``None`` to generate.
            attention_mask: mask for ``input_ids``; all-ones if omitted.

        Returns:
            With ``input_ids``: the decoder's output for
            ``[image, emotion, BOS, input_ids[:, :-1]]``. ``logits[:, 2:]`` then
            predicts ``input_ids`` position by position.
            Without ``input_ids``: the token ids returned by ``decoder.generate``.
        """
        prefix = self._prefix_embeddings(images)
        batch_size = images.shape[0]
        bos_tokens = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=images.device)

        if input_ids is None:
            # Inference prompt: [image, emotion, BOS]. BOS is included because
            # training always predicts the first caption token from the BOS
            # position, never from the emotion token.
            bos_embedding = self.decoder.transformer.wte(bos_tokens)
            decoder_inputs = torch.cat([prefix, bos_embedding], dim=1)
            prompt_mask = torch.ones(decoder_inputs.shape[:2], dtype=torch.long, device=images.device)
            return self.decoder.generate(
                inputs_embeds=decoder_inputs,
                attention_mask=prompt_mask,
                max_new_tokens=MAX_CAPTION_LENGTH,  # unambiguous cap on generated caption tokens
                num_beams=3,  # beam search for better quality
                early_stopping=True,
                no_repeat_ngram_size=2,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Teacher forcing: inputs are BOS followed by the caption shifted right by one.
        shifted_ids = torch.cat([bos_tokens, input_ids[:, :-1]], dim=1)
        caption_embeddings = self.decoder.transformer.wte(shifted_ids)
        decoder_inputs = torch.cat([prefix, caption_embeddings], dim=1)  # (batch, 2 + seq_len, embd)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # The mask is aligned with the shifted inputs: image, emotion and BOS are
        # always visible, followed by the mask of input_ids[:, :-1].
        expanded_attention_mask = torch.cat([
            torch.ones(batch_size, 3, dtype=attention_mask.dtype, device=images.device),
            attention_mask[:, :-1],
        ], dim=1)
        return self.decoder(inputs_embeds=decoder_inputs, attention_mask=expanded_attention_mask)

# Assemble the full emotion-aware captioner from the components built above.
emotion_aware_model = EmotionAwareCaptioningModel(
    encoder_cnn=encoder_cnn,
    emotion_classifier=emotion_classifier,
    projector=projector,
    decoder=decoder,
    num_emotions=NUM_CLASSES,
    decoder_embedding_dim=decoder.config.n_embd,
)
emotion_aware_model.to(device)

# --- 5. Training Configuration for Emotion-Aware Model ---

# The frozen CNN encoder and emotion classifier are deliberately excluded. Only
# the feature projector, the GPT-2 decoder and the emotion projection are
# optimised, in that order.
trainable_parameters = [
    param
    for module in (projector, decoder, emotion_aware_model.emotion_projection)
    for param in module.parameters()
]
optimizer = optim.AdamW(trainable_parameters, lr=LEARNING_RATE)

# Token-level cross-entropy over raw logits and integer targets. Targets equal
# to the pad id are excluded from the loss.
# NOTE(review): the tokenizer's pad token was set to its EOS token, so EOS
# targets are ignored as well and the model is never trained to emit EOS.
# Consider masking padding via the attention mask instead.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# --- 6. Training Loop ---


def _cpu_snapshot(obj):
    """Recursively replace every tensor in ``obj`` by a detached CPU clone.

    ``state_dict()`` tensors share storage with the live parameters and
    optimizer buffers. Keeping them by reference would let later optimizer steps
    overwrite the stored "best" checkpoint. Copying to CPU also avoids holding a
    second copy of the weights in GPU memory. Non-tensor leaves are returned
    unchanged; dict/list/tuple containers are rebuilt with the same type.
    """
    if torch.is_tensor(obj):
        return obj.detach().cpu().clone()
    if isinstance(obj, dict):
        snapshot = type(obj)((key, _cpu_snapshot(value)) for key, value in obj.items())
        # Module state_dicts carry version metadata used by load_state_dict.
        metadata = getattr(obj, '_metadata', None)
        if metadata is not None:
            snapshot._metadata = metadata
        return snapshot
    if isinstance(obj, (list, tuple)):
        return type(obj)(_cpu_snapshot(value) for value in obj)
    return obj


def _caption_loss(outputs, input_ids):
    """Return the cross-entropy between the caption logits and ``input_ids``.

    Logit positions 0 and 1 belong to the image and emotion prefix tokens. The
    caption predictions therefore start at index 2 and line up one-to-one with
    ``input_ids``.
    NOTE(review): loss_fn ignores targets equal to the pad id, which is the EOS
    id in this notebook, so EOS is never a training target.
    """
    logits_for_caption = outputs.logits[:, 2:, :]
    return loss_fn(logits_for_caption.reshape(-1, logits_for_caption.size(-1)), input_ids.reshape(-1))


early_stopping = EarlyStopping(patience=PATIENCE, verbose=True)

print("\n--- Starting Emotion-Aware Captioning Model Training ---")
for epoch in range(NUM_EPOCHS):
    emotion_aware_model.train()
    # nn.Module.train() recurses into every child module. The CNN encoder and the
    # emotion classifier are frozen, so return them to eval mode. Otherwise their
    # dropout is active and BatchNorm keeps updating its running statistics
    # (even under torch.no_grad()).
    encoder_cnn.eval()
    emotion_classifier.eval()
    total_train_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Training")
    # The dataset yields 5 items; the caption text and image name are unused here.
    for images, input_ids, attention_mask, _, _ in loop:
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        optimizer.zero_grad()
        outputs = emotion_aware_model(images, input_ids=input_ids, attention_mask=attention_mask)
        loss = _caption_loss(outputs, input_ids)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # a single host sync per step
        total_train_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

    # --- Validation Loop ---
    emotion_aware_model.eval()
    total_val_loss = 0.0
    val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Validation")
    with torch.no_grad():
        for images, input_ids, attention_mask, _, _ in val_loop:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            outputs = emotion_aware_model(images, input_ids=input_ids, attention_mask=attention_mask)
            batch_loss = _caption_loss(outputs, input_ids).item()
            total_val_loss += batch_loss
            val_loop.set_postfix(val_loss=batch_loss)

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}")

    # --- Early Stopping and Checkpointing ---
    # Detached CPU copies, so the stored "best" state cannot be changed by later training.
    model_state = {
        'encoder_cnn_state_dict': _cpu_snapshot(encoder_cnn.state_dict()),  # frozen; kept for completeness
        'projector_state_dict': _cpu_snapshot(projector.state_dict()),
        'decoder_state_dict': _cpu_snapshot(decoder.state_dict()),
        'emotion_projection_state_dict': _cpu_snapshot(emotion_aware_model.emotion_projection.state_dict()),
        'optimizer_state_dict': _cpu_snapshot(optimizer.state_dict()),
        'tokenizer': tokenizer,  # pickled object: reading the file needs torch.load(..., weights_only=False)
        'epoch': epoch + 1,
        'val_loss': avg_val_loss,
    }

    early_stopping(avg_val_loss, model_state)
    if early_stopping.counter == 0:
        # counter is 0 only right after the first epoch or an improvement.
        # Checkpoint immediately so a crash/disconnect later does not lose it.
        print(f"Saving best model to {CAPTIONING_MODEL_SAVE_PATH}")
        torch.save(early_stopping.best_state, CAPTIONING_MODEL_SAVE_PATH)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break

# Leave the in-memory model on the best weights, whether or not training stopped early.
best_state = early_stopping.best_state
if best_state is not None:
    print("Restoring best model weights.")
    projector.load_state_dict(best_state['projector_state_dict'])
    decoder.load_state_dict(best_state['decoder_state_dict'])
    emotion_aware_model.emotion_projection.load_state_dict(best_state['emotion_projection_state_dict'])


# --- 7. Final Inference/Testing Loop ---
print("\n--- Running Final Evaluation on Test Set ---")
# Load the best checkpoint so evaluation uses the best weights. The checkpoint is
# written by this notebook and pickles the tokenizer object. It therefore cannot
# be read with weights_only=True, the default since PyTorch 2.6. Only load
# trusted files this way.
best_saved_state = torch.load(CAPTIONING_MODEL_SAVE_PATH, map_location=device, weights_only=False)
projector.load_state_dict(best_saved_state['projector_state_dict'])
decoder.load_state_dict(best_saved_state['decoder_state_dict'])
emotion_aware_model.emotion_projection.load_state_dict(best_saved_state['emotion_projection_state_dict'])

emotion_aware_model.eval()

# Display names for the emotion classifier's output indices.
# NOTE(review): assumed to match the label order used when training the emotion model; confirm.
emotion_class_names = ['Angry', 'Disgusted', 'Fearful', 'Happy', 'Neutral', 'Sad', 'Surprised']


def _load_image_tensor(img_name):
    """Load ``img_name`` from FLICKR_BASE_PATH as a 1-image batch on ``device``.

    Missing or unreadable files fall back to a black image, as in the dataset.
    """
    img_path = os.path.join(FLICKR_BASE_PATH, img_name)
    try:
        with Image.open(img_path) as img:
            image = img.convert("RGB")
    except (OSError, UnidentifiedImageError):
        image = Image.new('RGB', (IMG_SIZE, IMG_SIZE), (0, 0, 0))  # Fallback
    return image_transform(image).unsqueeze(0).to(device)  # add batch dimension


def _generate_caption(image_tensor):
    """Generate, decode and clean a caption for a 1-image batch."""
    with torch.no_grad():
        generated_ids = emotion_aware_model(image_tensor)
    return clean_caption(tokenizer.decode(generated_ids[0], skip_special_tokens=True))


# test_samples holds one entry per caption (several per image). Generate exactly
# one caption per unique test image. Sorting makes the order reproducible, which
# iterating a set of strings is not, because of hash randomisation.
unique_test_image_names = sorted({img for img, _ in test_samples})

# Cleaned, tokenised reference captions for every unique test image.
references = {
    img_name: [clean_caption(cap).lower().split() for cap in captions_data[img_name]]
    for img_name in unique_test_image_names
}

# --- BLEU Evaluation on Test Set ---
hypotheses_for_bleu = []
references_for_bleu = []

print("\n--- Generating captions for unique test images for BLEU ---")
for img_name in tqdm(unique_test_image_names, desc="Generating captions"):
    pred_caption = _generate_caption(_load_image_tensor(img_name))
    hypotheses_for_bleu.append(pred_caption.split())
    references_for_bleu.append(references[img_name])  # all references for this image

# Calculate BLEU scores. BLEU-3 uses exact thirds; 0.33 each summed to only 0.99.
bleu_1 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(1, 0, 0, 0))
bleu_2 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(0.5, 0.5, 0, 0))
bleu_3 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(1/3, 1/3, 1/3, 0))
bleu_4 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(0.25, 0.25, 0.25, 0.25))

print()
for n, score in enumerate((bleu_1, bleu_2, bleu_3, bleu_4), start=1):
    print(f"Final BLEU-{n} Score on {len(unique_test_image_names)} images: {score * 100:.2f}%")


# --- Sample Generated Captions from Test Set ---
print("\n--- Example Generated Captions from Test Set (with Predicted Emotion) ---")
sample_img_names_for_qualitative = random.sample(unique_test_image_names, min(5, len(unique_test_image_names)))

for img_name in sample_img_names_for_qualitative:
    image_tensor = _load_image_tensor(img_name)

    # Predict emotion for the image (using the frozen classifier)
    with torch.no_grad():
        predicted_emotion_idx = torch.argmax(emotion_classifier(image_tensor), dim=1).item()
    predicted_emotion_name = emotion_class_names[predicted_emotion_idx]

    pred_caption = _generate_caption(image_tensor)

    print(f"\nImage: {img_name}")
    print(f"Predicted Emotion: {predicted_emotion_name}")
    print(f"Generated Caption: {pred_caption}")
    print("References:")
    for ref_cap in references[img_name]:
        print(f"  - {' '.join(ref_cap)}")  # join tokens back for display

print("\n--- Training and Evaluation Complete ---")


In [None]:
# --- 7. Final Inference/Testing Loop (standalone re-run) ---
print("\n--- Running Final Evaluation on Test Set ---")
# Load the best checkpoint. It pickles the tokenizer object, so weights_only=False
# is required; PyTorch 2.6+ defaults to weights_only=True. Only load trusted files.
best_saved_state = torch.load(CAPTIONING_MODEL_SAVE_PATH, map_location=device, weights_only=False)
projector.load_state_dict(best_saved_state['projector_state_dict'])
decoder.load_state_dict(best_saved_state['decoder_state_dict'])
emotion_aware_model.emotion_projection.load_state_dict(best_saved_state['emotion_projection_state_dict'])

emotion_aware_model.eval()

# Use the tokenizer stored with the checkpoint for decoding.
tokenizer = best_saved_state['tokenizer']

# Display names for the emotion classifier's output indices.
# NOTE(review): assumed to match the label order used when training the emotion model; confirm.
emotion_class_names = ['Angry', 'Disgusted', 'Fearful', 'Happy', 'Neutral', 'Sad', 'Surprised']


def _load_image_tensor(img_name):
    """Load ``img_name`` from FLICKR_BASE_PATH as a 1-image batch on ``device``.

    Missing or unreadable files fall back to a black image, as in the dataset.
    """
    img_path = os.path.join(FLICKR_BASE_PATH, img_name)
    try:
        with Image.open(img_path) as img:
            image = img.convert("RGB")
    except (OSError, UnidentifiedImageError):
        image = Image.new('RGB', (IMG_SIZE, IMG_SIZE), (0, 0, 0))  # Fallback
    return image_transform(image).unsqueeze(0).to(device)  # add batch dimension


def _generate_caption(image_tensor):
    """Generate, decode and clean a caption for a 1-image batch."""
    with torch.no_grad():
        generated_ids = emotion_aware_model(image_tensor)
    return clean_caption(tokenizer.decode(generated_ids[0], skip_special_tokens=True))


# test_samples holds one entry per caption (several per image). Generate exactly
# one caption per unique test image. Sorting makes the order reproducible, which
# iterating a set of strings is not, because of hash randomisation.
unique_test_image_names = sorted({img for img, _ in test_samples})

# Cleaned, tokenised reference captions for every unique test image.
references = {
    img_name: [clean_caption(cap).lower().split() for cap in captions_data[img_name]]
    for img_name in unique_test_image_names
}

# --- BLEU Evaluation on Test Set ---
hypotheses_for_bleu = []
references_for_bleu = []

print("\n--- Generating captions for unique test images for BLEU ---")
for img_name in tqdm(unique_test_image_names, desc="Generating captions"):
    pred_caption = _generate_caption(_load_image_tensor(img_name))
    hypotheses_for_bleu.append(pred_caption.split())
    references_for_bleu.append(references[img_name])  # all references for this image

# Calculate BLEU scores. BLEU-3 uses exact thirds; 0.33 each summed to only 0.99.
bleu_1 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(1, 0, 0, 0))
bleu_2 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(0.5, 0.5, 0, 0))
bleu_3 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(1/3, 1/3, 1/3, 0))
bleu_4 = corpus_bleu(references_for_bleu, hypotheses_for_bleu, weights=(0.25, 0.25, 0.25, 0.25))

print()
for n, score in enumerate((bleu_1, bleu_2, bleu_3, bleu_4), start=1):
    print(f"Final BLEU-{n} Score on {len(unique_test_image_names)} images: {score * 100:.2f}%")


# --- Sample Generated Captions from Test Set ---
print("\n--- Example Generated Captions from Test Set (with Predicted Emotion) ---")
sample_img_names_for_qualitative = random.sample(unique_test_image_names, min(5, len(unique_test_image_names)))

for img_name in sample_img_names_for_qualitative:
    image_tensor = _load_image_tensor(img_name)

    # Predict emotion for the image (using the frozen classifier)
    with torch.no_grad():
        predicted_emotion_idx = torch.argmax(emotion_classifier(image_tensor), dim=1).item()
    predicted_emotion_name = emotion_class_names[predicted_emotion_idx]

    pred_caption = _generate_caption(image_tensor)

    print(f"\nImage: {img_name}")
    print(f"Predicted Emotion: {predicted_emotion_name}")
    print(f"Generated Caption: {pred_caption}")
    print("References:")
    for ref_cap in references[img_name]:
        print(f"  - {' '.join(ref_cap)}")  # join tokens back for display

print("\n--- Training and Evaluation Complete ---")
