# Image Captioning với Attention – Lab Notebook

Notebook này minh họa các pipeline Image Captioning với **Encoder–Decoder có Attention**,
áp dụng cho các lựa chọn sau (chọn trong phần cấu hình):

1. `cnn_lstm_attn`: Encoder CNN – Decoder LSTM + Attention (có thể dùng GloVe).
2. `cnn_tf`: Encoder CNN – Decoder Transformer.
3. `vit_tf`: Encoder ViT – Decoder Transformer.
4. `vit_tf_gnn`: Như `vit_tf` nhưng bổ sung thông tin từ Graph NN nhẹ.

Các bước chính trong notebook:

1. Đọc dữ liệu caption Flickr8K, tiền xử lý, tạo vocab.
2. Tạo input phần ngôn ngữ (chuỗi token).
3. Dùng Encoder (pretrained) để trích xuất feature vector / memory từ ảnh.
4. Kết hợp Attention để tạo context vector chuyển sang Decoder.
5. Huấn luyện Decoder (phần ngôn ngữ).
6. Tính BLEU trên tập validation.
7. Chạy inference trên 1 hoặc nhiều ảnh.
8. Trực quan hóa chú ý (attention heatmap / mask) trên ảnh tương ứng caption sinh ra.


In [None]:
# =========================
# GENERAL CONFIGURATION
# =========================
import os

# SELECT THE ARCHITECTURE:
# 'cnn_lstm_attn', 'cnn_tf', 'vit_tf', 'vit_tf_gnn'
ARCH = 'cnn_lstm_attn'

# Flickr8K data paths
DATA_ROOT = '/kaggle/input/flickr8k'  # adjust to match your environment
IMG_DIR = os.path.join(DATA_ROOT, 'Images')
CAPTION_FILE = os.path.join(DATA_ROOT, 'Flickr8k_captions.txt')

# Path to the GloVe file (only used for ARCH = 'cnn_lstm_attn', and optional)
GLOVE_PATH = '/kaggle/input/glove-6b/glove.6B.50d.txt'  # adjust if needed
USE_GLOVE = True  # set to False if the GloVe file is not available

# Training configuration
BATCH_SIZE = 32
MAX_LEN = 20
FREQ_MIN = 5
VOCAB_MAX_SIZE = 10000
EMB_DIM = 256     # overridden when the 50d GloVe vectors are used
HIDDEN_DIM = 512
LR = 1e-4
EPOCHS = 3        # demo value; can be increased

# torch is pulled in through __import__ here because this cell runs before
# the regular import cell.
DEVICE = 'cuda' if __import__('torch').cuda.is_available() else 'cpu'
print("Using device:", DEVICE)

In [None]:
# =========================
# IMPORTS & HÀM TIỆN ÍCH
# =========================
import re
import math
import json
import random
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms

try:
    import nltk
    from nltk.translate.bleu_score import corpus_bleu
except ImportError:
    nltk = None
    print("Không tìm thấy nltk. BLEU sẽ không dùng được nếu chưa cài nltk.")

# Special vocabulary symbols: padding, sequence start, sequence end, unknown.
PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"


def tokenize(text):
    """Lower-case *text* and split it into word and punctuation tokens.

    A token is either a run of word characters or a single non-space
    character; whitespace is discarded.
    """
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\w+|\S", lowered)]

In [None]:
# =========================
# XÂY DỰNG VOCAB & DATASET
# =========================
def build_vocab(caption_file, freq_min=5, max_size=10000):
    """Build the vocabulary from a tab-separated caption file.

    Each usable line has the form ``<image id>\\t<caption>``; lines without
    a tab are ignored. Captions are tokenized with ``tokenize`` and tokens
    are added in decreasing frequency order.

    Args:
        caption_file: Path to the caption file (read as UTF-8).
        freq_min: Minimum number of occurrences for a token to be kept.
        max_size: Maximum vocabulary size, special tokens included.

    Returns:
        ``(itos, stoi)``: the index-to-token list (starting with PAD, BOS,
        EOS, UNK) and the matching token-to-index dict.
    """
    counter = Counter()
    with open(caption_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Split on the FIRST tab only, so a caption that itself contains
            # a tab does not break the two-way unpacking. `sep` is empty
            # when the only tab was trailing whitespace removed by strip().
            _, sep, cap = line.strip().partition('\t')
            if not sep:
                continue
            counter.update(tokenize(cap))

    itos = [PAD, BOS, EOS, UNK]
    for w, c in counter.most_common():
        # most_common() is sorted by count, so the first rare token ends
        # the scan; the size cap ends it as well.
        if c < freq_min or len(itos) >= max_size:
            break
        itos.append(w)
    stoi = {w: i for i, w in enumerate(itos)}
    print("Vocab size:", len(itos))
    return itos, stoi

def encode_caption(cap, stoi, max_len=20):
    """Convert a caption string into a fixed-length tensor of token ids.

    The result is ``BOS + tokens + EOS`` right-padded with PAD up to
    ``max_len``. Out-of-vocabulary tokens map to UNK.

    Captions that are too long are truncated *before* BOS/EOS are added,
    so the EOS marker always survives (previously the slice at the end
    could cut it off, leaving long captions without an end-of-sequence
    target).

    Args:
        cap: Raw caption text.
        stoi: Token-to-index mapping containing the special tokens.
        max_len: Length of the returned tensor.

    Returns:
        A 1-D ``torch.long`` tensor of length ``max_len``.
    """
    unk_id = stoi[UNK]
    body = [stoi.get(t, unk_id) for t in tokenize(cap)]
    # Reserve two slots for BOS and EOS.
    body = body[:max(max_len - 2, 0)]
    ids = [stoi[BOS]] + body + [stoi[EOS]]
    # A negative repeat count yields an empty list, so no guard is needed.
    ids += [stoi[PAD]] * (max_len - len(ids))
    # The final slice only matters for degenerate max_len < 2.
    return torch.tensor(ids[:max_len], dtype=torch.long)

class Flickr8kDataset(Dataset):
    """Caption dataset yielding ``(image, caption_ids, image_name)``.

    The caption file is read as tab-separated ``<image id>\\t<caption>``
    lines, shuffled with a fixed seed and cut into a train part (first
    ``split_ratio`` of the lines) and a validation part (the rest).

    NOTE(review): the split is made over caption lines, not over images.
    If an image has several caption lines (the ``#`` suffix stripped from
    the id suggests so), the same image can appear in both splits --
    confirm whether an image-level split is wanted.

    Args:
        img_dir: Directory containing the image files.
        caption_file: Path to the caption file (read as UTF-8).
        itos: Index-to-token list.
        stoi: Token-to-index dict, used to encode captions.
        max_len: Length of the encoded caption tensor.
        transform: Optional callable applied to the loaded PIL image.
        split: ``'train'`` selects the first part; anything else selects
            the remainder.
        split_ratio: Fraction of lines assigned to the train part.
    """

    def __init__(self, img_dir, caption_file, itos, stoi,
                 max_len=20, transform=None, split='train', split_ratio=0.8):
        self.img_dir = img_dir
        self.itos, self.stoi = itos, stoi
        self.max_len = max_len
        self.transform = transform
        self.samples = []  # list of (img_name, caption)

        lines = []
        with open(caption_file, 'r', encoding='utf-8') as f:
            for line in f:
                if '\t' not in line:
                    continue
                lines.append(line.strip())

        # A private generator seeded with 42 gives the same permutation as
        # seeding the module-level generator, but does not reset the
        # global `random` state every time a dataset is built.
        random.Random(42).shuffle(lines)

        n_train = int(len(lines) * split_ratio)
        use_lines = lines[:n_train] if split == 'train' else lines[n_train:]
        for line in use_lines:
            # Split on the first tab only: captions may contain tabs, and
            # strip() may have removed a trailing one (then `sep` is '').
            img_id, sep, cap = line.partition('\t')
            if not sep:
                continue
            # Drop the "#<n>" caption index from the image id.
            img_name = img_id.split('#')[0]
            self.samples.append((img_name, cap))

    def __len__(self):
        """Return the number of (image, caption) pairs in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, transform it and encode its caption."""
        img_name, cap = self.samples[idx]
        path = os.path.join(self.img_dir, img_name)
        img = Image.open(path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        cap_ids = encode_caption(cap, self.stoi, self.max_len)
        return img, cap_ids, img_name

In [None]:
# =========================
# TẠO DATALOADER
# =========================
# Image transform for the CNN encoder; for now it is also used for every
# other architecture (the ViT weights' own transforms() are not applied).
transform_cnn = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

itos, stoi = build_vocab(
    CAPTION_FILE, freq_min=FREQ_MIN, max_size=VOCAB_MAX_SIZE
)

# Both splits read the same caption file and differ only in `split`.
train_ds, val_ds = (
    Flickr8kDataset(
        IMG_DIR, CAPTION_FILE, itos, stoi,
        max_len=MAX_LEN, transform=transform_cnn,
        split=split_name, split_ratio=0.8,
    )
    for split_name in ('train', 'val')
)

# Only the training loader shuffles.
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)

print("Train samples:", len(train_ds), "Val samples:", len(val_ds))

In [None]:
# =========================
# NẠP GloVe (TÙY CHỌN CHO ARCH = 'cnn_lstm_attn')
# =========================
# Word -> vector map; stays None when GloVe is not used or not found.
glove_vectors = None
if ARCH == 'cnn_lstm_attn' and USE_GLOVE:
    if os.path.exists(GLOVE_PATH):
        print("Đang đọc GloVe từ:", GLOVE_PATH)
        glove_vectors = {}
        with open(GLOVE_PATH, 'r', encoding='utf-8') as glove_file:
            for row in glove_file:
                fields = row.strip().split()
                # Keep only rows with a word plus at least 50 values.
                if len(fields) >= 51:
                    glove_vectors[fields[0]] = np.asarray(
                        fields[1:], dtype=np.float32
                    )
        print("Số từ trong GloVe:", len(glove_vectors))
        EMB_DIM = 50  # fixed by the glove.6B.50d file
    else:
        # Fall back to embeddings trained from scratch.
        print("Không tìm thấy file GloVe, sẽ dùng embedding train từ đầu.")
        USE_GLOVE = False

In [None]:
# =========================
# ENCODER: CNN & ViT
# =========================
class CNNEncoder(nn.Module):
    """ResNet-50 feature extractor producing a grid of region features.

    The last two children of the torchvision ResNet-50 are removed, and a
    1x1 convolution projects the 2048-channel feature map to ``out_dim``
    channels.
    """

    def __init__(self, out_dim=512):
        super().__init__()
        backbone = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.DEFAULT
        )
        trunk = list(backbone.children())[:-2]
        self.cnn = nn.Sequential(*trunk)
        self.proj = nn.Conv2d(2048, out_dim, kernel_size=1)
        self.out_dim = out_dim

    def forward(self, x):
        """Encode a batch of images.

        Returns:
            ``(global_vec, memory, (H, W))`` where ``memory`` is B x (H*W) x D
            (the flattened feature grid), ``global_vec`` is its spatial mean
            (B x D) and ``(H, W)`` is the grid size -- presumably 7x7 for
            224x224 inputs, as the original comments state.
        """
        feat = self.proj(self.cnn(x))                 # B x D x H x W
        grid_h, grid_w = feat.shape[-2:]
        memory = feat.flatten(2).transpose(1, 2)      # B x L x D, L = H*W
        global_vec = feat.mean(dim=[2, 3])            # B x D
        return global_vec, memory, (grid_h, grid_w)

class ViTEncoder(nn.Module):
    """ViT-B/16 encoder returning the class token and the patch tokens.

    The torchvision model is run manually (patch embedding, class token,
    transformer encoder) so the per-patch tokens can be returned; a linear
    layer then projects them to ``out_dim``.
    """

    def __init__(self, out_dim=512):
        super().__init__()
        weights = torchvision.models.ViT_B_16_Weights.DEFAULT
        vit = torchvision.models.vit_b_16(weights=weights)
        self.preprocess = weights.transforms()
        self.vit = vit
        self.proj = nn.Linear(vit.hidden_dim, out_dim)
        self.out_dim = out_dim

    def forward(self, x):
        """Encode a batch of images (B x 3 x H x W).

        The input is assumed to be already resized and normalised the same
        way as for the CNN encoder.

        Returns:
            ``(cls, tokens, None)``: the projected class token (B x D), the
            projected patch tokens (B x L x D) and ``None`` because no
            explicit H x W grid is reported.
        """
        # NOTE(review): only takes effect if the weights' preset exposes a
        # `.transforms` sequence; otherwise the input passes through as is.
        if hasattr(self.preprocess, 'transforms'):
            x = self.preprocess.transforms[0](x)

        patches = self.vit._process_input(x)
        batch = patches.shape[0]
        cls_tok = self.vit.class_token.expand(batch, -1, -1)
        seq = torch.cat((cls_tok, patches), dim=1)
        projected = self.proj(self.vit.encoder(seq))   # B x (1+L) x D
        return projected[:, 0], projected[:, 1:], None

In [None]:
# =========================
# DECODER: LSTM + ATTENTION
# =========================
class AdditiveAttention(nn.Module):
    """Additive attention: ``score = v^T tanh(Wq q + Wk k)``."""

    def __init__(self, dim_q, dim_k, dim_h):
        super().__init__()
        self.Wq = nn.Linear(dim_q, dim_h)
        self.Wk = nn.Linear(dim_k, dim_h)
        self.v = nn.Linear(dim_h, 1)

    def forward(self, q, k, mask=None):
        """Attend over the keys ``k`` with a single query ``q`` per item.

        Args:
            q: Query, B x dim_q.
            k: Keys (also used as values), B x L x dim_k.
            mask: Optional B x L tensor; positions whose value is falsy
                are excluded from the softmax.

        Returns:
            ``(context, weights)``: the weighted sum of keys (B x dim_k)
            and the attention weights (B x L).
        """
        batch, length, _ = k.shape
        query_proj = self.Wq(q).unsqueeze(1).expand(batch, length, -1)
        key_proj = self.Wk(k)
        scores = self.v(torch.tanh(query_proj + key_proj)).squeeze(-1)
        if mask is not None:
            # A large negative score drives the softmax weight to ~0.
            scores = scores.masked_fill(~mask.bool(), -1e9)
        weights = torch.softmax(scores, dim=-1)
        context = torch.bmm(weights.unsqueeze(1), k).squeeze(1)
        return context, weights

class LSTMAttnDecoder(nn.Module):
    """Single-layer LSTM caption decoder with additive attention.

    At every step the previous hidden state queries the image memory; the
    resulting context vector is concatenated with the token embedding and
    fed to the LSTM, whose output is projected to vocabulary logits.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, mem_dim,
                 padding_idx=0, pretrained_emb=None):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=padding_idx)
        if pretrained_emb is not None:
            # pretrained_emb is a numpy array copied into the weights.
            self.emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.lstm = nn.LSTM(emb_dim + mem_dim, hidden_dim, batch_first=True)
        self.attn = AdditiveAttention(hidden_dim, mem_dim, dim_h=256)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim

    def forward(self, captions, memory):
        """Teacher-forced decoding.

        Args:
            captions: Input token ids, B x T.
            memory: Image features to attend over, B x L x mem_dim.

        Returns:
            Logits of shape B x T x vocab_size.
        """
        batch, steps = captions.shape
        embedded = self.emb(captions)                       # B x T x E
        # (h, c) both start at zero.
        state = tuple(
            torch.zeros(1, batch, self.hidden_dim, device=captions.device)
            for _ in range(2)
        )
        step_logits = []
        for t in range(steps):
            query = state[0][-1]                            # B x H
            ctx, _ = self.attn(query, memory)
            lstm_in = torch.cat([embedded[:, t, :], ctx], dim=-1).unsqueeze(1)
            out, state = self.lstm(lstm_in, state)
            step_logits.append(self.fc(out.squeeze(1)))
        return torch.stack(step_logits, dim=1)

In [None]:
# =========================
# DECODER: TRANSFORMER (CÓ LẤY RA ATTENTION)
# =========================
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a B x L x D input.

    Even feature indices carry ``sin`` terms and odd indices ``cos`` terms.
    The table is stored as the non-trainable buffer ``pe`` of shape
    ``1 x max_len x d_model``.

    Args:
        d_model: Feature dimension (odd values are supported).
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one fewer odd column than even columns,
        # so the frequency vector is trimmed to match; for even d_model the
        # slice is a no-op.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        L = x.size(1)
        return x + self.pe[:, :L, :]

class CustomDecoderLayer(nn.TransformerDecoderLayer):
    """``nn.TransformerDecoderLayer`` that can also return cross-attention.

    The computation follows the parent layer (self-attention,
    cross-attention, feed-forward, in pre-norm or post-norm arrangement);
    the only addition is the ``need_attn`` flag.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, tgt, memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None,
                need_attn=False):
        """Run the layer.

        Args:
            tgt: Target sequence.
            memory: Encoder output attended to by the cross-attention.
            tgt_mask: Attention mask for the self-attention.
            memory_mask: Attention mask for the cross-attention.
            tgt_key_padding_mask: Key padding mask for the target.
            memory_key_padding_mask: Key padding mask for the memory.
            need_attn: When True, also return the cross-attention weights.

        Returns:
            ``(output, attn)``. ``attn`` is ``None`` unless ``need_attn`` is
            set; otherwise it is what ``nn.MultiheadAttention`` returns with
            its default settings, i.e. weights averaged over heads.
        """
        x = tgt
        if self.norm_first:
            # Pre-norm: normalise, attend, then add the residual.
            # (The previous version passed the output of the parent's
            # private blocks as the only argument to the attention modules,
            # which raises a TypeError.)
            h = self.norm1(x)
            x2 = self.self_attn(
                h, h, h,
                attn_mask=tgt_mask,
                key_padding_mask=tgt_key_padding_mask,
                need_weights=False
            )[0]
            x = x + self.dropout1(x2)
            x2, attn_cross = self.multihead_attn(
                self.norm2(x), memory, memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask,
                need_weights=need_attn
            )
            x = x + self.dropout2(x2)
            x2 = self.linear2(
                self.dropout(self.activation(self.linear1(self.norm3(x))))
            )
            x = x + self.dropout3(x2)
        else:
            # Post-norm: attend, add the residual, then normalise.
            x2 = self.self_attn(
                x, x, x,
                attn_mask=tgt_mask,
                key_padding_mask=tgt_key_padding_mask,
                need_weights=False
            )[0]
            x = self.norm1(x + self.dropout1(x2))
            x2, attn_cross = self.multihead_attn(
                x, memory, memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask,
                need_weights=need_attn
            )
            x = self.norm2(x + self.dropout2(x2))
            x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
            x = self.norm3(x + self.dropout3(x2))
        # With need_weights=False the attention module returns None for the
        # weights, so this is (x, None) whenever need_attn is False.
        return x, (attn_cross if need_attn else None)

class TransformerCaptionDecoder(nn.Module):
    """Transformer caption decoder built from ``CustomDecoderLayer`` blocks.

    Token embeddings are scaled by ``sqrt(d_model)``, given sinusoidal
    positions, passed through the layer stack under a causal mask, layer
    normalised and projected to vocabulary logits.
    """

    def __init__(self, vocab_size,
                 d_model=512, nhead=8, num_layers=6,
                 padding_idx=0):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.pos = PositionalEncoding(d_model)
        blocks = [
            CustomDecoderLayer(d_model, nhead,
                               dim_feedforward=1024,
                               batch_first=True)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, captions_in, memory,
                tgt_padding_mask=None, mem_padding_mask=None,
                return_attn=False):
        """Decode with teacher forcing.

        Args:
            captions_in: Input token ids, B x T (starting with BOS).
            memory: Encoder tokens attended to by every layer.
            tgt_padding_mask: Optional key padding mask for the captions.
            mem_padding_mask: Optional key padding mask for the memory.
            return_attn: When True, also return the cross-attention of the
                last layer.

        Returns:
            ``(logits, attn)`` with logits of shape B x T x V. ``attn`` is
            ``None`` unless ``return_attn`` is set; it is whatever the last
            layer returns (with the default attention settings these are
            head-averaged weights -- TODO confirm the shape before
            plotting).
        """
        hidden = self.pos(self.emb(captions_in) * math.sqrt(self.d_model))
        seq_len = hidden.size(1)
        # True above the diagonal = positions a token may NOT attend to.
        causal = torch.triu(
            torch.ones(seq_len, seq_len, device=hidden.device), 1
        ).bool()

        last_idx = len(self.layers) - 1
        attn_last = None
        for idx, layer in enumerate(self.layers):
            hidden, attn = layer(
                hidden, memory,
                tgt_mask=causal,
                memory_mask=None,
                tgt_key_padding_mask=tgt_padding_mask,
                memory_key_padding_mask=mem_padding_mask,
                need_attn=(return_attn and idx == last_idx)
            )
            if attn is not None:
                attn_last = attn

        logits = self.fc(self.norm(hidden))  # B x T x V
        return logits, (attn_last if return_attn else None)

In [None]:
# =========================
# GNN NHẸ (Simple GCN)
# =========================
class SimpleGCN(nn.Module):
    """Two-layer graph convolution with row-normalised adjacency."""

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.lin1 = nn.Linear(in_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, adj):
        """Propagate node features over the graph.

        Args:
            x: Node features, B x N x D.
            adj: Adjacency matrices, B x N x N.

        Returns:
            Node embeddings, B x N x hidden_dim.
        """
        # Divide every row by its degree; the epsilon avoids a division by
        # zero for isolated nodes.
        row_degree = adj.sum(-1, keepdim=True) + 1e-6
        norm_adj = adj / row_degree
        hidden = F.relu(torch.bmm(norm_adj, self.lin1(x)))   # B x N x H
        return torch.bmm(norm_adj, self.lin2(hidden))        # B x N x H

In [None]:
# =========================
# KHỞI TẠO MÔ HÌNH
# =========================
vocab_size = len(itos)

encoder = None
decoder = None
gcn = None

# Encoder: CNN for the 'cnn_*' architectures, ViT for the 'vit_*' ones.
if ARCH in ['cnn_lstm_attn', 'cnn_tf']:
    encoder = CNNEncoder(out_dim=512).to(DEVICE)
elif ARCH in ['vit_tf', 'vit_tf_gnn']:
    encoder = ViTEncoder(out_dim=512).to(DEVICE)
else:
    raise ValueError("ARCH không hợp lệ")

# Prepare the GloVe embedding matrix (LSTM + attention decoder only).
pretrained_emb = None
if ARCH == 'cnn_lstm_attn':
    if USE_GLOVE and glove_vectors is not None:
        # Words missing from GloVe keep a small random vector.
        emb_matrix = np.random.normal(
            scale=0.1, size=(vocab_size, EMB_DIM)
        ).astype(np.float32)
        for i, w in enumerate(itos):
            if w in glove_vectors:
                emb_matrix[i] = glove_vectors[w]
        # Keep the padding row at zero, as nn.Embedding(padding_idx=...)
        # initialises it; the copy in the decoder would otherwise overwrite
        # it with random values.
        emb_matrix[stoi[PAD]] = 0.0
        pretrained_emb = emb_matrix
        print("Khởi tạo embedding từ GloVe với shape:", emb_matrix.shape)
    decoder = LSTMAttnDecoder(vocab_size, EMB_DIM, HIDDEN_DIM, mem_dim=512,
                              padding_idx=stoi[PAD],
                              pretrained_emb=pretrained_emb).to(DEVICE)
else:
    # Transformer decoder
    decoder = TransformerCaptionDecoder(vocab_size, d_model=512,
                                        nhead=8, num_layers=4,
                                        padding_idx=stoi[PAD]).to(DEVICE)

# GNN for ARCH 'vit_tf_gnn' (demo: object features are generated randomly).
# NOTE(review): the GCN outputs 256-d vectors while the encoder memory is
# 512-d, and forward_batch only fuses them when the sizes match -- as
# configured the graph branch has no effect. Confirm the intended size.
if ARCH == 'vit_tf_gnn':
    gcn = SimpleGCN(in_dim=256, hidden_dim=256).to(DEVICE)

# Optimizer & loss (padding positions are excluded from the loss).
criterion = nn.CrossEntropyLoss(ignore_index=stoi[PAD])

# Freeze the pretrained backbone only. The projection layer `encoder.proj`
# is created with random weights, so it must stay trainable; freezing it
# too would feed the decoder a fixed random projection.
for p in encoder.parameters():
    p.requires_grad = False
for p in encoder.proj.parameters():
    p.requires_grad = True

params = list(decoder.parameters()) + list(encoder.proj.parameters())
if gcn is not None:
    # The GCN is a freshly initialised module and needs optimising as well.
    params += list(gcn.parameters())
optimizer = torch.optim.Adam(params, lr=LR)

print("Model ready.")

In [None]:
# =========================
# HÀM TRAIN & EVAL
# =========================
def forward_batch(images, captions, encoder, decoder, arch,
                  gcn=None):
    """Compute the teacher-forced cross-entropy loss for one batch.

    Args:
        images: Image batch; moved to ``DEVICE``.
        captions: Encoded captions (B x T); moved to ``DEVICE``.
        encoder: Module returning ``(global_vec, memory, grid)``.
        decoder: LSTM decoder (returns logits) for ``'cnn_lstm_attn'``,
            otherwise a decoder returning ``(logits, attn)``.
        arch: Architecture name.
        gcn: Optional graph network, used only for ``'vit_tf_gnn'``.

    Returns:
        ``(loss, None)``; the second element is a placeholder kept for
        callers that unpack two values.
    """
    images = images.to(DEVICE)
    captions = captions.to(DEVICE)

    # Both encoder types share the same call signature; only the token
    # memory is used here.
    _, memory, _ = encoder(images)

    if arch == 'vit_tf_gnn' and gcn is not None:
        # DEMO: random object features and a fully connected adjacency.
        B = images.size(0)
        N = 5
        obj_feats = torch.randn(B, N, 256, device=images.device)
        adj = torch.ones(B, N, N, device=images.device)
        x_gnn = gcn(obj_feats, adj)    # B x N x 256
        graph_vec = x_gnn.mean(dim=1)  # B x 256
        # Fuse into the memory as an additive bias, but only when the
        # feature sizes agree; otherwise the graph branch is skipped.
        bias = graph_vec.unsqueeze(1)  # B x 1 x 256
        if memory.size(-1) == bias.size(-1):
            memory = memory + bias

    # Shift by one: the decoder sees tokens [0, T-1) and predicts [1, T).
    inp = captions[:, :-1]
    tgt = captions[:, 1:]

    if arch == 'cnn_lstm_attn':
        logits = decoder(inp, memory)     # B x T x V
    else:
        logits, _ = decoder(inp, memory)  # the attention output is unused

    loss = criterion(logits.reshape(-1, logits.size(-1)),
                     tgt.reshape(-1))
    return loss, None


def train_one_epoch(loader, encoder, decoder, arch, optimizer, gcn=None):
    """Train for one pass over ``loader`` and return the mean loss.

    Args:
        loader: Iterable of ``(images, captions, names)`` batches; must
            expose ``loader.dataset`` for the normalisation.
        encoder: Image encoder.
        decoder: Caption decoder being trained.
        arch: Architecture name forwarded to ``forward_batch``.
        optimizer: Optimizer stepped once per batch.
        gcn: Optional graph network forwarded to ``forward_batch``.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the
        dataset size.
    """
    decoder.train()
    # Keep the encoder in eval mode: batch-norm layers in train mode would
    # use batch statistics and update their running averages, and decoding
    # switches the encoder to eval anyway, so the mode would otherwise
    # differ between the first epoch and later ones.
    encoder.eval()
    total_loss = 0.0
    for images, caps, _ in loader:
        optimizer.zero_grad()
        loss, _ = forward_batch(images, caps, encoder, decoder, arch, gcn)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)


def evaluate_bleu(loader, encoder, decoder, arch, gcn=None, max_samples=200):
    """Compute corpus BLEU-4 of greedy captions on up to ``max_samples`` items.

    Every sample contributes one hypothesis (greedy decoding) and a single
    reference: its own caption with PAD/BOS removed, cut at EOS.

    Returns:
        The BLEU-4 score, or ``None`` when nltk is not installed.
    """
    if nltk is None:
        print("Chưa có nltk, không tính BLEU được.")
        return None

    decoder.eval()
    pad_id, bos_id, eos_id = stoi[PAD], stoi[BOS], stoi[EOS]
    references = []
    hypotheses = []
    with torch.no_grad():
        for images, caps, _ in loader:
            for row in range(images.size(0)):
                if len(hypotheses) >= max_samples:
                    break
                single = images[row:row + 1].to(DEVICE)
                hyp_words = greedy_decode(encoder, decoder, arch, single,
                                          gcn=gcn)
                # Reference words: skip pad/bos, stop at eos.
                ref_words = []
                for idx in caps[row].tolist():
                    if idx in (pad_id, bos_id):
                        continue
                    if idx == eos_id:
                        break
                    ref_words.append(itos[idx])
                references.append([ref_words])
                hypotheses.append(hyp_words)
            # Checked after the batch so no further batch is loaded once
            # the quota is reached.
            if len(hypotheses) >= max_samples:
                break

    bleu4 = corpus_bleu(references, hypotheses,
                        weights=(0.25, 0.25, 0.25, 0.25))
    print("BLEU-4:", bleu4)
    return bleu4

In [None]:
# =========================
# HÀM GREEDY DECODE + LẤY ATTENTION
# =========================
def greedy_decode(encoder, decoder, arch, image_tensor,
                  max_len=20, gcn=None):
    """Generate a caption for one image by greedy (argmax) decoding.

    Both modules are switched to eval mode and left there.

    Args:
        encoder: Module returning ``(global_vec, memory, grid)``.
        decoder: LSTM decoder for ``'cnn_lstm_attn'``, otherwise a decoder
            called as ``decoder(ids, memory, return_attn=False)``.
        arch: Architecture name.
        image_tensor: A single image, either 3-D or with a leading batch
            dimension of 1 (the decoding state is built for batch size 1).
        max_len: Maximum number of generated tokens.
        gcn: Optional graph network, used only for ``'vit_tf_gnn'``.

    Returns:
        List of generated token strings (without BOS/EOS).
    """
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        img = image_tensor.to(DEVICE)
        if img.ndim == 3:
            img = img.unsqueeze(0)
        _, memory, _ = encoder(img)

        if arch == 'vit_tf_gnn' and gcn is not None:
            # DEMO: random object features, fused only when sizes match.
            B = img.size(0)
            N = 5
            obj_feats = torch.randn(B, N, 256, device=img.device)
            adj = torch.ones(B, N, N, device=img.device)
            x_gnn = gcn(obj_feats, adj)
            graph_vec = x_gnn.mean(dim=1)
            bias = graph_vec.unsqueeze(1)
            if memory.size(-1) == bias.size(-1):
                memory = memory + bias

        bos_id = stoi[BOS]
        eos_id = stoi[EOS]
        tokens = []

        if arch == 'cnn_lstm_attn':
            # nn.LSTM exposes its width as `hidden_size`; it has no
            # `hidden_dim` attribute, so the old lookup raised
            # AttributeError.
            hidden_size = decoder.lstm.hidden_size
            h = torch.zeros(1, 1, hidden_size, device=img.device)
            c = torch.zeros(1, 1, hidden_size, device=img.device)
            prev_id = bos_id
            for _ in range(max_len):
                inp = torch.tensor([[prev_id]], device=img.device)
                emb = decoder.emb(inp)
                ctx, _ = decoder.attn(h[-1], memory)
                x_t = torch.cat([emb.squeeze(1), ctx], dim=-1).unsqueeze(1)
                o, (h, c) = decoder.lstm(x_t, (h, c))
                next_id = decoder.fc(o.squeeze(1)).argmax(-1).item()
                if next_id == eos_id:
                    break
                prev_id = next_id
                tokens.append(itos[next_id])
            return tokens

        # Transformer decoder: re-run on the growing prefix at each step.
        inp = torch.tensor([[bos_id]], device=img.device)
        for _ in range(max_len):
            logits, _ = decoder(inp, memory, return_attn=False)
            next_id = logits[0, -1].argmax(-1).item()
            if next_id == eos_id:
                break
            inp = torch.cat(
                [inp, torch.tensor([[next_id]], device=img.device)], dim=1
            )
            tokens.append(itos[next_id])
        return tokens

In [None]:
# =========================
# VIZ ATTENTION (CNN + LSTM + ATTENTION)
# =========================
def generate_with_attn(encoder, decoder, image_tensor,
                       stoi, itos, max_len=20, device="cuda"):
    """Greedy-decode a caption with the LSTM decoder and keep attention maps.

    Both modules are switched to eval mode and left there.

    Args:
        encoder: Module returning ``(global_vec, memory, grid)``.
        decoder: LSTM + attention decoder (uses ``emb``, ``attn``,
            ``lstm``, ``fc``).
        image_tensor: A single image without batch dimension.
        stoi: Token-to-index mapping.
        itos: Index-to-token list.
        max_len: Maximum number of decoding steps.
        device: Device on which decoding runs.

    Returns:
        ``(caption, tokens, attn_maps, (H, W))``. ``attn_maps`` holds one
        flat attention vector per decoding step; when decoding stops at
        EOS it contains one more entry than ``tokens`` (the step that
        produced EOS). ``(H, W)`` is the encoder grid, or ``(7, 7)`` when
        the encoder reports none.
    """
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        img = image_tensor.unsqueeze(0).to(device)
        global_vec, memory, grid = encoder(img)      # memory: 1 x L x D
        H, W = grid if grid is not None else (7, 7)

        bos_id = stoi[BOS]
        eos_id = stoi[EOS]
        ids = [bos_id]
        tokens = []
        attn_maps = []    # list of (L,) numpy arrays

        # nn.LSTM exposes its width as `hidden_size`; it has no
        # `hidden_dim` attribute, so the old lookup raised AttributeError.
        hidden_size = decoder.lstm.hidden_size
        h = torch.zeros(1, 1, hidden_size, device=device)
        c = torch.zeros(1, 1, hidden_size, device=device)

        for t in range(max_len):
            inp = torch.tensor([[ids[-1]]], device=device)
            emb = decoder.emb(inp)
            h_t = h[-1]                           # 1 x H
            ctx, att = decoder.attn(h_t, memory)  # att: 1 x L
            attn_maps.append(att.squeeze(0).cpu().numpy())

            x_t = torch.cat(
                [emb.squeeze(1), ctx], dim=-1
            ).unsqueeze(1)
            o, (h, c) = decoder.lstm(x_t, (h, c))
            logits = decoder.fc(o.squeeze(1))
            next_id = logits.argmax(-1).item()

            if next_id == eos_id:
                break
            ids.append(next_id)
            tokens.append(itos[next_id])

        return " ".join(tokens), tokens, attn_maps, (H, W)


def show_attention_on_image(img_pil, tokens, attn_maps,
                            H=7, W=7, cols=4):
    """Plot one attention heatmap per generated word over the image.

    Args:
        img_pil: The source PIL image.
        tokens: Generated words.
        attn_maps: Flat attention vectors, one per word; each must have
            ``H * W`` elements. Extra entries in either list are ignored.
        H, W: Attention grid size.
        cols: Number of subplot columns.
    """
    import math
    import numpy as np
    from PIL import Image as PILImage

    pairs = list(zip(tokens, attn_maps))
    if not pairs:
        # Nothing was generated (e.g. EOS predicted first): skip plotting
        # rather than create a degenerate zero-row figure.
        return

    rows = math.ceil(len(pairs) / cols)
    fig = plt.figure(figsize=(3*cols, 3*rows))

    for i, (word, attn_vec) in enumerate(pairs):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.set_title(word)

        # Rescale so the strongest cell is 1 (epsilon guards an all-zero
        # map).
        attn = attn_vec.reshape(H, W)
        attn = attn / (attn.max() + 1e-8)
        attn = np.clip(attn, 0.0, 1.0)

        # Upsample the coarse grid to the image size.
        attn_img = PILImage.fromarray(
            (attn * 255).astype(np.uint8)
        ).resize(img_pil.size, resample=PILImage.BILINEAR)

        ax.imshow(img_pil)
        ax.imshow(attn_img, cmap="jet", alpha=0.4)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


def attn_to_mask(attn_vec, H=7, W=7, threshold=0.6,
                 use_percentile=True):
    """Turn a flat attention vector into a binary H x W mask.

    The vector is reshaped and divided by its maximum. With
    ``use_percentile`` the cut-off is the ``threshold`` quantile of the
    rescaled map; otherwise ``threshold`` itself is the cut-off.

    Returns:
        A float32 array of shape ``(H, W)`` with 1.0 where the rescaled
        attention is at or above the cut-off and 0.0 elsewhere.
    """
    import numpy as np

    grid = attn_vec.reshape(H, W)
    grid = grid / (grid.max() + 1e-8)
    cutoff = np.quantile(grid, threshold) if use_percentile else threshold
    return (grid >= cutoff).astype(np.float32)


def show_attention_mask_on_image(img_pil, tokens, attn_maps,
                                 H=7, W=7, cols=4,
                                 threshold=0.6,
                                 use_percentile=True):
    """Plot, for every generated word, the image with its attended region.

    The attention map is binarised with ``attn_to_mask``; pixels inside
    the mask are tinted red, pixels outside are darkened.

    Args:
        img_pil: The source PIL image (converted to RGB for blending).
        tokens: Generated words.
        attn_maps: Flat attention vectors, one per word; extra entries in
            either list are ignored.
        H, W: Attention grid size.
        cols: Number of subplot columns.
        threshold, use_percentile: Forwarded to ``attn_to_mask``.
    """
    import math
    import numpy as np
    from PIL import Image as PILImage

    pairs = list(zip(tokens, attn_maps))
    if not pairs:
        # Nothing was generated: skip plotting rather than create a
        # degenerate zero-row figure.
        return

    rows = math.ceil(len(pairs) / cols)
    fig = plt.figure(figsize=(3*cols, 3*rows))

    # The blending below indexes a channel axis, so force three channels
    # (grayscale or RGBA input would otherwise break the overlay).
    img_np = np.array(img_pil.convert('RGB')).astype(np.float32) / 255.0

    # Pure red image used for tinting; identical for every word.
    red_overlay = np.zeros_like(img_np)
    red_overlay[..., 0] = 1.0
    bg = img_np * 0.3                              # darkened background
    highlighted = img_np * 0.7 + red_overlay * 0.3

    for i, (word, attn_vec) in enumerate(pairs):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.set_title(word)

        mask_hw = attn_to_mask(attn_vec, H, W,
                               threshold=threshold,
                               use_percentile=use_percentile)

        # Nearest-neighbour upsampling keeps the mask strictly binary.
        mask_img = PILImage.fromarray(
            (mask_hw * 255).astype(np.uint8)
        ).resize(img_pil.size, resample=PILImage.NEAREST)
        mask_np = np.array(mask_img).astype(np.float32) / 255.0
        mask_np = mask_np[..., None]

        out = bg * (1 - mask_np) + highlighted * mask_np

        ax.imshow(out)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# =========================
# QUICK DEMO TRAINING (EPOCHS epochs) + BLEU EVAL AFTER EACH EPOCH
# =========================
for epoch_no in range(1, EPOCHS + 1):
    epoch_loss = train_one_epoch(
        train_loader, encoder, decoder, ARCH, optimizer, gcn=gcn
    )
    print(f"Epoch {epoch_no}/{EPOCHS}, loss = {epoch_loss:.4f}")
    # BLEU on a small validation subset to keep the demo fast.
    evaluate_bleu(val_loader, encoder, decoder, ARCH,
                  gcn=gcn, max_samples=50)

In [None]:
# =========================
# DEMO INFERENCE TRÊN 1 ẢNH + VIZ ATTENTION (CHO ARCH = cnn_lstm_attn)
# =========================
test_img_path = None  # set to the path of a specific image to run the demo

# The attention visualisation only supports the LSTM + attention decoder.
if test_img_path is None or ARCH != 'cnn_lstm_attn':
    print("Đặt test_img_path và ARCH='cnn_lstm_attn' để viz attention.")
else:
    img_pil = Image.open(test_img_path).convert('RGB')
    img_t = transform_cnn(img_pil)
    caption, tokens, attn_maps, (H, W) = generate_with_attn(
        encoder, decoder, img_t, stoi, itos, device=DEVICE
    )
    print("Caption:", caption)
    # Soft heatmaps first, then the thresholded masks.
    show_attention_on_image(img_pil, tokens, attn_maps, H, W)
    show_attention_mask_on_image(
        img_pil, tokens, attn_maps, H, W,
        threshold=0.6, use_percentile=True,
    )

In [None]:
# =========================
# DEMO INFERENCE TRÊN CẢ THƯ MỤC ẢNH
# =========================
def run_inference_on_folder(img_dir, encoder, decoder,
                            arch, stoi, itos,
                            transform, gcn=None,
                            max_images=10):
    """Caption the first ``max_images`` images of a folder and display them.

    Args:
        img_dir: Directory scanned (non-recursively) for .jpg/.jpeg/.png.
        encoder, decoder, arch, gcn: Forwarded to ``greedy_decode``.
        stoi, itos: Unused here; kept so existing calls keep working.
        transform: Callable turning a PIL image into the model input.
        max_images: Maximum number of images processed.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the
    # selected subset reproducible between runs and machines.
    files = sorted(
        f for f in os.listdir(img_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )[:max_images]

    for fname in files:
        path = os.path.join(img_dir, fname)
        img_pil = Image.open(path).convert('RGB')
        img_t = transform(img_pil)
        tokens = greedy_decode(encoder, decoder, arch,
                               img_t.unsqueeze(0), gcn=gcn)
        caption = " ".join(tokens)
        plt.figure(figsize=(4, 4))
        plt.imshow(img_pil)
        plt.axis('off')
        plt.title(caption)
        plt.show()

# Ví dụ chạy:
# run_inference_on_folder(IMG_DIR, encoder, decoder,
#                         ARCH, stoi, itos,
#                         transform_cnn, gcn=gcn, max_images=5)