In [1]:
import os
os.chdir(os.path.dirname(os.getcwd()))
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
from utils.text_metrics import evaluate_all_metrics, save_metrics_to_json
from utils.temp_utils import *
from utils.gpt_models import DinoGPTCaptioner, DinoGPT2Captioner
from utils.chexpert_dataset import CheXpertDataset
from utils.padchest_dataset import PadChestGRDataset

# Data

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

CSV_PATH = "Datasets/CheXpertPlus/df_chexpert_plus_240401.csv"
IMG_ROOT = "Datasets/CheXpertPlus/PNG"

TEXT_COL = "section_impression"
PATH_COL = "path_to_image"

IMG_SIZE = 224
MAX_LEN = 64
NUM_BATCH = 8

tf = dino_image_transform(img_size=IMG_SIZE)

ds_train = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="train", transform=tf, text_col=TEXT_COL)
ds_valid = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="valid", transform=tf, text_col=TEXT_COL)
ds_test = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="test", transform=tf, text_col=TEXT_COL)

#labels = pd.read_csv(CSV_PATH)[TEXT_COL].tolist()

tokenizer = build_tokenizer_from_labels(captions=None)
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs
[INFO] Kept 47494/223462 rows with existing PNGs
[INFO] Kept 47494/223462 rows with existing PNGs


In [3]:
tokenizer_size = tokenizer.vocab_size
print("Tokenizer size:", tokenizer_size)

Tokenizer size: 58996


# Model

In [4]:
# DINO ViT-S/16 hidden size is 384 
EMBEDDING_D_IMG = 384
N_PREFIX = (IMG_SIZE // 16) ** 2  # number of visual prefix tokens (including CLS)

def pick_heads(d_model, target_head_dim=64):
    h = max(1, round(d_model / target_head_dim))
    while d_model % h != 0: h -= 1
    return h

D_MODEL = 768
N_HEAD = pick_heads(D_MODEL, 64)  # -> 12


model = DinoGPTCaptioner(
    vocab_size=tokenizer.vocab_size,
    d_img=EMBEDDING_D_IMG,
    pad_id=pad_id,
    d_model=D_MODEL,
    n_layer=12,
    n_head=N_HEAD,
    n_prefix=N_PREFIX,           # number of visual prefix tokens
    max_seq_len=512,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True,
).to(device)

# Print model parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params / 1_000_000:.2f} Millions")

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable model parameters: {trainable_params / 1_000_000:.2f} Millions")

# Print model footprint
model_footprint_in_gb = (total_params * 4) * (1e-9)  # assuming 4 bytes per parameter (float32)
print(f"Approximate model footprint: {model_footprint_in_gb:.2f} GB")

# after model init
model.decoder.lm_head.weight = model.decoder.tok_emb.weight  # weight tying

Total model parameters: 198.08 Millions
Trainable model parameters: 176.48 Millions
Approximate model footprint: 0.79 GB


# Train Parameters

In [5]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2
)
loss = sequence_ce_loss
NUM_EPOCHS = 100
BATCHES_PER_EPOCH = 10

# Training

In [6]:
import time
time_start = time.time()
for epoch in range(NUM_EPOCHS):
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")
training_time = time.time() - time_start

Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 1: Train Loss=9.5661, PPL=19159.79 | Val Loss=8.3382, Val PPL=4230.72


Training: 100%|██████████| 10/10 [00:04<00:00,  2.01it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.31it/s]


Epoch 2: Train Loss=7.8289, PPL=2561.26 | Val Loss=7.5083, Val PPL=1862.58


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]


Epoch 3: Train Loss=7.4871, PPL=1828.95 | Val Loss=7.4384, Val PPL=1731.43


Training: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.01it/s]


Epoch 4: Train Loss=7.3127, PPL=1517.05 | Val Loss=7.3380, Val PPL=1566.00


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 5: Train Loss=7.2270, PPL=1391.35 | Val Loss=7.2024, Val PPL=1370.34


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 6: Train Loss=7.1741, PPL=1347.62 | Val Loss=7.0020, Val PPL=1126.62


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 7: Train Loss=6.8880, PPL=976.55 | Val Loss=6.7528, Val PPL=886.24


Training: 100%|██████████| 10/10 [00:05<00:00,  1.96it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 8: Train Loss=6.5891, PPL=751.70 | Val Loss=6.4879, Val PPL=679.97


Training: 100%|██████████| 10/10 [00:05<00:00,  1.97it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 9: Train Loss=6.5090, PPL=702.30 | Val Loss=6.2712, Val PPL=550.43


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 10: Train Loss=6.1295, PPL=464.15 | Val Loss=6.1091, Val PPL=471.01


Training: 100%|██████████| 10/10 [00:05<00:00,  1.95it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 11: Train Loss=6.2127, PPL=514.06 | Val Loss=5.9493, Val PPL=401.50


Training: 100%|██████████| 10/10 [00:05<00:00,  1.86it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 12: Train Loss=5.9656, PPL=405.46 | Val Loss=5.7982, Val PPL=351.18


Training: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 13: Train Loss=5.7674, PPL=326.29 | Val Loss=5.6676, Val PPL=309.93


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 14: Train Loss=5.8936, PPL=369.87 | Val Loss=5.5743, Val PPL=282.77


Training: 100%|██████████| 10/10 [00:04<00:00,  2.05it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.30it/s]


Epoch 15: Train Loss=5.5744, PPL=270.72 | Val Loss=5.5136, Val PPL=264.16


Training: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 16: Train Loss=5.3193, PPL=208.29 | Val Loss=5.4165, Val PPL=238.02


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 17: Train Loss=5.4271, PPL=233.59 | Val Loss=5.3473, Val PPL=223.22


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 18: Train Loss=5.6804, PPL=305.99 | Val Loss=5.3119, Val PPL=217.36


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 19: Train Loss=5.2464, PPL=197.15 | Val Loss=5.2457, Val PPL=204.52


Training: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 20: Train Loss=5.3472, PPL=221.54 | Val Loss=5.2191, Val PPL=198.48


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 21: Train Loss=5.3776, PPL=218.93 | Val Loss=5.1759, Val PPL=189.02


Training: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 22: Train Loss=5.2944, PPL=206.76 | Val Loss=5.1666, Val PPL=187.58


Training: 100%|██████████| 10/10 [00:05<00:00,  1.92it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 23: Train Loss=5.0902, PPL=169.63 | Val Loss=5.1399, Val PPL=183.08


Training: 100%|██████████| 10/10 [00:05<00:00,  1.92it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 24: Train Loss=5.1629, PPL=178.68 | Val Loss=5.1131, Val PPL=176.80


Training: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 25: Train Loss=5.1285, PPL=176.86 | Val Loss=5.0789, Val PPL=169.84


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 26: Train Loss=5.2163, PPL=189.85 | Val Loss=5.0760, Val PPL=169.34


Training: 100%|██████████| 10/10 [00:05<00:00,  1.95it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.31it/s]


Epoch 27: Train Loss=5.1009, PPL=167.57 | Val Loss=5.0382, Val PPL=163.02


Training: 100%|██████████| 10/10 [00:04<00:00,  2.04it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 28: Train Loss=5.0201, PPL=155.88 | Val Loss=5.0186, Val PPL=159.36


Training: 100%|██████████| 10/10 [00:04<00:00,  2.03it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 29: Train Loss=5.0833, PPL=168.76 | Val Loss=5.0301, Val PPL=161.08


Training: 100%|██████████| 10/10 [00:04<00:00,  2.05it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 30: Train Loss=5.0812, PPL=171.86 | Val Loss=5.0109, Val PPL=158.72


Training: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 31: Train Loss=5.0518, PPL=157.35 | Val Loss=4.9732, Val PPL=152.43


Training: 100%|██████████| 10/10 [00:04<00:00,  2.03it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 32: Train Loss=5.1413, PPL=178.76 | Val Loss=4.9479, Val PPL=146.70


Training: 100%|██████████| 10/10 [00:05<00:00,  1.97it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 33: Train Loss=5.0423, PPL=165.52 | Val Loss=4.9117, Val PPL=141.25


Training: 100%|██████████| 10/10 [00:05<00:00,  1.97it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 34: Train Loss=4.9128, PPL=138.29 | Val Loss=4.9104, Val PPL=142.14


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 35: Train Loss=4.9662, PPL=152.74 | Val Loss=4.8865, Val PPL=138.59


Training: 100%|██████████| 10/10 [00:05<00:00,  1.91it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 36: Train Loss=5.0353, PPL=155.52 | Val Loss=4.8688, Val PPL=136.37


Training: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 37: Train Loss=5.1004, PPL=172.76 | Val Loss=4.8766, Val PPL=137.60


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 38: Train Loss=4.8806, PPL=135.75 | Val Loss=4.8721, Val PPL=136.02


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 39: Train Loss=5.0328, PPL=155.74 | Val Loss=4.8291, Val PPL=130.99


Training: 100%|██████████| 10/10 [00:04<00:00,  2.02it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 40: Train Loss=4.9926, PPL=153.21 | Val Loss=4.8184, Val PPL=129.08


Training: 100%|██████████| 10/10 [00:05<00:00,  1.92it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 41: Train Loss=4.8700, PPL=134.38 | Val Loss=4.8230, Val PPL=129.88


Training: 100%|██████████| 10/10 [00:05<00:00,  1.90it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 42: Train Loss=4.9255, PPL=145.17 | Val Loss=4.8267, Val PPL=130.99


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 43: Train Loss=4.8479, PPL=131.32 | Val Loss=4.8064, Val PPL=129.04


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 44: Train Loss=4.9548, PPL=145.13 | Val Loss=4.8090, Val PPL=128.52


Training: 100%|██████████| 10/10 [00:05<00:00,  1.86it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 45: Train Loss=4.8624, PPL=134.47 | Val Loss=4.7812, Val PPL=124.75


Training: 100%|██████████| 10/10 [00:04<00:00,  2.02it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 46: Train Loss=4.8019, PPL=124.17 | Val Loss=4.7909, Val PPL=126.59


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 47: Train Loss=4.7668, PPL=125.75 | Val Loss=4.8002, Val PPL=127.58


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 48: Train Loss=4.8410, PPL=132.26 | Val Loss=4.7956, Val PPL=126.13


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 49: Train Loss=4.7683, PPL=120.57 | Val Loss=4.7827, Val PPL=124.53


Training: 100%|██████████| 10/10 [00:05<00:00,  1.92it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 50: Train Loss=4.7495, PPL=116.46 | Val Loss=4.7528, Val PPL=121.26


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 51: Train Loss=4.7659, PPL=120.24 | Val Loss=4.7409, Val PPL=119.97


Training: 100%|██████████| 10/10 [00:04<00:00,  2.02it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 52: Train Loss=4.7762, PPL=123.60 | Val Loss=4.7345, Val PPL=119.83


Training: 100%|██████████| 10/10 [00:05<00:00,  1.91it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 53: Train Loss=4.8255, PPL=127.81 | Val Loss=4.7421, Val PPL=120.07


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 54: Train Loss=4.7706, PPL=125.76 | Val Loss=4.7279, Val PPL=118.49


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 55: Train Loss=4.8070, PPL=125.15 | Val Loss=4.7310, Val PPL=118.11


Training: 100%|██████████| 10/10 [00:05<00:00,  1.86it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 56: Train Loss=4.7453, PPL=117.48 | Val Loss=4.6946, Val PPL=114.12


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 57: Train Loss=4.8063, PPL=128.01 | Val Loss=4.6655, Val PPL=110.96


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 58: Train Loss=4.8166, PPL=129.32 | Val Loss=4.6652, Val PPL=110.79


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 59: Train Loss=4.6743, PPL=109.76 | Val Loss=4.6577, Val PPL=110.16


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 60: Train Loss=4.6186, PPL=103.32 | Val Loss=4.6552, Val PPL=109.86


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 61: Train Loss=4.7345, PPL=115.55 | Val Loss=4.6444, Val PPL=108.74


Training: 100%|██████████| 10/10 [00:05<00:00,  1.91it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 62: Train Loss=4.8699, PPL=132.56 | Val Loss=4.6333, Val PPL=107.39


Training: 100%|██████████| 10/10 [00:05<00:00,  1.92it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 63: Train Loss=4.7985, PPL=124.14 | Val Loss=4.6426, Val PPL=107.74


Training: 100%|██████████| 10/10 [00:05<00:00,  1.96it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 64: Train Loss=4.6981, PPL=110.76 | Val Loss=4.6430, Val PPL=108.30


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 65: Train Loss=4.6109, PPL=102.65 | Val Loss=4.6454, Val PPL=108.60


Training: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 66: Train Loss=4.7564, PPL=118.97 | Val Loss=4.6487, Val PPL=108.22


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 67: Train Loss=4.6833, PPL=109.75 | Val Loss=4.6570, Val PPL=109.26


Training: 100%|██████████| 10/10 [00:05<00:00,  1.97it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 68: Train Loss=4.7128, PPL=116.60 | Val Loss=4.6336, Val PPL=107.36


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 69: Train Loss=4.6403, PPL=107.64 | Val Loss=4.6361, Val PPL=107.14


Training: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 70: Train Loss=4.6304, PPL=105.76 | Val Loss=4.6207, Val PPL=105.45


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 71: Train Loss=4.6514, PPL=107.33 | Val Loss=4.6348, Val PPL=107.12


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 72: Train Loss=4.6242, PPL=105.06 | Val Loss=4.6161, Val PPL=104.96


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 73: Train Loss=4.8019, PPL=125.37 | Val Loss=4.6193, Val PPL=105.40


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 74: Train Loss=4.6755, PPL=108.96 | Val Loss=4.6036, Val PPL=103.98


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 75: Train Loss=4.6893, PPL=111.45 | Val Loss=4.6182, Val PPL=105.26


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 76: Train Loss=4.7083, PPL=115.66 | Val Loss=4.5859, Val PPL=102.03


Training: 100%|██████████| 10/10 [00:05<00:00,  1.90it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 77: Train Loss=4.6301, PPL=105.51 | Val Loss=4.5975, Val PPL=103.03


Training: 100%|██████████| 10/10 [00:05<00:00,  1.86it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 78: Train Loss=4.7975, PPL=124.38 | Val Loss=4.5807, Val PPL=101.13


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 79: Train Loss=4.5658, PPL=97.33 | Val Loss=4.5775, Val PPL=100.79


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 80: Train Loss=4.6886, PPL=113.90 | Val Loss=4.5807, Val PPL=101.64


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 81: Train Loss=4.5505, PPL=98.52 | Val Loss=4.5836, Val PPL=101.93


Training: 100%|██████████| 10/10 [00:05<00:00,  1.96it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 82: Train Loss=4.6525, PPL=107.22 | Val Loss=4.5700, Val PPL=100.29


Training: 100%|██████████| 10/10 [00:05<00:00,  1.90it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 83: Train Loss=4.6386, PPL=107.31 | Val Loss=4.5857, Val PPL=101.76


Training: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 84: Train Loss=4.6720, PPL=107.95 | Val Loss=4.5691, Val PPL=100.41


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 85: Train Loss=4.7284, PPL=117.44 | Val Loss=4.5649, Val PPL=99.73


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 86: Train Loss=4.5829, PPL=100.99 | Val Loss=4.5806, Val PPL=101.37


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 87: Train Loss=4.5937, PPL=101.29 | Val Loss=4.5575, Val PPL=99.68


Training: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 88: Train Loss=4.6470, PPL=107.52 | Val Loss=4.5560, Val PPL=99.33


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 89: Train Loss=4.5924, PPL=103.91 | Val Loss=4.5549, Val PPL=99.15


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 90: Train Loss=4.5257, PPL=94.70 | Val Loss=4.5453, Val PPL=98.38


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 91: Train Loss=4.5900, PPL=102.49 | Val Loss=4.5441, Val PPL=97.72


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 92: Train Loss=4.6121, PPL=103.60 | Val Loss=4.5339, Val PPL=96.86


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 93: Train Loss=4.6042, PPL=103.45 | Val Loss=4.5411, Val PPL=97.69


Training: 100%|██████████| 10/10 [00:05<00:00,  1.98it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 94: Train Loss=4.6087, PPL=102.22 | Val Loss=4.5296, Val PPL=96.55


Training: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 95: Train Loss=4.6308, PPL=105.69 | Val Loss=4.5330, Val PPL=96.61


Training: 100%|██████████| 10/10 [00:05<00:00,  1.95it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 96: Train Loss=4.5826, PPL=100.37 | Val Loss=4.5281, Val PPL=96.22


Training: 100%|██████████| 10/10 [00:05<00:00,  1.95it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 97: Train Loss=4.6005, PPL=101.23 | Val Loss=4.5330, Val PPL=96.77


Training: 100%|██████████| 10/10 [00:05<00:00,  1.90it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 98: Train Loss=4.6557, PPL=107.24 | Val Loss=4.5327, Val PPL=96.76


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 99: Train Loss=4.7361, PPL=117.86 | Val Loss=4.5205, Val PPL=95.47


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]

Epoch 100: Train Loss=4.6628, PPL=109.61 | Val Loss=4.5331, Val PPL=96.71





# Test Parameters

In [7]:
BATCHES_PER_TEST = 1
GREEDY_DECODE = True
TEST_MAX_LEN = 256
TEST_TOP_P = 0.9
TEST_TEMPERATURE = 0.9

# Test

In [8]:
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST)
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]

Test Loss=4.4065, Test PPL=81.98





# Test Report Generation

In [9]:
# capitalize first word and first word after each "."
def capitalize_sentences(s):
    parts = s.split('. ')
    parts = [p[:1].upper() + p[1:] if p else '' for p in parts]
    return '. '.join(parts)

generated_text = []
target_text = []
iteration = 0

with torch.no_grad():
    for pixel_values, ids_loader, paths, raw_labels in test_loader:
        iteration += 1
        pixel_values = pixel_values.to(device)
        # gen_ids = model.generate(
        #     pixel_values=pixel_values,
        #     bos_id=bos_id,
        #     eos_id=eos_id,
        #     max_new_tokens=TEST_MAX_LEN,
        #     beam_size=3,                # Set your desired beam size
        #     temperature=TEST_TEMPERATURE
        # )

        info = model.generate_with_logging(
            pixel_values=pixel_values,          # [B, C, H, W]
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            tokenizer=tokenizer,
            preset="safe_sample",
            stop_sequences=["\n\n", "Impression:"],
            max_new_tokens=256,
        )

        # print("batch sequences shape:", info["sequences"].shape)
        # for i, s in enumerate(info["per_sample"]):
        #     print(f"[sample {i}] hit_eos={s['stopping']['hit_eos']} repetition={s['repetition']}")
        #     if "generated" in s["text"]:
        #         print(capitalize_sentences(s["text"]["generated"]))
        #         print("[Target text]", capitalize_sentences(raw_labels[i]))

        generated_text.extend([s["text"]["generated"] for s in info["per_sample"]])
        target_text.extend(raw_labels)

        # eval_results = evaluate_all_metrics(raw_labels, [s["text"]["generated"] for s in info["per_sample"]], evaluation_mode="CheXagent")
        # for metric, scores in eval_results.items():
        #     print(f"{metric}: {scores}")

        # save_metrics_to_json(eval_results, f"lstm-vs-gpt/results/gpt_model_results_{NUM_EPOCHS}_Chexpert.json")


        # print("Predictions (first batch):")
        # for i in range(gen_ids.size(0)):
        #     text_gen = tokenizer.decode(gen_ids[i].tolist())
        #     text_tgt = tokenizer.decode(ids_loader[i].tolist())
        #     print(f"\nGEN {i+1}:", capitalize_sentences(text_gen))
        #     print(f"TGT {i+1}:", capitalize_sentences(text_tgt))
        #     results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
        #     for metric, scores in results.items():
        #         print(f"{metric}: {scores}")
        del pixel_values, ids_loader, paths, raw_labels, info
        torch.cuda.empty_cache()
        # break
        if iteration >= 10:  # Limit to 10 iterations for testing
            break

eval_results = evaluate_all_metrics(generated_text, target_text, evaluation_mode="CheXagent")
for metric, scores in eval_results.items():
    print(f"{metric}: {scores}")
eval_results["training_time_seconds"] = training_time
save_metrics_to_json(eval_results, f"lstm-vs-gpt/results/gpt_model_results_{NUM_EPOCHS}_Chexpert.json")

Using device: cuda:0
chexbert_f1_weighted: 0.36491390171440097
chexbert_f1_micro: 0.3927855711422846
chexbert_f1_macro: 0.20486577822308455
chexbert_f1_micro_5: 0.35125448028673834
chexbert_f1_macro_5: 0.3219951063429324
radgraph_f1_RG_E: 0.11892152991772556
radgraph_f1_RG_ER: 0.10539613373027526


In [10]:
import re
import string

def clean_text(text: str) -> str:
    # lowercase
    text = text.lower()

    # remove enumerators like "1." or "23." but KEEP decimals like "2.5"
    text = re.sub(r'(?<!\d)\b\d+\.(?!\d)', ' ', text)

    # remove all punctuation EXCEPT "."
    punctuation = string.punctuation.replace('.', '')
    text = text.translate(str.maketrans('', '', punctuation))

    # normalize spaces around periods to ". " → ". "
    text = re.sub(r'\s*\.\s*', '. ', text)

    # collapse multiple spaces and trim
    text = re.sub(r'\s+', ' ', text).strip()

    # capitalize first word and first word after each "."
    def capitalize_sentences(s):
        parts = s.split('. ')
        parts = [p[:1].upper() + p[1:] if p else '' for p in parts]
        return '. '.join(parts)
    text = capitalize_sentences(text)

    return text

# Example
text = "1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES... 2.5 cm nodule; item 2. next. 3. Done."
print(clean_text(text))


# Example
text = """
 1.  INTERVAL PLACEMENT OF A RIGHT INTERNAL JUGULAR VENOUS SHEATH 
WITH THE DISTAL TIP IN THE PROXIMAL SUPERIOR VENA CAVA.  NO 
PNEUMOTHORAX.
 
 2.  STABLE POSITION OF NASOGASTRIC TUBE, FEEDING TUBE, TRACHEOSTOMY 
CANULA, LEFT INTERNAL JUGULAR CENTRAL VENOUS CATHETER, AND LEFT UPPER 
EXTREMITY PICC.  
 
 3.  NO SIGNIFICANT INTERVAL CHANGE IN HYPEREXPANDED LUNG VOLUMES, 
RIGHT BASILAR OPACITIES, SMALL BILATERAL PLEURAL EFFUSIONS, TENTING 
OF THE RIGHT HEMIDIAPHRAGM AND BIAPICAL PLEURAL THICKENING. 
 
 """
cleaned_text = clean_text(text)
print(cleaned_text)


Stable small left internal jugular opacities. . . 2. 5 cm nodule item next. Done.
Interval placement of a right internal jugular venous sheath with the distal tip in the proximal superior vena cava. No pneumothorax. Stable position of nasogastric tube feeding tube tracheostomy canula left internal jugular central venous catheter and left upper extremity picc. No significant interval change in hyperexpanded lung volumes right basilar opacities small bilateral pleural effusions tenting of the right hemidiaphragm and biapical pleural thickening.


In [11]:
text = "1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES WITH PATCHY TUBE AND NASOGASTRIC TUBES, RIGHT LOWER MEDIASTINAL SIDED CATHETER.  NO SIGNIFICANT CHANGE IN THE PREVIOUS STUDYDEMONSTRATE ATELECTASIS O"
print("Original text:", text)
encoded = tokenizer.encode(text)
words = text.split()
print("Number of words:", len(words), "Number of tokens:", len(encoded), "pad_id:", pad_id, "eos_id:", eos_id, "bos_id:", bos_id)
print("BOS token id:", tokenizer.bos_token_id, "EOS token id:", tokenizer.eos_token_id, "PAD token id:", tokenizer.pad_token_id)
print(encoded)
for token_id in encoded:    
    print(f"Token ID: {token_id}, Token: {tokenizer.decode([token_id])}")

print("\nAfter lowercasing:")
textlower = cleaned_text
encoded = tokenizer.encode(textlower)
words = textlower.split()
print("Number of words:", len(words), "Number of tokens:", len(encoded), "pad_id:", pad_id, "eos_id:", eos_id, "bos_id:", bos_id)
print("BOS token id:", tokenizer.bos_token_id, "EOS token id:", tokenizer.eos_token_id, "PAD token id:", tokenizer.pad_token_id)
print(encoded)
for token_id in encoded:    
    print(f"Token ID: {token_id}, Token: {tokenizer.decode([token_id])}")

Original text: 1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES WITH PATCHY TUBE AND NASOGASTRIC TUBES, RIGHT LOWER MEDIASTINAL SIDED CATHETER.  NO SIGNIFICANT CHANGE IN THE PREVIOUS STUDYDEMONSTRATE ATELECTASIS O
Number of words: 27 Number of tokens: 43 pad_id: 0 eos_id: 102 bos_id: 101
BOS token id: 101 EOS token id: 102 PAD token id: 0
[101, 122, 119, 6111, 1353, 1286, 4422, 34986, 5552, 39280, 49176, 1114, 10085, 1183, 7159, 1105, 9468, 7301, 32519, 11182, 117, 1268, 2211, 2394, 34979, 7050, 11641, 5855, 30682, 119, 1185, 2418, 1849, 1107, 1103, 2166, 2025, 31386, 8756, 18465, 14229, 184, 102]
Token ID: 101, Token: 
Token ID: 122, Token: 1
Token ID: 119, Token: .
Token ID: 6111, Token: stable
Token ID: 1353, Token: small
Token ID: 1286, Token: left
Token ID: 4422, Token: internal
Token ID: 34986, Token: jug
Token ID: 5552, Token: ##ular
Token ID: 39280, Token: opa
Token ID: 49176, Token: ##cities
Token ID: 1114, Token: with
Token ID: 10085, Token: patch
Token ID: 1183, Token: ##y
T