In [1]:
# Clone the caption-lstm repository into the current working directory (shell escape).
!git clone https://github.com/omar-A-hassan/caption-lstm

Cloning into 'caption-lstm'...
remote: Enumerating objects: 953, done.[K
remote: Counting objects: 100% (182/182), done.[K
remote: Compressing objects: 100% (57/57), done.[K
remote: Total 953 (delta 143), reused 146 (delta 125), pack-reused 771 (from 1)[K
Receiving objects: 100% (953/953), 1.64 MiB | 29.02 MiB/s, done.
Resolving deltas: 100% (419/419), done.


In [2]:
# Change the kernel's working directory to the cloned repo; later cells rely on this location.
%cd caption-lstm

/kaggle/working/caption-lstm


In [3]:
# Recursively list the repository contents as a quick sanity check of the checkout.
!ls -R

.:
caption_lstm  eval.py	  LICENSE_APACHE  logs	     src	vision_lstm
docs	      hubconf.py  LICENSE_MIT	  README.md  tutorials

./caption_lstm:
decoder.py  fusion.py  __init__.py  mbr_decoder.py  model.py  tokenizer.py

./docs:
_config.yml  imgs  index.md

./docs/imgs:
flops_vs_performance.png  results_imagenet.png	  schematic.svg
results_ade20k.png	  results_tiny_small.png
results_base.png	  results_vtab1k.png

./logs:
clean_logs.py  pretrain

./logs/pretrain:
vil_b16_e400_finetune.log	vil_t16_e400_finetune.log
vil_b16_e400_pretrain_run1.log	vil_t16_e400_pretrain.log
vil_b16_e400_pretrain_run2.log	vil_t16_e800_finetune.log
vil_s16_e400_finetune.log	vil_t16_e800_pretrain_run1.log
vil_s16_e400_pretrain.log	vil_t16_e800_pretrain_run2.log

./src:
environment.yml  main_run_folder.py  main_train.py  setup     vislstm
ksuit		 main_sbatch.py      RUN.md	    SETUP.md

./src/ksuit:
callbacks  datasets	freezers      losses   optim		 runners
configs    distributed	i

In [4]:
!pip install -q einops transformers

In [5]:
!pip install -q kappamodules kappautils kappaschedules kappaconfig kappadata kappaprofiler

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.7/77.7 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.5/144.5 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m107.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m80.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m50.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [6]:
# Read the Kaggle notebook secret stored under the label "hf".
# NOTE(review): presumably a Hugging Face access token — confirm. The value is
# not referenced by any other cell visible in this notebook.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("hf")


In [7]:
# General imports
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [8]:
# Select the compute device: the CUDA GPU when one is visible, otherwise the CPU
has_gpu = torch.cuda.is_available()
device = torch.device("cuda" if has_gpu else "cpu")
if has_gpu:
    # Report which GPU was picked up
    print(f"GPU: {torch.cuda.get_device_name()}")
print(f"Device: {device}")

GPU: Tesla P100-PCIE-16GB
Device: cuda


In [9]:
# Make the repo's src/ directory importable so 'ksuit' and 'vislstm' resolve as top-level packages
import os
import sys

REPO_ROOT = "/kaggle/working/caption-lstm"  # adjust if your cwd differs
SRC_DIR = os.path.join(REPO_ROOT, "src")

# Prepend once; re-running the cell must not add a duplicate entry
if SRC_DIR not in sys.path:
    sys.path[:0] = [SRC_DIR]

print("Added to sys.path:", SRC_DIR)


Added to sys.path: /kaggle/working/caption-lstm/src


In [10]:
# Dataset and collator classes come from the repo's ksuit package
from ksuit.datasets.coco_captions_dataset import CocoCaptionsDataset
from ksuit.data.collators.caption_collator import CaptionCollator

# Image preprocessing: resize to 224x224 for ViL, convert to tensor,
# then normalize each channel with the fixed mean/std below
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
image_transform = transforms.Compose(preprocessing_steps)

# Both splits are read from the same dataset root
COCO_ROOT = "/kaggle/input/coco-2017-dataset/coco2017"

# Training split: random caption per image
train_dataset = CocoCaptionsDataset(root=COCO_ROOT, split="train", return_all_captions=False)

# Validation split: all captions for evaluation
val_dataset = CocoCaptionsDataset(root=COCO_ROOT, split="val", return_all_captions=True)

print(f"Train dataset: {len(train_dataset)} images")
print(f"Val dataset: {len(val_dataset)} images")

Train dataset: 118287 images
Val dataset: 5000 images


In [11]:
# Collator applies the image transform while assembling batches
collator = CaptionCollator(transform=image_transform)

batch_size = 8

# Settings shared by the train and validation loaders
shared_loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=collator,
    num_workers=2,
)

# Training: shuffled, and the trailing incomplete batch is dropped
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_loader_kwargs)

# Validation: fixed order, keeps every sample
val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

print(f"Train batches: {len(train_dataloader)}")
print(f"Val batches: {len(val_dataloader)}")

Train batches: 14785
Val batches: 625


In [12]:
# Download the updated model.py with fixed pooling support
# The file is fetched from the repository's `main` branch and overwrites the copy
# inside the cloned checkout. `-q` silences wget, and the print below runs
# regardless of whether the download succeeded.
!wget -q https://raw.githubusercontent.com/omar-A-hassan/caption-lstm/main/caption_lstm/model.py \
      -O /kaggle/working/caption-lstm/caption_lstm/model.py

print("Updated model.py")

Updated model.py


In [13]:
# Force a fresh import of caption_lstm.model so code changed on disk takes effect
import sys
import importlib

# Evict the cached submodule first, then the package itself; names that are
# not cached are simply skipped
for cached_name in ('caption_lstm.model', 'caption_lstm'):
    sys.modules.pop(cached_name, None)

# With the cache entries gone, this import re-executes the module from disk
from caption_lstm.model import ViLCap, ViLCapConfig

print("Reloaded model module with updated pooling support")

2025-10-11 22:56:14.950181: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760223375.134811      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760223375.190734      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Reloaded model module with updated pooling support


In [14]:
# Model classes
from caption_lstm.model import ViLCap, ViLCapConfig


def _params_in_millions(module):
    """Return the number of parameters in `module`, expressed in millions."""
    return sum(p.numel() for p in module.parameters()) / 1e6


# Encoder settings (ViL-T configuration)
encoder_settings = dict(
    encoder_dim=192,
    encoder_depth=24,
    encoder_input_shape=(3, 224, 224),
    encoder_patch_size=16,
    encoder_pooling="bilateral_avg",
    encoder_drop_path_rate=0.0,
    # Checkpoint file the pretrained encoder weights are read from
    encoder_pretrained_path="/kaggle/input/vil-encoder/vil2_tiny16_e400_in1k.th",
)

# Decoder settings
decoder_settings = dict(
    decoder_dim=512,
    decoder_num_blocks=3,
    decoder_num_heads=4,
    decoder_dropout=0.2,
    max_caption_length=50,
)

# Assemble the full config; the tokenizer is the only remaining option
config = ViLCapConfig(
    **encoder_settings,
    **decoder_settings,
    tokenizer_model="bert-base-uncased",
)

# Build the model on the selected device and report its size
model = ViLCap(config).to(device)
print(f"Total parameters: {_params_in_millions(model):.1f}M")
print(f"Encoder parameters: {_params_in_millions(model.encoder):.1f}M")
print(f"Decoder parameters: {_params_in_millions(model.decoder):.1f}M")



Loaded encoder weights from /kaggle/input/vil-encoder/vil2_tiny16_e400_in1k.th
mLSTMLayerConfig(proj_factor=2.0, round_proj_up_dim_up=True, round_proj_up_to_multiple_of=64, _proj_up_dim=1024, conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4, bidirectional=False, quaddirectional=False, sharedirs=False, alternation=None, layerscale=None, use_conv2d=False, use_v_conv=False, share_conv=True, embedding_dim=512, bias=False, dropout=0.0, context_length=50, _num_blocks=3, _inner_embedding_dim=1024)
mLSTMLayerConfig(proj_factor=2.0, round_proj_up_dim_up=True, round_proj_up_to_multiple_of=64, _proj_up_dim=1024, conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4, bidirectional=False, quaddirectional=False, sharedirs=False, alternation=None, layerscale=None, use_conv2d=False, use_v_conv=False, share_conv=True, embedding_dim=512, bias=False, dropout=0.0, context_length=50, _num_blocks=3, _inner_embedding_dim=1024)
mLSTMLayerConfig(proj_factor=2.0, round_proj_up_dim_up=True, round_proj_u

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Loading pretrained BERT embeddings...


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  BERT embedding dim: 768
  Decoder embedding dim: 512
  Vocab size: 30523
  Using projection: 768 -> 512
  ✓ Projected and copied 30522 embeddings
  ✓ 1 new tokens keep random initialization
✓ Pretrained embeddings initialized successfully!
Total parameters: 42.2M
Encoder parameters: 5.9M
Decoder parameters: 36.2M


In [15]:
# Training hyperparameters
epochs = 5
lr = 1e-4
weight_decay = 0.01

# AdamW over every model parameter
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# One optimizer update per training batch; the first 10% of updates are warmup
total_updates = len(train_dataloader) * epochs
warmup_updates = int(total_updates * 0.1)

# Per-update learning rates: linear ramp 0 -> lr during warmup,
# then linear decay lr -> 0 over the remaining updates
warmup_ramp = torch.linspace(0, lr, warmup_updates)
decay_ramp = torch.linspace(lr, 0, total_updates - warmup_updates)
lrs = torch.cat([warmup_ramp, decay_ramp])

print(f"Total updates: {total_updates}")
print(f"Warmup updates: {warmup_updates}")

Total updates: 73925
Warmup updates: 7392


In [16]:
# Training loop
import random

MAX_VAL_BATCHES = 100  # limit validation batches per epoch for speed


def _one_caption_per_image(batch_captions, pick_random=False):
    """Reduce every image's caption entry to a single caption.

    Entries that are a list/tuple of captions are reduced to one element
    (a random one when `pick_random` is true, otherwise the first one).
    Any other entry is passed through unchanged.

    Returns a new list with one entry per image.
    """
    selected = []
    for caps in batch_captions:
        if isinstance(caps, (list, tuple)):
            caps = random.choice(caps) if pick_random else caps[0]
        selected.append(caps)
    return selected


def _masked_caption_loss(output):
    """Token-level cross entropy averaged over the positions kept by the mask.

    `output` is the dict returned by the model forward pass; its 'logits',
    'target_ids' and 'attention_mask' entries are flattened so that every
    token position contributes `loss * mask` to the average.
    """
    logits = output['logits']
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        output['target_ids'].reshape(-1),
        reduction='none',
    )
    mask = output['attention_mask'].reshape(-1).to(per_token.dtype)
    # clamp avoids 0/0 = NaN should a batch contain no unmasked token
    return (per_token * mask).sum() / mask.sum().clamp_min(1)


update = 0
train_losses = []
val_losses = []

# The context manager closes the progress bar even if training raises
with tqdm(total=total_updates) as pbar:
    pbar.set_description("train_loss: ????? val_loss: ?????")

    for epoch in range(epochs):
        # Training
        model.train()
        epoch_loss = 0

        for batch in train_dataloader:
            images = batch['images'].to(device)
            # The model is given exactly one caption string per image. The
            # recorded run failed in the tokenizer with an "excessive nesting"
            # ValueError, which is consistent with per-image caption lists
            # reaching it, so any such list is reduced to one caption here.
            captions = _one_caption_per_image(batch['captions'], pick_random=True)

            # Schedule learning rate (plain float, not a 0-dim tensor)
            current_lr = float(lrs[update])
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            # Forward pass and masked cross entropy
            output = model(images, captions=captions, mode='train')
            loss = _masked_caption_loss(output)

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update
            optimizer.step()
            optimizer.zero_grad()

            # Logging (read the scalar once instead of three times)
            loss_value = loss.item()
            train_losses.append(loss_value)
            epoch_loss += loss_value
            update += 1

            pbar.update(1)
            pbar.set_description(f"train_loss: {loss_value:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        num_val_batches = 0

        with torch.no_grad():
            for batch in val_dataloader:
                images = batch['images'].to(device)
                # For validation, take first caption from each image's caption list
                captions = _one_caption_per_image(batch['captions'])

                # Same forward mode as training so the loss is computed identically
                output = model(images, captions=captions, mode='train')
                val_loss += _masked_caption_loss(output).item()
                num_val_batches += 1

                # Limit validation batches for speed
                if num_val_batches >= MAX_VAL_BATCHES:
                    break

        # max(..., 1) keeps an empty validation loader from dividing by zero
        val_loss /= max(num_val_batches, 1)
        val_losses.append(val_loss)

        print(f"\nEpoch {epoch+1}/{epochs} - Train Loss: {epoch_loss/len(train_dataloader):.4f}, Val Loss: {val_loss:.4f}")

  0%|          | 0/73925 [00:00<?, ?it/s]

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [None]:
# Plot the per-update training loss next to the per-epoch validation loss
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# One entry per panel: (axis, series, extra line options, x label, y label, title)
panels = [
    (axes[0], train_losses, {}, 'Updates', 'Train Loss', 'Training Loss'),
    (axes[1], val_losses, {'marker': 'o'}, 'Epochs', 'Val Loss', 'Validation Loss'),
]

for ax, series, line_options, x_label, y_label, panel_title in panels:
    ax.plot(range(len(series)), series, **line_options)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Generate captions for sample images
import math

model.eval()

NUM_SAMPLES = 8  # upper bound on the number of images captioned and displayed
GRID_COLS = 4    # images per row in the preview grid

# Get a batch from validation set
sample_batch = next(iter(val_dataloader))
sample_images = sample_batch['images'][:NUM_SAMPLES].to(device)
sample_gt_captions = sample_batch['captions'][:NUM_SAMPLES]

# The batch may hold fewer than NUM_SAMPLES images, so size everything from the data
num_shown = sample_images.shape[0]

# Generate captions
with torch.no_grad():
    generated_captions = model.generate_captions(sample_images, temperature=1.0)

# Display results (5x5 inches per cell; 8 images give the 2x4 grid at 20x10)
grid_rows = max(1, math.ceil(num_shown / GRID_COLS))
fig, axes = plt.subplots(
    grid_rows, GRID_COLS,
    figsize=(5 * GRID_COLS, 5 * grid_rows),
    squeeze=False,  # always a 2-D array, even for a single row
)
axes = axes.flatten()

# Unnormalize images for display
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
sample_images_display = sample_images.cpu() * std + mean

# Hide the frame of every grid cell, including cells left without an image
for ax in axes:
    ax.axis('off')

for i in range(num_shown):
    # channels-first -> channels-last for imshow, clipped to the displayable range
    axes[i].imshow(sample_images_display[i].permute(1, 2, 0).clip(0, 1))

    # Get ground truth (first caption if list)
    gt = sample_gt_captions[i][0] if isinstance(sample_gt_captions[i], list) else sample_gt_captions[i]

    axes[i].set_title(f"Generated: {generated_captions[i]}\n\nGT: {gt}", fontsize=8, wrap=True)

plt.tight_layout()
plt.show()

In [None]:
# Download the updated mbr_decoder.py
# The file is fetched from the repository's `main` branch and overwrites the copy
# inside the cloned checkout. `-q` silences wget, and the print below runs
# regardless of whether the download succeeded.
!wget -q https://raw.githubusercontent.com/omar-A-hassan/caption-lstm/main/caption_lstm/mbr_decoder.py \
      -O /kaggle/working/caption-lstm/caption_lstm/mbr_decoder.py

print("Updated mbr_decoder.py")

In [None]:
# Reload the mbr_decoder module to get the updated code
import importlib
# Import first so the module object exists, then re-execute it in place so the
# file currently on disk is the one in use.
import caption_lstm.mbr_decoder
importlib.reload(caption_lstm.mbr_decoder)

print("Reloaded mbr_decoder module with updated code")

In [None]:
# MBR Decoding: Generate better captions using Minimum Bayes Risk
# First install mbrs library
!pip install -q mbrs

from caption_lstm.mbr_decoder import MBRCaptionDecoder

# Initialize MBR decoder
print("Initializing MBR decoder with COMET metric...")
mbr_decoder = MBRCaptionDecoder(
    model=model,
    num_candidates=16,  # Generate 16 candidates per image
    metric_name="comet",
    metric_model="Unbabel/wmt22-comet-da",
    batch_size=64,
    fp16=True,
    temperature=1.0,
    top_k=50,
)

# Generate captions using MBR
print("Generating captions with MBR decoding (this may take a few minutes)...")
with torch.no_grad():
    mbr_captions = mbr_decoder.decode(sample_images)

# Display MBR results - one image per row (vertical layout)
fig, axes = plt.subplots(8, 1, figsize=(12, 40))  # 8 rows, 1 column

for i in range(8):
    axes[i].imshow(sample_images_display[i].permute(1, 2, 0).clip(0, 1))
    axes[i].axis('off')
    
    # Get ground truth (first caption if list)
    gt = sample_gt_captions[i][0] if isinstance(sample_gt_captions[i], list) else sample_gt_captions[i]
    
    axes[i].set_title(
        f"MBR: {mbr_captions[i]}\n\n"
        f"Simple: {generated_captions[i]}\n\n"
        f"GT: {gt}",
        fontsize=11,
        wrap=True,
        pad=20,
        loc='left'  # Left-align for readability
    )

# Add spacing between images
plt.subplots_adjust(hspace=0.4)
plt.tight_layout()
plt.show()

print("\nMBR Decoding complete!")
print("MBR generates multiple candidates and selects the best one based on COMET metric.")

In [None]:
# Persist everything needed to inspect or resume this run in a single file
checkpoint_path = 'vilcap_checkpoint.pth'

checkpoint = {
    'epoch': epochs,  # the configured number of epochs
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,  # one entry per optimizer update
    'val_losses': val_losses,      # one entry per epoch
    'config': config,              # stored as the config object itself
}
torch.save(checkpoint, checkpoint_path)

print("Model saved to vilcap_checkpoint.pth")