In [None]:
import os
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
from utils.text_metrics import evaluate_all_metrics
from utils.temp_utils import *
from utils.lstm_models import DinoLSTMAttnCaptioner, DinoBiLSTMAttnCaptioner
from utils.chexpert_dataset import CheXpertDataset
from utils.padchest_dataset import PadChestGRDataset

# Data

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

os.chdir(os.path.dirname(os.getcwd()))

CSV_PATH = "Datasets/CheXpertPlus/df_chexpert_plus_240401.csv"
IMG_ROOT = "Datasets/CheXpertPlus/PNG"
TEXT_COL = "section_impression"
PATH_COL = "path_to_image"

IMG_SIZE = 224
MAX_LEN = 64
NUM_BATCH = 8

tf = dino_image_transform(img_size=IMG_SIZE)

ds_train = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="train", transform=tf, text_col=TEXT_COL)
ds_valid = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="valid", transform=tf, text_col=TEXT_COL)
ds_test = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="test", transform=tf, text_col=TEXT_COL)

tokenizer = build_tokenizer_from_labels()
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG


# Model

In [4]:
# DINO ViT-S/16 hidden size is 384 
EMBEDDING_D_IMG = 384
N_PREFIX = (IMG_SIZE // 16) ** 2  # number of visual prefix tokens (including CLS)

model = DinoBiLSTMAttnCaptioner(
    vocab_size=tokenizer.vocab_size,
    d_img=EMBEDDING_D_IMG,
    d_h=512,
    pad_id=pad_id,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True,
).to(device)

# Train Parameters

In [5]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2
)
loss = sequence_ce_loss
NUM_EPOCHS = 100
BATCHES_PER_EPOCH = 10

# Training

In [6]:
for epoch in range(NUM_EPOCHS):
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
Training: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 1: Train Loss=9.5185, PPL=21870.51 | Val Loss=7.6469, Val PPL=2134.68


Training: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 2: Train Loss=6.8096, PPL=1059.04 | Val Loss=5.9161, Val PPL=390.86


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 3: Train Loss=5.5467, PPL=264.51 | Val Loss=5.1433, Val PPL=185.36


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 4: Train Loss=4.8883, PPL=137.09 | Val Loss=4.6878, Val PPL=117.28


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 5: Train Loss=4.6597, PPL=109.11 | Val Loss=4.3766, Val PPL=85.09


Training: 100%|██████████| 10/10 [00:05<00:00,  1.90it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 6: Train Loss=4.5047, PPL=92.68 | Val Loss=4.1507, Val PPL=66.91


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 7: Train Loss=4.1576, PPL=66.57 | Val Loss=3.9794, Val PPL=55.80


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 8: Train Loss=4.0338, PPL=58.19 | Val Loss=3.8452, Val PPL=48.52


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 9: Train Loss=3.7542, PPL=44.02 | Val Loss=3.7414, Val PPL=43.58


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 10: Train Loss=3.7655, PPL=44.03 | Val Loss=3.6755, Val PPL=40.68


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 11: Train Loss=3.7064, PPL=41.70 | Val Loss=3.6206, Val PPL=38.41


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 12: Train Loss=3.7531, PPL=43.84 | Val Loss=3.5439, Val PPL=35.51


Training: 100%|██████████| 10/10 [00:05<00:00,  1.92it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 13: Train Loss=3.4346, PPL=31.59 | Val Loss=3.4897, Val PPL=33.60


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 14: Train Loss=3.5739, PPL=36.36 | Val Loss=3.4583, Val PPL=32.54


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 15: Train Loss=3.5389, PPL=35.82 | Val Loss=3.4121, Val PPL=31.06


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 16: Train Loss=3.4722, PPL=33.91 | Val Loss=3.3738, Val PPL=29.86


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.31it/s]


Epoch 17: Train Loss=3.3303, PPL=28.83 | Val Loss=3.3348, Val PPL=28.68


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.30it/s]


Epoch 18: Train Loss=3.3610, PPL=29.31 | Val Loss=3.3400, Val PPL=28.82


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 19: Train Loss=3.2603, PPL=26.31 | Val Loss=3.2888, Val PPL=27.34


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 20: Train Loss=3.3245, PPL=28.86 | Val Loss=3.2792, Val PPL=27.06


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.30it/s]


Epoch 21: Train Loss=3.2982, PPL=28.02 | Val Loss=3.2558, Val PPL=26.43


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 22: Train Loss=3.2894, PPL=27.22 | Val Loss=3.2401, Val PPL=25.99


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 23: Train Loss=3.2271, PPL=25.61 | Val Loss=3.2223, Val PPL=25.51


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 24: Train Loss=3.1864, PPL=24.49 | Val Loss=3.1859, Val PPL=24.58


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 25: Train Loss=3.2516, PPL=26.90 | Val Loss=3.1976, Val PPL=24.84


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 26: Train Loss=3.2139, PPL=25.54 | Val Loss=3.1679, Val PPL=24.10


Training: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 27: Train Loss=3.1997, PPL=24.86 | Val Loss=3.1678, Val PPL=24.09


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 28: Train Loss=3.0943, PPL=22.40 | Val Loss=3.1433, Val PPL=23.50


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 29: Train Loss=3.1530, PPL=23.75 | Val Loss=3.1456, Val PPL=23.54


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 30: Train Loss=3.0389, PPL=21.03 | Val Loss=3.1247, Val PPL=23.04


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 31: Train Loss=3.0733, PPL=21.95 | Val Loss=3.1234, Val PPL=23.00


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 32: Train Loss=3.2143, PPL=25.36 | Val Loss=3.1236, Val PPL=22.99


Training: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 33: Train Loss=3.1612, PPL=24.25 | Val Loss=3.1074, Val PPL=22.60


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 34: Train Loss=3.0875, PPL=22.51 | Val Loss=3.1002, Val PPL=22.42


Training: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 35: Train Loss=3.0932, PPL=22.36 | Val Loss=3.0886, Val PPL=22.15


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 36: Train Loss=3.0873, PPL=22.25 | Val Loss=3.0763, Val PPL=21.87


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 37: Train Loss=2.9784, PPL=19.74 | Val Loss=3.0740, Val PPL=21.81


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 38: Train Loss=2.9970, PPL=20.30 | Val Loss=3.0746, Val PPL=21.82


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 39: Train Loss=2.9837, PPL=19.82 | Val Loss=3.0532, Val PPL=21.35


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 40: Train Loss=3.0002, PPL=20.28 | Val Loss=3.0656, Val PPL=21.62


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 41: Train Loss=3.0299, PPL=20.95 | Val Loss=3.0486, Val PPL=21.25


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 42: Train Loss=3.0785, PPL=22.03 | Val Loss=3.0423, Val PPL=21.11


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 43: Train Loss=3.0215, PPL=20.77 | Val Loss=3.0296, Val PPL=20.84


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 44: Train Loss=3.0152, PPL=20.59 | Val Loss=3.0309, Val PPL=20.87


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 45: Train Loss=3.0727, PPL=21.91 | Val Loss=3.0238, Val PPL=20.72


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 46: Train Loss=2.9429, PPL=19.10 | Val Loss=3.0176, Val PPL=20.59


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 47: Train Loss=3.0552, PPL=21.52 | Val Loss=3.0075, Val PPL=20.38


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 48: Train Loss=2.9774, PPL=19.95 | Val Loss=3.0218, Val PPL=20.67


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 49: Train Loss=2.9433, PPL=19.13 | Val Loss=3.0015, Val PPL=20.25


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 50: Train Loss=2.9786, PPL=19.82 | Val Loss=3.0027, Val PPL=20.27


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 51: Train Loss=2.9181, PPL=18.70 | Val Loss=2.9999, Val PPL=20.21


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 52: Train Loss=2.9919, PPL=20.19 | Val Loss=3.0014, Val PPL=20.24


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 53: Train Loss=2.9483, PPL=19.16 | Val Loss=2.9925, Val PPL=20.05


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 54: Train Loss=2.9527, PPL=19.40 | Val Loss=2.9915, Val PPL=20.03


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 55: Train Loss=2.9617, PPL=19.60 | Val Loss=2.9836, Val PPL=19.87


Training: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 56: Train Loss=2.9635, PPL=19.48 | Val Loss=2.9907, Val PPL=20.00


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 57: Train Loss=2.9614, PPL=19.56 | Val Loss=2.9673, Val PPL=19.54


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 58: Train Loss=2.9765, PPL=19.78 | Val Loss=2.9735, Val PPL=19.66


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.33it/s]


Epoch 59: Train Loss=2.9073, PPL=18.48 | Val Loss=2.9688, Val PPL=19.57


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 60: Train Loss=2.9626, PPL=19.48 | Val Loss=2.9624, Val PPL=19.44


Training: 100%|██████████| 10/10 [00:05<00:00,  1.68it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 61: Train Loss=2.9108, PPL=18.45 | Val Loss=2.9638, Val PPL=19.46


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 62: Train Loss=2.9040, PPL=18.52 | Val Loss=2.9616, Val PPL=19.42


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 63: Train Loss=2.9373, PPL=19.20 | Val Loss=2.9682, Val PPL=19.55


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 64: Train Loss=2.9146, PPL=18.61 | Val Loss=2.9498, Val PPL=19.19


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 65: Train Loss=2.9753, PPL=19.75 | Val Loss=2.9674, Val PPL=19.53


Training: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 66: Train Loss=2.8805, PPL=17.89 | Val Loss=2.9494, Val PPL=19.18


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 67: Train Loss=2.9212, PPL=18.82 | Val Loss=2.9539, Val PPL=19.25


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 68: Train Loss=2.9588, PPL=19.54 | Val Loss=2.9613, Val PPL=19.39


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 69: Train Loss=2.9066, PPL=18.40 | Val Loss=2.9428, Val PPL=19.03


Training: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 70: Train Loss=2.8944, PPL=18.27 | Val Loss=2.9416, Val PPL=19.01


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 71: Train Loss=2.8957, PPL=18.24 | Val Loss=2.9415, Val PPL=19.01


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 72: Train Loss=2.9259, PPL=18.81 | Val Loss=2.9425, Val PPL=19.03


Training: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 73: Train Loss=2.9125, PPL=18.57 | Val Loss=2.9337, Val PPL=18.86


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 74: Train Loss=2.8778, PPL=17.96 | Val Loss=2.9375, Val PPL=18.93


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 75: Train Loss=2.9028, PPL=18.39 | Val Loss=2.9441, Val PPL=19.06


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 76: Train Loss=2.8657, PPL=17.66 | Val Loss=2.9358, Val PPL=18.90


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 77: Train Loss=2.9088, PPL=18.44 | Val Loss=2.9187, Val PPL=18.57


Training: 100%|██████████| 10/10 [00:06<00:00,  1.63it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 78: Train Loss=2.8991, PPL=18.31 | Val Loss=2.9197, Val PPL=18.59


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 79: Train Loss=2.8470, PPL=17.31 | Val Loss=2.9164, Val PPL=18.52


Training: 100%|██████████| 10/10 [00:06<00:00,  1.64it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 80: Train Loss=2.9213, PPL=18.76 | Val Loss=2.9256, Val PPL=18.69


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 81: Train Loss=2.8638, PPL=17.62 | Val Loss=2.9225, Val PPL=18.63


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 82: Train Loss=2.8666, PPL=17.70 | Val Loss=2.9130, Val PPL=18.45


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 83: Train Loss=2.8437, PPL=17.31 | Val Loss=2.9118, Val PPL=18.43


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 84: Train Loss=2.9137, PPL=18.59 | Val Loss=2.9269, Val PPL=18.71


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 85: Train Loss=2.8567, PPL=17.39 | Val Loss=2.9090, Val PPL=18.38


Training: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 86: Train Loss=2.8612, PPL=17.60 | Val Loss=2.8991, Val PPL=18.20


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 87: Train Loss=2.8405, PPL=17.21 | Val Loss=2.9098, Val PPL=18.39


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 88: Train Loss=2.8864, PPL=18.08 | Val Loss=2.9179, Val PPL=18.54


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 89: Train Loss=2.8585, PPL=17.56 | Val Loss=2.9030, Val PPL=18.27


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 90: Train Loss=2.8812, PPL=17.98 | Val Loss=2.9014, Val PPL=18.24


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 91: Train Loss=2.8898, PPL=18.21 | Val Loss=2.9085, Val PPL=18.36


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 92: Train Loss=2.8466, PPL=17.31 | Val Loss=2.9008, Val PPL=18.22


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 93: Train Loss=2.8998, PPL=18.38 | Val Loss=2.8890, Val PPL=18.01


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 94: Train Loss=2.8611, PPL=17.60 | Val Loss=2.9004, Val PPL=18.22


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 95: Train Loss=2.8356, PPL=17.04 | Val Loss=2.8952, Val PPL=18.12


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 96: Train Loss=2.8683, PPL=17.83 | Val Loss=2.8957, Val PPL=18.13


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 97: Train Loss=2.8041, PPL=16.58 | Val Loss=2.8905, Val PPL=18.04


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 98: Train Loss=2.8743, PPL=17.83 | Val Loss=2.8919, Val PPL=18.06


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 99: Train Loss=2.8839, PPL=18.00 | Val Loss=2.8885, Val PPL=18.00


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]

Epoch 100: Train Loss=2.8836, PPL=18.03 | Val Loss=2.8940, Val PPL=18.10





# Test Parameters

In [7]:
BATCHES_PER_TEST = 1
GREEDY_DECODE = True
TEST_MAX_LEN = 256
TEST_TOP_P = 0.9
TEST_TEMPERATURE = 0.9

# Test

In [8]:
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST)
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.94it/s]

Test Loss=2.8456, Test PPL=17.21





# Test Report Generation

In [9]:
with torch.no_grad():
    for pixel_values, ids_loader, paths, raw_labels in test_loader:
        pixel_values = pixel_values.to(device)
        gen_ids = model.generate(
            pixel_values=pixel_values,
            bos_id=bos_id, eos_id=eos_id,
            max_new_tokens=TEST_MAX_LEN, top_p=TEST_TOP_P, temperature=TEST_TEMPERATURE, greedy=GREEDY_DECODE
        )

        info = model.generate_with_logging(
            pixel_values=pixel_values,             # [B, C, H, W]
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            tokenizer=tokenizer,
            preset="safe_sample",
            stop_sequences=None, #["\n\n", "Impression:"],
            max_new_tokens=128,
        )
        print("sequences:", info["sequences"].shape)
        for i, s in enumerate(info["per_sample"]):
            print(f"[{i}] EOS={s['stopping']['hit_eos']} rep={s['repetition']}")
            print(s["text"].get("generated", "")[:200])

        eval_results = evaluate_all_metrics(raw_labels, [s["text"]["generated"] for s in info["per_sample"]], evaluation_mode="CheXagent")
        for metric, scores in eval_results.items():
            print(f"{metric}: {scores}")

        print("Predictions (first batch):")
        for i in range(gen_ids.size(0)):
            text_gen = tokenizer.decode(gen_ids[i].tolist())
            text_tgt = tokenizer.decode(ids_loader[i].tolist())
            print(f"\nGEN {i+1}:", text_gen)
            print(f"TGT {i+1}:", text_tgt)
            results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
            for metric, scores in results.items():
                print(f"{metric}: {scores}")
        del pixel_values, ids_loader, paths, raw_labels, gen_ids
        torch.cuda.empty_cache()
        break

sequences: torch.Size([8, 129])
[0] EOS=True rep={'max_token_run': 3, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
initial initial initial0101017777772
[1] EOS=True rep={'max_token_run': 3, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
gradual gradual gradual views views views reveal study study study dated 52 52 522
[2] EOS=True rep={'max_token_run': 3, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
sequence sequence sequenceic angles angles angles rotation rotation rotation artifact artifactual the the preceding t t t444 857575 of of treitzzzgan.
[3] EOS=False rep={'max_token_run': 3, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
negative negative for confirmationionionion clearly clearly clearly central centrally questioned questioned questioned described described described below below below just just just beyond beyond beyo
[4] EOS=True rep={'max_token_run': 3, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
gradual gradual gradual vertebral bodies bodies bodiesll cephalic

In [10]:
# Print number of model parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params}")

Total model parameters: 119691252


In [11]:
info

{'per_sample': [{'preset': 'safe_sample',
   'params': {'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 50,
    'repetition_penalty': 1.15,
    'no_repeat_ngram_size': 3},
   'lengths': {'prompt_tokens': 1, 'new_tokens': 11, 'total_tokens': 12},
   'stopping': {'hit_eos': True, 'eos_pos': 10, 'stop_sequences': []},
   'repetition': {'max_token_run': 3,
    'max_repeat_trigram': 1,
    'max_repeat_4gram': 1},
   'probes': [{'step': 1,
     'entropy': 1.865197777748108,
     'topk': [{'token_id': 18258, 'p': 0.419748455286026},
      {'token_id': 4366, 'p': 0.21426939964294434},
      {'token_id': 3068, 'p': 0.08862230181694031},
      {'token_id': 4954, 'p': 0.0687548890709877},
      {'token_id': 23589, 'p': 0.03689390793442726}]},
    {'step': 2,
     'entropy': -0.0,
     'topk': [{'token_id': 3288, 'p': 1.0},
      {'token_id': 2, 'p': 0.0},
      {'token_id': 0, 'p': 0.0},
      {'token_id': 3, 'p': 0.0},
      {'token_id': 1, 'p': 0.0}]},
    {'step': 3,


In [12]:
text = "1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES WITH PATCHY TUBE AND NASOGASTRIC TUBES, RIGHT LOWER MEDIASTINAL SIDED CATHETER.  NO SIGNIFICANT CHANGE IN THE PREVIOUS STUDYDEMONSTRATE ATELECTASIS O"
encoded = tokenizer.encode(text)
print("BOS token id:", tokenizer.bos_token_id, "EOS token id:", tokenizer.eos_token_id, "PAD token id:", tokenizer.pad_token_id)
print(encoded)

BOS token id: 101 EOS token id: 102 PAD token id: 0
[101, 122, 119, 6111, 1353, 1286, 4422, 34986, 5552, 39280, 49176, 1114, 10085, 1183, 7159, 1105, 9468, 7301, 32519, 11182, 117, 1268, 2211, 2394, 34979, 7050, 11641, 5855, 30682, 119, 1185, 2418, 1849, 1107, 1103, 2166, 2025, 31386, 8756, 18465, 14229, 184, 102]
