In [1]:
from LSTM_Attn import *

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
valid_df = build_valid_df(CSV_PATH, IMG_ROOT)
if valid_df.empty:
    print("[WARN] No valid rows found; check paths and PNG conversion.")
labels_as_str = valid_df[TEXT_COL].astype(str).tolist()
tokenizer = build_tokenizer_from_labels(labels_as_str)
pad_id = getattr(tokenizer, "pad_token_id", 0)
bos_id = getattr(tokenizer, "bos_token_id", 1)
eos_id = getattr(tokenizer, "eos_token_id", 2)
tf = dino_image_transform(img_size=516)
ds = CheXpertDataset(img_root=IMG_ROOT, csv=valid_df, transform=tf, text_col=TEXT_COL)
collate_fn = CaptionCollate(tokenizer, pad_id)
is_windows = os.name == "nt"
num_workers = 0 if is_windows else 2
persistent_workers = False if num_workers == 0 else True
loader = DataLoader(
    ds,
    batch_size=8,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=persistent_workers,
    collate_fn=collate_fn
)
train_ds = torch.utils.data.Subset(ds, range(0, 80))#int(len(ds)*.8)))
valid_ds = torch.utils.data.Subset(ds, range(80, 160))#int(len(ds)*.8), len(ds)))
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_ds, batch_size=8, shuffle=False, collate_fn=collate_fn)
D_IMG = 384
model = DinoLSTMAttnCaptioner(
    vocab_size=tokenizer.vocab_size,
    d_img=D_IMG,
    d_h=512,
    pad_id=pad_id,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True,
).to(device)
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2
)


Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG


In [3]:
for epoch in range(100):
    slice_train_loader = islice(train_loader, 10)
    slice_valid_loader = islice(valid_loader, 10)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=10, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=10)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")
test_loader_sliced = iter(valid_loader)
with torch.no_grad():
    for batch in test_loader_sliced:
        pixel_values, ids_loader, paths, raw_labels = batch
        pixel_values = pixel_values.to(device)
        gen_ids = model.generate(
            pixel_values=pixel_values,
            bos_id=bos_id, eos_id=eos_id,
            max_new_tokens=50, top_p=0.8, temperature=0.9, greedy=False
        )
        print("Predictions test:")
        for i in range(gen_ids.size(0)):
            print(f"\nTEST GEN {i+1}:", tokenizer.decode(gen_ids[i].tolist()))
            print(f"TEST TARGET {i+1}:", tokenizer.decode(ids_loader[i].tolist()))
        # Free batch memory
        del pixel_values, ids_loader, paths, raw_labels, gen_ids
        torch.cuda.empty_cache()
        break

Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Epoch 1: Train Loss=9.4246, PPL=15369.86 | Val Loss=8.1643, Val PPL=3575.08


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Epoch 2: Train Loss=7.2939, PPL=1726.83 | Val Loss=6.4971, Val PPL=698.94


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 3: Train Loss=5.7195, PPL=328.88 | Val Loss=5.6024, Val PPL=297.61


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 4: Train Loss=4.8139, PPL=124.56 | Val Loss=5.2023, Val PPL=204.75


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch 5: Train Loss=4.2250, PPL=71.30 | Val Loss=4.9525, Val PPL=163.04


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]


Epoch 6: Train Loss=3.7656, PPL=45.44 | Val Loss=4.7822, Val PPL=138.02


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch 7: Train Loss=3.3966, PPL=31.08 | Val Loss=4.6528, Val PPL=121.73


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 8: Train Loss=3.1115, PPL=22.71 | Val Loss=4.5631, Val PPL=111.49


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 9: Train Loss=2.8585, PPL=17.88 | Val Loss=4.5011, Val PPL=105.08


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 10: Train Loss=2.6499, PPL=14.37 | Val Loss=4.4623, Val PPL=101.04


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]


Epoch 11: Train Loss=2.4560, PPL=11.78 | Val Loss=4.4319, Val PPL=98.31


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch 12: Train Loss=2.2920, PPL=10.08 | Val Loss=4.4182, Val PPL=97.18


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 13: Train Loss=2.1473, PPL=8.72 | Val Loss=4.4225, Val PPL=97.76


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 14: Train Loss=2.0507, PPL=7.83 | Val Loss=4.4035, Val PPL=96.51


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 15: Train Loss=1.9508, PPL=7.12 | Val Loss=4.4063, Val PPL=96.94


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 16: Train Loss=1.8709, PPL=6.51 | Val Loss=4.4180, Val PPL=98.67


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 17: Train Loss=1.7968, PPL=6.06 | Val Loss=4.4242, Val PPL=99.17


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 18: Train Loss=1.7236, PPL=5.62 | Val Loss=4.4368, Val PPL=100.30


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]


Epoch 19: Train Loss=1.6788, PPL=5.37 | Val Loss=4.4571, Val PPL=102.47


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 20: Train Loss=1.6314, PPL=5.12 | Val Loss=4.4593, Val PPL=103.05


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 21: Train Loss=1.5976, PPL=4.94 | Val Loss=4.4812, Val PPL=105.17


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 22: Train Loss=1.5646, PPL=4.78 | Val Loss=4.5080, Val PPL=107.81


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 23: Train Loss=1.5412, PPL=4.67 | Val Loss=4.5170, Val PPL=109.23


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 24: Train Loss=1.5189, PPL=4.57 | Val Loss=4.5354, Val PPL=111.26


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 25: Train Loss=1.4994, PPL=4.48 | Val Loss=4.5427, Val PPL=111.95


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 26: Train Loss=1.4878, PPL=4.43 | Val Loss=4.5643, Val PPL=114.54


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 27: Train Loss=1.4757, PPL=4.37 | Val Loss=4.5810, Val PPL=116.63


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 28: Train Loss=1.4648, PPL=4.33 | Val Loss=4.5931, Val PPL=118.22


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 29: Train Loss=1.4550, PPL=4.28 | Val Loss=4.6156, Val PPL=121.04


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 30: Train Loss=1.4453, PPL=4.24 | Val Loss=4.6329, Val PPL=122.83


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 31: Train Loss=1.4426, PPL=4.23 | Val Loss=4.6483, Val PPL=125.73


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 32: Train Loss=1.4370, PPL=4.21 | Val Loss=4.6611, Val PPL=127.08


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 33: Train Loss=1.4315, PPL=4.18 | Val Loss=4.6796, Val PPL=129.09


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 34: Train Loss=1.4266, PPL=4.16 | Val Loss=4.6990, Val PPL=132.19


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 35: Train Loss=1.4232, PPL=4.15 | Val Loss=4.7162, Val PPL=134.46


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 36: Train Loss=1.4210, PPL=4.14 | Val Loss=4.7270, Val PPL=136.03


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 37: Train Loss=1.4181, PPL=4.13 | Val Loss=4.7266, Val PPL=135.85


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 38: Train Loss=1.4155, PPL=4.12 | Val Loss=4.7503, Val PPL=139.57


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch 39: Train Loss=1.4134, PPL=4.11 | Val Loss=4.7667, Val PPL=141.57


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 40: Train Loss=1.4124, PPL=4.11 | Val Loss=4.7635, Val PPL=141.18


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 41: Train Loss=1.4101, PPL=4.10 | Val Loss=4.7738, Val PPL=143.16


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 42: Train Loss=1.4091, PPL=4.09 | Val Loss=4.8011, Val PPL=146.65


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 43: Train Loss=1.4069, PPL=4.08 | Val Loss=4.8145, Val PPL=148.89


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 44: Train Loss=1.4050, PPL=4.08 | Val Loss=4.8226, Val PPL=150.35


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 45: Train Loss=1.4043, PPL=4.07 | Val Loss=4.8248, Val PPL=150.71


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 46: Train Loss=1.4031, PPL=4.07 | Val Loss=4.8316, Val PPL=151.27


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 47: Train Loss=1.4015, PPL=4.06 | Val Loss=4.8505, Val PPL=154.47


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 48: Train Loss=1.4010, PPL=4.06 | Val Loss=4.8576, Val PPL=155.33


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch 49: Train Loss=1.3997, PPL=4.05 | Val Loss=4.8422, Val PPL=153.10


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 50: Train Loss=1.4001, PPL=4.06 | Val Loss=4.8716, Val PPL=157.73


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]


Epoch 51: Train Loss=1.3991, PPL=4.05 | Val Loss=4.8783, Val PPL=158.88


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 52: Train Loss=1.3988, PPL=4.05 | Val Loss=4.8837, Val PPL=160.20


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 53: Train Loss=1.3986, PPL=4.05 | Val Loss=4.8691, Val PPL=157.03


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 54: Train Loss=1.3985, PPL=4.05 | Val Loss=4.8951, Val PPL=162.03


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 55: Train Loss=1.3963, PPL=4.04 | Val Loss=4.8971, Val PPL=162.15


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 56: Train Loss=1.3963, PPL=4.04 | Val Loss=4.9070, Val PPL=164.51


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Epoch 57: Train Loss=1.3952, PPL=4.04 | Val Loss=4.9184, Val PPL=165.80


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch 58: Train Loss=1.3935, PPL=4.03 | Val Loss=4.9170, Val PPL=165.42


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]


Epoch 59: Train Loss=1.3924, PPL=4.02 | Val Loss=4.9407, Val PPL=170.61


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 60: Train Loss=1.3925, PPL=4.02 | Val Loss=4.9471, Val PPL=171.15


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 61: Train Loss=1.3925, PPL=4.02 | Val Loss=4.9589, Val PPL=172.91


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 62: Train Loss=1.3930, PPL=4.03 | Val Loss=4.9567, Val PPL=172.44


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch 63: Train Loss=1.3903, PPL=4.02 | Val Loss=4.9669, Val PPL=174.02


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch 64: Train Loss=1.3903, PPL=4.02 | Val Loss=4.9747, Val PPL=176.15


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 65: Train Loss=1.3892, PPL=4.01 | Val Loss=4.9774, Val PPL=177.13


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 66: Train Loss=1.3886, PPL=4.01 | Val Loss=4.9934, Val PPL=179.62


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 67: Train Loss=1.3884, PPL=4.01 | Val Loss=4.9947, Val PPL=180.42


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 68: Train Loss=1.3873, PPL=4.00 | Val Loss=5.0022, Val PPL=181.48


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 69: Train Loss=1.3872, PPL=4.00 | Val Loss=4.9861, Val PPL=178.67


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch 70: Train Loss=1.3869, PPL=4.00 | Val Loss=5.0109, Val PPL=182.94


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch 71: Train Loss=1.3865, PPL=4.00 | Val Loss=4.9987, Val PPL=180.50


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 72: Train Loss=1.3857, PPL=4.00 | Val Loss=5.0376, Val PPL=188.16


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 73: Train Loss=1.3858, PPL=4.00 | Val Loss=5.0183, Val PPL=184.79


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 74: Train Loss=1.3860, PPL=4.00 | Val Loss=5.0274, Val PPL=185.70


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 75: Train Loss=1.3854, PPL=4.00 | Val Loss=5.0315, Val PPL=186.96


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Epoch 76: Train Loss=1.3848, PPL=3.99 | Val Loss=5.0455, Val PPL=189.38


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 77: Train Loss=1.3850, PPL=3.99 | Val Loss=5.0340, Val PPL=188.06


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 78: Train Loss=1.3844, PPL=3.99 | Val Loss=5.0661, Val PPL=192.55


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 79: Train Loss=1.3843, PPL=3.99 | Val Loss=5.0481, Val PPL=190.81


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 80: Train Loss=1.3836, PPL=3.99 | Val Loss=5.0730, Val PPL=194.58


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 81: Train Loss=1.3837, PPL=3.99 | Val Loss=5.0623, Val PPL=193.31


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 82: Train Loss=1.3832, PPL=3.99 | Val Loss=5.0814, Val PPL=196.27


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Epoch 83: Train Loss=1.3831, PPL=3.99 | Val Loss=5.0709, Val PPL=194.44


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 84: Train Loss=1.3827, PPL=3.99 | Val Loss=5.0919, Val PPL=198.25


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 85: Train Loss=1.3825, PPL=3.98 | Val Loss=5.0794, Val PPL=195.94


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch 86: Train Loss=1.3821, PPL=3.98 | Val Loss=5.0746, Val PPL=195.05


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 87: Train Loss=1.3825, PPL=3.99 | Val Loss=5.1035, Val PPL=201.79


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Epoch 88: Train Loss=1.3824, PPL=3.98 | Val Loss=5.0920, Val PPL=199.22


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch 89: Train Loss=1.3816, PPL=3.98 | Val Loss=5.0746, Val PPL=195.69


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 90: Train Loss=1.3815, PPL=3.98 | Val Loss=5.1228, Val PPL=205.05


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 91: Train Loss=1.3815, PPL=3.98 | Val Loss=5.0985, Val PPL=201.77


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 92: Train Loss=1.3810, PPL=3.98 | Val Loss=5.1017, Val PPL=201.30


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 93: Train Loss=1.3810, PPL=3.98 | Val Loss=5.1224, Val PPL=205.10


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 94: Train Loss=1.3805, PPL=3.98 | Val Loss=5.1297, Val PPL=206.47


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 95: Train Loss=1.3802, PPL=3.98 | Val Loss=5.1241, Val PPL=205.84


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 96: Train Loss=1.3802, PPL=3.98 | Val Loss=5.1225, Val PPL=205.21


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 97: Train Loss=1.3803, PPL=3.98 | Val Loss=5.1305, Val PPL=207.61


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 98: Train Loss=1.3796, PPL=3.97 | Val Loss=5.1346, Val PPL=208.29


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch 99: Train Loss=1.3795, PPL=3.97 | Val Loss=5.1310, Val PPL=208.23


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch 100: Train Loss=1.3791, PPL=3.97 | Val Loss=5.1423, Val PPL=210.00
Predictions test:

TEST GEN 1: 1. supine frontal view of the chest demonstrates interval removal of the enteric tube. the remaining right ij catheter and surgical materials are stable. 2. the heart is moderately enlarged and mitral annular calcification as well as
TEST TARGET 1: 1. stable right internal jugular central venous catheter, prosthetic heart valve, and median sternotomy wires. 2. loculated right pleural effusion, low lung volumes in the right hemithorax, and mild cardiomegaly are unchanged. 3. slight increase in diffuse interstitial opacities, likely reflecting underlying mild pulmonary edema increased.

TEST GEN 2: 1. supine frontal view of the chest demonstrates interval removal of the enteric tube. the remaining right ij catheter and surgical materials are stable. 2. the heart is moderately enlarged and mitral annular calcification as well as
TEST TARGET 2: 1. stable right internal jugular central ve

In [4]:
# Installs you might need:
# !pip install -U nltk rouge-score bert-score f1chexbert requests appdirs

import os, shutil, requests
import torch, numpy as np
from appdirs import user_cache_dir
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from nltk.tokenize import wordpunct_tokenize
from rouge_score import rouge_scorer

# ---- Optional imports (guarded) ----
try:
    from bert_score import score as bertscore_score
except Exception:
    bertscore_score = None

# ---- CheXbert cache utils ----
CHEXBERT_URLS = [
    # Official Stanford Box link (may redirect / require retries)
    "https://stanfordmedicine.box.com/shared/static/c3stck6w6dol3h36grdc97xoydzxd7w9",
    # Hugging Face mirror (often more reliable)
    "https://huggingface.co/StanfordAIMI/RRG_scorers/resolve/main/chexbert.pth",
]

def default_chexbert_cache_dir():
    # Matches f1chexbert's default on Windows:
    # C:\Users\<YOU>\AppData\Local\chexbert\chexbert\Cache
    return os.path.join(user_cache_dir("chexbert", "chexbert"), "Cache")

def chexbert_ckpt_path(cache_dir=None, filename="chexbert.pth"):
    cache_dir = cache_dir or default_chexbert_cache_dir()
    return os.path.join(cache_dir, filename)

def delete_chexbert_weights(cache_dir=None, filename="chexbert.pth"):
    """Delete existing CheXbert weights if present."""
    path = chexbert_ckpt_path(cache_dir, filename)
    if os.path.exists(path):
        try:
            os.remove(path)
            print(f"[CheXbert] Deleted old weights at: {path}")
        except Exception as e:
            print(f"[CheXbert] Could not delete {path}: {e}")
    else:
        print("[CheXbert] No existing weights to delete.")

def ensure_chexbert_weights(cache_dir=None, filename="chexbert.pth", force=False, min_bytes=10_000_000):
    """Ensure weights exist. If force=True, redownload even if file exists."""
    cache_dir = cache_dir or default_chexbert_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    dst = os.path.join(cache_dir, filename)

    if not force and os.path.exists(dst) and os.path.getsize(dst) > min_bytes:
        return dst  # looks valid

    # (Re)download
    for url in CHEXBERT_URLS:
        try:
            print(f"[CheXbert] Downloading weights from: {url}")
            with requests.get(url, stream=True, timeout=120, allow_redirects=True) as r:
                r.raise_for_status()
                tmp = dst + ".tmp"
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
                os.replace(tmp, dst)
            if os.path.getsize(dst) > min_bytes:
                print(f"[CheXbert] Weights saved to: {dst}")
                return dst
            else:
                print(f"[CheXbert] Downloaded file too small ({os.path.getsize(dst)} bytes), trying next mirror...")
        except Exception as e:
            print(f"[CheXbert] Download failed from {url}: {e}")
    return None

# ---- Delete then (re)download weights ----
delete_chexbert_weights()
_ckpt_path = ensure_chexbert_weights(force=True)

# ---- Try to load CheXbert evaluator ----
try:
    from f1chexbert import F1CheXbert
    _chexbert = F1CheXbert() if _ckpt_path else None
    if _chexbert is None:
        print("[Info] CheXbert unavailable (no weights). Place chexbert.pth in the cache dir above.")
except Exception as e:
    _chexbert = None
    print(f"[Info] CheXbert not available: {e}")

# ---- Helpers ----
def tok(s: str):
    return [w for w in wordpunct_tokenize(s.lower()) if w.strip()]

_warned_bert = False
def safe_bertscore(pred: str, tgt: str) -> float:
    global _warned_bert
    if bertscore_score is None:
        if not _warned_bert:
            print("[Info] bert-score not installed/available; BERTScore will be NaN. `pip install bert-score`.")
            _warned_bert = True
        return float("nan")
    P, R, F = bertscore_score([pred], [tgt], lang="en", rescale_with_baseline=True)
    return float(F.mean().item())

# ---- Metric setup ----
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
smooth = SmoothingFunction().method1

bleu_list, rougeL_list, meteor_list, bert_list = [], [], [], []
gens_all, tgts_all = [], []

# ===== Evaluation loop (single batch), assumes: valid_loader, model, tokenizer, device, bos_id, eos_id =====
test_loader_sliced = iter(valid_loader)
with torch.no_grad():
    for batch in test_loader_sliced:
        pixel_values, ids_loader, paths, raw_labels = batch
        pixel_values = pixel_values.to(device)

        gen_ids = model.generate(
            pixel_values=pixel_values,
            bos_id=bos_id, eos_id=eos_id,
            max_new_tokens=50, top_p=0.8, temperature=0.9, greedy=False
        )

        print("Predictions test:")
        for i in range(gen_ids.size(0)):
            # Skip special tokens to avoid artifacts
            gen_text = tokenizer.decode(gen_ids[i].tolist())
            tgt_text = tokenizer.decode(ids_loader[i].tolist())

            print(f"\nTEST GEN {i+1}:", gen_text)
            print(f"TEST TARGET {i+1}:", tgt_text)

            # BLEU
            bleu = sentence_bleu([tgt_text.split()], gen_text.split(), smoothing_function=smooth)
            # ROUGE-L (F)
            rougeL = scorer.score(tgt_text, gen_text)['rougeL'].fmeasure
            # METEOR (token lists)
            try:
                meteor = meteor_score([tok(tgt_text)], tok(gen_text))
            except TypeError:
                meteor = meteor_score([tgt_text], gen_text)
            # BERTScore
            bert = safe_bertscore(gen_text, tgt_text)

            print(f"BLEU: {bleu:.4f} | ROUGE-L: {rougeL:.4f} | METEOR: {meteor:.4f} | BERTScore(F1): {bert:.4f}")

            bleu_list.append(bleu); rougeL_list.append(rougeL); meteor_list.append(meteor); bert_list.append(bert)
            gens_all.append(gen_text); tgts_all.append(tgt_text)

        def nanmean(x):
            arr = np.array(x, dtype=float)
            return float(np.nanmean(arr)) if arr.size else float("nan")

        print("\n--- Batch means (text metrics) ---")
        print(f"BLEU: {nanmean(bleu_list):.4f} | ROUGE-L: {nanmean(rougeL_list):.4f} | "
              f"METEOR: {nanmean(meteor_list):.4f} | BERTScore(F1): {nanmean(bert_list):.4f}")

        # ---- CheXbert (batch clinical metric) ----
        if _chexbert is not None and gens_all and len(gens_all) == len(tgts_all):
            try:
                chex_accuracy, chex_acc_not_avg, chex_report, chex_report_5 = _chexbert(
                    hyps=gens_all, refs=tgts_all
                )
                print("\n--- CheXbert (batch) ---")
                print(f"CheXbert macro F1 (called 'accuracy' in pkg): {float(chex_accuracy):.4f}")
                if isinstance(chex_report_5, dict):
                    top5 = ', '.join([f"{k}: {v:.3f}" for k, v in chex_report_5.items()])
                    print(f"Top-5 precision [Cardiomegaly, Edema, Consolidation, Atelectasis, Pleural Effusion]: {top5}")
            except Exception as e:
                print(f"[Warn] CheXbert failed to run: {e}")

        # Cleanup
        del pixel_values, ids_loader, paths, raw_labels, gen_ids
        torch.cuda.empty_cache()
        break


[CheXbert] Deleted old weights at: C:\Users\emman\AppData\Local\chexbert\chexbert\Cache\Cache\chexbert.pth
[CheXbert] Downloading weights from: https://stanfordmedicine.box.com/shared/static/c3stck6w6dol3h36grdc97xoydzxd7w9
[CheXbert] Download failed from https://stanfordmedicine.box.com/shared/static/c3stck6w6dol3h36grdc97xoydzxd7w9: 404 Client Error: Not Found for url: https://stanfordmedicine.app.box.com/public/static/c3stck6w6dol3h36grdc97xoydzxd7w9
[CheXbert] Downloading weights from: https://huggingface.co/StanfordAIMI/RRG_scorers/resolve/main/chexbert.pth
[CheXbert] Weights saved to: C:\Users\emman\AppData\Local\chexbert\chexbert\Cache\Cache\chexbert.pth
[Info] CheXbert not available: [Errno 2] No such file or directory: 'C:\\Users\\emman\\AppData\\Local\\chexbert\\chexbert\\Cache\\chexbert.pth'
Predictions test:

TEST GEN 1: 1. interval appearance of endotracheal tube which is 4 cm above the carina. a defibrillator pad overlies the left hemithorax. 2. right lower lobe opacifica