# ViL-Cap: Image Captioning with Vision-LSTM
This notebook trains the ViL-Cap model on the COCO 2017 Captions dataset.

**Dataset**: COCO 2017 Captions (kaggle.com/datasets/awsaf49/coco-2017-dataset)

**Architecture**:
- Encoder: Vision-LSTM (ViL) - bidirectional mLSTM for visual features
- Decoder: Causal mLSTM for caption generation (3 blocks)
- Fusion: Simple add/merge (no cross-attention, following Bi-LSTM paper)

**Training Notes**:
- This is a simplified training loop for demonstration
- For production: use mixed precision (bfloat16), torch.compile, better augmentations
- Recommended: Start with pretrained ViL encoder for faster convergence

In [None]:
# Clone repository
# Fetches the Vision-LSTM sources and switches the notebook's working
# directory into the checkout; later cells import `src.ksuit...` and
# `caption_lstm...` relative to it.
# NOTE(review): confirm that this repository (rather than a fork) actually
# provides `caption_lstm` and `CocoCaptionsDataset` — not verifiable from here.
# NOTE: not idempotent — re-running this cell clones a second copy inside the
# existing checkout and then changes into that nested directory.
!git clone https://github.com/NX-AI/vision-lstm
%cd vision-lstm

In [None]:
# Install dependencies
# `-q` keeps pip's output short.
# NOTE(review): `transformers` is presumably needed for the
# "bert-base-uncased" tokenizer named in the model config — confirm.
!pip install -q einops transformers

In [None]:
# General imports
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda" if has_gpu else "cpu")
if has_gpu:
    # Report which GPU was picked up.
    print(f"GPU: {torch.cuda.get_device_name()}")
print(f"Device: {device}")

## Create Dataset and DataLoaders

In [None]:
# Dataset and collator classes from the cloned repository
from src.ksuit.datasets import CocoCaptionsDataset
from src.ksuit.data.collators import CaptionCollator

# Location of the COCO 2017 files inside the Kaggle input mount
coco_root = "/kaggle/input/coco-2017-dataset/coco2017"

# Per-channel normalization constants (the commonly used ImageNet statistics)
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Image preprocessing: resize to 224x224 for ViL, convert to tensor, normalize
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
)

# Training split: a single caption per image (random, per the original note)
train_dataset = CocoCaptionsDataset(root=coco_root, split="train", return_all_captions=False)

# Validation split: every caption of an image is returned, for evaluation
val_dataset = CocoCaptionsDataset(root=coco_root, split="val", return_all_captions=True)

# Report the size of each split
for split_label, split_dataset in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{split_label} dataset: {len(split_dataset)} images")

In [None]:
# The collator is handed the image transform (presumably applied while a
# batch is assembled — see CaptionCollator).
collator = CaptionCollator(transform=image_transform)

batch_size = 32

# Settings shared by the training and validation loaders
shared_loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=collator,
    num_workers=2,
)

# Training: reshuffled every epoch; the incomplete final batch is dropped so
# the number of updates per epoch is fixed.
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_loader_kwargs)

# Validation: fixed order, every sample kept.
val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

print(f"Train batches: {len(train_dataloader)}")
print(f"Val batches: {len(val_dataloader)}")

## Create Model

In [None]:
# Model and config classes
from caption_lstm.model import ViLCap, ViLCapConfig

# Encoder settings (ViL-T configuration)
encoder_settings = dict(
    encoder_dim=192,
    encoder_depth=24,
    encoder_input_shape=(3, 224, 224),
    encoder_patch_size=16,
    encoder_pooling="bilateral_avg",
    encoder_drop_path_rate=0.0,
    encoder_pretrained_path=None,  # Set to path if using pretrained encoder
)

# Decoder settings
decoder_settings = dict(
    decoder_dim=512,
    decoder_num_blocks=3,
    decoder_num_heads=4,
    decoder_dropout=0.2,
    max_caption_length=50,
)

# Assemble the full configuration, including the tokenizer to use
config = ViLCapConfig(
    **encoder_settings,
    **decoder_settings,
    tokenizer_model="bert-base-uncased",
)

# Build the model on the selected device
model = ViLCap(config).to(device)

# Parameter counts (in millions) for the whole model and its two halves
for part_label, part in (("Total", model), ("Encoder", model.encoder), ("Decoder", model.decoder)):
    param_count = sum(p.numel() for p in part.parameters())
    print(f"{part_label} parameters: {param_count / 1e6:.1f}M")

## Training Setup

In [None]:
# Optimization hyperparameters
epochs = 5
lr = 1e-4
weight_decay = 0.01

# AdamW over every model parameter; the learning rate is overwritten per
# update from the `lrs` table built below.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# One optimizer update per training batch; the first 10% of updates are warmup.
total_updates = epochs * len(train_dataloader)
warmup_updates = int(total_updates * 0.1)

# Per-update learning rates: linear ramp 0 -> lr during warmup, followed by a
# linear decay lr -> 0 over the remaining updates.
warmup_lrs = torch.linspace(0, lr, warmup_updates)
decay_lrs = torch.linspace(lr, 0, total_updates - warmup_updates)
lrs = torch.cat([warmup_lrs, decay_lrs])

print(f"Total updates: {total_updates}")
print(f"Warmup updates: {warmup_updates}")

## Training Loop

In [None]:
# Training loop
# Runs `epochs` passes over the training loader, each followed by a capped
# validation pass. Fills `train_losses` (one entry per optimizer update) and
# `val_losses` (one entry per epoch).


def _masked_cross_entropy(logits, target_ids, attention_mask):
    """Return the cross entropy averaged over the positions kept by the mask.

    Args:
        logits: Model scores; the last dimension is the class (vocabulary)
            dimension, all leading dimensions are flattened.
        target_ids: Target class ids, flattened to match the logits rows.
        attention_mask: Per-position weights, flattened the same way
            (presumably 1 for real tokens and 0 for padding — confirm
            against the model's output).

    Returns:
        Scalar tensor: sum of masked per-token losses divided by the mask sum.
    """
    per_token_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        target_ids.reshape(-1),
        reduction='none',
    )
    mask_flat = attention_mask.reshape(-1)
    # clamp avoids 0/0 = NaN if a batch has no unmasked position at all
    return (per_token_loss * mask_flat).sum() / mask_flat.sum().clamp(min=1)


update = 0
train_losses = []
val_losses = []

# Limit validation batches per epoch for speed
max_val_batches = 100

# The context manager closes the progress bar even if training is interrupted.
with tqdm(total=total_updates) as pbar:
    pbar.set_description("train_loss: ????? val_loss: ?????")

    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        epoch_loss = 0.0
        # Most recent validation loss, shown next to the running train loss
        val_text = f"{val_losses[-1]:.4f}" if val_losses else "?????"

        for batch in train_dataloader:
            images = batch['images'].to(device)
            captions = batch['captions']

            # Schedule learning rate. Convert to a Python float: storing a
            # 0-dim tensor would leak into optimizer.state_dict() and makes
            # recent PyTorch versions skip the foreach fast path.
            current_lr = float(lrs[update])
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            # Forward pass and masked token-level loss
            output = model(images, captions=captions, mode='train')
            loss = _masked_cross_entropy(
                output['logits'],
                output['target_ids'],
                output['attention_mask'],
            )

            # Backward pass, gradient clipping, parameter update
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

            # Logging: read the loss back once instead of once per use
            loss_value = loss.item()
            train_losses.append(loss_value)
            epoch_loss += loss_value
            update += 1

            pbar.update(1)
            pbar.set_description(f"train_loss: {loss_value:.4f} val_loss: {val_text}")

        # ---- Validation ----
        model.eval()
        val_loss = 0.0
        num_val_batches = 0

        with torch.no_grad():
            for batch in val_dataloader:
                images = batch['images'].to(device)
                # For validation, take first caption from each image's caption list
                captions = [caps[0] if isinstance(caps, list) else caps for caps in batch['captions']]

                output = model(images, captions=captions, mode='train')
                loss = _masked_cross_entropy(
                    output['logits'],
                    output['target_ids'],
                    output['attention_mask'],
                )

                val_loss += loss.item()
                num_val_batches += 1

                if num_val_batches >= max_val_batches:
                    break

        # An empty validation loader yields NaN rather than a ZeroDivisionError
        val_loss = val_loss / num_val_batches if num_val_batches else float('nan')
        val_losses.append(val_loss)

        print(f"\nEpoch {epoch+1}/{epochs} - Train Loss: {epoch_loss/len(train_dataloader):.4f}, Val Loss: {val_loss:.4f}")

## Plot Training Curves

In [None]:
# Two panels side by side: training loss per update (left) and validation
# loss per epoch (right).
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# (values, x label, y label, title, extra line options) for each panel
panels = [
    (train_losses, 'Updates', 'Train Loss', 'Training Loss', {}),
    (val_losses, 'Epochs', 'Val Loss', 'Validation Loss', {'marker': 'o'}),
]

for ax, (values, x_label, y_label, title, line_options) in zip(axes, panels):
    ax.plot(range(len(values)), values, **line_options)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True)

plt.tight_layout()
plt.show()

## Generate Captions on Validation Images

In [None]:
# Generate captions for a few validation images and display each image with
# its generated caption and one ground-truth caption.
model.eval()

# Number of images to show and grid width; everything below derives from
# these two values instead of repeating a literal.
num_samples = 8
num_cols = 4

# Get a batch from validation set
sample_batch = next(iter(val_dataloader))
sample_images = sample_batch['images'][:num_samples].to(device)
sample_gt_captions = sample_batch['captions'][:num_samples]
# The batch may hold fewer images than requested
num_shown = len(sample_gt_captions)

# Generate captions
with torch.no_grad():
    generated_captions = model.generate_captions(sample_images, temperature=1.0)

# Grid with enough rows for all shown images (ceiling division);
# squeeze=False keeps `axes` 2-D so flatten() works for any grid size.
num_rows = max(1, -(-num_shown // num_cols))
fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=(5 * num_cols, 5 * num_rows),
    squeeze=False,
)
axes = axes.flatten()

# Unnormalize images for display (inverse of the Normalize transform)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
sample_images_display = sample_images.cpu() * std + mean

for ax, image, generated, gt_entry in zip(
    axes, sample_images_display, generated_captions, sample_gt_captions
):
    # Channels-last layout for imshow; clip to the valid [0, 1] range
    ax.imshow(image.permute(1, 2, 0).clip(0, 1))
    ax.axis('off')

    # Get ground truth (first caption if list)
    gt = gt_entry[0] if isinstance(gt_entry, list) else gt_entry

    ax.set_title(f"Generated: {generated}\n\nGT: {gt}", fontsize=8, wrap=True)

# Hide grid cells that did not receive an image
for ax in axes[num_shown:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## Save Model

In [None]:
# Write a single checkpoint file holding the weights, optimizer state, loss
# histories and the model config.
# NOTE(review): `config` is stored as a Python object, so loading this file
# presumably requires `ViLCapConfig` to be importable — confirm.
checkpoint_path = 'vilcap_checkpoint.pth'

checkpoint = {
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'config': config,
}
torch.save(checkpoint, checkpoint_path)

print(f"Model saved to {checkpoint_path}")