# Captioner Diagnostics Notebook

Purpose: step-by-step checks to diagnose NaN losses, validation length, data consistency, and mixed precision issues for Gemma+DINO caption model.

You can run cells independently. If a cell raises/asserts, fix the noted issue then re-run.

Sections:
1. Environment & Versions
2. Project Path + Repro Settings
3. Dataloaders (train/val) Build
4. Inspect First Batch (shapes, channels, sample caption)
5. Channel / Grayscale Scan (optional fix)
6. Build Model + Param Count
7. Parameter NaN / Inf Scan
8. Single Forward (FP32) Without Lightning
9. Tokenization Sanity (empty / all-pad detection)
10. Autocast (FP16) Forward Test
11. Gradient Step (detect exploding grad / NaN)
12. Mini Lightning Trainer (few steps, limited val)
13. Full Dataset Channel Distribution Summary
14. Summary & Next Steps

In [1]:
# 1. Environment & Versions
import os, sys, platform, math, json, random
import torch, transformers, datasets as hfds, pytorch_lightning as pl
print('Python        :', sys.version.split()[0])
print('Platform      :', platform.platform())
print('Torch         :', torch.__version__)
print('Transformers  :', transformers.__version__)
print('Datasets      :', hfds.__version__)
print('Lightning     :', pl.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU           :', torch.cuda.get_device_name(0))

Python        : 3.13.3
Platform      : Linux-6.14.0-29-generic-x86_64-with-glibc2.39
Torch         : 2.8.0+cu128
Transformers  : 4.56.1
Datasets      : 4.1.0
Lightning     : 2.5.5
CUDA available: True
GPU           : NVIDIA GeForce GTX 1650 Ti


In [2]:
# 2. Project Path + Repro Settings
import os, sys
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..')) if os.path.basename(os.getcwd()) == 'tests' else os.path.abspath(os.path.join(os.getcwd()))
if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT)
print('Project root:', PROJECT_ROOT)
SEED = 42
pl.seed_everything(SEED, workers=True)
os.environ.setdefault('TOKENIZERS_PARALLELISM','false')
os.environ.setdefault('WANDB_MODE','disabled')
HF_CACHE = os.environ.get('HF_HOME') or None
print('HF cache:', HF_CACHE)

Seed set to 42


Project root: /home/divyansh/Documents/SEM4/DL/Project
HF cache: None


In [3]:
# 3. Dataloaders (validation split used as small corpus)
from src.data.dataloader import make_coco_dataloader
BATCH_SIZE = int(os.environ.get('DIAG_BS', 2))
NUM_WORKERS = int(os.environ.get('DIAG_NUM_WORKERS', 2))
PIN_MEMORY = torch.cuda.is_available()
TRAIN_SPLIT = 'validation'
VAL_SPLIT = 'validation'
train_loader = make_coco_dataloader(split=TRAIN_SPLIT, batch_size=BATCH_SIZE, shuffle=True, caption_index=None, seed=SEED, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, cache_dir=HF_CACHE)
val_loader   = make_coco_dataloader(split=VAL_SPLIT,   batch_size=BATCH_SIZE, shuffle=False, caption_index=None, seed=SEED, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, cache_dir=HF_CACHE)
print('Train batches (full):', len(train_loader), '| Val batches:', len(val_loader))
print('Expected val steps per full validation:', len(val_loader))

Train batches (full): 2500 | Val batches: 2500
Expected val steps per full validation: 2500


In [4]:
# 4. Inspect First Batch
first_images, first_caps = next(iter(train_loader))
print('Batch image count:', len(first_images))
print('First caption sample:', first_caps[0][:120])
# Show channel modes if PIL Images
modes = [getattr(img, 'mode', 'NA') for img in first_images]
print('Image modes in batch:', modes)

Batch image count: 2
First caption sample: A woman standing in front of a fruit stand.
Image modes in batch: ['RGB', 'RGB']


In [5]:
# 5. Channel / Grayscale Scan (quick)
from collections import Counter
def is_grayscale(img):
    if hasattr(img, 'mode'):
        return img.mode in ('L','1','I','F')
    return False
gs = sum(is_grayscale(im) for im in first_images)
print(f'Grayscale in first batch: {gs}/{len(first_images)}')
if gs: print('Consider converting to RGB inside dataloader transform.')

Grayscale in first batch: 0/2


In [6]:
# 6. Build Model + Param Count
from src.models.caption_modelling import GemmaDinoImageCaptioner
from src.utils.training import LitCaptioner
model = GemmaDinoImageCaptioner(include_cls=True, include_registers=False, include_patches=False, freeze_gemma=True)
lit = LitCaptioner(model, optimizer_cfg={'lr':1e-4,'weight_decay':0.01,'betas':(0.9,0.999),'eps':1e-8})
trainable = sum(p.numel() for p in lit.parameters() if p.requires_grad)
total = sum(p.numel() for p in lit.parameters())
print(f'Trainable params: {trainable/1e6:.2f}M / Total: {total/1e6:.2f}M')

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Trainable params: 0.66M / Total: 297.45M


In [7]:
# 7. Parameter NaN / Inf Scan
bad = []
for n,p in lit.named_parameters():
    if torch.isnan(p).any(): bad.append((n,'NaN'))
    elif torch.isinf(p).any(): bad.append((n,'Inf'))
print('Param anomalies:' if bad else 'No NaN/Inf in parameters at init.', bad)

No NaN/Inf in parameters at init. []


In [8]:
# 8. Single Forward (FP32) Without Lightning
model.eval()
sample_images, sample_caps = next(iter(train_loader))
with torch.autocast(device_type='cuda', enabled=False) if torch.cuda.is_available() else torch.no_grad():
    out = model(sample_images, sample_caps)
loss = out.loss if hasattr(out,'loss') else out['loss']
print('Forward loss (fp32):', float(loss))
assert torch.isfinite(loss), 'Loss is not finite (fp32 forward).'
print('OK: finite loss.')

Forward loss (fp32): 6.289526462554932
OK: finite loss.


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  print('Forward loss (fp32):', float(loss))


In [11]:
# 9. Tokenization Sanity
tok = model.gemma_tokenizer(sample_caps, return_tensors='pt', padding=True)
ids = tok['input_ids']
pad_id = model.gemma_tokenizer.pad_token_id
empty_mask = (ids!=pad_id).sum(dim=1)==0
if empty_mask.any():
    print('Empty captions detected at indices:', torch.nonzero(empty_mask).view(-1).tolist())
else:
    print('All captions have at least one non-pad token.')
print('Tokenized shape:', ids.shape)

All captions have at least one non-pad token.
Tokenized shape: torch.Size([2, 12])


In [12]:
# 10. Autocast (FP16) Forward Test
if torch.cuda.is_available():
    model.train()
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        out16 = model(sample_images, sample_caps)
    loss16 = out16.loss if hasattr(out16,'loss') else out16['loss']
    print('Forward loss (fp16 mixed):', float(loss16))
    assert torch.isfinite(loss16), 'Loss is not finite under autocast.'
    print('OK: finite mixed-precision loss.')
else:
    print('CUDA not available; skipping mixed precision test.')

Forward loss (fp16 mixed): 6.268799304962158
OK: finite mixed-precision loss.


In [13]:
# 11. Gradient Step (detect exploding grad / NaN)
model.train()
optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4, weight_decay=0.01)
sample_images2, sample_caps2 = next(iter(train_loader))
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
def forward_loss(imgs, caps):
    out = model(imgs, caps)
    return out.loss if hasattr(out,'loss') else out['loss']
optim.zero_grad()
if use_amp:
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        l = forward_loss(sample_images2, sample_caps2)
    scaler.scale(l).backward()
    # Gradient anomaly detection
    found_nan = False
    for n,p in model.named_parameters():
        if p.grad is not None and torch.isnan(p.grad).any():
            print('NaN grad in', n); found_nan=True; break
    if not found_nan: print('No NaN grads (AMP).')
    scaler.step(optim); scaler.update()
else:
    l = forward_loss(sample_images2, sample_caps2)
    l.backward()
    for n,p in model.named_parameters():
        if p.grad is not None and torch.isnan(p.grad).any():
            print('NaN grad in', n); break
    optim.step()
print('Gradient step loss:', float(l))

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


No NaN grads (AMP).
Gradient step loss: 7.310398578643799


In [14]:
# 12. Mini Lightning Trainer (few steps, limit val)
from pytorch_lightning.callbacks import TQDMProgressBar
mini_lit = LitCaptioner(model, optimizer_cfg={'lr':1e-4,'weight_decay':0.01,'betas':(0.9,0.999),'eps':1e-8})
precision = '16-mixed' if torch.cuda.is_available() else '32-true'
trainer = pl.Trainer(accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=1, max_steps=5, log_every_n_steps=1, precision=precision, limit_val_batches=2, enable_checkpointing=False, enable_model_summary=False, callbacks=[TQDMProgressBar(refresh_rate=1)])
trainer.fit(mini_lit, train_dataloaders=train_loader, val_dataloaders=val_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=5` reached.


In [15]:
# 13. Full Dataset Channel Distribution Summary (sampled)
import itertools
def count_channels(dataloader, max_batches=50):
    gray = 0; total = 0
    for i,(imgs,caps) in enumerate(dataloader):
        for im in imgs:
            if getattr(im,'mode',None) in ('L','1','I','F'): gray+=1
            total+=1
        if i+1>=max_batches: break
    return gray, total
gray_train, tot_train = count_channels(train_loader)
gray_val, tot_val = count_channels(val_loader)
print(f'Train grayscale: {gray_train}/{tot_train} ({100*gray_train/max(1,tot_train):.1f}%)')
print(f'Val   grayscale: {gray_val}/{tot_val} ({100*gray_val/max(1,tot_val):.1f}%)')
if gray_train or gray_val:
    print('NOTE: Convert grayscale -> RGB in dataloader to avoid channel mismatch issues.')

Train grayscale: 0/100 (0.0%)
Val   grayscale: 0/100 (0.0%)


In [16]:
# 14. Summary & Next Steps
print('Diagnostics complete. If no assertions failed, model/data pipeline is numerically stable for initial steps.')
print('Next: adjust LR, unfreeze more layers, or extend training steps as needed.')

Diagnostics complete. If no assertions failed, model/data pipeline is numerically stable for initial steps.
Next: adjust LR, unfreeze more layers, or extend training steps as needed.
