In [1]:
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
import sys
sys.path.append(r"C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025")

from utils.text_metrics import evaluate_all_metrics
from utils.temp_utils import *
from utils.lstm_models import DinoLSTMAttnCaptioner, DinoBiLSTMAttnCaptioner
from utils.chexpert_dataset import CheXpertDataset
from utils.padchest_dataset import PadChestGRDataset

# Data

In [2]:
# Seed torch's global RNG before anything random happens (the shuffling
# train DataLoader below, and whatever later cells draw from the same RNG)
# so that Restart Kernel -> Run All starts from the same random state.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Machine-specific CheXpert Plus locations; adjust these for your own setup.
CSV_PATH = r"C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\df_chexpert_plus_240401.csv"
IMG_ROOT = r"C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG"
TEXT_COL = "section_impression"  # report column handed to the dataset as `text_col`
PATH_COL = "path_to_image"  # not referenced by the cells visible here

IMG_SIZE = 224  # side length passed to the DINO image transform
MAX_LEN = 64  # not referenced by the cells visible here
NUM_BATCH = 8  # batch size shared by the three DataLoaders

tf = dino_image_transform(img_size=IMG_SIZE)

# The three splits differ only in `split`, so the common arguments live in one place.
dataset_kwargs = dict(img_root=IMG_ROOT, csv_path=CSV_PATH, transform=tf, text_col=TEXT_COL)
ds_train = CheXpertDataset(split="train", **dataset_kwargs)
ds_valid = CheXpertDataset(split="valid", **dataset_kwargs)
ds_test = CheXpertDataset(split="test", **dataset_kwargs)

tokenizer = build_tokenizer_from_labels()
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

# Only the training loader shuffles; validation and test keep a fixed order so
# that "the first N batches" always means the same samples.
train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG


# Model

In [3]:
# Feature width of the DINO ViT-S/16 backbone.
EMBEDDING_D_IMG = 384
# Visual prefix length: one token per 16x16 patch of an IMG_SIZE x IMG_SIZE
# image. This is the patch count only; no extra CLS token is added here.
N_PREFIX = (IMG_SIZE // 16) ** 2

# Collect the constructor arguments first, then build the captioner with a
# frozen DINO backbone and move it to the selected device.
captioner_kwargs = dict(
    vocab_size=tokenizer.vocab_size,
    d_img=EMBEDDING_D_IMG,
    d_h=512,
    pad_id=pad_id,
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True,
)
model = DinoBiLSTMAttnCaptioner(**captioner_kwargs).to(device)

# Train Parameters

In [4]:
# Optimise only the parameters that still require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=3e-4, weight_decay=1e-2)

# Loss callable handed to the train/eval helpers.
loss = sequence_ce_loss

# Each "epoch" is capped at BATCHES_PER_EPOCH batches rather than a full pass.
NUM_EPOCHS = 10
BATCHES_PER_EPOCH = 10

# Training

In [5]:
for epoch in range(NUM_EPOCHS):
    # Both capped iterators are built up front, before the training pass runs.
    train_batches = islice(train_loader, BATCHES_PER_EPOCH)
    valid_batches = islice(valid_loader, BATCHES_PER_EPOCH)

    # Keyword arguments common to the training and validation helpers.
    shared_kwargs = dict(num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    train_stats = train_one_epoch(
        model, train_batches, optimizer, device, pad_id, **shared_kwargs, grad_clip=1.0
    )
    val_stats = evaluate(model, valid_batches, device, pad_id, **shared_kwargs)

    # Assemble the one-line epoch summary from its train and validation halves.
    train_part = f"Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f}"
    valid_part = f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}"
    print(f"Epoch {epoch + 1}: {train_part} | {valid_part}")

Training: 100%|██████████| 10/10 [00:06<00:00,  1.63it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.01it/s]


Epoch 1: Train Loss=9.1222, PPL=18237.58 | Val Loss=7.0573, Val PPL=1222.79


Training: 100%|██████████| 10/10 [00:06<00:00,  1.65it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 2: Train Loss=6.1443, PPL=597.71 | Val Loss=5.1491, Val PPL=189.05


Training: 100%|██████████| 10/10 [00:06<00:00,  1.62it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 3: Train Loss=4.6286, PPL=109.07 | Val Loss=4.2516, Val PPL=78.23


Training: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 4: Train Loss=3.9095, PPL=52.19 | Val Loss=3.6415, Val PPL=41.94


Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 5: Train Loss=3.3632, PPL=30.26 | Val Loss=3.2330, Val PPL=27.41


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Epoch 6: Train Loss=3.0337, PPL=21.57 | Val Loss=2.9483, Val PPL=20.26


Training: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 7: Train Loss=2.9123, PPL=18.96 | Val Loss=2.7485, Val PPL=16.45


Training: 100%|██████████| 10/10 [00:06<00:00,  1.63it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.17it/s]


Epoch 8: Train Loss=2.6542, PPL=14.71 | Val Loss=2.6106, Val PPL=14.25


Training: 100%|██████████| 10/10 [00:06<00:00,  1.63it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.12it/s]


Epoch 9: Train Loss=2.5154, PPL=12.68 | Val Loss=2.5003, Val PPL=12.70


Training: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]

Epoch 10: Train Loss=2.4297, PPL=11.64 | Val Loss=2.4014, Val PPL=11.48





# Test Parameters

In [6]:
# Settings for the test-set evaluation and report-generation cells below.
BATCHES_PER_TEST = 1  # number of test batches consumed by the loss/perplexity evaluation
GREEDY_DECODE = True  # forwarded to model.generate as `greedy`
TEST_MAX_LEN = 256  # forwarded to model.generate as `max_new_tokens`
# Sampling knobs forwarded to model.generate as `top_p` / `temperature`.
# NOTE(review): presumably ignored while GREEDY_DECODE is True -- confirm in generate().
TEST_TOP_P = 0.9
TEST_TEMPERATURE = 0.9

# Test

In [7]:
# Score the first BATCHES_PER_TEST test batches with the same loss callable
# used for training and validation, so the reported test loss/perplexity is
# directly comparable with the per-epoch numbers above.
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(
    model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST, loss_fn=loss
)
# evaluate() reports its metrics under the 'val_*' keys regardless of the split.
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.77it/s]

Test Loss=2.1606, Test PPL=8.68





# Test Report Generation

In [8]:
# Generate reports for the first test batch and score each one against its
# reference text.
# Switch to inference mode explicitly so generation does not depend on which
# cell happened to run last (e.g. re-running training leaves train mode on).
model.eval()
with torch.no_grad():
    # Only the first batch is inspected; islice stops the loader after it.
    for pixel_values, ids_loader, paths, raw_labels in islice(test_loader, 1):
        pixel_values = pixel_values.to(device)
        gen_ids = model.generate(
            pixel_values=pixel_values,
            bos_id=bos_id, eos_id=eos_id,
            max_new_tokens=TEST_MAX_LEN, top_p=TEST_TOP_P, temperature=TEST_TEMPERATURE, greedy=GREEDY_DECODE
        )
        print("Predictions (first batch):")
        for i in range(gen_ids.size(0)):
            text_gen = tokenizer.decode(gen_ids[i].tolist())
            text_tgt = tokenizer.decode(ids_loader[i].tolist())
            print(f"\nGEN {i+1}:", text_gen)
            print(f"TGT {i+1}:", text_tgt)
            # Metrics are computed per sample (one reference, one hypothesis).
            results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
            for metric, scores in results.items():
                print(f"{metric}: {scores}")
        # Drop the batch tensors and release cached GPU memory before leaving the cell.
        del pixel_values, ids_loader, paths, raw_labels, gen_ids
        torch.cuda.empty_cache()

Predictions (first batch):

GEN 1: 1.
TGT 1: 1. interval placement of a right internal jugular venous sheath with the distal tip in the proximal superior vena cava. no pneumothorax. 2. stable position of nasogastric tube, feeding tube, tracheostomy canula, left internal jugular central venous catheter, and left upper extremity picc. 3. no significant interval change in hyperexpanded lung volumes, right basilar opacities, small bilateral pleural effusions, tenting of the right hemidiaphragm and biapical pleural thickening.
Using device: cuda:0
chexbert_f1_weighted: 0.0
chexbert_f1_micro: 0.0
chexbert_f1_macro: 0.0
chexbert_f1_micro_5: 0.0
chexbert_f1_macro_5: 0.0
bertscore_f1: [0.2504917085170746]
radgraph_f1_RG_E: 0.0
radgraph_f1_RG_ER: 0.0
rouge_l: [0.028985507246376812]

GEN 2: 1.
TGT 2: 1. comparison to 4 - 16 - 2015. 2. persistent airspace disease in right lower lobe worrisome for developing pneumonia or focal atelectasis. 3. interval removal of right pigtail catheter. 4. interval 