# Image Captioning - Data Preprocessing

In [24]:
import re
import string
from pathlib import Path
from collections import defaultdict, Counter
import torch
from torch.nn.utils.rnn import pad_sequence
import json
import pickle
from torch.utils.data import Dataset
from PIL import Image
import sys 


In [None]:

# Path to captions
captions_file = Path("../data/Flickr8k_text/Flickr8k.token.txt")

# Load all lines
with open(captions_file, 'r') as f:
    lines = f.readlines()

# Preview a few lines
for line in lines[:5]:
    print(line.strip())

1000268201_693b08cb0e.jpg#0	A child in a pink dress is climbing up a set of stairs in an entry way .
1000268201_693b08cb0e.jpg#1	A girl going into a wooden building .
1000268201_693b08cb0e.jpg#2	A little girl climbing into a wooden playhouse .
1000268201_693b08cb0e.jpg#3	A little girl climbing the stairs to her playhouse .
1000268201_693b08cb0e.jpg#4	A little girl in a pink dress going into a wooden cabin .


In [None]:
# Initialize a dictionary with lists
image_captions = defaultdict(list)

# Parse each line
for line in lines:
    line = line.strip()
    if '\t' in line:
        img_id_with_index, caption = line.split('\t')
        img_id = img_id_with_index.split('#')[0]
        image_captions[img_id].append(caption)

# Preview example
example_key = list(image_captions.keys())[0]
print(f"Image: {example_key}")
print("Captions:")
for cap in image_captions[example_key]:
    print(f"- {cap}")

Image: 1000268201_693b08cb0e.jpg
Captions:
- A child in a pink dress is climbing up a set of stairs in an entry way .
- A girl going into a wooden building .
- A little girl climbing into a wooden playhouse .
- A little girl climbing the stairs to her playhouse .
- A little girl in a pink dress going into a wooden cabin .


In [4]:
def clean_caption(caption):
    # Convert to lowercase
    caption = caption.lower()
    
    # Remove punctuation
    caption = caption.translate(str.maketrans('', '', string.punctuation))
    
    # Remove numbers (optional)
    caption = re.sub(r'\d+', '', caption)
    
    # Remove extra whitespace
    caption = caption.strip()
    
    # Add special tokens
    caption = f"<start> {caption} <end>"
    
    return caption

In [7]:
# Apply cleaning to all captions
for img_id in image_captions:
    cleaned = [clean_caption(c) for c in image_captions[img_id]]
    image_captions[img_id] = cleaned

# Preview cleaned captions
print(f"Cleaned captions for {example_key}:")
for cap in image_captions[example_key]:
    print(f"- {cap}")

Cleaned captions for 1000268201_693b08cb0e.jpg:
- <start> a child in a pink dress is climbing up a set of stairs in an entry way <end>
- <start> a girl going into a wooden building <end>
- <start> a little girl climbing into a wooden playhouse <end>
- <start> a little girl climbing the stairs to her playhouse <end>
- <start> a little girl in a pink dress going into a wooden cabin <end>


In [9]:
# Flatten all captions into a single list of words
all_captions = []
for captions in image_captions.values():
    for cap in captions:
        all_captions.extend(cap.split())

# Count word frequencies
word_freq = Counter(all_captions)

# Minimum word frequency threshold
min_word_freq = 5

# Filter out rare words
words = [w for w in word_freq if word_freq[w] >= min_word_freq]

# Special tokens
special_tokens = ['<pad>', '<start>', '<end>', '<unk>']

# Final vocabulary
vocab = special_tokens + sorted(words)

# Build mappings
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

print(f"Total words in vocab (after filtering): {len(vocab)}")
print(f"Sample mapping: 'girl' → {word2idx.get('girl')}")

Total words in vocab (after filtering): 2990
Sample mapping: 'girl' → 1055


In [10]:
# Convert a caption into a list of word indices
def caption_to_indices(caption, word2idx):
    return [word2idx.get(word, word2idx['<unk>']) for word in caption.split()]

# Store caption sequences in a new dictionary
image_caption_seqs = {}

for img_id, captions in image_captions.items():
    image_caption_seqs[img_id] = [caption_to_indices(cap, word2idx) for cap in captions]

# Preview example
print(f"Indexed captions for {example_key}:")
for seq in image_caption_seqs[example_key]:
    print(seq)

Indexed captions for 1000268201_693b08cb0e.jpg:
[5, 6, 496, 1278, 6, 1875, 768, 1311, 530, 2821, 6, 2220, 1695, 2468, 1278, 54, 3, 2902, 4]
[5, 6, 1055, 1076, 1309, 6, 2957, 354, 4]
[5, 6, 1479, 1055, 530, 1309, 6, 2957, 1906, 4]
[5, 6, 1479, 1055, 530, 2657, 2468, 2699, 1198, 1906, 4]
[5, 6, 1479, 1055, 1278, 6, 1875, 768, 1076, 1309, 6, 2957, 3, 4]


In [12]:
# Convert lists of token IDs into tensors
all_seqs = []
for img_id, caption_list in image_caption_seqs.items():
    for seq in caption_list:
        all_seqs.append(torch.tensor(seq, dtype=torch.long))

# Pad all sequences to the same length (returns tensor of shape [num_captions, max_len])
padded_seqs = pad_sequence(all_seqs, batch_first=True, padding_value=word2idx['<pad>'])

# Optionally get sequence lengths (not padded)
seq_lengths = torch.tensor([len(seq) for seq in all_seqs], dtype=torch.long)

# Preview
print(f"Padded shape: {padded_seqs.shape}")
print(f"First padded caption: {padded_seqs[0]}")
print(f"Original length: {seq_lengths[0]}")

Padded shape: torch.Size([40460, 37])
First padded caption: tensor([   5,    6,  496, 1278,    6, 1875,  768, 1311,  530, 2821,    6, 2220,
        1695, 2468, 1278,   54,    3, 2902,    4,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0])
Original length: 19


In [15]:
save_dir = Path("../data/processed/")
save_dir.mkdir(parents=True, exist_ok=True)

# Save word2idx
with open(save_dir / "word2idx.json", 'w') as f:
    json.dump(word2idx, f)

# Save idx2word
with open(save_dir / "idx2word.json", 'w') as f:
    json.dump(idx2word, f)

print("Vocab mappings saved.")

Vocab mappings saved.


In [16]:
torch.save(padded_seqs, save_dir / "padded_captions.pt")
torch.save(seq_lengths, save_dir / "caption_lengths.pt")

print("Caption tensors saved.")

Caption tensors saved.


In [18]:
with open(save_dir / "image_caption_seqs.pkl", "wb") as f:
    pickle.dump(image_caption_seqs, f)

print("Image-caption sequences saved.")

Image-caption sequences saved.


In [20]:
class CaptionDataset(Dataset):
    def __init__(self, image_folder, captions_tensor, lengths_tensor, image_filenames, transform=None):
        """
        image_folder: path to folder with images (e.g., Flicker8k_Dataset)
        captions_tensor: Tensor of padded caption sequences
        lengths_tensor: Tensor of original caption lengths
        image_filenames: list of image filenames (one per caption)
        transform: torchvision transforms for the image
        """
        self.image_folder = image_folder
        self.captions = captions_tensor
        self.lengths = lengths_tensor
        self.image_filenames = image_filenames
        self.transform = transform

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.image_folder, self.image_filenames[idx])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Get caption and length
        caption = self.captions[idx]
        length = self.lengths[idx]

        return image, caption, length

In [25]:
sys.path.append("..")  # If utils is in the parent folder
# or
sys.path.append("../")  # Adjust based on where your notebook is relative to the repo root
from utils.dataloader import get_transforms, load_split_ids
from utils.caption_dataset import CaptionDataset

train_ids = load_split_ids("../data/Flickr8k_text/Flickr_8k.trainImages.txt")
val_ids = train_ids[-1000:]        # Last 1000 as validation (or use dev split)
train_ids = train_ids[:-1000]

transform = get_transforms('train')

# Map image IDs to their corresponding caption tensors
# Flatten caption sequences and match with filenames
image_filenames = []
caption_tensors = []
lengths = []

for img_id, captions in image_caption_seqs.items():
    if img_id in train_ids:
        for seq in captions:
            image_filenames.append(img_id)
            caption_tensors.append(torch.tensor(seq))
            lengths.append(len(seq))

# Pad
padded_seqs = pad_sequence(caption_tensors, batch_first=True, padding_value=word2idx['<pad>'])
lengths_tensor = torch.tensor(lengths)

# Build dataset
train_dataset = CaptionDataset(
    image_folder="../data/Flicker8k_Dataset",
    captions_tensor=padded_seqs,
    lengths_tensor=lengths_tensor,
    image_filenames=image_filenames,
    transform=transform
)