In [None]:
import os
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
import sys
from pathlib import Path
sys.path.append(str(Path().resolve().parents[1]))
from utils.text_metrics import evaluate_all_metrics, save_metrics_to_json
from utils.train_comparison import *
from utils.processing import image_transform
from utils.models.gpt_models import DinoGPTCaptioner, DinoGPT2Captioner
from utils.data.chexpert_dataset import CheXpertDataset
from utils.data.padchest_dataset import PadChestGRDataset

# Data

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

root_dir = "../../../Datasets/PadChest-GR/PadChest_GR"
json_file = "../../../Datasets/PadChest-GR/grounded_reports_20240819.json"
csv_path = "../../../Datasets/PadChest-GR/master_table.csv"

df = pd.read_csv(csv_path)

df_train = df[df['split'] == 'train']
df_validation = df[df['split'] == 'validation']
df_test = df[df['split'] == 'test']


IMG_SIZE = 224
MAX_LEN = 64
NUM_BATCH = 8

tf = image_transform(img_size=IMG_SIZE)

ds_train = PadChestGRDataset(
        dataframe=df_train,
        root_dir=root_dir,
        json_file=json_file,
        max_txt_len=MAX_LEN,
        image_size=IMG_SIZE,
        normalize=True,
        transform=None,
        return_paths=False,
        sentence_key="sentence_en",
    )

ds_valid = PadChestGRDataset(
        dataframe=df_validation,
        root_dir=root_dir,
        json_file=json_file,
        max_txt_len=MAX_LEN,
        image_size=IMG_SIZE,
        normalize=True,
        transform=None,
        return_paths=False,
        sentence_key="sentence_en",
    )

ds_test = PadChestGRDataset(
        dataframe=df_test,
        root_dir=root_dir,
        json_file=json_file,
        max_txt_len=MAX_LEN,
        image_size=IMG_SIZE,
        normalize=True,
        transform=None,
        return_paths=False,
        sentence_key="sentence_en",
    )

tokenizer = build_tokenizer_from_labels()
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda


# Model

In [3]:
# DINO ViT-S/16 hidden size is 384 
EMBEDDING_D_IMG = 384
N_PREFIX = (IMG_SIZE // 16) ** 2  # number of visual prefix tokens (including CLS)

model = DinoGPTCaptioner(
    vocab_size=tokenizer.vocab_size,
    d_img=EMBEDDING_D_IMG,
    pad_id=pad_id,
    d_model=512,
    n_layer=8,
    n_head=8,
    n_prefix=N_PREFIX,           # number of visual prefix tokens
    max_seq_len=256,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True,
).to(device)

# Train Parameters

In [4]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2
)
loss = sequence_ce_loss
NUM_EPOCHS = 100
BATCHES_PER_EPOCH = 10

# Training

In [5]:
for epoch in range(NUM_EPOCHS):
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")

Training: 100%|██████████| 10/10 [00:08<00:00,  1.17it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.26it/s]


Epoch 1: Train Loss=9.6992, PPL=20625.72 | Val Loss=8.5942, Val PPL=5425.54


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


Epoch 2: Train Loss=7.9904, PPL=3372.07 | Val Loss=7.2902, Val PPL=1497.76


Training: 100%|██████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 3: Train Loss=6.7708, PPL=898.24 | Val Loss=6.4553, Val PPL=646.61


Training: 100%|██████████| 10/10 [00:08<00:00,  1.23it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]


Epoch 4: Train Loss=6.2606, PPL=530.23 | Val Loss=6.1979, Val PPL=496.70


Training: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 5: Train Loss=6.1830, PPL=499.70 | Val Loss=6.1309, Val PPL=486.48


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Epoch 6: Train Loss=5.9977, PPL=416.53 | Val Loss=5.9688, Val PPL=400.28


Training: 100%|██████████| 10/10 [00:07<00:00,  1.27it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]


Epoch 7: Train Loss=5.6678, PPL=295.47 | Val Loss=5.7205, Val PPL=320.28


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]


Epoch 8: Train Loss=5.4669, PPL=245.69 | Val Loss=5.4054, Val PPL=234.21


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Epoch 9: Train Loss=5.2721, PPL=205.77 | Val Loss=5.4077, Val PPL=241.67


Training: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 10: Train Loss=5.4488, PPL=240.84 | Val Loss=4.8941, Val PPL=138.87


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 11: Train Loss=5.0470, PPL=166.08 | Val Loss=5.0260, Val PPL=157.91


Training: 100%|██████████| 10/10 [00:07<00:00,  1.27it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 12: Train Loss=4.9163, PPL=146.48 | Val Loss=4.7921, Val PPL=123.25


Training: 100%|██████████| 10/10 [00:08<00:00,  1.18it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]


Epoch 13: Train Loss=4.8483, PPL=133.46 | Val Loss=4.8289, Val PPL=136.12


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.61it/s]


Epoch 14: Train Loss=4.6007, PPL=103.91 | Val Loss=4.6836, Val PPL=118.20


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 15: Train Loss=4.5176, PPL=93.16 | Val Loss=4.5753, Val PPL=98.88


Training: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.58it/s]


Epoch 16: Train Loss=4.4639, PPL=93.37 | Val Loss=4.4149, Val PPL=85.54


Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]


Epoch 17: Train Loss=4.4978, PPL=98.17 | Val Loss=4.2300, Val PPL=71.48


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s]


Epoch 18: Train Loss=4.4582, PPL=88.47 | Val Loss=4.3115, Val PPL=82.09


Training: 100%|██████████| 10/10 [00:08<00:00,  1.23it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.56it/s]


Epoch 19: Train Loss=4.2669, PPL=76.31 | Val Loss=4.0999, Val PPL=63.16


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Epoch 20: Train Loss=4.4177, PPL=85.22 | Val Loss=4.2391, Val PPL=74.29


Training: 100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]


Epoch 21: Train Loss=4.1940, PPL=69.40 | Val Loss=4.0653, Val PPL=60.49


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.38it/s]


Epoch 22: Train Loss=4.0549, PPL=60.28 | Val Loss=4.1790, Val PPL=71.09


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]


Epoch 23: Train Loss=3.8512, PPL=48.45 | Val Loss=4.0299, Val PPL=59.14


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]


Epoch 24: Train Loss=3.8651, PPL=48.54 | Val Loss=4.0653, Val PPL=65.27


Training: 100%|██████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]


Epoch 25: Train Loss=3.7327, PPL=42.69 | Val Loss=3.5995, Val PPL=38.15


Training: 100%|██████████| 10/10 [00:07<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]


Epoch 26: Train Loss=3.8870, PPL=50.45 | Val Loss=3.7103, Val PPL=42.04


Training: 100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]


Epoch 27: Train Loss=3.6300, PPL=38.47 | Val Loss=3.8669, Val PPL=51.90


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]


Epoch 28: Train Loss=3.6371, PPL=39.08 | Val Loss=3.9777, Val PPL=57.56


Training: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]


Epoch 29: Train Loss=3.6689, PPL=41.32 | Val Loss=3.8492, Val PPL=48.79


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]


Epoch 30: Train Loss=3.7063, PPL=42.31 | Val Loss=3.8203, Val PPL=50.47


Training: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]


Epoch 31: Train Loss=3.6239, PPL=38.24 | Val Loss=3.6574, Val PPL=40.09


Training: 100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Epoch 32: Train Loss=3.4974, PPL=34.34 | Val Loss=3.6310, Val PPL=40.39


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.56it/s]


Epoch 33: Train Loss=3.3583, PPL=29.53 | Val Loss=3.7694, Val PPL=47.05


Training: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 34: Train Loss=3.4198, PPL=32.29 | Val Loss=3.3888, Val PPL=31.08


Training: 100%|██████████| 10/10 [00:08<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]


Epoch 35: Train Loss=3.4617, PPL=33.57 | Val Loss=3.4639, Val PPL=33.49


Training: 100%|██████████| 10/10 [00:08<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]


Epoch 36: Train Loss=3.4138, PPL=31.78 | Val Loss=3.3831, Val PPL=30.96


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Epoch 37: Train Loss=3.4033, PPL=32.35 | Val Loss=3.5560, Val PPL=36.11


Training: 100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 38: Train Loss=3.4050, PPL=30.98 | Val Loss=3.5565, Val PPL=36.91


Training: 100%|██████████| 10/10 [00:08<00:00,  1.18it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]


Epoch 39: Train Loss=3.6064, PPL=38.36 | Val Loss=3.6752, Val PPL=41.03


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


Epoch 40: Train Loss=3.3279, PPL=28.85 | Val Loss=3.6297, Val PPL=38.71


Training: 100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 41: Train Loss=3.4431, PPL=33.35 | Val Loss=3.5081, Val PPL=35.58


Training: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 42: Train Loss=3.2698, PPL=26.95 | Val Loss=3.5004, Val PPL=34.57


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 43: Train Loss=3.4125, PPL=31.04 | Val Loss=3.4497, Val PPL=32.54


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]


Epoch 44: Train Loss=3.4940, PPL=34.76 | Val Loss=3.4702, Val PPL=34.52


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]


Epoch 45: Train Loss=3.1567, PPL=24.44 | Val Loss=3.3406, Val PPL=30.86


Training: 100%|██████████| 10/10 [00:07<00:00,  1.27it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 46: Train Loss=3.2752, PPL=30.26 | Val Loss=3.5280, Val PPL=35.91


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.59it/s]


Epoch 47: Train Loss=3.2542, PPL=26.38 | Val Loss=3.5755, Val PPL=39.38


Training: 100%|██████████| 10/10 [00:07<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.59it/s]


Epoch 48: Train Loss=3.2904, PPL=28.68 | Val Loss=3.4303, Val PPL=32.14


Training: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]


Epoch 49: Train Loss=3.3181, PPL=29.07 | Val Loss=3.2841, Val PPL=27.40


Training: 100%|██████████| 10/10 [00:07<00:00,  1.27it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]


Epoch 50: Train Loss=3.1297, PPL=23.35 | Val Loss=3.4999, Val PPL=34.72


Training: 100%|██████████| 10/10 [00:08<00:00,  1.18it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]


Epoch 51: Train Loss=3.2227, PPL=25.89 | Val Loss=3.4536, Val PPL=32.81


Training: 100%|██████████| 10/10 [00:08<00:00,  1.18it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.58it/s]


Epoch 52: Train Loss=3.3078, PPL=28.16 | Val Loss=3.4378, Val PPL=35.33


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 53: Train Loss=3.0339, PPL=21.97 | Val Loss=3.2278, Val PPL=27.27


Training: 100%|██████████| 10/10 [00:07<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.56it/s]


Epoch 54: Train Loss=3.0140, PPL=21.17 | Val Loss=3.4516, Val PPL=33.45


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.43it/s]


Epoch 55: Train Loss=3.0983, PPL=23.12 | Val Loss=3.2599, Val PPL=27.27


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.58it/s]


Epoch 56: Train Loss=3.0189, PPL=21.27 | Val Loss=3.4863, Val PPL=33.98


Training: 100%|██████████| 10/10 [00:08<00:00,  1.23it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Epoch 57: Train Loss=3.0982, PPL=23.36 | Val Loss=3.5669, Val PPL=40.43


Training: 100%|██████████| 10/10 [00:07<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


Epoch 58: Train Loss=3.0928, PPL=23.10 | Val Loss=3.0399, Val PPL=21.51


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]


Epoch 59: Train Loss=2.9932, PPL=20.31 | Val Loss=3.3650, Val PPL=29.63


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.61it/s]


Epoch 60: Train Loss=3.0637, PPL=22.31 | Val Loss=3.1977, Val PPL=24.71


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch 61: Train Loss=3.0149, PPL=20.88 | Val Loss=3.5712, Val PPL=39.84


Training: 100%|██████████| 10/10 [00:08<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]


Epoch 62: Train Loss=3.1613, PPL=24.01 | Val Loss=3.3163, Val PPL=30.99


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.40it/s]


Epoch 63: Train Loss=3.1645, PPL=24.72 | Val Loss=3.2938, Val PPL=30.03


Training: 100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]


Epoch 64: Train Loss=3.1437, PPL=23.76 | Val Loss=3.2738, Val PPL=27.23


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]


Epoch 65: Train Loss=3.0188, PPL=21.17 | Val Loss=3.4105, Val PPL=31.16


Training: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]


Epoch 66: Train Loss=2.9364, PPL=19.28 | Val Loss=3.4457, Val PPL=32.41


Training: 100%|██████████| 10/10 [00:08<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


Epoch 67: Train Loss=2.9964, PPL=20.58 | Val Loss=3.5576, Val PPL=37.48


Training: 100%|██████████| 10/10 [00:07<00:00,  1.27it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s]


Epoch 68: Train Loss=3.0420, PPL=21.73 | Val Loss=3.3459, Val PPL=29.30


Training: 100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 69: Train Loss=2.9161, PPL=18.78 | Val Loss=3.4483, Val PPL=33.12


Training: 100%|██████████| 10/10 [00:07<00:00,  1.31it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Epoch 70: Train Loss=3.0082, PPL=20.53 | Val Loss=3.2357, Val PPL=26.79


Training: 100%|██████████| 10/10 [00:08<00:00,  1.13it/s]
Evaluating: 100%|██████████| 10/10 [00:08<00:00,  1.25it/s]


Epoch 71: Train Loss=2.9380, PPL=19.20 | Val Loss=3.3434, Val PPL=29.42


Training: 100%|██████████| 10/10 [00:08<00:00,  1.18it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s]


Epoch 72: Train Loss=2.9797, PPL=20.47 | Val Loss=3.1888, Val PPL=25.52


Training: 100%|██████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s]


Epoch 73: Train Loss=2.9270, PPL=19.20 | Val Loss=3.1501, Val PPL=23.82


Training: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]


Epoch 74: Train Loss=2.9816, PPL=21.20 | Val Loss=3.4311, Val PPL=33.04


Training: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Epoch 75: Train Loss=3.0100, PPL=21.46 | Val Loss=3.1873, Val PPL=25.54


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.60it/s]


Epoch 76: Train Loss=2.8452, PPL=17.77 | Val Loss=3.2721, Val PPL=28.49


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s]


Epoch 77: Train Loss=2.9133, PPL=19.32 | Val Loss=3.2215, Val PPL=25.90


Training: 100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]


Epoch 78: Train Loss=2.9852, PPL=20.05 | Val Loss=3.2233, Val PPL=25.99


Training: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]


Epoch 79: Train Loss=2.8248, PPL=17.55 | Val Loss=3.1245, Val PPL=23.03


Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.42it/s]


Epoch 80: Train Loss=2.8429, PPL=17.35 | Val Loss=3.1322, Val PPL=24.05


Training: 100%|██████████| 10/10 [00:07<00:00,  1.25it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


Epoch 81: Train Loss=2.9241, PPL=18.77 | Val Loss=3.2599, Val PPL=26.97


Training: 100%|██████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 82: Train Loss=2.8752, PPL=18.76 | Val Loss=3.2271, Val PPL=26.65


Training: 100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]


Epoch 83: Train Loss=3.0172, PPL=21.36 | Val Loss=3.0754, Val PPL=23.33


Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.43it/s]


Epoch 84: Train Loss=2.9514, PPL=19.66 | Val Loss=3.0872, Val PPL=22.48


Training: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]


Epoch 85: Train Loss=2.9493, PPL=19.83 | Val Loss=3.2271, Val PPL=26.95


Training: 100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.60it/s]


Epoch 86: Train Loss=2.8944, PPL=18.49 | Val Loss=3.0953, Val PPL=23.14


Training: 100%|██████████| 10/10 [00:07<00:00,  1.27it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.43it/s]


Epoch 87: Train Loss=2.8492, PPL=18.21 | Val Loss=3.2462, Val PPL=27.82


Training: 100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Epoch 88: Train Loss=2.9736, PPL=19.73 | Val Loss=3.1438, Val PPL=24.17


Training: 100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]


Epoch 89: Train Loss=2.9043, PPL=19.16 | Val Loss=2.9692, Val PPL=19.75


Training: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s]


Epoch 90: Train Loss=2.8235, PPL=17.00 | Val Loss=3.3235, Val PPL=29.05


Training: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 91: Train Loss=2.9934, PPL=20.56 | Val Loss=3.1459, Val PPL=24.87


Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]


Epoch 92: Train Loss=2.9418, PPL=19.35 | Val Loss=3.1641, Val PPL=24.89


Training: 100%|██████████| 10/10 [00:08<00:00,  1.21it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Epoch 93: Train Loss=2.8185, PPL=17.27 | Val Loss=3.2938, Val PPL=30.15


Training: 100%|██████████| 10/10 [00:10<00:00,  1.00s/it]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]


Epoch 94: Train Loss=2.8998, PPL=18.80 | Val Loss=3.1621, Val PPL=24.30


Training: 100%|██████████| 10/10 [00:08<00:00,  1.23it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.48it/s]


Epoch 95: Train Loss=2.8395, PPL=17.31 | Val Loss=3.2401, Val PPL=26.58


Training: 100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s]


Epoch 96: Train Loss=2.7372, PPL=15.63 | Val Loss=3.0695, Val PPL=22.40


Training: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s]


Epoch 97: Train Loss=2.8161, PPL=17.05 | Val Loss=3.2420, Val PPL=27.38


Training: 100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]


Epoch 98: Train Loss=2.7364, PPL=15.79 | Val Loss=3.3198, Val PPL=28.06


Training: 100%|██████████| 10/10 [00:06<00:00,  1.46it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.44it/s]


Epoch 99: Train Loss=2.7123, PPL=15.26 | Val Loss=3.2018, Val PPL=28.21


Training: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating: 100%|██████████| 10/10 [00:06<00:00,  1.57it/s]

Epoch 100: Train Loss=2.7243, PPL=15.43 | Val Loss=3.0783, Val PPL=22.54





# Test Parameters

In [14]:
BATCHES_PER_TEST = 1
GREEDY_DECODE = False
TEST_MAX_LEN = 512
TEST_TOP_P = 0.7
TEST_TEMPERATURE = 0.7

# Test

In [15]:
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST)
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.42it/s]

Test Loss=2.9071, Test PPL=18.30





# Test Report Generation

In [None]:
with torch.no_grad():
    for pixel_values, ids_loader, paths, raw_labels in test_loader:
        pixel_values = pixel_values.to(device)
        gen_ids = model.generate(
            pixel_values=pixel_values,
            bos_id=bos_id, eos_id=eos_id,
            max_new_tokens=TEST_MAX_LEN, top_p=TEST_TOP_P, temperature=TEST_TEMPERATURE, greedy=GREEDY_DECODE
        )
        print("Predictions (first batch):")
        for i in range(gen_ids.size(0)):
            text_gen = tokenizer.decode(gen_ids[i].tolist())
            text_tgt = tokenizer.decode(ids_loader[i].tolist())
            print(f"\nGEN {i+1}:", text_gen)
            print(f"TGT {i+1}:", text_tgt)
            results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
            for metric, scores in results.items():
                print(f"{metric}: {scores}")

            save_metrics_to_json(results, f"./results/gpt_model_results_{NUM_EPOCHS}_Padchest.json")
        del pixel_values, ids_loader, paths, raw_labels, gen_ids
        torch.cuda.empty_cache()
        break

Predictions (first batch):

GEN 1: no significant findings.
TGT 1: minimal biapical pleural thickening. slight blunting of the posterior left costophrenic angle. no other significant alterations.
Using device: cuda:0
chexbert_f1_weighted: 0.0
chexbert_f1_micro: 0.0
chexbert_f1_macro: 0.0
chexbert_f1_micro_5: 0.0
chexbert_f1_macro_5: 0.0
bertscore_f1: [0.5073857307434082]
radgraph_f1_RG_E: 0.0
radgraph_f1_RG_ER: 0.0
rouge_l: [0.21052631578947367]

GEN 2: no significant findings.
TGT 2: minimal biapical pleural thickening. slight blunting of the posterior left costophrenic angle. no other significant alterations.
Using device: cuda:0
chexbert_f1_weighted: 0.0
chexbert_f1_micro: 0.0
chexbert_f1_macro: 0.0
chexbert_f1_micro_5: 0.0
chexbert_f1_macro_5: 0.0
bertscore_f1: [0.5073857307434082]
radgraph_f1_RG_E: 0.0
radgraph_f1_RG_ER: 0.0
rouge_l: [0.21052631578947367]

GEN 3: no significant findings.
TGT 3: slight residual atelectasis in the right pulmonary base. minimal blunting of the costop