In [None]:
import os

# Work from the project root (the parent of the directory the kernel started in)
# so the `utils` package and the relative "Datasets/..." paths resolve.
# The root is computed once per kernel, so re-running this cell does not climb
# one more directory level each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.dirname(os.getcwd())
os.chdir(PROJECT_ROOT)

import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
from utils.text_metrics import evaluate_all_metrics
# Explicit names instead of `import *`, so readers can see where each helper used
# below comes from.
from utils.temp_utils import (
    CaptionCollate,
    build_tokenizer_from_labels,
    dino_image_transform,
    evaluate,
    sequence_ce_loss,
    train_one_epoch,
)
from utils.lstm_models import DinoLSTMAttnCaptioner
from utils.chexpert_dataset import CheXpertDataset
from utils.padchest_dataset import PadChestGRDataset

# Data

In [None]:
# Seed torch's global RNG first, so the shuffled train_loader order and any
# randomly initialised model weights created in the next cell are reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Dataset locations, relative to the working directory set in the first cell.
CSV_REL_PATH = "Datasets/CheXpertPlus/df_chexpert_plus_240401.csv"
IMG_REL_ROOT = "Datasets/CheXpertPlus/PNG"
CSV_PATH = os.path.join(os.getcwd(), CSV_REL_PATH)
IMG_ROOT = os.path.join(os.getcwd(), IMG_REL_ROOT)

TEXT_COL = "section_impression"  # report section used as the caption target
# NOTE(review): PATH_COL is not passed to CheXpertDataset below. Confirm that the
# dataset's default image-path column matches it.
PATH_COL = "path_to_image"

IMG_SIZE = 224
# NOTE(review): MAX_LEN is not passed to CaptionCollate here. Confirm whether
# captions are truncated anywhere.
MAX_LEN = 64
NUM_BATCH = 8

tf = dino_image_transform(img_size=IMG_SIZE)

# The three splits differ only in `split`. Each constructor re-reads the CSV
# (see the three "[INFO] Kept ..." lines in the output).
dataset_kwargs = dict(img_root=IMG_ROOT, csv_path=CSV_PATH, transform=tf, text_col=TEXT_COL)
ds_train = CheXpertDataset(split="train", **dataset_kwargs)
ds_valid = CheXpertDataset(split="valid", **dataset_kwargs)
ds_test = CheXpertDataset(split="test", **dataset_kwargs)

tokenizer = build_tokenizer_from_labels()
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

# Only the training loader is shuffled. Validation and test batches come out in a
# fixed order, so sliced evaluations always see the same samples.
train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG


# Model

In [3]:
# DINO ViT-S/16 backbone: hidden size 384, 16x16-pixel patches.
DINO_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
DINO_PATCH_SIZE = 16
EMBEDDING_D_IMG = 384
DECODER_HIDDEN_D = 512  # passed as d_h to the captioner

# Number of image patch tokens for an IMG_SIZE x IMG_SIZE input. This counts
# patch tokens ONLY. Any CLS (or register) tokens the backbone adds are not
# included. N_PREFIX is not passed to the model in this notebook.
N_PREFIX = (IMG_SIZE // DINO_PATCH_SIZE) ** 2

# NOTE(review): if `tokenizer` is a Hugging Face tokenizer with added special
# tokens, `tokenizer.vocab_size` excludes them and `len(tokenizer)` would be the
# safe output size. Confirm against build_tokenizer_from_labels.
model = DinoLSTMAttnCaptioner(
    vocab_size=tokenizer.vocab_size,
    d_img=EMBEDDING_D_IMG,
    d_h=DECODER_HIDDEN_D,
    pad_id=pad_id,
    dino_model_id=DINO_MODEL_ID,
    freeze_dino=True,
).to(device)

# Training Hyperparameters

In [4]:
# AdamW over the parameters that still require gradients. The DINO backbone
# was built with freeze_dino=True, so only the remaining trainable weights are
# handed to the optimiser.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=3e-4,
    weight_decay=1e-2,
)

# Loss callable passed as `loss_fn` to train_one_epoch / evaluate.
loss = sequence_ce_loss

# An "epoch" in the loop below is BATCHES_PER_EPOCH mini-batches, not a full
# pass over the training set.
BATCHES_PER_EPOCH = 10
NUM_EPOCHS = 100

# Training

In [5]:
# Each "epoch" is only BATCHES_PER_EPOCH mini-batches.
# - A fresh islice over the shuffled train_loader gives a different random
#   subset of training batches every epoch.
# - valid_loader is not shuffled, so the validation slice always covers the same
#   first BATCHES_PER_EPOCH validation batches. This keeps Val Loss comparable
#   across epochs, but it is a small fixed subset, not the full validation split.
history = []  # one dict of train/val metrics per epoch, for plotting after the run
for epoch in range(NUM_EPOCHS):
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(
        model, slice_train_loader, optimizer, device, pad_id,
        num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0,
    )
    val_stats = evaluate(
        model, slice_valid_loader, device, pad_id,
        num_batches=BATCHES_PER_EPOCH, loss_fn=loss,
    )
    history.append({
        "epoch": epoch + 1,
        "loss": train_stats["loss"],
        "ppl": train_stats["ppl"],
        "val_loss": val_stats["val_loss"],
        "val_ppl": val_stats["val_ppl"],
    })
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
          f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 1: Train Loss=10.3416, PPL=36499.38 | Val Loss=9.3065, Val PPL=11065.84


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 2: Train Loss=8.7528, PPL=6905.63 | Val Loss=7.7062, Val PPL=2252.32


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 3: Train Loss=7.2850, PPL=1557.32 | Val Loss=6.7683, Val PPL=899.52


Training: 100%|██████████| 10/10 [00:07<00:00,  1.27it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 4: Train Loss=6.6514, PPL=785.19 | Val Loss=6.3782, Val PPL=619.88


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 5: Train Loss=6.3132, PPL=569.04 | Val Loss=6.0854, Val PPL=471.01


Training: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 6: Train Loss=5.9947, PPL=421.42 | Val Loss=5.8820, Val PPL=387.13


Training: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 7: Train Loss=5.8213, PPL=345.86 | Val Loss=5.7029, Val PPL=324.73


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 8: Train Loss=5.7076, PPL=307.29 | Val Loss=5.5494, Val PPL=278.83


Training: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 9: Train Loss=5.6241, PPL=288.15 | Val Loss=5.4300, Val PPL=248.24


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 10: Train Loss=5.4615, PPL=242.34 | Val Loss=5.3393, Val PPL=227.58


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 11: Train Loss=5.4447, PPL=236.25 | Val Loss=5.2699, Val PPL=211.81


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 12: Train Loss=5.2083, PPL=184.42 | Val Loss=5.2023, Val PPL=197.30


Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 13: Train Loss=5.3511, PPL=218.36 | Val Loss=5.1519, Val PPL=188.03


Training: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 14: Train Loss=5.2659, PPL=198.38 | Val Loss=5.1089, Val PPL=180.45


Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 15: Train Loss=5.1191, PPL=174.66 | Val Loss=5.0765, Val PPL=174.38


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 16: Train Loss=5.2471, PPL=199.43 | Val Loss=5.0524, Val PPL=170.57


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 17: Train Loss=5.1280, PPL=171.29 | Val Loss=5.0081, Val PPL=162.24


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 18: Train Loss=5.0894, PPL=170.30 | Val Loss=4.9779, Val PPL=156.92


Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 19: Train Loss=5.0582, PPL=168.38 | Val Loss=4.9546, Val PPL=152.24


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 20: Train Loss=4.9681, PPL=147.47 | Val Loss=4.9079, Val PPL=144.98


Training: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 21: Train Loss=4.9519, PPL=146.74 | Val Loss=4.8712, Val PPL=139.22


Training: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 22: Train Loss=4.9991, PPL=156.18 | Val Loss=4.8530, Val PPL=136.45


Training: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 23: Train Loss=4.8893, PPL=137.21 | Val Loss=4.8349, Val PPL=133.86


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 24: Train Loss=4.8886, PPL=140.94 | Val Loss=4.8174, Val PPL=131.40


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 25: Train Loss=4.9023, PPL=141.66 | Val Loss=4.7992, Val PPL=129.05


Training: 100%|██████████| 10/10 [00:06<00:00,  1.56it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 26: Train Loss=4.8785, PPL=136.80 | Val Loss=4.7799, Val PPL=126.35


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 27: Train Loss=4.7982, PPL=128.34 | Val Loss=4.7594, Val PPL=123.48


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 28: Train Loss=5.0566, PPL=174.17 | Val Loss=4.7517, Val PPL=122.07


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 29: Train Loss=4.8842, PPL=137.80 | Val Loss=4.7435, Val PPL=120.82


Training: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 30: Train Loss=4.8371, PPL=129.51 | Val Loss=4.7358, Val PPL=119.98


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 31: Train Loss=4.7833, PPL=121.94 | Val Loss=4.7241, Val PPL=118.45


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 32: Train Loss=4.7580, PPL=117.35 | Val Loss=4.7131, Val PPL=117.36


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 33: Train Loss=4.7554, PPL=118.46 | Val Loss=4.6965, Val PPL=115.48


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 34: Train Loss=4.7131, PPL=114.50 | Val Loss=4.6945, Val PPL=115.26


Training: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 35: Train Loss=4.7077, PPL=119.81 | Val Loss=4.6661, Val PPL=111.95


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 36: Train Loss=4.9060, PPL=139.44 | Val Loss=4.6556, Val PPL=110.71


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 37: Train Loss=4.7566, PPL=119.95 | Val Loss=4.6555, Val PPL=110.53


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 38: Train Loss=4.7010, PPL=114.81 | Val Loss=4.6533, Val PPL=109.99


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 39: Train Loss=4.7015, PPL=116.94 | Val Loss=4.6399, Val PPL=108.64


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 40: Train Loss=4.6762, PPL=109.21 | Val Loss=4.6243, Val PPL=107.03


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 41: Train Loss=4.5673, PPL=97.51 | Val Loss=4.6222, Val PPL=106.85


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 42: Train Loss=4.7534, PPL=121.79 | Val Loss=4.6108, Val PPL=105.75


Training: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 43: Train Loss=4.6041, PPL=101.44 | Val Loss=4.6127, Val PPL=106.12


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 44: Train Loss=4.5601, PPL=98.20 | Val Loss=4.5909, Val PPL=103.66


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 45: Train Loss=4.6506, PPL=107.17 | Val Loss=4.5822, Val PPL=102.73


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 46: Train Loss=4.7800, PPL=128.76 | Val Loss=4.5911, Val PPL=103.46


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 47: Train Loss=4.6637, PPL=107.17 | Val Loss=4.5679, Val PPL=100.96


Training: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 48: Train Loss=4.5105, PPL=92.75 | Val Loss=4.5632, Val PPL=100.65


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 49: Train Loss=4.6152, PPL=102.76 | Val Loss=4.5552, Val PPL=99.72


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 50: Train Loss=4.5823, PPL=99.87 | Val Loss=4.5433, Val PPL=98.54


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 51: Train Loss=4.5980, PPL=102.68 | Val Loss=4.5404, Val PPL=98.19


Training: 100%|██████████| 10/10 [00:07<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 52: Train Loss=4.6248, PPL=109.09 | Val Loss=4.5446, Val PPL=98.74


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 53: Train Loss=4.5645, PPL=100.45 | Val Loss=4.5292, Val PPL=97.07


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 54: Train Loss=4.6377, PPL=104.80 | Val Loss=4.5207, Val PPL=96.24


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 55: Train Loss=4.5874, PPL=100.77 | Val Loss=4.5184, Val PPL=96.07


Training: 100%|██████████| 10/10 [00:07<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 56: Train Loss=4.5338, PPL=95.16 | Val Loss=4.5093, Val PPL=95.17


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 57: Train Loss=4.4972, PPL=90.36 | Val Loss=4.5197, Val PPL=96.23


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 58: Train Loss=4.6179, PPL=104.54 | Val Loss=4.5250, Val PPL=96.82


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 59: Train Loss=4.5906, PPL=101.90 | Val Loss=4.5041, Val PPL=94.73


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 60: Train Loss=4.5077, PPL=93.21 | Val Loss=4.4891, Val PPL=93.31


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 61: Train Loss=4.5029, PPL=91.85 | Val Loss=4.4908, Val PPL=93.29


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 62: Train Loss=4.5937, PPL=102.31 | Val Loss=4.4821, Val PPL=92.34


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 63: Train Loss=4.4607, PPL=88.09 | Val Loss=4.4744, Val PPL=91.81


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 64: Train Loss=4.6190, PPL=104.75 | Val Loss=4.4749, Val PPL=91.70


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 65: Train Loss=4.4515, PPL=89.10 | Val Loss=4.4638, Val PPL=90.65


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 66: Train Loss=4.5229, PPL=95.41 | Val Loss=4.4781, Val PPL=92.03


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 67: Train Loss=4.4725, PPL=89.10 | Val Loss=4.4708, Val PPL=91.36


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 68: Train Loss=4.3936, PPL=82.61 | Val Loss=4.4760, Val PPL=91.84


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 69: Train Loss=4.6366, PPL=106.44 | Val Loss=4.4592, Val PPL=90.28


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 70: Train Loss=4.4537, PPL=87.78 | Val Loss=4.4595, Val PPL=90.27


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 71: Train Loss=4.4571, PPL=89.73 | Val Loss=4.4503, Val PPL=89.52


Training: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 72: Train Loss=4.6250, PPL=102.95 | Val Loss=4.4669, Val PPL=90.84


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 73: Train Loss=4.4781, PPL=90.38 | Val Loss=4.4569, Val PPL=90.12


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 74: Train Loss=4.4321, PPL=85.57 | Val Loss=4.4611, Val PPL=90.63


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 75: Train Loss=4.4677, PPL=88.13 | Val Loss=4.4505, Val PPL=89.60


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 76: Train Loss=4.4426, PPL=88.20 | Val Loss=4.4473, Val PPL=89.36


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 77: Train Loss=4.5203, PPL=93.32 | Val Loss=4.4435, Val PPL=88.83


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 78: Train Loss=4.4715, PPL=89.81 | Val Loss=4.4475, Val PPL=89.21


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 79: Train Loss=4.5185, PPL=93.62 | Val Loss=4.4404, Val PPL=88.47


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 80: Train Loss=4.3786, PPL=80.79 | Val Loss=4.4344, Val PPL=87.86


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 81: Train Loss=4.4841, PPL=89.83 | Val Loss=4.4250, Val PPL=87.09


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 82: Train Loss=4.3866, PPL=81.52 | Val Loss=4.4246, Val PPL=87.01


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 83: Train Loss=4.4277, PPL=86.41 | Val Loss=4.4281, Val PPL=87.37


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 84: Train Loss=4.5567, PPL=98.61 | Val Loss=4.4254, Val PPL=87.14


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 85: Train Loss=4.4524, PPL=87.88 | Val Loss=4.4229, Val PPL=87.07


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 86: Train Loss=4.3904, PPL=81.13 | Val Loss=4.4161, Val PPL=86.35


Training: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 87: Train Loss=4.5072, PPL=93.19 | Val Loss=4.4195, Val PPL=86.65


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 88: Train Loss=4.3771, PPL=80.65 | Val Loss=4.4138, Val PPL=86.07


Training: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 89: Train Loss=4.3813, PPL=81.23 | Val Loss=4.4090, Val PPL=85.71


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 90: Train Loss=4.3828, PPL=81.23 | Val Loss=4.3988, Val PPL=84.89


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 91: Train Loss=4.3929, PPL=82.33 | Val Loss=4.3923, Val PPL=84.22


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 92: Train Loss=4.3815, PPL=81.62 | Val Loss=4.4074, Val PPL=85.62


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 93: Train Loss=4.3768, PPL=81.08 | Val Loss=4.4070, Val PPL=85.60


Training: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 94: Train Loss=4.3455, PPL=78.80 | Val Loss=4.4084, Val PPL=85.54


Training: 100%|██████████| 10/10 [00:07<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 95: Train Loss=4.3526, PPL=79.25 | Val Loss=4.4128, Val PPL=86.04


Training: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 96: Train Loss=4.3894, PPL=81.61 | Val Loss=4.4044, Val PPL=85.12


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 97: Train Loss=4.4600, PPL=88.05 | Val Loss=4.3938, Val PPL=84.21


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 98: Train Loss=4.4969, PPL=92.10 | Val Loss=4.3872, Val PPL=83.80


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 99: Train Loss=4.4949, PPL=92.83 | Val Loss=4.3745, Val PPL=82.72


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]

Epoch 100: Train Loss=4.4679, PPL=89.16 | Val Loss=4.3852, Val PPL=83.51





# Test and Decoding Parameters

In [6]:
# Decoding settings for model.generate in the report-generation cell.
# NOTE(review): with GREEDY_DECODE = True, TEST_TOP_P and TEST_TEMPERATURE are
# presumably ignored by model.generate. Confirm in DinoLSTMAttnCaptioner.
GREEDY_DECODE = True
TEST_MAX_LEN = 256      # max_new_tokens for model.generate
TEST_TOP_P = 0.9        # nucleus cutoff (presumably only used when sampling)
TEST_TEMPERATURE = 0.9  # softmax temperature (presumably only used when sampling)

# Number of test batches consumed by the test loss/perplexity check.
BATCHES_PER_TEST = 1

# Test

In [7]:
# Loss/perplexity on the first BATCHES_PER_TEST test batches (test_loader is not
# shuffled, so this is always the same subset). Pass the same loss_fn used for
# training and validation, so the result is computed the same way as the
# per-epoch Val Loss rather than with evaluate's default loss.
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST, loss_fn=loss)
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.01it/s]

Test Loss=4.1979, Test PPL=66.55





# Test Report Generation

In [8]:
def _strip_special_ids(token_ids, pad_token, bos_token, eos_token):
    """Return ``token_ids`` without BOS/PAD ids, truncated at the first EOS id.

    Target ids come from the padded collate batch. Decoding them as-is can leak
    pad/bos/eos tokens into the text scored by the metrics; whether they render
    depends on the tokenizer's decode defaults. Stripping them first is harmless
    if decode would have dropped them anyway.
    """
    kept = []
    for tok in token_ids:
        if tok == eos_token:
            break
        if tok != pad_token and tok != bos_token:
            kept.append(tok)
    return kept


# Set eval mode explicitly, so generation does not depend on which mode an
# earlier cell left the model in.
model.eval()

with torch.no_grad():
    # Qualitative check on the first test batch only.
    for pixel_values, ids_loader, paths, raw_labels in islice(test_loader, 1):
        pixel_values = pixel_values.to(device)

        # Decode using the TEST_* / GREEDY_DECODE settings.
        gen_ids = model.generate(
            pixel_values=pixel_values,
            bos_id=bos_id, eos_id=eos_id,
            max_new_tokens=TEST_MAX_LEN, top_p=TEST_TOP_P, temperature=TEST_TEMPERATURE, greedy=GREEDY_DECODE
        )

        # Second, independent decode with per-sample diagnostics. NOTE: it uses its
        # own "safe_sample" preset, stop sequences and 128-token cap, not the
        # TEST_* settings used above.
        info = model.generate_with_logging(
            pixel_values=pixel_values,             # [B, C, H, W]
            bos_id=bos_id,
            eos_id=eos_id,
            tokenizer=tokenizer,
            preset="safe_sample",
            stop_sequences=["\n\n", "Impression:"],
            max_new_tokens=128,
        )
        print("sequences:", info["sequences"].shape)
        logged_texts = [s["text"].get("generated", "") for s in info["per_sample"]]
        for i, (s, text) in enumerate(zip(info["per_sample"], logged_texts)):
            print(f"[{i}] EOS={s['stopping']['hit_eos']} rep={s['repetition']}")
            print(text[:200])

        # Batch-level metrics: logged decode vs. the raw_labels from the loader.
        eval_results = evaluate_all_metrics(raw_labels, logged_texts, evaluation_mode="CheXagent")
        for metric, scores in eval_results.items():
            print(f"{metric}: {scores}")

        # Per-sample metrics for the TEST_* decode. Special ids are stripped from
        # both sides, so padding does not pollute the reference text.
        print("Predictions (first batch):")
        for i in range(gen_ids.size(0)):
            text_gen = tokenizer.decode(_strip_special_ids(gen_ids[i].tolist(), pad_id, bos_id, eos_id))
            text_tgt = tokenizer.decode(_strip_special_ids(ids_loader[i].tolist(), pad_id, bos_id, eos_id))
            print(f"\nGEN {i+1}:", text_gen)
            print(f"TGT {i+1}:", text_tgt)
            results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
            for metric, scores in results.items():
                print(f"{metric}: {scores}")
        del pixel_values, ids_loader, paths, raw_labels, gen_ids
        torch.cuda.empty_cache()

sequences: torch.Size([8, 58])
[0] EOS=True rep={'max_token_run': 1, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
there is a new left internal jugular line. right ij catheter tip in the proximal svc.
[1] EOS=True rep={'max_token_run': 1, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
no change in bilateral pleural effusions.
[2] EOS=True rep={'max_token_run': 1, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
a single ap view of the chest demonstrates interval placement of a right internal jugular catheter. there is also a small leftsided pleural effusion and basilar atelectasis.
[3] EOS=True rep={'max_token_run': 1, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
a single semiupright view of the chest demonstrates interval removal of right internal jugular line. no pneumothorax.
[4] EOS=True rep={'max_token_run': 1, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
no change in cardiopulmonary status with persistent bibasilar opacities and bilateral pleural effusions.
[5] EOS=True r

In [9]:
# Print number of model parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params}")

Total model parameters: 86334452
