In [1]:
import os
os.chdir(os.path.dirname(os.getcwd()))
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
from utils.text_metrics import evaluate_all_metrics, save_metrics_to_json
from utils.temp_utils import *
from utils.lstm_models import DinoLSTMAttnCaptioner
from utils.chexpert_dataset import CheXpertDataset
from utils.padchest_dataset import PadChestGRDataset

# Data

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

CSV_PATH = "Datasets/CheXpertPlus/df_chexpert_plus_240401.csv"
IMG_ROOT = "Datasets/CheXpertPlus/PNG"

TEXT_COL = "section_impression"
PATH_COL = "path_to_image"

IMG_SIZE = 224
MAX_LEN = 64
NUM_BATCH = 8

tf = dino_image_transform(img_size=IMG_SIZE)

ds_train = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="train", transform=tf, text_col=TEXT_COL)
ds_valid = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="valid", transform=tf, text_col=TEXT_COL)
ds_test = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="test", transform=tf, text_col=TEXT_COL)

tokenizer = build_tokenizer_from_labels()
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs
[INFO] Kept 47494/223462 rows with existing PNGs
[INFO] Kept 47494/223462 rows with existing PNGs


# Model

In [3]:
# DINO ViT-S/16 hidden size is 384 
EMBEDDING_D_IMG = 384
N_PREFIX = (IMG_SIZE // 16) ** 2  # number of visual prefix tokens (including CLS)

model = DinoLSTMAttnCaptioner(
    vocab_size=tokenizer.vocab_size,
    d_img=EMBEDDING_D_IMG,
    d_h=512,
    pad_id=pad_id,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True,
).to(device)

# Train Parameters

In [4]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2
)
loss = sequence_ce_loss
NUM_EPOCHS = 100
BATCHES_PER_EPOCH = 10

# Training

In [5]:
import time
time_start = time.time()
for epoch in range(NUM_EPOCHS):
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")
training_time = time.time() - time_start

Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 1: Train Loss=10.3024, PPL=34587.33 | Val Loss=9.2550, Val PPL=10506.63


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.31it/s]


Epoch 2: Train Loss=8.7168, PPL=6575.14 | Val Loss=7.7016, Val PPL=2244.13


Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.33it/s]


Epoch 3: Train Loss=7.2312, PPL=1489.60 | Val Loss=6.8034, Val PPL=930.07


Training: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 4: Train Loss=6.6987, PPL=824.07 | Val Loss=6.4041, Val PPL=630.63


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 5: Train Loss=6.2878, PPL=549.93 | Val Loss=6.1255, Val PPL=480.27


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 6: Train Loss=6.0453, PPL=444.73 | Val Loss=5.8898, Val PPL=383.52


Training: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 7: Train Loss=5.6665, PPL=304.78 | Val Loss=5.7380, Val PPL=331.69


Training: 100%|██████████| 10/10 [00:07<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 8: Train Loss=5.7909, PPL=341.86 | Val Loss=5.5745, Val PPL=283.43


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 9: Train Loss=5.8983, PPL=375.82 | Val Loss=5.4489, Val PPL=250.75


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 10: Train Loss=5.5903, PPL=285.91 | Val Loss=5.3403, Val PPL=224.07


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 11: Train Loss=5.4948, PPL=251.67 | Val Loss=5.2526, Val PPL=205.24


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 12: Train Loss=5.3854, PPL=223.70 | Val Loss=5.1845, Val PPL=191.10


Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 13: Train Loss=5.3261, PPL=211.79 | Val Loss=5.1202, Val PPL=178.13


Training: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 14: Train Loss=5.3257, PPL=216.87 | Val Loss=5.0835, Val PPL=171.25


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 15: Train Loss=5.4380, PPL=243.90 | Val Loss=5.0407, Val PPL=163.74


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 16: Train Loss=5.0878, PPL=164.25 | Val Loss=5.0040, Val PPL=157.44


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]


Epoch 17: Train Loss=5.1237, PPL=176.20 | Val Loss=4.9639, Val PPL=151.68


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 18: Train Loss=5.0571, PPL=165.88 | Val Loss=4.9267, Val PPL=146.37


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 19: Train Loss=4.9649, PPL=148.88 | Val Loss=4.8899, Val PPL=141.06


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 20: Train Loss=5.0185, PPL=155.79 | Val Loss=4.8653, Val PPL=137.49


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 21: Train Loss=5.0225, PPL=161.31 | Val Loss=4.8581, Val PPL=136.47


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 22: Train Loss=5.0127, PPL=156.58 | Val Loss=4.8316, Val PPL=132.51


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 23: Train Loss=4.9420, PPL=148.01 | Val Loss=4.8055, Val PPL=128.88


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 24: Train Loss=4.7842, PPL=121.09 | Val Loss=4.7802, Val PPL=126.04


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 25: Train Loss=4.8038, PPL=123.56 | Val Loss=4.7736, Val PPL=125.23


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 26: Train Loss=4.7636, PPL=118.46 | Val Loss=4.7501, Val PPL=122.28


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 27: Train Loss=4.8132, PPL=125.26 | Val Loss=4.7284, Val PPL=119.53


Training: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 28: Train Loss=4.8407, PPL=131.72 | Val Loss=4.7270, Val PPL=119.14


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 29: Train Loss=4.9329, PPL=142.39 | Val Loss=4.7150, Val PPL=117.62


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 30: Train Loss=4.6382, PPL=108.70 | Val Loss=4.7078, Val PPL=116.75


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 31: Train Loss=4.8550, PPL=133.81 | Val Loss=4.6914, Val PPL=114.72


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 32: Train Loss=4.8449, PPL=132.31 | Val Loss=4.6685, Val PPL=112.18


Training: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 33: Train Loss=4.8402, PPL=130.76 | Val Loss=4.6711, Val PPL=112.48


Training: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 34: Train Loss=4.7432, PPL=118.11 | Val Loss=4.6673, Val PPL=112.05


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 35: Train Loss=4.8123, PPL=126.04 | Val Loss=4.6393, Val PPL=108.97


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 36: Train Loss=4.8269, PPL=127.66 | Val Loss=4.6318, Val PPL=108.24


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 37: Train Loss=4.8371, PPL=129.50 | Val Loss=4.6262, Val PPL=107.33


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 38: Train Loss=4.8852, PPL=143.96 | Val Loss=4.6225, Val PPL=106.88


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 39: Train Loss=4.7216, PPL=113.79 | Val Loss=4.6092, Val PPL=105.55


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 40: Train Loss=4.7286, PPL=117.71 | Val Loss=4.5920, Val PPL=103.53


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 41: Train Loss=4.6585, PPL=107.87 | Val Loss=4.5910, Val PPL=103.58


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 42: Train Loss=4.7397, PPL=117.21 | Val Loss=4.5843, Val PPL=102.72


Training: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 43: Train Loss=4.5509, PPL=96.89 | Val Loss=4.5741, Val PPL=101.64


Training: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 44: Train Loss=4.6205, PPL=103.19 | Val Loss=4.5701, Val PPL=101.31


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 45: Train Loss=4.6192, PPL=102.68 | Val Loss=4.5572, Val PPL=99.83


Training: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 46: Train Loss=4.6341, PPL=107.82 | Val Loss=4.5435, Val PPL=98.53


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 47: Train Loss=4.7288, PPL=117.98 | Val Loss=4.5517, Val PPL=99.56


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 48: Train Loss=4.6566, PPL=109.55 | Val Loss=4.5453, Val PPL=98.78


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 49: Train Loss=4.6108, PPL=108.00 | Val Loss=4.5229, Val PPL=96.59


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 50: Train Loss=4.6605, PPL=113.98 | Val Loss=4.5254, Val PPL=96.73


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 51: Train Loss=4.7396, PPL=118.70 | Val Loss=4.5149, Val PPL=95.52


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 52: Train Loss=4.5034, PPL=90.37 | Val Loss=4.5239, Val PPL=96.32


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 53: Train Loss=4.5390, PPL=95.64 | Val Loss=4.5185, Val PPL=95.84


Training: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 54: Train Loss=4.5590, PPL=97.11 | Val Loss=4.5111, Val PPL=95.30


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 55: Train Loss=4.5367, PPL=96.75 | Val Loss=4.4992, Val PPL=94.13


Training: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 56: Train Loss=4.4254, PPL=85.10 | Val Loss=4.4853, Val PPL=92.78


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 57: Train Loss=4.5439, PPL=97.21 | Val Loss=4.4777, Val PPL=91.98


Training: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 58: Train Loss=4.5381, PPL=94.69 | Val Loss=4.4825, Val PPL=92.51


Training: 100%|██████████| 10/10 [00:07<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 59: Train Loss=4.6532, PPL=107.72 | Val Loss=4.4789, Val PPL=92.12


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 60: Train Loss=4.4447, PPL=86.90 | Val Loss=4.4731, Val PPL=91.76


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 61: Train Loss=4.5845, PPL=100.57 | Val Loss=4.4794, Val PPL=92.31


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 62: Train Loss=4.4565, PPL=88.13 | Val Loss=4.4740, Val PPL=91.77


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 63: Train Loss=4.4962, PPL=91.16 | Val Loss=4.4542, Val PPL=89.91


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 64: Train Loss=4.4661, PPL=89.77 | Val Loss=4.4514, Val PPL=89.70


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 65: Train Loss=4.6842, PPL=110.20 | Val Loss=4.4649, Val PPL=90.64


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 66: Train Loss=4.5228, PPL=96.01 | Val Loss=4.4462, Val PPL=89.03


Training: 100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 67: Train Loss=4.5564, PPL=95.68 | Val Loss=4.4329, Val PPL=87.89


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 68: Train Loss=4.5759, PPL=98.81 | Val Loss=4.4356, Val PPL=88.18


Training: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 69: Train Loss=4.4343, PPL=85.29 | Val Loss=4.4270, Val PPL=87.31


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 70: Train Loss=4.4666, PPL=89.48 | Val Loss=4.4330, Val PPL=87.64


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 71: Train Loss=4.4743, PPL=89.59 | Val Loss=4.4143, Val PPL=86.22


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 72: Train Loss=4.6242, PPL=103.95 | Val Loss=4.4223, Val PPL=86.68


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 73: Train Loss=4.4635, PPL=88.81 | Val Loss=4.4261, Val PPL=86.98


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 74: Train Loss=4.4400, PPL=86.86 | Val Loss=4.4099, Val PPL=85.56


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 75: Train Loss=4.3560, PPL=79.89 | Val Loss=4.4160, Val PPL=86.01


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 76: Train Loss=4.5386, PPL=97.23 | Val Loss=4.4125, Val PPL=85.57


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 77: Train Loss=4.5049, PPL=90.88 | Val Loss=4.4039, Val PPL=84.99


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 78: Train Loss=4.3570, PPL=80.10 | Val Loss=4.3944, Val PPL=84.14


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 79: Train Loss=4.3884, PPL=82.09 | Val Loss=4.4058, Val PPL=85.02


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 80: Train Loss=4.4915, PPL=91.75 | Val Loss=4.4060, Val PPL=84.96


Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 81: Train Loss=4.4377, PPL=86.29 | Val Loss=4.4118, Val PPL=85.67


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Epoch 82: Train Loss=4.3896, PPL=83.95 | Val Loss=4.4021, Val PPL=84.75


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 83: Train Loss=4.4382, PPL=85.99 | Val Loss=4.3930, Val PPL=83.88


Training: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 84: Train Loss=4.3790, PPL=81.66 | Val Loss=4.3972, Val PPL=84.37


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 85: Train Loss=4.4255, PPL=84.98 | Val Loss=4.3980, Val PPL=84.41


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 86: Train Loss=4.4762, PPL=89.44 | Val Loss=4.3795, Val PPL=82.68


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 87: Train Loss=4.3248, PPL=77.25 | Val Loss=4.3894, Val PPL=83.64


Training: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 88: Train Loss=4.2967, PPL=74.95 | Val Loss=4.3842, Val PPL=83.19


Training: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 89: Train Loss=4.4954, PPL=92.78 | Val Loss=4.3881, Val PPL=83.61


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 90: Train Loss=4.5725, PPL=100.77 | Val Loss=4.3884, Val PPL=83.65


Training: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 91: Train Loss=4.3301, PPL=77.28 | Val Loss=4.3783, Val PPL=82.79


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 92: Train Loss=4.4510, PPL=90.66 | Val Loss=4.3806, Val PPL=82.90


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 93: Train Loss=4.5067, PPL=92.49 | Val Loss=4.3694, Val PPL=82.03


Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 94: Train Loss=4.4557, PPL=89.28 | Val Loss=4.3753, Val PPL=82.56


Training: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 95: Train Loss=4.5434, PPL=98.91 | Val Loss=4.3766, Val PPL=82.57


Training: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 96: Train Loss=4.3328, PPL=78.55 | Val Loss=4.3691, Val PPL=81.87


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 97: Train Loss=4.3583, PPL=79.95 | Val Loss=4.3601, Val PPL=81.14


Training: 100%|██████████| 10/10 [00:08<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 98: Train Loss=4.4277, PPL=85.51 | Val Loss=4.3525, Val PPL=80.59


Training: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 99: Train Loss=4.4405, PPL=88.48 | Val Loss=4.3584, Val PPL=81.02


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]

Epoch 100: Train Loss=4.4665, PPL=88.91 | Val Loss=4.3660, Val PPL=81.50





# Test Parameters

In [6]:
BATCHES_PER_TEST = 1
GREEDY_DECODE = True
TEST_MAX_LEN = 256
TEST_TOP_P = 0.9
TEST_TEMPERATURE = 0.9

# Test

In [7]:
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST)
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]

Test Loss=4.2775, Test PPL=72.06





# Test Report Generation

In [8]:
from utils.text_metrics import evaluate_all_metrics, save_metrics_to_json
text_generated = []
text_target = []
iteration = 0
with torch.no_grad():
    for pixel_values, ids_loader, paths, raw_labels in test_loader:
        iteration += 1
        pixel_values = pixel_values.to(device)
        # gen_ids = model.generate(
        #     pixel_values=pixel_values,
        #     bos_id=bos_id, eos_id=eos_id,
        #     max_new_tokens=TEST_MAX_LEN, top_p=TEST_TOP_P, temperature=TEST_TEMPERATURE, greedy=GREEDY_DECODE
        # )

        info = model.generate_with_logging(
            pixel_values=pixel_values,             # [B, C, H, W]
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            tokenizer=tokenizer,
            preset="safe_sample",
            stop_sequences=["\n\n", "Impression:"],
            max_new_tokens=128,
        )
        # print("sequences:", info["sequences"].shape)
        # for i, s in enumerate(info["per_sample"]):
        #     print(f"[{i}] EOS={s['stopping']['hit_eos']} rep={s['repetition']}")
        #     print(s["text"].get("generated", "")[:200])
        
        text_generated.extend([s["text"]["generated"] for s in info["per_sample"]])
        text_target.extend(raw_labels)

        # eval_results = evaluate_all_metrics(raw_labels, [s["text"]["generated"] for s in info["per_sample"]], evaluation_mode="CheXagent")
        # for metric, scores in eval_results.items():
        #     print(f"{metric}: {scores}")

        # save_metrics_to_json(eval_results, f"lstm-vs-gpt/results/lstm_model_results_{NUM_EPOCHS}_Chexpert.json")

        # print("Predictions (first batch):")
        # for i in range(gen_ids.size(0)):
        #     text_gen = tokenizer.decode(gen_ids[i].tolist())
        #     text_tgt = tokenizer.decode(ids_loader[i].tolist())
        #     print(f"\nGEN {i+1}:", text_gen)
        #     print(f"TGT {i+1}:", text_tgt)
        #     results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
        #     for metric, scores in results.items():
        #         print(f"{metric}: {scores}")
        del pixel_values, ids_loader, paths, raw_labels, info
        torch.cuda.empty_cache()
        if iteration >= 10:  # Limit to 1 batch for quick testing
            break

eval_results = evaluate_all_metrics(text_generated, text_target, evaluation_mode="CheXagent")
for metric, scores in eval_results.items():
    print(f"{metric}: {scores}")
eval_results["training_time_seconds"] = training_time
save_metrics_to_json(eval_results, f"lstm-vs-gpt/results/lstm_model_results_{NUM_EPOCHS}_Chexpert.json")

Using device: cuda:0
chexbert_f1_weighted: 0.30352609226598665
chexbert_f1_micro: 0.32653061224489793
chexbert_f1_macro: 0.1702417790644423
chexbert_f1_micro_5: 0.25225225225225223
chexbert_f1_macro_5: 0.22538158370880307
radgraph_f1_RG_E: 0.1315466601550468
radgraph_f1_RG_ER: 0.11165696032559828
