# Caption Preprocessing, Tokenization, and Sequence Padding

## Setup and Imports

In [1]:
# Confirm environment
!conda info


     active environment : genai_project
    active env location : /opt/miniconda3/envs/genai_project
            shell level : 2
       user config file : /Users/shonie/.condarc
 populated config files : /opt/miniconda3/.condarc
          conda version : 25.5.1
    conda-build version : not installed
         python version : 3.13.5.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=m2
                          __conda=25.5.1=0
                          __osx=15.5=0
                          __unix=0=0
       base environment : /opt/miniconda3  (writable)
      conda av data dir : /opt/miniconda3/etc/conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/osx-arm64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/osx-arm64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /opt/minico

In [2]:
# Setup autoreload
%load_ext autoreload
%autoreload 2

In [3]:
import re
from typing import Dict, List, Set, Tuple
from collections import defaultdict, Counter
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer, tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle
import json
from vtt.utils.config import START_TOKEN, END_TOKEN, OOV_TOKEN, MIN_WORD_FREQ

In [4]:
def clean_caption(text: str) -> str:
    """
    Clean and normalize a single caption:
    - Lowercase
    - Remove punctuation
    - Remove extra whitespace

    Args:
        text (str): Raw caption string.

    Returns:
        str: Cleaned caption with special tokens added.
    """
    text = text.lower()
    text = re.sub(r"[^a-z0-9'\s]", "", text)  # remove punctuation except apostrophes
    text = re.sub(r"\s+", " ", text).strip()  # normalize whitespace
    return f"{START_TOKEN} {text} {END_TOKEN}"

In [6]:
def load_and_clean_captions(filepath: str) -> Dict[str, List[str]]:
    """
    Load captions from a comma-separated file and clean them.

    Args:
        filepath (str): Path to the captions CSV file (e.g., Flickr8k.token.txt).

    Returns:
        Dict[str, List[str]]: Mapping from image filename to list of cleaned captions.
    """
    captions: Dict[str, List[str]] = defaultdict(list)

    with open(filepath, "r") as file:
        next(file)  # Skip header

        for line in file:
            tokens = line.strip().split(",")

            # Expect 2 tokens per line: image_name, caption
            if len(tokens) != 2:
                continue

            image_id, caption = tokens
            image_filename = image_id.split("#")[0].strip()
            cleaned = clean_caption(caption)
            captions[image_filename].append(cleaned)

    return dict(captions)

In [7]:
def count_word_frequencies(captions_dict: Dict[str, List[str]]) -> Counter:
    """
    Count word frequencies across all captions.

    Args:
        captions_dict (Dict[str, List[str]]): Caption dictionary from `load_and_clean_captions`.

    Returns:
        Counter: Word frequency dictionary.
    """
    counter = Counter()
    for captions in captions_dict.values():
        for caption in captions:
            counter.update(caption.split())
    return counter

In [14]:
def filter_captions_by_frequency(
    captions_dict: Dict[str, List[str]], min_word_freq: int
) -> Tuple[Dict[str, List[str]], Set[str]]:
    """
    Replace infrequent words in captions with the OOV token.

    Args:
        captions_dict (Dict[str, List[str]]): Mapping from image filename to list of captions.
        min_word_freq (int): Minimum number of times a word must appear to be kept.

    Returns:
        Tuple:
            - filtered_captions (Dict[str, List[str]]): Updated captions with rare words replaced.
            - vocab (Set[str]): Set of retained vocabulary words.
    """
    freq = count_word_frequencies(captions_dict)

    vocab = {
        word
        for word, count in freq.items()
        if count >= min_word_freq or word in {START_TOKEN, END_TOKEN, OOV_TOKEN}
    }

    filtered = {}
    for img_id, captions in captions_dict.items():
        new_captions = []
        for caption in captions:
            tokens = [word if word in vocab else OOV_TOKEN for word in caption.split()]
            new_captions.append(" ".join(tokens))
        filtered[img_id] = new_captions

    return filtered, vocab

In [15]:
def fit_tokenizer(
    filtered_captions: Dict[str, List[str]],
    num_words: int = None,
    oov_token: str = OOV_TOKEN,
) -> Tokenizer:
    """
    Fit a Keras tokenizer on all captions.

    Args:
        filtered_captions (Dict[str, List[str]]): Mapping from image filename to list of cleaned captions.
        num_words (int): Max number of words to keep based on frequency. If None, keep all.
        oov_token (str): Token to represent out-of-vocabulary words.

    Returns:
        Tokenizer: Fitted tokenizer object.
    """
    all_captions = [
        caption for captions in filtered_captions.values() for caption in captions
    ]
    tokenizer = Tokenizer(num_words=num_words, oov_token=oov_token, filters="")
    tokenizer.fit_on_texts(all_captions)
    return tokenizer

In [17]:
def captions_to_sequences(
    filtered_captions: Dict[str, List[str]], tokenizer: Tokenizer
) -> Dict[str, List[List[int]]]:
    """
    Convert captions to sequences of token IDs.

    Args:
        filtered_captions (Dict[str, List[str]]): Dictionary of cleaned and filtered image captions.
        tokenizer (Tokenizer): Fitted Keras tokenizer.

    Returns:
        Dict[str, List[List[int]]]: Caption token sequences by image ID.
    """
    seq_dict = {}
    for img_id, captions in filtered_captions.items():
        seq_dict[img_id] = tokenizer.texts_to_sequences(captions)
    return seq_dict

In [20]:
def pad_caption_sequences(
    seq_dict: Dict[str, List[List[int]]], max_length: int
) -> Dict[str, List[List[int]]]:
    """
    Pad or truncate token sequences to a uniform length.

    Args:
        seq_dict (Dict[str, List[List[int]]]): Tokenized caption sequences.
        max_length (int): Max allowed sequence length.

    Returns:
        Dict[str, List[List[int]]]: Padded token sequences by image ID.
    """
    padded_dict = {}
    for img_id, sequences in seq_dict.items():
        padded = pad_sequences(
            sequences, maxlen=max_length, padding="post", truncating="post"
        ).tolist()
        padded_dict[img_id] = padded
    return padded_dict

In [23]:
def compute_max_caption_length(
    seq_dict: Dict[str, List[str]], quantile: float = 0.95
) -> int:
    """
    Compute the maximum caption length based on a given percentile of sequence lengths.

    Args:
        seq_dict (Dict[str, List[List[int]]]): Tokenized caption sequences.
        quantile (float): Percentile cutoff for max length (e.g., 0.95 means ignore top 5% longest captions).

    Returns:
        int: Computed maximum sequence length for padding.
    """
    lengths = [len(seq) for seqs in seq_dict.values() for seq in seqs]

    # Use percentile cutoff to ignore extreme outliers
    max_len = int(np.quantile(lengths, quantile))
    return max_len

In [24]:
def save_padded_sequences(
    padded_dict: Dict[str, List[List[int]]], filepath: str
) -> None:
    """
    Save padded caption sequences to a .npz file.

    Args:
        padded_dict (Dict[str, List[List[int]]]): Dict mapping image IDs to lists of padded sequences.
        filepath (str): Output file path (should end in .npz).
    """
    npz_dict = {
        img_id: np.array(seqs, dtype=np.int32) for img_id, seqs in padded_dict.items()
    }
    np.savez_compressed(filepath, **npz_dict)

In [25]:
def load_padded_sequences(filepath: str) -> Dict[str, List[List[int]]]:
    """
    Load padded caption sequences from a .npz file.

    Args:
        filepath (str): Path to the .npz file.

    Returns:
        Dict[str, List[List[int]]]: Restored padded sequences.
    """
    data = np.load(filepath, allow_pickle=True)
    return {img_id: data[img_id].tolist() for img_id in data.files}

In [None]:
# captions_path = "../data/raw/flickr8k_captions.csv"
captions_path = "../data/raw/flickr30k_captions.csv"
captions_dict = load_and_clean_captions(captions_path)
filtered_captions, vocab = filter_captions_by_frequency(captions_dict, min_word_freq=MIN_WORD_FREQ)
tokenizer = fit_tokenizer(filtered_captions, num_words=10000)
seqs = captions_to_sequences(filtered_captions, tokenizer)
max_length = compute_max_caption_length(seqs, quantile=0.95)
padded_seqs = pad_caption_sequences(seqs, max_length=max_length)
print(f"# filtered captions: {len(filtered_captions)}")
print(f"# sequences with tokens: {sum(len(seqs) for seqs in seqs.values())}")
print("Example sequences:", next(iter(seqs.values()), []))
print("Max caption length:", max_length)

# filtered captions: 31783
# sequences with tokens: 157624
Example sequences: [[3, 14, 22, 310, 12, 2182, 116, 195, 19, 63, 165, 27, 325, 73, 5, 6, 471, 4], [3, 14, 22, 4], [3, 14, 30, 5, 51, 262, 16, 35, 5, 2, 471, 4], [3, 2, 8, 5, 2, 28, 23, 35, 5, 2, 686, 4], [3, 14, 457, 786, 586, 15, 134, 4]]
Max caption length: 22


In [42]:
# Save padded sequences
# output_path = "../data/processed/flickr8k_padded_caption_sequences.npz"
output_path = "../data/processed/flickr30k_padded_caption_sequences.npz"
save_padded_sequences(padded_seqs, output_path)

In [43]:
# Load padded sequences
loaded = load_padded_sequences(output_path)
# Example
# image = "1000268201_693b08cb0e.jpg" # flickr8k
image = "1000092795.jpg"  # flickr30k
# Print each padded caption sequence associated with the image
for idx in range(5):
    print(f"Caption {idx}:\n", loaded[image][idx])

Caption 0:
 [3, 14, 22, 310, 12, 2182, 116, 195, 19, 63, 165, 27, 325, 73, 5, 6, 471, 4, 0, 0, 0, 0]
Caption 1:
 [3, 14, 22, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Caption 2:
 [3, 14, 30, 5, 51, 262, 16, 35, 5, 2, 471, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Caption 3:
 [3, 2, 8, 5, 2, 28, 23, 35, 5, 2, 686, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Caption 4:
 [3, 14, 457, 786, 586, 15, 134, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [None]:
# TODO: Need functions to save and load trained tokenizer
# save_tokenizer(filepath)
# load_tokenizer(filepath)

## Sandbox

In [None]:
def save_tokenizer(tokenizer: Tokenizer, filepath: str) -> None:
    """
    Save tokenizer to a pickle file.

    Args:
        tokenizer (Tokenizer): Fitted Keras tokenizer.
        filepath (str): Output path for saving.
    """
    with open(filepath, "wb") as f:
        pickle.dump(tokenizer, f)

In [None]:
def load_tokenizer(filepath: str) -> Tokenizer:
    """
    Load a tokenizer from a pickle file.

    Args:
        filepath (str): Path to the tokenizer file.

    Returns:
        Tokenizer: Loaded tokenizer.
    """
    with open(filepath, "rb") as f:
        return pickle.load(f)

In [None]:
def save_tokenizer_json(tokenizer, filepath: str) -> None:
    """
    Save a Keras tokenizer to a JSON file.

    Args:
        tokenizer (Tokenizer): Trained tokenizer.
        filepath (str): Destination JSON path.
    """
    tokenizer_json = tokenizer.to_json()
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(tokenizer_json)

In [None]:
def load_tokenizer_json(filepath: str):
    """
    Load a Keras tokenizer from a JSON file.

    Args:
        filepath (str): Path to saved JSON tokenizer.

    Returns:
        Tokenizer: Reconstructed tokenizer.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        tokenizer_json = json.load(f)
    return tokenizer_from_json(tokenizer_json)