# Image Captioning with Beam Search (PyTorch)
Template notebook for training an image‑captioning model and generating captions with beam search.

In [369]:
# Imports
import os, random
from collections import Counter
from typing import List

import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision import models, transforms as T
from PIL import Image

In [368]:
# Download the Flickr8k dataset from:
# https://www.kaggle.com/datasets/adityajn105/flickr8k/data

In [370]:
# 1. Configuration
# ---------------------------
# Data locations
IMG_DIR = "images"          # directory holding every JPG
CAPT_FILE = "captions.txt"  # TSV file, one <img>\t<caption> pair per line

# Model size
EMBED_DIM = 256
HIDDEN_DIM = 512
NUM_LAYERS = 1

# Optimisation
BATCH_SIZE = 32
EPOCHS = 10
LR = 1e-3

# Decoding (defaults picked up by the beam search)
MAX_LEN = 30
BEAM_SIZE = 3

# Reproducibility and hardware selection
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f55200aa030>

In [371]:
# 2. Vocabulary helper
# ---------------------------
class Vocabulary:
    """Word-level vocabulary mapping tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. ``itos`` maps id -> token and ``stoi`` maps
    token -> id; both are plain dicts.
    """

    def __init__(self, freq_threshold: int = 5):
        """Create a vocabulary that contains only the four special tokens.

        Args:
            freq_threshold: Minimum number of occurrences a token needs in
                the sentences given to :meth:`build` to receive its own id.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {tok: idx for idx, tok in self.itos.items()}

    def tokenize(self, text: str) -> List[str]:
        """Lower-case ``text`` and split it on whitespace."""
        return text.lower().strip().split()

    def build(self, sents: List[str]):
        """Add every sufficiently frequent token of ``sents`` to the vocabulary.

        Tokens are numbered in order of first appearance. Calling this more
        than once is safe: tokens that already have an id keep it, and new
        tokens are appended after the highest id in use.

        Args:
            sents: Raw sentences; each one is passed through :meth:`tokenize`.
        """
        freqs = Counter()
        for sent in sents:
            freqs.update(self.tokenize(sent))

        # Continue after the largest existing id instead of restarting at 4,
        # so a repeated call cannot hand out an id that is already taken.
        next_idx = max(self.itos, default=-1) + 1
        for tok, count in freqs.items():
            if count < self.freq_threshold or tok in self.stoi:
                continue
            self.stoi[tok] = next_idx
            self.itos[next_idx] = tok
            next_idx += 1

    def numericalize(self, text: str) -> List[int]:
        """Convert ``text`` to token ids, using the ``<UNK>`` id for unknown tokens."""
        unk = self.stoi['<UNK>']
        return [self.stoi.get(tok, unk) for tok in self.tokenize(text)]

In [372]:
# 3. Dataset & collate
# ---------------------------
class FlickrDataset(Dataset):
    """Image/caption pairs backed by an image folder and a captions file.

    Each item is ``(image, caption_ids)`` where ``caption_ids`` is a 1-D
    tensor holding ``<SOS>``, the numericalized caption and ``<EOS>``.
    The vocabulary is built from all captions in the file.
    """

    def __init__(self, img_dir: str, captions_file: str, transform=None, freq_threshold: int = 5):
        """Read the captions file and build the vocabulary.

        Args:
            img_dir: Folder containing the image files named in the captions file.
            captions_file: UTF-8 text file with one ``<img>\\t<caption>`` pair per line.
            transform: Optional callable applied to each loaded RGB image.
            freq_threshold: Minimum token frequency, forwarded to ``Vocabulary``.

        Raises:
            FileNotFoundError: If ``img_dir`` or ``captions_file`` does not exist.
        """
        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if not os.path.isdir(img_dir):
            raise FileNotFoundError(f"Image folder '{img_dir}' not found!")
        if not os.path.isfile(captions_file):
            raise FileNotFoundError(f"Captions file '{captions_file}' missing!")
        with open(captions_file, encoding='utf8') as f:
            self.data = self._parse_captions(f)
        self.img_dir = img_dir
        self.transform = transform
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build([caption for _, caption in self.data])

    @staticmethod
    def _parse_captions(lines):
        """Return ``[image_name, caption]`` pairs parsed from caption-file lines."""
        rows = [ln.strip() for ln in lines]
        # Tab-separated is the expected layout. Only when the file contains
        # no tab at all do we fall back to splitting on the first comma.
        # NOTE(review): the fallback targets an ``image,caption`` CSV with a
        # header row, which the linked Kaggle export appears to use - confirm.
        sep = '\t' if any('\t' in row for row in rows) else ','
        pairs = []
        for row in rows:
            # Split on the FIRST separator only so captions may contain it.
            name, found, caption = row.partition(sep)
            name, caption = name.strip(), caption.strip()
            if not found or not name or not caption:
                continue  # malformed or empty row
            if sep == ',' and (name.lower(), caption.lower()) == ('image', 'caption'):
                continue  # CSV header row
            pairs.append([name, caption])
        return pairs

    def __len__(self):
        """Number of image/caption pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and encode its caption as a tensor of ids."""
        img_id, cap = self.data[idx]
        # ``convert`` returns a fully loaded copy, so the file handle can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(os.path.join(self.img_dir, img_id)) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        stoi = self.vocab.stoi
        cap_idx = [stoi['<SOS>']] + self.vocab.numericalize(cap) + [stoi['<EOS>']]
        return img, torch.tensor(cap_idx)

class PadCollate:
    """Collate function that batches images and pads captions to equal length."""

    def __init__(self, pad_idx):
        # Value written into the padded tail of shorter captions.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batched tensors."""
        images = []
        captions = []
        for sample in batch:
            images.append(sample[0])
            captions.append(sample[1])
        stacked = torch.stack(images)
        padded = pad_sequence(captions, batch_first=True, padding_value=self.pad_idx)
        return stacked, padded


In [373]:
# 4. Encoder & Decoder (with optional beam search)
# ---------------------------
class EncoderCNN(nn.Module):
    """Frozen pretrained ResNet-50 followed by a trainable linear projection.

    Only ``fc`` (the projection to ``embed`` dimensions) is meant to be
    trained; all ResNet parameters have ``requires_grad=False`` and the
    backbone is always run in eval mode.
    """

    def __init__(self, embed: int):
        """Build the encoder.

        Args:
            embed: Size of the feature vector produced for each image.
        """
        super().__init__()
        res = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for p in res.parameters():
            p.requires_grad = False
        # Keep every child except the last one (the classification layer).
        self.backbone = nn.Sequential(*list(res.children())[:-1])
        self.flat = nn.Flatten()
        self.fc = nn.Linear(res.fc.in_features, embed)
        self.relu = nn.ReLU()
        self.backbone.eval()

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen backbone in eval mode.

        BatchNorm layers update their running statistics whenever they are in
        training mode, even inside ``torch.no_grad()``. Without this override
        ``encoder.train()`` would silently alter the pretrained statistics.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Encode a batch of images into ``embed``-dimensional feature vectors."""
        # The backbone is frozen, so skip building its autograd graph.
        with torch.no_grad():
            x = self.backbone(x)
        x = self.flat(x)
        return self.relu(self.fc(x))

class DecoderRNN(nn.Module):
    """LSTM caption decoder that is primed with an image feature vector.

    The image feature is fed to the LSTM as the first input step, followed
    by the embedded caption tokens. The LSTM is sequence-first
    (``batch_first`` is left at its default).
    """

    def __init__(self, embed, hidden, vocab, num_layers=1, dropout=0.3):
        """Build the decoder.

        Args:
            embed: Embedding size; must equal the encoder's feature size.
            hidden: LSTM hidden size.
            vocab: Number of entries in the vocabulary.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout probability applied to the token embeddings.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab, embed)
        self.lstm = nn.LSTM(embed, hidden, num_layers)
        self.fc = nn.Linear(hidden, vocab)
        self.drop = nn.Dropout(dropout)

    def forward(self, feats, caps):
        """Compute vocabulary logits for every step (teacher forcing).

        Args:
            feats: Image features of shape ``(batch, embed)``.
            caps: Caption ids of shape ``(seq, batch)``.

        Returns:
            Logits of shape ``(seq, batch, vocab)``: the image feature plus
            all but the last caption token form the ``seq`` input steps.
        """
        caps = caps[:-1]  # the last token is never used as an input
        emb = self.drop(self.embed(caps))
        emb = torch.cat((feats.unsqueeze(0), emb), 0)
        h, _ = self.lstm(emb)
        return self.fc(h)

    def beam_search(self, feats, vocab, beam=BEAM_SIZE, max_len=MAX_LEN):
        """Decode one caption with beam search.

        Args:
            feats: Features of a single image, shape ``(1, embed)``.
            vocab: Object whose ``stoi`` dict contains ``'<EOS>'``.
            beam: Number of hypotheses kept per step (``1`` = greedy).
            max_len: Maximum number of decoding steps.

        Returns:
            Token ids of the hypothesis with the highest summed
            log-probability (finished or not), as a list of ints.

        Raises:
            ValueError: If ``beam`` is smaller than 1.
        """
        if beam < 1:
            raise ValueError("beam must be >= 1")
        eos = vocab.stoi['<EOS>']
        # Hypothesis = (score, token ids, LSTM state, next LSTM input).
        # Every LSTM input is kept 3-D, (seq=1, batch=1, embed), so it always
        # matches the 3-D hidden state carried from the previous step.
        seqs = [(0.0, [], None, feats.unsqueeze(0))]
        finished = []
        for _ in range(max_len):
            cand = []
            for sc, seq, st, inp in seqs:
                if seq and seq[-1] == eos:
                    finished.append((sc, seq))
                    continue
                h, ns = self.lstm(inp, st)
                logp = torch.log_softmax(self.fc(h.squeeze(0)), dim=-1)
                # topk raises if asked for more entries than the vocabulary has.
                width = min(beam, logp.size(-1))
                lp, idx = logp.topk(width)
                for step_lp, tok in zip(lp[0].tolist(), idx[0]):
                    nxt = self.embed(tok).view(1, 1, -1)
                    cand.append((sc + step_lp, seq + [tok.item()], ns, nxt))
            seqs = sorted(cand, key=lambda t: t[0], reverse=True)[:beam]
            if not seqs:
                break  # every hypothesis has produced <EOS>
        finished += [(sc, seq) for sc, seq, _, _ in seqs]
        return max(finished, key=lambda t: t[0])[1]

In [374]:
# 5. Training routine
# ---------------------------

def train(output_path="model.pth"):
    """Train the captioning model and save a checkpoint.

    Builds the dataset from ``IMG_DIR`` / ``CAPT_FILE``, trains the decoder
    and the encoder's projection layer for ``EPOCHS`` epochs with Adam, and
    prints the mean loss per epoch.

    Args:
        output_path: Destination of the checkpoint, a dict with the keys
            ``'encoder'``, ``'decoder'`` (state dicts) and ``'vocab'``
            (the token -> id mapping).
    """
    transform = T.Compose([
        T.Resize(256), T.CenterCrop(224), T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    ds = FlickrDataset(IMG_DIR, CAPT_FILE, transform)
    pad_idx = ds.vocab.stoi['<PAD>']
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=PadCollate(pad_idx))

    enc = EncoderCNN(EMBED_DIM).to(DEVICE)
    dec = DecoderRNN(EMBED_DIM, HIDDEN_DIM, len(ds.vocab.stoi), NUM_LAYERS).to(DEVICE)

    # Padding positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    # The ResNet backbone is frozen; only the projection and decoder learn.
    params = list(dec.parameters()) + list(enc.fc.parameters())
    optim = torch.optim.Adam(params, lr=LR)

    enc.train()
    dec.train()
    # Keep the frozen backbone in eval mode: BatchNorm layers in training
    # mode update their running statistics even under torch.no_grad().
    enc.backbone.eval()
    for epoch in range(EPOCHS):
        total = 0.0
        for imgs, caps in dl:
            imgs, caps = imgs.to(DEVICE), caps.to(DEVICE)
            # The decoder is sequence-first, so use (seq, batch) for BOTH its
            # input and the loss targets. Flattening the batch-first tensor
            # would pair each logit with a target from another position.
            targets = caps.t()
            feats = enc(imgs)
            outputs = dec(feats, targets)  # (seq, batch, vocab)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
            optim.zero_grad()
            loss.backward()
            optim.step()
            total += loss.item()
        print(f"Epoch {epoch+1}/{EPOCHS} | loss: {total/len(dl):.4f}")
    torch.save({'encoder': enc.state_dict(), 'decoder': dec.state_dict(), 'vocab': ds.vocab.stoi}, output_path)

In [375]:
def generate_caption(img_path: str, model_path="model.pth", beam=True, beam_size=BEAM_SIZE):
    """Generate a caption for one image from a checkpoint written by ``train``.

    Args:
        img_path: Path of the image to caption.
        model_path: Checkpoint with the keys ``'encoder'``, ``'decoder'``
            and ``'vocab'``.
        beam: If true, decode with ``beam_size`` hypotheses; otherwise use a
            beam of 1 (greedy decoding).
        beam_size: Beam width used when ``beam`` is true.

    Returns:
        The caption as a space-separated string, without the ``<PAD>``,
        ``<SOS>`` and ``<EOS>`` markers.
    """
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    ckpt = torch.load(model_path, map_location=DEVICE)
    stoi = ckpt['vocab']
    vocab = Vocabulary()
    vocab.stoi = stoi
    vocab.itos = {i: t for t, i in stoi.items()}

    enc = EncoderCNN(EMBED_DIM).to(DEVICE)
    enc.load_state_dict(ckpt['encoder'])
    enc.eval()
    # Rebuild the decoder with the same layer count used for training;
    # otherwise load_state_dict fails whenever NUM_LAYERS != 1.
    dec = DecoderRNN(EMBED_DIM, HIDDEN_DIM, len(stoi), NUM_LAYERS).to(DEVICE)
    dec.load_state_dict(ckpt['decoder'])
    dec.eval()

    trans = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(),
                       T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # Close the image file as soon as the pixels have been converted.
    with Image.open(img_path) as raw:
        img = trans(raw.convert('RGB')).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        feat = enc(img)
        ids = dec.beam_search(feat, vocab, beam_size if beam else 1)

    # The decoder is trained to emit <SOS> as its first output (the image
    # feature is the first input), so strip it along with <EOS> and <PAD>.
    specials = {stoi[t] for t in ('<PAD>', '<SOS>', '<EOS>') if t in stoi}
    tokens = [vocab.itos[i] for i in ids if i not in specials]
    return ' '.join(tokens)
