In [1]:
import os, re, time, math, random, json, pickle, itertools, warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import collections
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from PIL import Image
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers.schedules import CosineDecay
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
# import os
# import re
# import time
# 
# import random
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# from PIL import Image
# from typing import Dict, List, Tuple, Optional
# import tensorflow as tf
# from tensorflow.keras import layers, Model
# from tensorflow.keras.preprocessing.text import Tokenizer
# from tensorflow.keras.preprocessing.sequence import pad_sequences
# from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
# import tqdm
# from gtts import gTTS
# from IPython.display import Audio, display

###  CONFIGURATION

In [3]:
CONFIG: Dict[str, object] = {
    'image_dir': '/home/flickr8k/Images',
    'caption_file': '/home/flickr8k/captions_8k.csv',
    'feature_cache_dir': '/home/flickr8k/cache',
    'num_examples': None,
    'max_caption_length': 50,
    'min_word_frequency': 5,

    'embedding_dim': 256,
    'units': 512,
    'decoder_dropout': 0.5,

    'learning_rate': 5e-5,
    'epochs': 20,
    'batch_size': 64,

    'buffer_size': 1000,
    'patience': 5,
    'checkpoint_path': './checkpoints/lstm_attention_flickr8k',
    'mixed_precision': True,

    'attention_reg_lambda': 0.5,
    'grad_clip_value': 5.0, # Added missing grad_clip_value
    'scheduled_sampling_max_prob': 0.2, # Added missing scheduled_sampling_max_prob
    'seed': 42,
}

###  ENV SETUP

In [4]:
np.random.seed(CONFIG['seed'])
random.seed(CONFIG['seed'])
tf.random.set_seed(CONFIG['seed'])

if CONFIG['mixed_precision']:
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)
    print("[AMP] mixed_float16 policy active")
else:
    print("[AMP] disabled – using float32 throughout")

physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(f"Using GPU: {physical_devices[0].name} | batch={CONFIG['batch_size']}")
else:
    print("GPU not found – fallback to CPU")

AUTOTUNE = tf.data.AUTOTUNE

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA RTX 6000 Ada Generation, compute capability 8.9
[AMP] mixed_float16 policy active
Using GPU: /physical_device:GPU:0 | batch=64


###  DATA LOADING AND PREPROCESSING

In [5]:
class DataProcessor:
    def __init__(self, config):
        self.config = config
        self.tokenizer: Optional[Tokenizer] = None
        self.img_to_cap_map: Dict[str, List[str]] = collections.defaultdict(list)
        self.image_paths: List[str] = []
        self.all_captions: List[str] = []
        self.train_data: List[Tuple[str, List[int]]] = []
        self.val_data: List[Tuple[str, List[int]]] = []
        self.test_data: List[Tuple[str, List[int]]] = []
        self.max_caption_length = 0
        self.vocab_size = 0
        self.num_steps_per_epoch = 0

    def load_and_preprocess_data(self):
        print("Loading and preprocessing captions...")
        df = pd.read_csv(self.config['caption_file'], header=None, names=['image_name', 'comment'], engine='python')

        df['image_name'] = df['image_name'].str.strip()
        df['comment'] = df['comment'].str.strip()

        temp_img_to_cap_map = collections.defaultdict(list)
        all_unique_img_names_from_csv = df['image_name'].unique()

        print(f"Checking {len(all_unique_img_names_from_csv)} unique image files from CSV...")
        found_images_count = 0
        
        existing_image_files = set(os.listdir(self.config['image_dir']))

        for index, row in tqdm.tqdm(df.iterrows(), total=df.shape[0], desc="Validating images & processing captions"):
            img_name = row['image_name']
            caption = row['comment']
            
            if img_name in existing_image_files:
                temp_img_to_cap_map[img_name].append(self.preprocess_text(caption))
                if img_name not in self.img_to_cap_map:
                    found_images_count += 1
                self.img_to_cap_map[img_name] = temp_img_to_cap_map[img_name]
            
        if found_images_count < len(all_unique_img_names_from_csv):
            print(f"Warning: {len(all_unique_img_names_from_csv) - found_images_count} images mentioned in CSV were not found in {self.config['image_dir']}. They have been discarded.")
        
        self.image_paths = sorted(list(self.img_to_cap_map.keys()))

        if self.config['num_examples']:
            if len(self.image_paths) > self.config['num_examples']:
                self.image_paths = random.sample(self.image_paths, self.config['num_examples'])
                self.img_to_cap_map = {img: self.img_to_cap_map[img] for img in self.image_paths}
                print(f"Using a subset of {len(self.image_paths)} images due to 'num_examples' config.")

        self.all_captions = []
        for img_name in self.image_paths:
            self.all_captions.extend(self.img_to_cap_map[img_name])

        print(f"Total valid images (with captions): {len(self.image_paths)}")
        print(f"Total valid captions: {len(self.all_captions)}")

        self.tokenizer = Tokenizer(num_words=None, oov_token="<unk>",
                                   filters='!"#$%&()*+.,-/:;=?@[\]^_`{|}~ ',
                                   lower=True)
        self.tokenizer.fit_on_texts(self.all_captions)

        word_counts = collections.Counter(word for caption in self.all_captions for word in caption.split())
        filtered_word_index = {
            word: index for word, index in self.tokenizer.word_index.items()
            if word_counts[word] >= self.config['min_word_frequency'] or word in ['<pad>', '<start>', '<end>', '<unk>']
        }
        self.tokenizer.word_index = filtered_word_index
        self.tokenizer.index_word = {v: k for k, v in filtered_word_index.items()}

        if '<pad>' not in self.tokenizer.word_index:
            self.tokenizer.word_index['<pad>'] = len(self.tokenizer.word_index) + 1
            self.tokenizer.index_word[len(self.tokenizer.index_word) + 1] = '<pad>'
            
        self.vocab_size = len(self.tokenizer.word_index) + 1
        print(f"Vocabulary size after pruning (min_word_frequency={self.config['min_word_frequency']}): {self.vocab_size}")

        all_seqs = self.tokenizer.texts_to_sequences(self.all_captions)
        self.max_caption_length = max(len(s) for s in all_seqs)
        self.config['max_caption_length'] = self.max_caption_length
        print(f"Max caption length: {self.max_caption_length}")

    def preprocess_text(self, caption: str) -> str:
        caption = caption.lower()
        caption = re.sub(r"[^a-z ]", "", caption)
        caption = re.sub(r'\s+', ' ', caption).strip()
        caption = '<start> ' + caption + ' <end>'
        return caption

    def create_dataset_splits(self, train_ratio=0.8, val_ratio=0.1):
        random.shuffle(self.image_paths)
        num_images = len(self.image_paths)
        num_train = int(train_ratio * num_images)
        num_val = int(val_ratio * num_images)

        train_image_paths = self.image_paths[:num_train]
        val_image_paths = self.image_paths[num_train:num_train + num_val]
        test_image_paths = self.image_paths[num_train + num_val:]

        print(f"Train images: {len(train_image_paths)}, Val images: {len(val_image_paths)}, Test images: {len(test_image_paths)}")

        self.train_data = self._create_pairs(train_image_paths)
        self.val_data = self._create_pairs(val_image_paths)
        self.test_data = self._create_pairs(test_image_paths)

        print(f"Train pairs: {len(self.train_data)}, Val pairs: {len(self.val_data)}, Test pairs: {len(self.test_data)}")

        self.num_steps_per_epoch = len(self.train_data) // self.config['batch_size']

    def _create_pairs(self, image_names: List[str]) -> List[Tuple[str, List[int]]]:
        pairs = []
        for img_name in image_names:
            full_img_path = os.path.join(self.config['image_dir'], img_name)
            for caption in self.img_to_cap_map[img_name]:
                seq = self.tokenizer.texts_to_sequences([caption])[0]
                padded_seq = pad_sequences([seq], maxlen=self.max_caption_length, padding='post')[0]
                pairs.append((full_img_path, list(padded_seq)))
        return pairs

    def get_data_with_cached_features(self, image_name_to_cached_path_map: Dict[str, str]) -> Tuple[List, List, List]:
        def _reconstruct_list(data_list):
            reconstructed = []
            for original_img_path, caption_ids in data_list:
                basename = os.path.basename(original_img_path)
                cached_path = image_name_to_cached_path_map.get(basename)
                if cached_path:
                    reconstructed.append((original_img_path, cached_path, caption_ids))
            return reconstructed

        final_train = _reconstruct_list(self.train_data)
        final_val = _reconstruct_list(self.val_data)
        final_test = _reconstruct_list(self.test_data)

        print(f"Adjusted train pairs (after feature caching check): {len(final_train)}")
        print(f"Adjusted val pairs (after feature caching check): {len(final_val)}")
        print(f"Adjusted test pairs (after feature caching check): {len(final_test)}")
        
        return final_train, final_val, final_test



###  IMAGE FEATURE EXTRACTION (CACHE FEATURES)

In [6]:
class ImageFeatureExtractor(Model):
    def __init__(self, target_size=(299, 299)):
        super().__init__()
        self.target_size = target_size
        self.inception_v3 = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')
        self.inception_v3.trainable = False
        self.feature_extractor = Model(inputs=self.inception_v3.input,
                                       outputs=self.inception_v3.get_layer('mixed7').output)

    @tf.function
    def load_and_preprocess_image(self, image_path: tf.Tensor) -> tf.Tensor:
        img = tf.io.read_file(image_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, self.target_size)
        img = tf.keras.applications.inception_v3.preprocess_input(img)
        return img

class ImageFeatureCacheManager:
    def __init__(self, config, feature_extractor: ImageFeatureExtractor):
        self.config = config
        self.feature_extractor = feature_extractor
        self.cache_dir = config['feature_cache_dir']
        os.makedirs(self.cache_dir, exist_ok=True)

    def manage_feature_cache(self, image_names: List[str]) -> Dict[str, str]:
        print(f"\nManaging image feature cache in: {self.cache_dir}")
        print(f"Checking/extracting features for {len(image_names)} unique images.")

        image_name_to_cached_path = {}
        images_to_extract = []

        for img_name in image_names:
            cache_file = os.path.join(self.cache_dir, img_name + '.npy')
            if not os.path.exists(cache_file):
                images_to_extract.append(img_name)
            image_name_to_cached_path[img_name] = cache_file

        if images_to_extract:
            print(f"Found {len(images_to_extract)} images whose features need extraction...")
            
            full_paths_for_extraction = [os.path.join(self.config['image_dir'], img_name) for img_name in images_to_extract]

            for i, img_path in enumerate(tqdm.tqdm(full_paths_for_extraction, desc="Extracting & Caching Features")):
                img_name = os.path.basename(img_path)
                cache_path = os.path.join(self.cache_dir, img_name + '.npy')
                try:
                    img_tensor_processed = self.feature_extractor.load_and_preprocess_image(tf.constant(img_path))
                    features = self.feature_extractor.feature_extractor(tf.expand_dims(img_tensor_processed, 0))
                    features_flat = tf.reshape(features, (features.shape[0], -1, features.shape[3]))
                    np.save(cache_path, features_flat.numpy())
                except Exception as e:
                    print(f"\nError processing {img_path}: {e}. Skipping feature caching for this image.")
                    if img_name in image_name_to_cached_path:
                        del image_name_to_cached_path[img_name]
                    continue
        else:
            print("All image features already cached. Skipping extraction.")
            
        print("Image feature cache management complete.")
        return image_name_to_cached_path


###  MODEL ARCHITECTURE

In [7]:
class Encoder(Model):
    def __init__(self, embedding_dim):
        super(Encoder, self).__init__()
        self.fc = layers.Dense(embedding_dim)

    def call(self, x):
        x = self.fc(x)
        return x

class BahdanauAttention(layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = layers.Dense(units)
        self.W2 = layers.Dense(units)
        self.V = layers.Dense(1)

    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, tf.squeeze(attention_weights, -1)

class Decoder(Model):
    def __init__(self, embedding_dim, units, vocab_size, dropout=0.5):
        super(Decoder, self).__init__()
        self.units = units
        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        self.lstm = layers.LSTM(self.units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform',
                                   dropout=dropout)
        self.fc1 = layers.Dense(self.units)
        self.fc2 = layers.Dense(vocab_size, dtype='float32')
        self.attention = BahdanauAttention(self.units)

    def call(self, x, features, hidden_state, cell_state):
        context_vector, attention_weights = self.attention(features, hidden_state)
        x = self.embedding(x)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        output, new_hidden_state, new_cell_state = self.lstm(x, initial_state=[hidden_state, cell_state])
        output = tf.reshape(output, (-1, output.shape[2]))
        x = self.fc1(output)
        logits = self.fc2(x)
        return logits, new_hidden_state, new_cell_state, attention_weights

###  TEXT-TO-SPEECH UTILITY

In [8]:
class TextToSpeech:
    def __init__(self):
        try:
            self.gTTS = gTTS
            self.Audio = Audio
            self.display = display
            self.available = True
        except ImportError:
            print("WARNING: gTTS or IPython.display not found. Speech functionality will be disabled.")
            self.available = False

    def speak(self, text: str, filename: str = "caption_audio.mp3"):
        if not self.available:
            print("Text-to-speech functionality is not available. Please install 'gtts' and ensure running in an IPython environment.")
            return
        
        if not text.strip():
            print("Empty text, nothing to speak.")
            return
        
        try:
            tts = self.gTTS(text=text, lang='en')
            tts.save(filename)
            self.display(self.Audio(filename))
            print(f"Audio saved to {filename} and played.")
        except Exception as e:
            print(f"Error generating or playing audio: {e}")

###  TRAINING LOOP & UTILITIES

In [9]:
class ImageCaptioningTrainer:
    def __init__(self, config, processor: DataProcessor, feature_extractor: ImageFeatureExtractor):
        self.config = config
        self.processor = processor
        self.feature_extractor = feature_extractor

        self.encoder = Encoder(self.config['embedding_dim'])
        self.decoder = Decoder(self.config['embedding_dim'], self.config['units'],
                               self.processor.vocab_size, self.config['decoder_dropout'])

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.config['learning_rate'])
        if self.config['mixed_precision']:
            self.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(self.optimizer)

        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')

        self.checkpoint_prefix = os.path.join(self.config['checkpoint_path'], "ckpt")
        self.checkpoint = tf.train.Checkpoint(encoder=self.encoder,
                                              decoder=self.decoder,
                                              optimizer=self.optimizer)
        self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint, self.config['checkpoint_path'], max_to_keep=5)

        self.tts_speaker = TextToSpeech()

        self.train_loss_results = []
        self.val_bleu_results = []
        self.best_val_bleu = 0.0
        self.smoothing_function = SmoothingFunction().method4
        self.patience_counter = 0


        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(f"Restored from {self.checkpoint_manager.latest_checkpoint}")
        else:
            print("Initializing from scratch.")

    def loss_function(self, real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = self.loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        return tf.reduce_mean(loss_)

    @tf.function
    def train_step(self, img_tensor, target):
        batch_size = tf.shape(target)[0]
        loss = 0.0

        hidden = tf.zeros((batch_size, self.config['units']), dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16)
        cell = tf.zeros((batch_size, self.config['units']), dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16)

        dec_input = tf.expand_dims([self.processor.tokenizer.word_index['<start>']] * batch_size, 1)

        with tf.GradientTape() as tape:
            features = self.encoder(img_tensor)

            attention_sum_square_error = tf.constant(0.0, dtype=tf.float32)

            for i in tf.range(1, target.shape[1]):
                predictions, hidden, cell, attention_weights = self.decoder(dec_input, features, hidden, cell)

                loss += self.loss_function(target[:, i], predictions)

                if self.config.get('scheduled_sampling_max_prob', 0.0) > 0:
                    prob = tf.random.uniform([], 0, 1)
                    if prob < self.current_scheduled_sampling_prob:
                        predicted_id = tf.argmax(predictions, axis=1)
                        dec_input = tf.expand_dims(predicted_id, 1)
                    else:
                        dec_input = tf.expand_dims(target[:, i], 1)
                else:
                    dec_input = tf.expand_dims(target[:, i], 1)

                attention_sum_square_error += tf.reduce_mean(tf.square(tf.reduce_sum(attention_weights, axis=1) - 1.0))

            total_loss = (loss / tf.cast(target.shape[1], tf.float32))
            total_loss += self.config['attention_reg_lambda'] * attention_sum_square_error / tf.cast(target.shape[1], tf.float32)

            if self.config['mixed_precision']:
                scaled_loss = self.optimizer.get_scaled_loss(total_loss)
            else:
                scaled_loss = total_loss

        trainable_variables = self.encoder.trainable_variables + self.decoder.trainable_variables
        gradients = tape.gradient(scaled_loss, trainable_variables)

        if self.config['mixed_precision']:
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        gradients, _ = tf.clip_by_global_norm(gradients, self.config['grad_clip_value'])
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        return total_loss

    def evaluate_bleu_score(self, dataset_pairs: List[Tuple[str, List[int]]], num_samples=None):
        references = []
        hypotheses = []

        if num_samples is None:
            samples_to_evaluate = dataset_pairs
        else:
            samples_to_evaluate = random.sample(dataset_pairs, min(num_samples, len(dataset_pairs)))

        print(f"\nEvaluating BLEU on {len(samples_to_evaluate)} samples...")
        for img_path, _ in tqdm.tqdm(samples_to_evaluate):
            generated_caption_tokens = self.greedy_inference(img_path)
            if not generated_caption_tokens:
                continue
            hypotheses.append(generated_caption_tokens)

            img_name = os.path.basename(img_path)
            raw_captions = self.processor.img_to_cap_map.get(img_name, [])
            
            img_references = []
            for raw_cap in raw_captions:
                cleaned_cap = raw_cap.replace('<start>', '').replace('<end>', '').strip()
                if cleaned_cap:
                    img_references.append(cleaned_cap.split())
            
            if img_references:
                references.append(img_references)
            else:
                hypotheses.pop()

        if not references:
            print("No valid reference captions found for BLEU evaluation.")
            return {"bleu-1": 0.0, "bleu-2": 0.0, "bleu-3": 0.0, "bleu-4": 0.0}

        bleu_scores = {}
        for n in range(1, 5):
            weights = (1.0 / n,) * n + (0.0,) * (4 - n)
            bleu_scores[f"bleu-{n}"] = corpus_bleu(references, hypotheses, weights=weights,
                                                    smoothing_function=self.smoothing_function)
            print(f"BLEU-{n}: {bleu_scores[f'bleu-{n}']:.4f}")
        
        return bleu_scores

    def greedy_inference(self, image_path: str):
        filename = os.path.basename(image_path)
        feature_cache_path = os.path.join(self.config['feature_cache_dir'], filename + '.npy')
        
        if not os.path.exists(feature_cache_path):
            print(f"Error: Feature cache not found for {image_path}")
            return []

        img_features = np.load(feature_cache_path)
        img_features_tensor = tf.convert_to_tensor(img_features, dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16)

        features = self.encoder(img_features_tensor)

        hidden = tf.zeros((1, self.config['units']), dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16)
        cell = tf.zeros((1, self.config['units']), dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16)

        dec_input = tf.expand_dims([self.processor.tokenizer.word_index['<start>']], 0)

        result = []
        for i in range(self.config['max_caption_length']):
            predictions, hidden, cell, _ = self.decoder(dec_input, features, hidden, cell)
            predicted_id = tf.argmax(predictions[0]).numpy()
            predicted_word = self.processor.tokenizer.index_word.get(predicted_id, '<unk>')

            if predicted_word == '<end>':
                break
            if predicted_word not in ('<unk>', '<start>', '<pad>'):
                result.append(predicted_word)

            dec_input = tf.expand_dims([predicted_id], 0)

        return result

    def beam_search_inference(self, image_path: str, beam_size: int = 3, length_penalty_weight: float = 0.7):
        filename = os.path.basename(image_path)
        feature_cache_path = os.path.join(self.config['feature_cache_dir'], filename + '.npy')
        if not os.path.exists(feature_cache_path):
            print(f"Error: Feature cache not found for {image_path}")
            return [], []

        img_features = np.load(feature_cache_path)
        img_features_tensor = tf.convert_to_tensor(img_features, dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16)
        features = self.encoder(img_features_tensor)

        start_token = self.processor.tokenizer.word_index['<start>']
        end_token = self.processor.tokenizer.word_index['<end>']

        beams = [(
            [start_token],
            0.0,
            tf.zeros((1, self.config['units']), dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16),
            tf.zeros((1, self.config['units']), dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16),
            []
        )]

        completed_beams = []

        for _ in range(self.config['max_caption_length']):
            new_beams = []
            for seq, score, hidden, cell, alphas in beams:
                last_token = seq[-1]

                if last_token == end_token:
                    completed_beams.append((seq, score, alphas))
                    continue

                dec_input = tf.expand_dims([last_token], 0)
                predictions, new_hidden, new_cell, attention_weights = self.decoder(dec_input, features, hidden, cell)

                predictions = tf.cast(predictions[0], tf.float32)
                log_probs = tf.nn.log_softmax(predictions).numpy()

                top_k_indices = np.argsort(log_probs)[::-1][:beam_size]

                for idx in top_k_indices:
                    token_id = int(idx)
                    token_log_prob = float(log_probs[idx])
                    
                    new_beams.append((
                        seq + [token_id],
                        score + token_log_prob,
                        new_hidden,
                        new_cell,
                        alphas + [attention_weights[0].numpy()]
                    ))

            new_beams.sort(key=lambda x: x[1] / (len(x[0]) ** length_penalty_weight), reverse=True)
            beams = new_beams[:beam_size]

            if len(completed_beams) >= beam_size:
                break
        
        completed_beams.extend([(seq, score, alphas) for seq, score, _, _, alphas in beams])

        if not completed_beams:
             return [], []

        best_seq, best_score, best_alphas = max(completed_beams, key=lambda x: x[1] / (len(x[0]) ** length_penalty_weight))

        caption_words = [self.processor.tokenizer.index_word.get(i, '<unk>') for i in best_seq]
        
        filtered_caption_words = [
            word for word in caption_words
            if word not in ['<start>', '<end>', '<pad>', '<unk>']
        ]

        return filtered_caption_words, best_alphas

    def train(self, train_cached_paths: List[Tuple[str, str, List[int]]],
              val_dataset_pairs: List[Tuple[str, List[int]]]):
        
        # Generator function for the TensorFlow dataset
        def _data_generator():
            for orig_img_path, cache_path, caption_ids in train_cached_paths:
                # Ensure the loaded feature array has the expected shape (64, 768)
                features = np.load(cache_path)
                features = features.reshape(features.shape[-2], features.shape[-1]) # Ensure (64, 768)
                yield features, np.array(caption_ids, dtype=np.int32)


        train_dataset = tf.data.Dataset.from_generator(
            _data_generator,
            output_signature=(
                tf.TensorSpec(shape=(64, 768), dtype=tf.float32 if not self.config['mixed_precision'] else tf.float16),
                tf.TensorSpec(shape=(self.config['max_caption_length'],), dtype=tf.int32)
            )
        )
        train_dataset = train_dataset.shuffle(self.config['buffer_size']).batch(self.config['batch_size'])
        train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)


        for epoch in range(self.config['epochs']):
            start = time.time()
            total_loss = 0

            self.current_scheduled_sampling_prob = (
                self.config['scheduled_sampling_max_prob'] * (epoch / max(1, self.config['epochs'] - 1))
            )
            print(f"\nEpoch {epoch+1}/{self.config['epochs']} (Scheduled Sampling Prob: {self.current_scheduled_sampling_prob:.3f})")


            for (batch, (img_tensor, target)) in enumerate(train_dataset):
                batch_loss = self.train_step(img_tensor, target)
                total_loss += batch_loss

                if batch % 100 == 0:
                    print(f'Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy():.4f}')
            
            avg_train_loss = total_loss / self.processor.num_steps_per_epoch
            self.train_loss_results.append(avg_train_loss.numpy())
            print(f'Epoch {epoch+1} Loss {avg_train_loss:.4f}')

            val_bleu_scores = self.evaluate_bleu_score(val_dataset_pairs, num_samples=1000)
            current_val_bleu4 = val_bleu_scores.get('bleu-4', 0.0)
            self.val_bleu_results.append(current_val_bleu4)

            if current_val_bleu4 > self.best_val_bleu:
                self.best_val_bleu = current_val_bleu4
                self.checkpoint_manager.save()
                print(f"Saving checkpoint at epoch {epoch+1} with BLEU-4: {current_val_bleu4:.4f}")
                self.patience_counter = 0
            else:
                self.patience_counter += 1
                print(f"BLEU-4 not improved. Patience counter: {self.patience_counter}/{self.config['patience']}")
                if self.patience_counter >= self.config['patience']:
                    print(f"Early stopping triggered at epoch {epoch+1}.")
                    break

            print(f'Time taken for 1 epoch: {time.time() - start:.2f} secs')

    def plot_history(self):
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(self.train_loss_results, label='Train Loss')
        plt.title('Training Loss per Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(self.val_bleu_results, label='Validation BLEU-4')
        plt.title('Validation BLEU-4 Score per Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('BLEU-4 Score')
        plt.grid(True)
        plt.legend()

        plt.tight_layout()
        plt.show()

    def plot_attention(self, image_path: str, caption: List[str], alphas: List[np.ndarray]):
        img = Image.open(image_path)
        img = np.array(img.resize((299, 299)))

        fig = plt.figure(figsize=(15, 8))

        for t in range(len(caption)):
            if t > 24:
                break
            
            ax = fig.add_subplot(5, 5, t + 1)
            ax.imshow(img)
            ax.axis('off')

            alpha = np.array(alphas[t])
            attention_grid_size = int(np.sqrt(alpha.shape[0]))
            alpha_reshaped = alpha.reshape(attention_grid_size, attention_grid_size)

            alpha_resized = Image.fromarray(np.uint8(255 * alpha_reshaped)).resize(
                (299, 299), resample=Image.BICUBIC
            )
            alpha_resized = np.array(alpha_resized) / 255.0

            ax.imshow(alpha_resized, cmap='jet', alpha=0.5, extent=(0, 299, 299, 0))
            ax.set_title(f"{t+1}: '{caption[t]}'", fontsize=10, color='blue')

        plt.tight_layout()
        plt.suptitle(f"Attention Map for: {os.path.basename(image_path)}", fontsize=16, y=1.02)
        plt.show()

    def speak_caption(self, caption: str, filename="caption_audio.mp3"):
        self.tts_speaker.speak(caption, filename)

    def demo(self, image_file_name: str):
        full_image_path = os.path.join(self.config['image_dir'], image_file_name)
        
        if not os.path.exists(full_image_path):
            print(f"Error: Image not found at {full_image_path}")
            return

        print(f"\n--- Demo for {image_file_name}")
        
        img = Image.open(full_image_path)
        plt.figure(figsize=(8, 6))
        plt.imshow(img)
        plt.title(f"Image: {image_file_name}")
        plt.axis('off')
        plt.show()

        gt_captions = self.processor.img_to_cap_map.get(image_file_name, [])
        print("\nGround Truth Captions:")
        if gt_captions:
            for i, cap in enumerate(gt_captions):
                clean_cap = cap.replace('<start>', '').replace('<end>', '').strip()
                print(f"  {i+1}. {clean_cap}")
        else:
            print("  No ground truth captions available.")
        
        generated_caption_words, attention_weights = self.beam_search_inference(full_image_path, beam_size=3)
        generated_caption = " ".join(generated_caption_words)
        print(f"\nGenerated Caption (Beam Search): {generated_caption}")

        print("\nPlaying generated caption:")
        self.speak_caption(generated_caption, filename=f"caption_audio_{os.path.basename(image_file_name).split('.')[0]}.mp3")

        if generated_caption_words and attention_weights:
            self.plot_attention(full_image_path, generated_caption_words, attention_weights)
        else:
            print("Could not generate caption or attention for plotting.")



###  MAIN EXECUTION

In [10]:
processor = DataProcessor(CONFIG)

In [11]:
processor.load_and_preprocess_data()

Loading and preprocessing captions...


ParserError: Expected 2 fields in line 87, saw 3

In [None]:
processor.create_dataset_splits()
feature_extractor_model = ImageFeatureExtractor()

cache_manager = ImageFeatureCacheManager(CONFIG, feature_extractor_model)
image_name_to_cached_path_map = cache_manager.manage_feature_cache(processor.image_paths)
final_train_data, final_val_data, final_test_data = processor.get_data_with_cached_features(image_name_to_cached_path_map)
trainer = ImageCaptioningTrainer(CONFIG, processor, feature_extractor_model)

print("\nStarting training...")
trainer.train(final_train_data, processor.val_data)

trainer.plot_history()
print("\n--- Final Evaluation on Test Set")
trainer.evaluate_bleu_score(processor.test_data)
print("\n--- Running a Demo")
if processor.test_data:
    random_test_img_path, _ = random.choice(processor.test_data)
    trainer.demo(os.path.basename(random_test_img_path))
else:
    print("No test data available for demo.")