In [None]:
import os
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
import sys
from pathlib import Path
sys.path.append(str(Path().resolve().parents[1]))
from utils.text_metrics import evaluate_all_metrics, save_metrics_to_json
from utils.train_comparison import *
from utils.processing import image_transform
from utils.models.gpt_models import DinoGPTCaptioner, DinoGPT2Captioner
from utils.data.chexpert_dataset import CheXpertDataset
from utils.data.padchest_dataset import PadChestGRDataset

# Data

In [None]:
# Seed PyTorch's global RNG before anything random happens (DataLoader
# shuffling below, and any later random operation) so that every run of the
# notebook starts from the same random state.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# CheXpert Plus metadata CSV and PNG image root, relative to this notebook.
CSV_PATH = "../../../Datasets/CheXpertPlus/df_chexpert_plus_240401.csv"
IMG_ROOT = "../../../Datasets/CheXpertPlus/PNG"

TEXT_COL = "section_impression"  # CSV column handed to the datasets as text_col
PATH_COL = "path_to_image"       # defined here but not passed to anything in this cell

IMG_SIZE = 224   # passed to image_transform as img_size
MAX_LEN = 64     # defined here but not referenced in this cell
NUM_BATCH = 8    # used as the DataLoader batch_size below

tf = image_transform(img_size=IMG_SIZE)

# One dataset per split; all three read the same CSV and image root.
# NOTE(review): the recorded output prints the same "Kept 95718/223462 rows"
# line for train, valid and test -- confirm that CheXpertDataset really
# filters rows by `split`, otherwise the three splits overlap.
ds_train = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="train", transform=tf, text_col=TEXT_COL)
ds_valid = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="valid", transform=tf, text_col=TEXT_COL)
ds_test = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="test", transform=tf, text_col=TEXT_COL)

# Tokenizer and its special-token ids; pad_id is also handed to the collate
# function and, in later cells, to the train/eval helpers.
tokenizer = build_tokenizer_from_labels(gpt2=True)
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 95718/223462 rows with existing PNGs
[INFO] Kept 95718/223462 rows with existing PNGs
[INFO] Kept 95718/223462 rows with existing PNGs
Using GPT2 tokenizer.


# Model

In [3]:
# DINO ViT-S/16 hidden size is 384; this value is passed to the captioner as d_img.
EMBEDDING_D_IMG = 384
# Size of the 16x16-patch grid: (224 // 16) ** 2 = 196 for the current IMG_SIZE.
# This expression counts patch positions only and adds no extra slot for a CLS token.
# NOTE(review): the previous comment said "including CLS" -- confirm what
# DinoGPT2Captioner expects for num_prefix_tokens.
N_PREFIX = (IMG_SIZE // 16) ** 2  # number of visual prefix tokens

def pick_heads(d_model, target_head_dim=64):
    """Choose an attention-head count that divides ``d_model`` evenly.

    The search starts at ``round(d_model / target_head_dim)`` (clamped to at
    least 1) and walks downwards until a divisor of ``d_model`` is found.

    Args:
        d_model: Model width to be split across heads.
        target_head_dim: Preferred per-head width used for the first guess.

    Returns:
        The first head count at or below the initial guess that divides
        ``d_model`` with no remainder.
    """
    heads = max(1, round(d_model / target_head_dim))
    while True:
        if d_model % heads == 0:
            return heads
        # Not a divisor: try the next smaller head count.
        heads -= 1

# Width / head-count pair derived with pick_heads. Neither value is passed to
# the captioner constructed below.
D_MODEL = 768
N_HEAD = pick_heads(D_MODEL, 64)  # -> 12


# Build the DINO + GPT-2 captioner, then move it onto the selected device.
model = DinoGPT2Captioner(
    freeze_dino=True,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    gpt2_name="gpt2",
    num_prefix_tokens=N_PREFIX,
    d_img=EMBEDDING_D_IMG,
)
model = model.to(device)

# Parameter counts: every parameter vs. only those with requires_grad set.
total_params = sum(tensor.numel() for tensor in model.parameters())
trainable_params = sum(
    tensor.numel() for tensor in model.parameters() if tensor.requires_grad
)

# Rough memory footprint, assuming each parameter is stored as a 4-byte float32.
model_footprint_in_gb = total_params * 4 * 1e-9

print(f"Total model parameters: {total_params / 1_000_000:.2f} Millions")
print(f"Trainable model parameters: {trainable_params / 1_000_000:.2f} Millions")
print(f"Approximate model footprint: {model_footprint_in_gb:.2f} GB")

# Weight tying is left disabled:
#model.decoder.lm_head.weight = model.decoder.tok_emb.weight

Total model parameters: 146.33 Millions
Trainable model parameters: 124.74 Millions
Approximate model footprint: 0.59 GB


# Train Parameters

In [4]:
# AdamW over the parameters that still require gradients; parameters with
# requires_grad=False are left out of the optimizer entirely.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    weight_decay=1e-2,
    lr=3e-4,
)

# Loss callable handed to the train/eval helpers as loss_fn.
loss = sequence_ce_loss

# Training schedule: number of epochs, and how many batches each epoch consumes.
NUM_EPOCHS, BATCHES_PER_EPOCH = 100, 10

# Training

In [5]:
import time

# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration cannot be skewed by system clock adjustments (unlike time.time()).
time_start = time.perf_counter()
for epoch in range(NUM_EPOCHS):
    # Each "epoch" only consumes the first BATCHES_PER_EPOCH batches of a fresh
    # pass over the loader. train_loader shuffles, so the training batches
    # differ between epochs; valid_loader does not, so validation always sees
    # the same leading batches.
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")
# Wall-clock seconds for the whole loop, i.e. training plus per-epoch validation.
training_time = time.perf_counter() - time_start

Training: 100%|██████████| 10/10 [00:02<00:00,  4.58it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.22it/s]


Epoch 1: Train Loss=6.7158, PPL=1813.97 | Val Loss=5.5622, Val PPL=262.83


Training: 100%|██████████| 10/10 [00:01<00:00,  5.50it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.63it/s]


Epoch 2: Train Loss=5.2775, PPL=199.30 | Val Loss=5.1959, Val PPL=182.40


Training: 100%|██████████| 10/10 [00:01<00:00,  5.29it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.51it/s]


Epoch 3: Train Loss=5.1363, PPL=176.36 | Val Loss=5.0597, Val PPL=159.29


Training: 100%|██████████| 10/10 [00:01<00:00,  5.39it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.45it/s]


Epoch 4: Train Loss=5.0928, PPL=164.14 | Val Loss=4.9212, Val PPL=138.27


Training: 100%|██████████| 10/10 [00:01<00:00,  5.75it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.57it/s]


Epoch 5: Train Loss=4.8543, PPL=132.79 | Val Loss=4.8401, Val PPL=127.69


Training: 100%|██████████| 10/10 [00:01<00:00,  5.93it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.64it/s]


Epoch 6: Train Loss=4.7736, PPL=119.17 | Val Loss=4.8010, Val PPL=122.62


Training: 100%|██████████| 10/10 [00:01<00:00,  5.85it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.51it/s]


Epoch 7: Train Loss=4.6625, PPL=107.18 | Val Loss=4.7699, Val PPL=119.19


Training: 100%|██████████| 10/10 [00:01<00:00,  5.38it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.57it/s]


Epoch 8: Train Loss=4.6174, PPL=102.43 | Val Loss=4.7247, Val PPL=113.77


Training: 100%|██████████| 10/10 [00:01<00:00,  5.30it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.83it/s]


Epoch 9: Train Loss=4.6686, PPL=109.43 | Val Loss=4.6807, Val PPL=109.06


Training: 100%|██████████| 10/10 [00:01<00:00,  5.82it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.70it/s]


Epoch 10: Train Loss=4.5254, PPL=93.41 | Val Loss=4.6412, Val PPL=104.53


Training: 100%|██████████| 10/10 [00:01<00:00,  5.94it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.70it/s]


Epoch 11: Train Loss=4.6875, PPL=111.05 | Val Loss=4.6002, Val PPL=100.28


Training: 100%|██████████| 10/10 [00:01<00:00,  5.76it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.83it/s]


Epoch 12: Train Loss=4.5354, PPL=95.20 | Val Loss=4.6134, Val PPL=101.64


Training: 100%|██████████| 10/10 [00:01<00:00,  5.59it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.55it/s]


Epoch 13: Train Loss=4.6473, PPL=115.21 | Val Loss=4.5891, Val PPL=99.35


Training: 100%|██████████| 10/10 [00:01<00:00,  6.10it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.64it/s]


Epoch 14: Train Loss=4.4294, PPL=84.81 | Val Loss=4.5974, Val PPL=100.01


Training: 100%|██████████| 10/10 [00:01<00:00,  5.87it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 15: Train Loss=4.4812, PPL=92.18 | Val Loss=4.5758, Val PPL=98.13


Training: 100%|██████████| 10/10 [00:01<00:00,  6.00it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 16: Train Loss=4.5033, PPL=92.33 | Val Loss=4.5623, Val PPL=96.77


Training: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.41it/s]


Epoch 17: Train Loss=4.4959, PPL=90.61 | Val Loss=4.5652, Val PPL=97.12


Training: 100%|██████████| 10/10 [00:01<00:00,  5.35it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 18: Train Loss=4.4418, PPL=86.25 | Val Loss=4.5487, Val PPL=95.54


Training: 100%|██████████| 10/10 [00:01<00:00,  5.40it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.52it/s]


Epoch 19: Train Loss=4.5314, PPL=93.68 | Val Loss=4.4959, Val PPL=90.43


Training: 100%|██████████| 10/10 [00:01<00:00,  5.57it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.48it/s]


Epoch 20: Train Loss=4.4347, PPL=85.43 | Val Loss=4.5064, Val PPL=91.65


Training: 100%|██████████| 10/10 [00:01<00:00,  5.35it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.41it/s]


Epoch 21: Train Loss=4.4903, PPL=90.43 | Val Loss=4.4876, Val PPL=89.71


Training: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.36it/s]


Epoch 22: Train Loss=4.4232, PPL=85.78 | Val Loss=4.4971, Val PPL=90.49


Training: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.46it/s]


Epoch 23: Train Loss=4.4943, PPL=90.68 | Val Loss=4.4764, Val PPL=88.62


Training: 100%|██████████| 10/10 [00:01<00:00,  5.10it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.37it/s]


Epoch 24: Train Loss=4.4069, PPL=84.59 | Val Loss=4.4611, Val PPL=87.28


Training: 100%|██████████| 10/10 [00:01<00:00,  5.71it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.76it/s]


Epoch 25: Train Loss=4.4546, PPL=87.87 | Val Loss=4.4413, Val PPL=85.62


Training: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 26: Train Loss=4.3386, PPL=78.25 | Val Loss=4.4431, Val PPL=85.82


Training: 100%|██████████| 10/10 [00:01<00:00,  5.31it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.61it/s]


Epoch 27: Train Loss=4.4538, PPL=87.60 | Val Loss=4.4178, Val PPL=83.72


Training: 100%|██████████| 10/10 [00:01<00:00,  6.08it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 28: Train Loss=4.3600, PPL=80.06 | Val Loss=4.4073, Val PPL=82.76


Training: 100%|██████████| 10/10 [00:01<00:00,  6.19it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.23it/s]


Epoch 29: Train Loss=4.3091, PPL=76.45 | Val Loss=4.4047, Val PPL=82.52


Training: 100%|██████████| 10/10 [00:01<00:00,  5.82it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 30: Train Loss=4.2779, PPL=73.19 | Val Loss=4.4061, Val PPL=82.65


Training: 100%|██████████| 10/10 [00:01<00:00,  6.06it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 31: Train Loss=4.3537, PPL=79.87 | Val Loss=4.3873, Val PPL=81.13


Training: 100%|██████████| 10/10 [00:01<00:00,  5.77it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 32: Train Loss=4.3062, PPL=74.32 | Val Loss=4.3871, Val PPL=81.14


Training: 100%|██████████| 10/10 [00:01<00:00,  5.53it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 33: Train Loss=4.3404, PPL=78.82 | Val Loss=4.3871, Val PPL=81.11


Training: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.57it/s]


Epoch 34: Train Loss=4.3125, PPL=76.94 | Val Loss=4.3906, Val PPL=81.32


Training: 100%|██████████| 10/10 [00:01<00:00,  5.82it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 35: Train Loss=4.2836, PPL=73.74 | Val Loss=4.3701, Val PPL=79.63


Training: 100%|██████████| 10/10 [00:01<00:00,  5.55it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 36: Train Loss=4.3928, PPL=83.31 | Val Loss=4.3571, Val PPL=78.69


Training: 100%|██████████| 10/10 [00:01<00:00,  5.80it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.65it/s]


Epoch 37: Train Loss=4.2521, PPL=71.29 | Val Loss=4.3486, Val PPL=77.99


Training: 100%|██████████| 10/10 [00:01<00:00,  5.66it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 38: Train Loss=4.2698, PPL=73.11 | Val Loss=4.3498, Val PPL=78.01


Training: 100%|██████████| 10/10 [00:01<00:00,  5.85it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.73it/s]


Epoch 39: Train Loss=4.3112, PPL=75.72 | Val Loss=4.3412, Val PPL=77.30


Training: 100%|██████████| 10/10 [00:01<00:00,  5.82it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.24it/s]


Epoch 40: Train Loss=4.1995, PPL=67.66 | Val Loss=4.3434, Val PPL=77.43


Training: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s]


Epoch 41: Train Loss=4.2707, PPL=71.94 | Val Loss=4.3350, Val PPL=76.97


Training: 100%|██████████| 10/10 [00:01<00:00,  5.22it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.55it/s]


Epoch 42: Train Loss=4.3207, PPL=76.45 | Val Loss=4.3397, Val PPL=77.39


Training: 100%|██████████| 10/10 [00:01<00:00,  5.88it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.52it/s]


Epoch 43: Train Loss=4.3443, PPL=77.67 | Val Loss=4.3223, Val PPL=76.08


Training: 100%|██████████| 10/10 [00:01<00:00,  5.51it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 44: Train Loss=4.3410, PPL=77.61 | Val Loss=4.3110, Val PPL=75.18


Training: 100%|██████████| 10/10 [00:01<00:00,  5.85it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.61it/s]


Epoch 45: Train Loss=4.2869, PPL=73.34 | Val Loss=4.3215, Val PPL=76.02


Training: 100%|██████████| 10/10 [00:01<00:00,  5.93it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.65it/s]


Epoch 46: Train Loss=4.2020, PPL=68.75 | Val Loss=4.3145, Val PPL=75.42


Training: 100%|██████████| 10/10 [00:01<00:00,  6.40it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.48it/s]


Epoch 47: Train Loss=4.2128, PPL=69.20 | Val Loss=4.3289, Val PPL=76.50


Training: 100%|██████████| 10/10 [00:01<00:00,  5.71it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.83it/s]


Epoch 48: Train Loss=4.2781, PPL=73.00 | Val Loss=4.3252, Val PPL=76.30


Training: 100%|██████████| 10/10 [00:01<00:00,  5.68it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 49: Train Loss=4.2603, PPL=71.88 | Val Loss=4.3025, Val PPL=74.52


Training: 100%|██████████| 10/10 [00:01<00:00,  5.96it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 50: Train Loss=4.1643, PPL=65.51 | Val Loss=4.3101, Val PPL=75.09


Training: 100%|██████████| 10/10 [00:01<00:00,  5.26it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.63it/s]


Epoch 51: Train Loss=4.2839, PPL=74.08 | Val Loss=4.3022, Val PPL=74.43


Training: 100%|██████████| 10/10 [00:01<00:00,  5.74it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.47it/s]


Epoch 52: Train Loss=4.3505, PPL=78.92 | Val Loss=4.2994, Val PPL=74.16


Training: 100%|██████████| 10/10 [00:01<00:00,  5.54it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.60it/s]


Epoch 53: Train Loss=4.1660, PPL=65.27 | Val Loss=4.3019, Val PPL=74.36


Training: 100%|██████████| 10/10 [00:01<00:00,  5.76it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.50it/s]


Epoch 54: Train Loss=4.2003, PPL=67.75 | Val Loss=4.2986, Val PPL=74.10


Training: 100%|██████████| 10/10 [00:01<00:00,  5.93it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 55: Train Loss=4.2424, PPL=71.59 | Val Loss=4.3067, Val PPL=74.74


Training: 100%|██████████| 10/10 [00:01<00:00,  6.24it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.33it/s]


Epoch 56: Train Loss=4.2419, PPL=70.07 | Val Loss=4.2917, Val PPL=73.75


Training: 100%|██████████| 10/10 [00:01<00:00,  5.93it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 57: Train Loss=4.2449, PPL=70.77 | Val Loss=4.2820, Val PPL=72.96


Training: 100%|██████████| 10/10 [00:01<00:00,  5.87it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s]


Epoch 58: Train Loss=4.1035, PPL=61.56 | Val Loss=4.2960, Val PPL=73.92


Training: 100%|██████████| 10/10 [00:01<00:00,  5.57it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.65it/s]


Epoch 59: Train Loss=4.1801, PPL=66.59 | Val Loss=4.2896, Val PPL=73.49


Training: 100%|██████████| 10/10 [00:01<00:00,  5.63it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.66it/s]


Epoch 60: Train Loss=4.2938, PPL=75.16 | Val Loss=4.2739, Val PPL=72.30


Training: 100%|██████████| 10/10 [00:01<00:00,  5.82it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.79it/s]


Epoch 61: Train Loss=4.1769, PPL=66.16 | Val Loss=4.2619, Val PPL=71.35


Training: 100%|██████████| 10/10 [00:01<00:00,  5.98it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.54it/s]


Epoch 62: Train Loss=4.1991, PPL=68.40 | Val Loss=4.2783, Val PPL=72.60


Training: 100%|██████████| 10/10 [00:01<00:00,  6.12it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.87it/s]


Epoch 63: Train Loss=4.1764, PPL=66.63 | Val Loss=4.2646, Val PPL=71.63


Training: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.75it/s]


Epoch 64: Train Loss=4.2638, PPL=72.67 | Val Loss=4.2567, Val PPL=71.14


Training: 100%|██████████| 10/10 [00:01<00:00,  6.31it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.51it/s]


Epoch 65: Train Loss=4.1963, PPL=67.15 | Val Loss=4.2588, Val PPL=71.37


Training: 100%|██████████| 10/10 [00:01<00:00,  5.72it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 66: Train Loss=4.2914, PPL=73.51 | Val Loss=4.2438, Val PPL=70.23


Training: 100%|██████████| 10/10 [00:01<00:00,  6.15it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.74it/s]


Epoch 67: Train Loss=4.1556, PPL=64.96 | Val Loss=4.2425, Val PPL=70.10


Training: 100%|██████████| 10/10 [00:01<00:00,  5.74it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.78it/s]


Epoch 68: Train Loss=4.1287, PPL=62.92 | Val Loss=4.2560, Val PPL=71.14


Training: 100%|██████████| 10/10 [00:01<00:00,  6.31it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.64it/s]


Epoch 69: Train Loss=4.1934, PPL=67.21 | Val Loss=4.2635, Val PPL=71.60


Training: 100%|██████████| 10/10 [00:01<00:00,  5.94it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.74it/s]


Epoch 70: Train Loss=4.2171, PPL=69.16 | Val Loss=4.2425, Val PPL=70.03


Training: 100%|██████████| 10/10 [00:01<00:00,  6.00it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 71: Train Loss=4.2311, PPL=69.97 | Val Loss=4.2368, Val PPL=69.69


Training: 100%|██████████| 10/10 [00:01<00:00,  5.44it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.64it/s]


Epoch 72: Train Loss=4.1958, PPL=67.03 | Val Loss=4.2339, Val PPL=69.44


Training: 100%|██████████| 10/10 [00:01<00:00,  5.62it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 73: Train Loss=4.2008, PPL=67.56 | Val Loss=4.2290, Val PPL=69.26


Training: 100%|██████████| 10/10 [00:01<00:00,  5.56it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 74: Train Loss=4.1714, PPL=65.41 | Val Loss=4.2317, Val PPL=69.37


Training: 100%|██████████| 10/10 [00:01<00:00,  5.98it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 75: Train Loss=4.1853, PPL=67.27 | Val Loss=4.2276, Val PPL=69.14


Training: 100%|██████████| 10/10 [00:01<00:00,  5.88it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.36it/s]


Epoch 76: Train Loss=4.0893, PPL=61.00 | Val Loss=4.2255, Val PPL=68.95


Training: 100%|██████████| 10/10 [00:01<00:00,  6.02it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.81it/s]


Epoch 77: Train Loss=4.1423, PPL=63.58 | Val Loss=4.2174, Val PPL=68.39


Training: 100%|██████████| 10/10 [00:01<00:00,  6.00it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.53it/s]


Epoch 78: Train Loss=4.1456, PPL=63.99 | Val Loss=4.2208, Val PPL=68.65


Training: 100%|██████████| 10/10 [00:01<00:00,  5.82it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.85it/s]


Epoch 79: Train Loss=4.1288, PPL=63.29 | Val Loss=4.2212, Val PPL=68.70


Training: 100%|██████████| 10/10 [00:01<00:00,  5.88it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 80: Train Loss=4.0871, PPL=61.13 | Val Loss=4.2355, Val PPL=69.68


Training: 100%|██████████| 10/10 [00:01<00:00,  5.76it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.52it/s]


Epoch 81: Train Loss=4.1837, PPL=66.24 | Val Loss=4.2390, Val PPL=69.95


Training: 100%|██████████| 10/10 [00:01<00:00,  5.71it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.57it/s]


Epoch 82: Train Loss=4.0973, PPL=60.40 | Val Loss=4.2196, Val PPL=68.52


Training: 100%|██████████| 10/10 [00:01<00:00,  5.96it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 83: Train Loss=4.1457, PPL=63.65 | Val Loss=4.2209, Val PPL=68.58


Training: 100%|██████████| 10/10 [00:01<00:00,  5.94it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.63it/s]


Epoch 84: Train Loss=4.1158, PPL=61.81 | Val Loss=4.2157, Val PPL=68.22


Training: 100%|██████████| 10/10 [00:01<00:00,  5.85it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.63it/s]


Epoch 85: Train Loss=4.1444, PPL=63.95 | Val Loss=4.2125, Val PPL=68.01


Training: 100%|██████████| 10/10 [00:01<00:00,  5.79it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]


Epoch 86: Train Loss=4.1374, PPL=63.03 | Val Loss=4.2217, Val PPL=68.60


Training: 100%|██████████| 10/10 [00:01<00:00,  5.71it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 87: Train Loss=4.1118, PPL=61.63 | Val Loss=4.2185, Val PPL=68.39


Training: 100%|██████████| 10/10 [00:01<00:00,  5.66it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.77it/s]


Epoch 88: Train Loss=4.1953, PPL=68.14 | Val Loss=4.2093, Val PPL=67.76


Training: 100%|██████████| 10/10 [00:01<00:00,  5.71it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s]


Epoch 89: Train Loss=4.1347, PPL=63.24 | Val Loss=4.1957, Val PPL=66.88


Training: 100%|██████████| 10/10 [00:01<00:00,  5.88it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.22it/s]


Epoch 90: Train Loss=4.1361, PPL=63.82 | Val Loss=4.2030, Val PPL=67.39


Training: 100%|██████████| 10/10 [00:01<00:00,  5.55it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.00it/s]


Epoch 91: Train Loss=4.1446, PPL=63.85 | Val Loss=4.2008, Val PPL=67.27


Training: 100%|██████████| 10/10 [00:01<00:00,  5.80it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.51it/s]


Epoch 92: Train Loss=4.1492, PPL=64.31 | Val Loss=4.1922, Val PPL=66.77


Training: 100%|██████████| 10/10 [00:01<00:00,  5.66it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.37it/s]


Epoch 93: Train Loss=4.0192, PPL=56.55 | Val Loss=4.1844, Val PPL=66.17


Training: 100%|██████████| 10/10 [00:01<00:00,  5.87it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.44it/s]


Epoch 94: Train Loss=4.1705, PPL=65.42 | Val Loss=4.1926, Val PPL=66.73


Training: 100%|██████████| 10/10 [00:01<00:00,  5.76it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.25it/s]


Epoch 95: Train Loss=4.0760, PPL=59.81 | Val Loss=4.1829, Val PPL=66.03


Training: 100%|██████████| 10/10 [00:01<00:00,  5.76it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.40it/s]


Epoch 96: Train Loss=4.0928, PPL=61.41 | Val Loss=4.1821, Val PPL=65.93


Training: 100%|██████████| 10/10 [00:01<00:00,  5.66it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.36it/s]


Epoch 97: Train Loss=4.1577, PPL=64.96 | Val Loss=4.1862, Val PPL=66.22


Training: 100%|██████████| 10/10 [00:01<00:00,  5.40it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.39it/s]


Epoch 98: Train Loss=4.0656, PPL=58.62 | Val Loss=4.1857, Val PPL=66.29


Training: 100%|██████████| 10/10 [00:01<00:00,  5.66it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.52it/s]


Epoch 99: Train Loss=4.1006, PPL=60.99 | Val Loss=4.1925, Val PPL=66.71


Training: 100%|██████████| 10/10 [00:01<00:00,  5.26it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.02it/s]

Epoch 100: Train Loss=4.1827, PPL=67.28 | Val Loss=4.1887, Val PPL=66.35





# Test Parameters

In [6]:
# Number of test batches consumed by the test-loss cell.
BATCHES_PER_TEST = 1
# NOTE(review): the four settings below are not referenced by the report
# generation cell, which passes preset="safe_sample" and max_new_tokens=128
# instead -- either wire them in or remove them.
GREEDY_DECODE = True
TEST_MAX_LEN = 256
TEST_TOP_P = 0.9
TEST_TEMPERATURE = 0.9

# Test

In [7]:
# Score the first BATCHES_PER_TEST batches of the (unshuffled) test loader.
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
# Pass the same loss_fn as the validation call in the training loop so that
# test and validation losses are computed the same way and stay comparable.
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST, loss_fn=loss)
# evaluate() reports its results under the 'val_*' keys, also for the test split.
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00, 10.15it/s]

Test Loss=4.2559, Test PPL=70.52





# Test Report Generation

In [None]:
# Upper bound on how many test batches are run through generation.
MAX_GENERATION_BATCHES = 100
RESULTS_PATH = f"./results/gpt2_model_results_{NUM_EPOCHS}_Chexpert.json"

generated_text = []
target_text = []

# Make sure layers such as dropout are in inference mode before generating,
# regardless of which helper touched the model last.
model.eval()

with torch.no_grad():
    # islice stops after MAX_GENERATION_BATCHES batches without a manual counter.
    for pixel_values, ids_loader, _paths, raw_labels in islice(test_loader, MAX_GENERATION_BATCHES):
        pixel_values = pixel_values.to(device)

        # The result's "per_sample" entries hold the decoded output under
        # ["text"]["generated"].
        info = model.generate_with_logging(
            pixel_values=pixel_values,
            input_ids=ids_loader.to(device),
            tokenizer=tokenizer,
            preset="safe_sample",
            stop_sequences=["\n\n", "Impression:"],
            max_new_tokens=128,
        )

        generated_text.extend([s["text"]["generated"] for s in info["per_sample"]])
        target_text.extend(raw_labels)

# NOTE(review): this call passes (generated, target), whereas the earlier
# per-batch experiments that used to live in this cell passed
# (target, generated). Confirm the parameter order of evaluate_all_metrics;
# if it is (references, hypotheses) these two arguments must be swapped.
eval_results = evaluate_all_metrics(generated_text, target_text, evaluation_mode="CheXagent")
for metric, scores in eval_results.items():
    print(f"{metric}: {scores}")
eval_results["training_time_seconds"] = training_time

# Create the output directory up front so saving cannot fail on a fresh checkout.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
save_metrics_to_json(eval_results, RESULTS_PATH)

Using device: cuda:0
chexbert_f1_weighted: 0.4496145919513179
chexbert_f1_micro: 0.43919753086419755
chexbert_f1_macro: 0.2433516039883008
chexbert_f1_micro_5: 0.41468531468531467
chexbert_f1_macro_5: 0.3240299027279697
radgraph_f1_RG_E: 0.1445787453931649
radgraph_f1_RG_ER: 0.12288872377250878
