In [1]:
import os
os.chdir(os.path.dirname(os.getcwd()))
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
from utils.text_metrics import evaluate_all_metrics, save_metrics_to_json
from utils.temp_utils import *
from utils.lstm_models import DinoLSTMAttnCaptioner, DinoBiLSTMAttnCaptioner
from utils.chexpert_dataset import CheXpertDataset
from utils.padchest_dataset import PadChestGRDataset

# Data

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

CSV_PATH = "Datasets/CheXpertPlus/df_chexpert_plus_240401.csv"
IMG_ROOT = "Datasets/CheXpertPlus/PNG"

TEXT_COL = "section_impression"
PATH_COL = "path_to_image"

IMG_SIZE = 224
MAX_LEN = 64
NUM_BATCH = 8

tf = dino_image_transform(img_size=IMG_SIZE)

ds_train = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="train", transform=tf, text_col=TEXT_COL)
ds_valid = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="valid", transform=tf, text_col=TEXT_COL)
ds_test = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="test", transform=tf, text_col=TEXT_COL)

tokenizer = build_tokenizer_from_labels()
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs
[INFO] Kept 47494/223462 rows with existing PNGs
[INFO] Kept 47494/223462 rows with existing PNGs


# Model

In [3]:
# DINO ViT-S/16 hidden size is 384 
EMBEDDING_D_IMG = 384
N_PREFIX = (IMG_SIZE // 16) ** 2  # number of visual prefix tokens (including CLS)

model = DinoBiLSTMAttnCaptioner(
    vocab_size=tokenizer.vocab_size,
    d_img=EMBEDDING_D_IMG,
    d_h=512,
    pad_id=pad_id,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True,
).to(device)

# Train Parameters

In [4]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2
)
loss = sequence_ce_loss
NUM_EPOCHS = 100
BATCHES_PER_EPOCH = 10

# Training

In [5]:
import time

time_start = time.time()
for epoch in range(NUM_EPOCHS):
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")
training_time = time.time() - time_start

Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 1: Train Loss=9.6034, PPL=23935.71 | Val Loss=7.8257, Val PPL=2540.15


Training: 100%|██████████| 10/10 [00:05<00:00,  1.91it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 2: Train Loss=7.0715, PPL=1393.85 | Val Loss=6.0644, Val PPL=445.74


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 3: Train Loss=5.8736, PPL=367.91 | Val Loss=5.2337, Val PPL=195.75


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 4: Train Loss=5.1469, PPL=177.90 | Val Loss=4.7249, Val PPL=118.56


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.36it/s]


Epoch 5: Train Loss=4.6599, PPL=109.44 | Val Loss=4.3749, Val PPL=83.71


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 6: Train Loss=4.4001, PPL=86.50 | Val Loss=4.1468, Val PPL=66.23


Training: 100%|██████████| 10/10 [00:05<00:00,  1.91it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 7: Train Loss=4.2901, PPL=76.23 | Val Loss=3.9576, Val PPL=54.55


Training: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 8: Train Loss=4.0198, PPL=56.49 | Val Loss=3.8261, Val PPL=47.61


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 9: Train Loss=3.8730, PPL=49.43 | Val Loss=3.7453, Val PPL=43.82


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 10: Train Loss=3.8115, PPL=47.46 | Val Loss=3.6792, Val PPL=40.91


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.41it/s]


Epoch 11: Train Loss=3.7456, PPL=43.07 | Val Loss=3.5901, Val PPL=37.34


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.30it/s]


Epoch 12: Train Loss=3.6334, PPL=38.62 | Val Loss=3.5222, Val PPL=34.73


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 13: Train Loss=3.4943, PPL=33.28 | Val Loss=3.5017, Val PPL=33.97


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 14: Train Loss=3.4592, PPL=32.60 | Val Loss=3.4514, Val PPL=32.25


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 15: Train Loss=3.4612, PPL=32.12 | Val Loss=3.3971, Val PPL=30.49


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 16: Train Loss=3.3974, PPL=30.65 | Val Loss=3.3704, Val PPL=29.66


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.39it/s]


Epoch 17: Train Loss=3.2854, PPL=27.13 | Val Loss=3.3432, Val PPL=28.83


Training: 100%|██████████| 10/10 [00:05<00:00,  1.86it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 18: Train Loss=3.3351, PPL=29.01 | Val Loss=3.3302, Val PPL=28.42


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 19: Train Loss=3.4103, PPL=31.09 | Val Loss=3.2878, Val PPL=27.21


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 20: Train Loss=3.2678, PPL=26.65 | Val Loss=3.2856, Val PPL=27.13


Training: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 21: Train Loss=3.3593, PPL=29.07 | Val Loss=3.2483, Val PPL=26.12


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 22: Train Loss=3.3145, PPL=28.54 | Val Loss=3.2391, Val PPL=25.85


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.33it/s]


Epoch 23: Train Loss=3.2248, PPL=25.50 | Val Loss=3.2111, Val PPL=25.12


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 24: Train Loss=3.1540, PPL=23.57 | Val Loss=3.2156, Val PPL=25.22


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 25: Train Loss=3.1893, PPL=24.45 | Val Loss=3.1920, Val PPL=24.63


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 26: Train Loss=3.1517, PPL=23.77 | Val Loss=3.1887, Val PPL=24.54


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 27: Train Loss=3.1456, PPL=23.56 | Val Loss=3.1706, Val PPL=24.09


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 28: Train Loss=3.1724, PPL=24.28 | Val Loss=3.1528, Val PPL=23.65


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]


Epoch 29: Train Loss=3.0600, PPL=21.57 | Val Loss=3.1375, Val PPL=23.26


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Epoch 30: Train Loss=3.1302, PPL=23.06 | Val Loss=3.1499, Val PPL=23.54


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 31: Train Loss=3.1866, PPL=24.55 | Val Loss=3.1125, Val PPL=22.67


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 32: Train Loss=3.0955, PPL=22.33 | Val Loss=3.1384, Val PPL=23.25


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.38it/s]


Epoch 33: Train Loss=3.0929, PPL=22.31 | Val Loss=3.1011, Val PPL=22.40


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 34: Train Loss=3.0532, PPL=21.31 | Val Loss=3.0995, Val PPL=22.36


Training: 100%|██████████| 10/10 [00:05<00:00,  1.85it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.43it/s]


Epoch 35: Train Loss=3.0883, PPL=22.28 | Val Loss=3.1114, Val PPL=22.62


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.31it/s]


Epoch 36: Train Loss=3.1173, PPL=22.97 | Val Loss=3.0741, Val PPL=21.78


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 37: Train Loss=3.0294, PPL=20.97 | Val Loss=3.0685, Val PPL=21.66


Training: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 38: Train Loss=3.0197, PPL=20.76 | Val Loss=3.0782, Val PPL=21.87


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]


Epoch 39: Train Loss=3.0673, PPL=21.74 | Val Loss=3.0626, Val PPL=21.52


Training: 100%|██████████| 10/10 [00:05<00:00,  1.86it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 40: Train Loss=3.0268, PPL=20.90 | Val Loss=3.0533, Val PPL=21.33


Training: 100%|██████████| 10/10 [00:05<00:00,  1.91it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.43it/s]


Epoch 41: Train Loss=3.0096, PPL=20.45 | Val Loss=3.0541, Val PPL=21.34


Training: 100%|██████████| 10/10 [00:05<00:00,  1.99it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.40it/s]


Epoch 42: Train Loss=2.9991, PPL=20.18 | Val Loss=3.0335, Val PPL=20.91


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 43: Train Loss=3.0091, PPL=20.37 | Val Loss=3.0580, Val PPL=21.42


Training: 100%|██████████| 10/10 [00:05<00:00,  1.94it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 44: Train Loss=3.0552, PPL=21.49 | Val Loss=3.0347, Val PPL=20.93


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.30it/s]


Epoch 45: Train Loss=2.9817, PPL=19.92 | Val Loss=3.0143, Val PPL=20.50


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 46: Train Loss=2.9940, PPL=20.12 | Val Loss=3.0268, Val PPL=20.76


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.39it/s]


Epoch 47: Train Loss=2.9507, PPL=19.35 | Val Loss=3.0230, Val PPL=20.68


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 48: Train Loss=2.9496, PPL=19.24 | Val Loss=3.0014, Val PPL=20.23


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 49: Train Loss=3.0135, PPL=20.54 | Val Loss=2.9932, Val PPL=20.06


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 50: Train Loss=2.9737, PPL=19.62 | Val Loss=2.9929, Val PPL=20.06


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 51: Train Loss=2.9350, PPL=18.98 | Val Loss=3.0056, Val PPL=20.32


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Epoch 52: Train Loss=2.9928, PPL=20.21 | Val Loss=3.0012, Val PPL=20.22


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.37it/s]


Epoch 53: Train Loss=2.9518, PPL=19.52 | Val Loss=2.9844, Val PPL=19.88


Training: 100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.36it/s]


Epoch 54: Train Loss=2.9173, PPL=18.69 | Val Loss=2.9860, Val PPL=19.91


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.35it/s]


Epoch 55: Train Loss=2.9446, PPL=19.19 | Val Loss=2.9790, Val PPL=19.77


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Epoch 56: Train Loss=2.9635, PPL=19.63 | Val Loss=2.9729, Val PPL=19.65


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 57: Train Loss=2.9441, PPL=19.12 | Val Loss=2.9807, Val PPL=19.80


Training: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 58: Train Loss=2.9971, PPL=20.36 | Val Loss=2.9662, Val PPL=19.51


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 59: Train Loss=2.9985, PPL=20.19 | Val Loss=2.9558, Val PPL=19.31


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


Epoch 60: Train Loss=2.9177, PPL=18.68 | Val Loss=2.9715, Val PPL=19.61


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 61: Train Loss=2.9559, PPL=19.43 | Val Loss=2.9569, Val PPL=19.33


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 62: Train Loss=2.8902, PPL=18.18 | Val Loss=2.9547, Val PPL=19.28


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 63: Train Loss=2.9456, PPL=19.26 | Val Loss=2.9672, Val PPL=19.52


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 64: Train Loss=2.9107, PPL=18.47 | Val Loss=2.9430, Val PPL=19.05


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


Epoch 65: Train Loss=2.9355, PPL=19.14 | Val Loss=2.9483, Val PPL=19.15


Training: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]


Epoch 66: Train Loss=2.9109, PPL=18.51 | Val Loss=2.9477, Val PPL=19.13


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.25it/s]


Epoch 67: Train Loss=2.9256, PPL=18.81 | Val Loss=2.9547, Val PPL=19.26


Training: 100%|██████████| 10/10 [00:05<00:00,  1.84it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 68: Train Loss=2.8977, PPL=18.22 | Val Loss=2.9375, Val PPL=18.93


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 69: Train Loss=2.8702, PPL=17.84 | Val Loss=2.9327, Val PPL=18.84


Training: 100%|██████████| 10/10 [00:05<00:00,  1.82it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 70: Train Loss=2.9179, PPL=18.69 | Val Loss=2.9345, Val PPL=18.88


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]


Epoch 71: Train Loss=2.8637, PPL=17.53 | Val Loss=2.9332, Val PPL=18.85


Training: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 72: Train Loss=2.8959, PPL=18.29 | Val Loss=2.9233, Val PPL=18.67


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 73: Train Loss=2.9288, PPL=18.74 | Val Loss=2.9107, Val PPL=18.43


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 74: Train Loss=2.8597, PPL=17.41 | Val Loss=2.9280, Val PPL=18.75


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 75: Train Loss=2.8962, PPL=18.25 | Val Loss=2.9150, Val PPL=18.50


Training: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 76: Train Loss=2.8944, PPL=18.13 | Val Loss=2.9144, Val PPL=18.49


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]


Epoch 77: Train Loss=2.8747, PPL=17.94 | Val Loss=2.9161, Val PPL=18.52


Training: 100%|██████████| 10/10 [00:05<00:00,  1.83it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 78: Train Loss=2.9126, PPL=18.59 | Val Loss=2.9208, Val PPL=18.61


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.26it/s]


Epoch 79: Train Loss=2.9237, PPL=18.91 | Val Loss=2.9233, Val PPL=18.66


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 80: Train Loss=2.8927, PPL=18.18 | Val Loss=2.9070, Val PPL=18.36


Training: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 81: Train Loss=2.8603, PPL=17.54 | Val Loss=2.8974, Val PPL=18.18


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 82: Train Loss=2.8898, PPL=18.19 | Val Loss=2.9152, Val PPL=18.51


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.35it/s]


Epoch 83: Train Loss=2.8646, PPL=17.60 | Val Loss=2.9102, Val PPL=18.41


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]


Epoch 84: Train Loss=2.8300, PPL=16.96 | Val Loss=2.9086, Val PPL=18.38


Training: 100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 85: Train Loss=2.8821, PPL=18.02 | Val Loss=2.9007, Val PPL=18.24


Training: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 86: Train Loss=2.8459, PPL=17.25 | Val Loss=2.8947, Val PPL=18.13


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 87: Train Loss=2.8612, PPL=17.66 | Val Loss=2.8906, Val PPL=18.05


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 88: Train Loss=2.8851, PPL=17.98 | Val Loss=2.8934, Val PPL=18.10


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.31it/s]


Epoch 89: Train Loss=2.8697, PPL=17.82 | Val Loss=2.8906, Val PPL=18.05


Training: 100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 90: Train Loss=2.8629, PPL=17.70 | Val Loss=2.8869, Val PPL=17.99


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]


Epoch 91: Train Loss=2.8897, PPL=18.14 | Val Loss=2.8808, Val PPL=17.88


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]


Epoch 92: Train Loss=2.8930, PPL=18.20 | Val Loss=2.8931, Val PPL=18.10


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 93: Train Loss=2.8642, PPL=17.61 | Val Loss=2.8984, Val PPL=18.19


Training: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 94: Train Loss=2.8630, PPL=17.77 | Val Loss=2.8873, Val PPL=17.99


Training: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.33it/s]


Epoch 95: Train Loss=2.8479, PPL=17.39 | Val Loss=2.8842, Val PPL=17.94


Training: 100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.30it/s]


Epoch 96: Train Loss=2.8684, PPL=17.77 | Val Loss=2.8780, Val PPL=17.82


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s]


Epoch 97: Train Loss=2.8534, PPL=17.39 | Val Loss=2.8786, Val PPL=17.84


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]


Epoch 98: Train Loss=2.8504, PPL=17.54 | Val Loss=2.8770, Val PPL=17.81


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Epoch 99: Train Loss=2.9240, PPL=18.81 | Val Loss=2.8895, Val PPL=18.03


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.16it/s]

Epoch 100: Train Loss=2.8731, PPL=17.83 | Val Loss=2.8834, Val PPL=17.92





# Test Parameters

In [6]:
BATCHES_PER_TEST = 1
GREEDY_DECODE = True
TEST_MAX_LEN = 256
TEST_TOP_P = 0.9
TEST_TEMPERATURE = 0.9

# Test

In [7]:
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST)
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.11it/s]

Test Loss=2.8424, Test PPL=17.16





# Test Report Generation

In [8]:
generated_text = []
target_text = []
iteration = 0

with torch.no_grad():
    for pixel_values, ids_loader, paths, raw_labels in test_loader:
        iteration += 1
        pixel_values = pixel_values.to(device)
        # gen_ids = model.generate(
        #     pixel_values=pixel_values,
        #     bos_id=bos_id, eos_id=eos_id,
        #     max_new_tokens=TEST_MAX_LEN, top_p=TEST_TOP_P, temperature=TEST_TEMPERATURE, greedy=GREEDY_DECODE
        # )

        info = model.generate_with_logging(
            pixel_values=pixel_values,             # [B, C, H, W]
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            tokenizer=tokenizer,
            preset="safe_sample",
            stop_sequences=None, #["\n\n", "Impression:"],
            max_new_tokens=128,
        )
        # print("sequences:", info["sequences"].shape)
        # for i, s in enumerate(info["per_sample"]):
        #     print(f"[{i}] EOS={s['stopping']['hit_eos']} rep={s['repetition']}")
        #     print(s["text"].get("generated", "")[:200])

        # eval_results = evaluate_all_metrics(raw_labels, [s["text"]["generated"] for s in info["per_sample"]], evaluation_mode="CheXagent")
        # for metric, scores in eval_results.items():
        #     print(f"{metric}: {scores}")

        generated_text.extend([s["text"]["generated"] for s in info["per_sample"]])
        target_text.extend(raw_labels)
        # save_metrics_to_json(eval_results, f"lstm-vs-gpt/results/bilstm_model_results_{NUM_EPOCHS}_Chexpert.json")

        # print("Predictions (first batch):")
        # for i in range(gen_ids.size(0)):
        #     text_gen = tokenizer.decode(gen_ids[i].tolist())
        #     text_tgt = tokenizer.decode(ids_loader[i].tolist())
        #     print(f"\nGEN {i+1}:", text_gen)
        #     print(f"TGT {i+1}:", text_tgt)
        #     results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
        #     for metric, scores in results.items():
        #         print(f"{metric}: {scores}")
        del pixel_values, ids_loader, paths, raw_labels, info
        torch.cuda.empty_cache()
        # break
        if iteration >= 10:  # Limit to 10 iterations for testing
            break

eval_results = evaluate_all_metrics(generated_text, target_text, evaluation_mode="CheXagent")
for metric, scores in eval_results.items():
    print(f"{metric}: {scores}")
eval_results["training_time_seconds"] = training_time
save_metrics_to_json(eval_results, f"lstm-vs-gpt/results/bilstm_model_results_{NUM_EPOCHS}_Chexpert.json")

Using device: cuda:0
chexbert_f1_weighted: 0.13673738744272532
chexbert_f1_micro: 0.18848167539267016
chexbert_f1_macro: 0.08115589064211391
chexbert_f1_micro_5: 0.05555555555555555
chexbert_f1_macro_5: 0.06183966840109356
radgraph_f1_RG_E: 0.0030077086656034027
radgraph_f1_RG_ER: 0.0021904761904761906


In [9]:
# Print number of model parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params}")

Total model parameters: 119691252


In [10]:
info

NameError: name 'info' is not defined

In [None]:
text = "1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES WITH PATCHY TUBE AND NASOGASTRIC TUBES, RIGHT LOWER MEDIASTINAL SIDED CATHETER.  NO SIGNIFICANT CHANGE IN THE PREVIOUS STUDYDEMONSTRATE ATELECTASIS O"
encoded = tokenizer.encode(text)
print("BOS token id:", tokenizer.bos_token_id, "EOS token id:", tokenizer.eos_token_id, "PAD token id:", tokenizer.pad_token_id)
print(encoded)

BOS token id: 101 EOS token id: 102 PAD token id: 0
[101, 122, 119, 6111, 1353, 1286, 4422, 34986, 5552, 39280, 49176, 1114, 10085, 1183, 7159, 1105, 9468, 7301, 32519, 11182, 117, 1268, 2211, 2394, 34979, 7050, 11641, 5855, 30682, 119, 1185, 2418, 1849, 1107, 1103, 2166, 2025, 31386, 8756, 18465, 14229, 184, 102]
