In [None]:
import os
import numpy as np
import random
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from nltk.tokenize import word_tokenize
from PIL import Image
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torch.optim import Adam
import torch.nn as nn
import pandas as pd

from  customDatasetFromCSV import CustomDatasetFromCSV

CREATE DATASET AND GET CUDA DEVICE

In [None]:
# Locations of the training images and of the caption / vocabulary tables.
data_dir = "./imagesTrainVal/train2017/"
captions_csv = './filesCSV/captions.csv'
vocab_csv = './filesCSV/vocab.csv'

# Training-time image pipeline: resize to 224x224 and convert to a float
# tensor (ToTensor scales 8-bit images to [0, 1]). No mean/std normalization
# is applied, so any inference preprocessing should match this.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# NOTE(review): `percentage` presumably selects the share of rows to load —
# confirm in CustomDatasetFromCSV.
dataset = CustomDatasetFromCSV(
    data_dir,
    captions_csv,
    vocab_csv,
    transform=transform,
    percentage=100,
)
len(dataset.vocab)

DEVICE 

In [None]:


# Pick the compute device once; labels below are Spanish
# ("disponible" = available, "Nombre" = name, "Capacidad" = compute capability).
cuda_ok = torch.cuda.is_available()
print("CUDA disponible:", cuda_ok)

if not cuda_ok:
    device = torch.device("cpu")
    print("No GPU available. Using CPU.")
else:
    print("Nombre de la GPU:", torch.cuda.get_device_name(0))
    print("Capacidad de la GPU:", torch.cuda.get_device_capability(0))
    device = torch.device("cuda:0")
    print("Using GPU:", torch.cuda.get_device_name(device))

MODEL ARCHITECTURE

In [None]:

# EncoderCNN
class EncoderCNN(nn.Module):
    """Encode a batch of images into one embedding vector per image.

    A ResNet-50 backbone (``pretrained=True``) with its final ``fc`` layer
    removed produces a pooled feature vector per image; ``self.embed``
    projects it to ``embed_size``.

    Args:
        embed_size: Size of the output feature vector.
        device: Device the backbone and projection layer are moved to.
        fine_tune: If True (default), backbone weights receive gradients and
            are trained with the decoder. If False, the backbone is frozen
            and only ``self.embed`` is learned.
    """

    def __init__(self, embed_size, device, fine_tune=True):
        super(EncoderCNN, self).__init__()
        from torchvision.models.resnet import resnet50
        resnet = resnet50(pretrained=True)
        self.device = device

        # Explicitly choose whether the backbone is trained (True matches the
        # behaviour this notebook's checkpoints were produced with).
        for param in resnet.parameters():
            param.requires_grad_(fine_tune)

        # Keep every child except the final fc layer (the avgpool stays, so the
        # backbone output is (batch, resnet.fc.in_features, 1, 1)).
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules).to(self.device)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size).to(self.device)

    def forward(self, images):
        """Return a (batch, embed_size) feature tensor for ``images``."""
        features = self.resnet(images)
        features = features.view(features.size(0), -1)  # flatten pooled map
        return self.embed(features)

class Attention(nn.Module):
    """Additive attention over a set of encoder regions.

    ``features`` is expected as (batch, num_regions, encoder_dim); a 2-D
    (batch, encoder_dim) tensor — what ``EncoderCNN`` in this notebook
    produces — is treated as a single region. ``hidden_state`` is
    (batch, decoder_dim).

    Returns:
        alpha: (batch, num_regions) weights, softmax-normalized per sample.
        context: (batch, encoder_dim) alpha-weighted sum of the regions.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.attention_dim = attention_dim
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        # A pooled (batch, encoder_dim) vector must become a one-region set.
        # Otherwise `u_hs + w_ah.unsqueeze(1)` broadcasts to
        # (batch, batch, attention_dim) and each sample attends over the
        # *other images in the batch*, so outputs depend on batch composition.
        if features.dim() == 2:
            features = features.unsqueeze(1)

        u_hs = self.U(features)  # (batch, num_regions, attention_dim)
        w_ah = self.W(hidden_state)  # (batch, attention_dim)
        combined_states = torch.tanh(u_hs + w_ah.unsqueeze(1))

        attention_scores = self.A(combined_states).squeeze(2)  # (batch, num_regions)
        alpha = torch.softmax(attention_scores, dim=1)

        context = (features * alpha.unsqueeze(2)).sum(dim=1)  # (batch, encoder_dim)
        return alpha, context

# DecoderRNN
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on image features and attention.

    Each step's LSTM input is ``[token embedding | features | attention
    context]``, hence the ``embed_size + encoder_dim * 2`` input width.

    Token ids hard-coded here: ``1`` is fed as the first input (presumably the
    start token — confirm against the vocabulary), ``2`` is treated as
    end-of-sequence in ``generate``, and ``0`` is used as padding (the
    notebook's loss uses ``ignore_index=0``).
    """

    def __init__(self, embed_size, hidden_size, vocab_size, caption_length, num_layers, attention_dim, encoder_dim, device):
        super(DecoderRNN, self).__init__()
        self.device = device
        self.num_layers = num_layers
        self.caption_length = caption_length
        self.hidden_dim = hidden_size
        self.embed_size = embed_size

        self.embed = nn.Embedding(vocab_size, embed_size).to(device)
        self.attention = Attention(encoder_dim, hidden_size, attention_dim).to(device)
        self.lstm = nn.LSTM(embed_size + encoder_dim*2, hidden_size, num_layers, batch_first=True).to(device)
        self.linear = nn.Linear(hidden_size, vocab_size).to(device)

    def init_hidden(self, batch_size, device):
        """Return zero (h0, c0), each of shape (num_layers, batch_size, hidden_dim)."""
        return (
            torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device),
            torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        )

    def _start_tokens(self, batch_size):
        """Return a (batch_size, 1) tensor filled with the start id 1."""
        return torch.full((batch_size, 1), 1, dtype=torch.long, device=self.device)

    def _step(self, input_tokens, features, hidden):
        """Run one decoding step; return (alpha, logits (batch, 1, vocab), hidden)."""
        embeddings = self.embed(input_tokens)  # (batch, 1, embed_size)
        # Attention is queried with layer 0's hidden state.
        alpha, context = self.attention(features, hidden[0][0])
        features_and_context = torch.cat((features, context), dim=1).unsqueeze(1)
        lstm_input = torch.cat((embeddings, features_and_context), dim=2)
        h, hidden = self.lstm(lstm_input, hidden)
        return alpha, self.linear(h), hidden

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: encoder output for the batch.
            captions: (batch, >= caption_length) ground-truth ids; the token at
                position t is fed as the input of step t + 1.

        Returns:
            (batch, caption_length, vocab_size) logits.
        """
        batch_size = features.size(0)
        hidden = self.init_hidden(batch_size, self.device)
        input_tokens = self._start_tokens(batch_size)
        outputs = []

        for t in range(self.caption_length):
            _, output, hidden = self._step(input_tokens, features, hidden)
            outputs.append(output)
            input_tokens = captions[:, t].unsqueeze(1).to(self.device)

        return torch.cat(outputs, dim=1)

    def generate(self, features, end_token=2):
        """Greedy decoding for a batch.

        Decoding stops once every sample has produced ``end_token`` (or after
        ``caption_length`` steps). Positions after a sample's end token are
        filled with 0 so each row can be decoded independently.

        Returns:
            tokens: (batch, steps) predicted ids, end token included.
            alphas: list with one attention tensor per step.
        """
        batch_size = features.size(0)
        hidden = self.init_hidden(batch_size, self.device)
        input_token = self._start_tokens(batch_size)
        finished = torch.zeros(batch_size, 1, dtype=torch.bool, device=self.device)
        alphas = []
        tokens = []

        for _ in range(self.caption_length):
            alpha, output, hidden = self._step(input_token, features, hidden)
            alphas.append(alpha)

            predicted = output.argmax(dim=2).to(self.device)  # (batch, 1)
            # Samples that already ended only emit padding from now on.
            tokens.append(predicted.masked_fill(finished, 0))
            finished = finished | (predicted == end_token)
            input_token = predicted

            # Stop only when *all* samples are done, not just sample 0.
            if bool(finished.all()):
                break

        return torch.cat(tokens, dim=1), alphas

# Model
class Model(nn.Module):
    """Image-captioning model: ``EncoderCNN`` followed by ``DecoderRNN``.

    The encoder output size (``embed_size``) is also passed to the decoder
    as its ``encoder_dim``.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, caption_length, num_layers, attention_dim, device):
        super(Model, self).__init__()
        self.encoder = EncoderCNN(embed_size, device)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, caption_length, num_layers, attention_dim, embed_size, device)
        self.encoder.to(device)
        self.decoder.to(device)

        # Keep the constructor arguments for later inspection.
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.caption_length = caption_length
        self.num_layers = num_layers
        self.attention_dim = attention_dim
        self.device = device

    def forward(self, images, captions):
        """Teacher-forced pass; returns the decoder's per-step logits."""
        return self.decoder(self.encoder(images), captions)

    def generate(self, image):
        """Greedy captioning; returns whatever ``self.decoder.generate`` returns."""
        return self.decoder.generate(self.encoder(image))

HYPERPARAMETERS

In [None]:
# --- Model / optimisation hyperparameters ---------------------------------
embed_size = 256       # word-embedding and image-feature width
hidden_size = 512      # LSTM hidden width
num_layers = 1         # stacked LSTM layers
learning_rate = 0.001

vocab = dataset.vocab
vocab_size = len(vocab)
max_caption_length = dataset.maxCaptionLength

# attention_dim reuses embed_size.
model = Model(
    embed_size=embed_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    caption_length=max_caption_length,
    num_layers=num_layers,
    attention_dim=embed_size,
    device=device,
)
model.to(device)

# Token id 0 (the padding id) does not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- Data loaders -----------------------------------------------------------
# NOTE(review): assumes the dataset is laid out as [train | val | test] in index
# order — confirm in CustomDatasetFromCSV.
batch_size = 64
n_train = len(dataset.train_data)
n_val = len(dataset.val_data)

train_loader = DataLoader(dataset, batch_size=batch_size,
                          sampler=SubsetRandomSampler(range(n_train)))
val_loader = DataLoader(dataset, batch_size=batch_size,
                        sampler=SubsetRandomSampler(range(n_train, n_train + n_val)))
test_loader = DataLoader(dataset, batch_size=batch_size * 6,
                         sampler=SubsetRandomSampler(range(n_train + n_val, len(dataset))))

len(train_loader)

NUMBER OF TRAINABLE PARAMETERS

In [None]:
def count_parameters(model):
    """Return the number of scalar parameters in ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Only parameters with requires_grad=True are counted, i.e. what the optimizer updates.
num_params = count_parameters(model)
print('Number of trainable parameters:', num_params)

In [None]:
def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching torch tensors and moving them to CPU."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _to_hwc_array(image):
    """Return ``image`` as an (H, W, C) array, converting from (3, H, W) when needed."""
    image = _to_numpy(image)
    if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
        image = image.transpose(1, 2, 0)
    return image


def visualize_attention(image, alphas, caption, num_boxes=8):
    """Plot an image followed by one attention-shaded copy per caption word.

    Args:
        image: (3, H, W) array/tensor (``ToTensor`` layout) or (H, W, 3).
        alphas: one attention vector per word; each may be a tensor or array of
            shape (num_boxes**2,) or (batch, num_boxes**2) — for 2-D entries
            the first sample is shown.
        caption: sequence of words; one panel is drawn per word.
        num_boxes: the attention vector is laid out as a num_boxes x num_boxes
            grid over the image.
    """
    image = _to_hwc_array(image)
    img_height, img_width = image.shape[:2]
    num_timesteps = len(caption)

    # Panel 0 holds the original image; panels 1..T hold one map per word.
    num_cols = max(1, (num_timesteps + 2) // 2)
    fig, axes = plt.subplots(2, num_cols, figsize=(15, 7), squeeze=False)
    panels = axes.ravel()
    for ax in panels:
        ax.axis('off')

    panels[0].imshow(image)
    panels[0].set_title("Original Image")

    box_height = img_height // num_boxes
    box_width = img_width // num_boxes
    expected_size = num_boxes * num_boxes

    for t in range(num_timesteps):
        ax = panels[t + 1]
        ax.set_title(caption[t])

        temp_att = _to_numpy(alphas[t])
        if temp_att.ndim == 2:
            temp_att = temp_att[0]
        if temp_att.ndim != 1 or temp_att.shape[0] != expected_size:
            print(f"Attention map at timestep {t} has an unexpected shape: {temp_att.shape}")
            ax.imshow(image)
            continue

        # Min-max normalize; a constant map becomes uniform full weight.
        att_range = temp_att.max() - temp_att.min()
        if att_range > 0:
            attention_scores = (temp_att - temp_att.min()) / att_range
        else:
            attention_scores = np.ones_like(temp_att, dtype=np.float64)

        # Scale each box of a float copy of the image by its attention score.
        img_copy = image.astype(np.float64, copy=True)
        if np.issubdtype(image.dtype, np.integer):
            img_copy /= 255.0  # imshow expects floats in [0, 1]
        for i in range(num_boxes):
            for j in range(num_boxes):
                img_copy[i * box_height:(i + 1) * box_height,
                         j * box_width:(j + 1) * box_width, :] *= attention_scores[i * num_boxes + j]

        img = ax.imshow(img_copy)
        ax.imshow(attention_scores.reshape((num_boxes, num_boxes)),
                  cmap='gray', alpha=0.7, extent=img.get_extent())

    plt.tight_layout()
    plt.show()


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
patience = 3  # default early-stopping patience (epochs without val improvement)


def _mean(values):
    """Arithmetic mean of ``values``; NaN for an empty list (e.g. an empty loader)."""
    return sum(values) / len(values) if values else float('nan')


def _ids_to_text(ids, unk):
    """Join the vocabulary words for ``ids``, skipping padding id 0."""
    return " ".join(dataset.id_to_token.get(int(token_id), unk) for token_id in ids if token_id != 0)


def _caption_loss(model, criterion, images, captions):
    """Move a batch to ``device``, run the model and return (outputs, captions, loss)."""
    images = images.to(device)
    captions = captions.to(device)
    outputs = model(images, captions)  # (batch, steps, vocab)
    # `captions` is already a rectangular (batch, steps) tensor, so it can be
    # flattened directly without extra padding.
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions.reshape(-1))
    return outputs, captions, loss


def train_image_captioning_model(model, optimizer, criterion, train_loader, val_loader, save_dir,
                                 num_epochs=10, start_epoch=0, patience=patience):
    """Train ``model`` with teacher forcing and early stopping on validation loss.

    Args:
        model: called as ``model(images, captions)``, returning
            (batch, steps, vocab) logits.
        optimizer: optimizer over ``model``'s parameters.
        criterion: token-level loss applied to flattened logits/targets.
        train_loader, val_loader: iterables of ``(images, captions)`` batches.
        save_dir: checkpoint directory (created if missing). A checkpoint is
            written whenever the mean validation loss improves.
        num_epochs: number of epochs to run in this call.
        start_epoch: epoch offset when resuming (affects numbering only).
        patience: stop after this many consecutive epochs without improvement.

    Uses the notebook globals ``device`` and ``dataset`` (for ``id_to_token``).
    """
    model.to(device)
    os.makedirs(save_dir, exist_ok=True)
    last_epoch = start_epoch + num_epochs

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(start_epoch, last_epoch):
        model.train()
        total_batches = len(train_loader)
        train_losses = []

        progress = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
        for batch_counter, (images, captions) in enumerate(progress, start=1):
            optimizer.zero_grad()
            outputs, captions, loss = _caption_loss(model, criterion, images, captions)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

            # Periodically show one greedy (argmax) prediction next to its target.
            if batch_counter % 100 == 0:
                print(f"Epoch [{epoch + 1}/{last_epoch}] - Batch [{batch_counter}/{total_batches}] - Loss: {loss.item():.4f}")
                generated_ids = outputs[0].argmax(dim=-1).tolist()
                print(f"Generated Output: {_ids_to_text(generated_ids, 'UNK')}")
                print(f"Real Caption: {_ids_to_text(captions[0].tolist(), '<UNK>')}")

        avg_train_loss = _mean(train_losses)

        model.eval()
        with torch.no_grad():
            val_losses = [
                _caption_loss(model, criterion, images, captions)[2].item()
                for images, captions in val_loader
            ]
        avg_val_loss = _mean(val_losses)

        print(f"Epoch [{epoch + 1}/{last_epoch}] - Training Loss: {avg_train_loss:.4f} - Validation Loss: {avg_val_loss:.4f}")

        # Keep only checkpoints that improve on the best validation loss so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), os.path.join(save_dir, f'model_epoch_{epoch + 1}_val_loss_{avg_val_loss:.4f}.pt'))
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Validation loss hasn't improved for {patience} epochs. Early stopping...")
            break

    print("Training completed.")


# Checkpoint directory (the same folder later cells load checkpoints from).
save_dir = './SAVED_MODELS/model2'
# Pass `patience` by keyword: as the 7th positional argument it would land in
# `num_epochs` and silently cap training at 3 epochs.
train_image_captioning_model(model, optimizer, criterion, train_loader, val_loader, save_dir,
                             num_epochs=10, patience=patience)


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import csv

# Checkpoint to evaluate. Forward slashes also work on Windows and avoid the
# invalid "\S" / "\m" escape sequences of a backslash path in a normal string.
model_path = './SAVED_MODELS/model2/model_epoch_1_val_loss_2.3506.pt'
csv_file = 'resultsModel2Curriculum.csv'
PREVIEW_EVERY = 20  # show one image + captions every N test batches

# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


def _decode_ids(ids, unk):
    """Join the vocabulary words for ``ids``, skipping padding id 0."""
    return " ".join(dataset.id_to_token.get(int(token_id), unk) for token_id in ids if token_id != 0)


# Write one CSV row (real, generated) per test image; the file is opened once.
with open(csv_file, mode='w', newline='') as file, torch.no_grad():
    writer = csv.writer(file)
    writer.writerow(['Real Caption', 'Generated Caption'])

    total_batches = len(test_loader)
    for batch_counter, (images, captions) in enumerate(tqdm(test_loader, desc='Testing')):
        images = images.to(device)
        outputs, alphas = model.generate(images)  # outputs: (batch, steps) token ids
        generated_ids = outputs.tolist()
        real_ids = captions.tolist()

        if batch_counter % PREVIEW_EVERY == 0:
            plt.imshow(np.transpose(images[0].cpu().numpy(), (1, 2, 0)))
            plt.axis('off')
            plt.show()
            print(f"Batch [{batch_counter}/{total_batches}]")
            print(f"Generated Output: {_decode_ids(generated_ids[0], 'UNK')}")
            print(f"Real Caption: {_decode_ids(real_ids[0], '<UNK>')}")

        writer.writerows(
            [_decode_ids(real, '<UNK>'), _decode_ids(generated, 'UNK')]
            for real, generated in zip(real_ids, generated_ids)
        )

print(f"Testing completed. Real and generated captions saved in '{csv_file}'.")

In [None]:
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# Download the tokenizer data used by word_tokenize.
nltk.download('punkt')

# Read every cell as text; keep_default_na=False keeps empty captions as ""
# instead of NaN (which word_tokenize cannot handle).
csv_file = './resultsModel2Curriculum.csv'
df = pd.read_csv(csv_file, dtype=str, keep_default_na=False)

bleu_scores = []
rouge_l_scores = []

# Smoothing avoids zero BLEU for short captions missing higher-order n-grams.
smoother = SmoothingFunction().method1
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

for real_text, generated_text in zip(df['Real Caption'], df['Generated Caption']):
    real_tokens = word_tokenize(real_text)
    generated_tokens = word_tokenize(generated_text)

    # sentence_bleu expects token lists; raw strings would be scored on
    # character n-grams instead of word n-grams.
    bleu_scores.append(sentence_bleu([real_tokens], generated_tokens, smoothing_function=smoother))

    # ROUGE-L F1 on the tokenized text re-joined with single spaces.
    rouge_scores = scorer.score(" ".join(real_tokens), " ".join(generated_tokens))
    rouge_l_scores.append(rouge_scores['rougeL'].fmeasure)

df['BLEU Score'] = bleu_scores
df['ROUGE-L Score'] = rouge_l_scores

# Series.mean() yields NaN (not ZeroDivisionError) for an empty results file.
mean_bleu_score = df['BLEU Score'].mean()
mean_rouge_l_score = df['ROUGE-L Score'].mean()

print(df)
print("\nMean Metrics:")
print(f"Mean BLEU Score: {mean_bleu_score:.4f}")
print(f"Mean ROUGE-L Score: {mean_rouge_l_score:.4f}")


In [None]:
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# Download the tokenizer data used by word_tokenize.
nltk.download('punkt')

# Read every cell as text; keep_default_na=False keeps empty captions as ""
# instead of NaN (which word_tokenize cannot handle).
csv_file = './resultsCSV/resultsModel2.csv'
df = pd.read_csv(csv_file, dtype=str, keep_default_na=False)

bleu_scores = []
rouge_l_scores = []

# Smoothing avoids zero BLEU for short captions missing higher-order n-grams.
smoother = SmoothingFunction().method1
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

for real_text, generated_text in zip(df['Real Caption'], df['Generated Caption']):
    real_tokens = word_tokenize(real_text)
    generated_tokens = word_tokenize(generated_text)

    # sentence_bleu expects token lists; raw strings would be scored on
    # character n-grams instead of word n-grams.
    bleu_scores.append(sentence_bleu([real_tokens], generated_tokens, smoothing_function=smoother))

    # ROUGE-L F1 on the tokenized text re-joined with single spaces.
    rouge_scores = scorer.score(" ".join(real_tokens), " ".join(generated_tokens))
    rouge_l_scores.append(rouge_scores['rougeL'].fmeasure)

df['BLEU Score'] = bleu_scores
df['ROUGE-L Score'] = rouge_l_scores

# Series.mean() yields NaN (not ZeroDivisionError) for an empty results file.
mean_bleu_score = df['BLEU Score'].mean()
mean_rouge_l_score = df['ROUGE-L Score'].mean()

print(df)
print("\nMean Metrics:")
print(f"Mean BLEU Score: {mean_bleu_score:.4f}")
print(f"Mean ROUGE-L Score: {mean_rouge_l_score:.4f}")


In [None]:
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# Download the tokenizer data used by word_tokenize.
nltk.download('punkt')

# Read every cell as text; keep_default_na=False keeps empty captions as ""
# instead of NaN (which word_tokenize cannot handle).
csv_file = './resultsCSV/resultsModel2Encoder.csv'
df = pd.read_csv(csv_file, dtype=str, keep_default_na=False)

bleu_scores = []
rouge_l_scores = []

# Smoothing avoids zero BLEU for short captions missing higher-order n-grams.
smoother = SmoothingFunction().method1
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

for real_text, generated_text in zip(df['Real Caption'], df['Generated Caption']):
    real_tokens = word_tokenize(real_text)
    generated_tokens = word_tokenize(generated_text)

    # sentence_bleu expects token lists; raw strings would be scored on
    # character n-grams instead of word n-grams.
    bleu_scores.append(sentence_bleu([real_tokens], generated_tokens, smoothing_function=smoother))

    # ROUGE-L F1 on the tokenized text re-joined with single spaces.
    rouge_scores = scorer.score(" ".join(real_tokens), " ".join(generated_tokens))
    rouge_l_scores.append(rouge_scores['rougeL'].fmeasure)

df['BLEU Score'] = bleu_scores
df['ROUGE-L Score'] = rouge_l_scores

# Series.mean() yields NaN (not ZeroDivisionError) for an empty results file.
mean_bleu_score = df['BLEU Score'].mean()
mean_rouge_l_score = df['ROUGE-L Score'].mean()

print(df)
print("\nMean Metrics:")
print(f"Mean BLEU Score: {mean_bleu_score:.4f}")
print(f"Mean ROUGE-L Score: {mean_rouge_l_score:.4f}")


In [None]:
from PIL import Image
from torchvision import transforms
import torch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import cv2

model_path = "./SAVED_MODELS/model2/ModelEncoder2_epoch_4_Loss_2.2639.pth"

# Hyperparameters must match the ones the checkpoint was trained with.
embed_size = 256
hidden_size = 512
num_layers = 1
vocab = dataset.vocab
vocab_size = len(vocab)
learning_rate = 0.001
max_caption_length = dataset.maxCaptionLength

# Fresh model instance; attention_dim reuses embed_size as during training.
model = Model(embed_size, hidden_size, vocab_size, max_caption_length, num_layers, embed_size, device)
model.to(device)

# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))



def load_and_preprocess_image_cv2(image_path):
    """Load an image with OpenCV as a (1, 3, 224, 224) float32 tensor in [0, 1].

    Like the training transform, this only resizes and rescales to [0, 1];
    no mean/std normalization is applied.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
    image = cv2.resize(image, (224, 224))
    image = image / 255.0
    # HWC -> CHW, then add a batch dimension.
    return torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
# Load an image with PIL and preprocess it for the captioning model.
def load_and_preprocess_image(image_path, normalize=False):
    """Return the image at ``image_path`` as a (1, 3, 224, 224) float tensor.

    By default this matches the training ``transform`` (resize + ``ToTensor``,
    values in [0, 1]). The training transform applies no mean/std
    normalization, so normalization is off by default here. Pass
    ``normalize=True`` only for a model trained on ImageNet-normalized inputs.
    """
    steps = [transforms.Resize((224, 224)), transforms.ToTensor()]
    if normalize:
        steps.append(transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))
    # The context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        image = transforms.Compose(steps)(img.convert('RGB'))
    return image.unsqueeze(0)  # add batch dimension
def generate_caption(model, image_path, id_to_token, real_description=None):
    """Show an image and print the model's greedy caption for it.

    Args:
        model: captioning model whose ``generate(image)`` returns
            ``(tokens, attention)`` with ``tokens`` of shape (batch, steps).
        image_path: path of the image file.
        id_to_token: mapping from token id to word; unknown ids print "UNK".
        real_description: optional human-written caption printed for comparison.
    """
    # Preprocess with the PIL-based loader.
    image = load_and_preprocess_image(image_path)

    # Run on whatever device the model's weights already live on.
    device = next(model.parameters()).device
    image = image.to(device)
    model.eval()

    # Display the original file (OpenCV decodes BGR, so convert to RGB).
    plt.imshow(cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()

    with torch.no_grad():
        generated_tokens, _attention = model.generate(image)

    # Row 0 is the single image in the batch; id 0 is padding.
    token_ids = generated_tokens[0].tolist()
    generated_caption_str = " ".join(id_to_token.get(token_id, "UNK") for token_id in token_ids if token_id != 0)

    print("Generated description by the model:", "###  ", generated_caption_str, "  ###")
    if real_description:
        print(f"Real description made by hand by the tester: ###  {real_description}  ###")







# Caption every image file in the folder, in a stable (sorted) order.
folder_path = './ImagenesPrueba/'
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')
for filename in sorted(os.listdir(folder_path)):
    # Skip non-image entries (e.g. .DS_Store, Thumbs.db) and sub-folders,
    # which would otherwise make the image loaders raise.
    image_path = os.path.join(folder_path, filename)
    if not filename.lower().endswith(IMAGE_EXTENSIONS) or not os.path.isfile(image_path):
        continue

    print(f"Image: {filename}")
    generate_caption(model, image_path, dataset.id_to_token)
    print("\n")