In [None]:
import os
import pandas as pd
from itertools import islice
import torch
from torch.utils.data import DataLoader
from utils.text_metrics import evaluate_all_metrics
from utils.temp_utils import *
from utils.gpt_models import DinoGPTCaptioner, DinoGPT2Captioner
from utils.chexpert_dataset import CheXpertDataset
from utils.padchest_dataset import PadChestGRDataset

# Data

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

os.chdir(os.path.dirname(os.getcwd()))

CSV_PATH = "Datasets/CheXpertPlus/df_chexpert_plus_240401.csv"
IMG_ROOT = "Datasets/CheXpertPlus/PNG"
TEXT_COL = "section_impression"
PATH_COL = "path_to_image"

IMG_SIZE = 224
MAX_LEN = 64
NUM_BATCH = 8

tf = dino_image_transform(img_size=IMG_SIZE)

ds_train = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="train", transform=tf, text_col=TEXT_COL)
ds_valid = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="valid", transform=tf, text_col=TEXT_COL)
ds_test = CheXpertDataset(img_root=IMG_ROOT, csv_path=CSV_PATH, split="test", transform=tf, text_col=TEXT_COL)

tokenizer = build_tokenizer_from_labels(gpt2=True)
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
bos_id = tokenizer.bos_token_id
collate_fn = CaptionCollate(tokenizer, pad_id)

train_loader = DataLoader(ds_train, batch_size=NUM_BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(ds_valid, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(ds_test, batch_size=NUM_BATCH, shuffle=False, collate_fn=collate_fn)

Using device: cuda
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG
[INFO] Kept 47494/223462 rows with existing PNGs under C:\Users\emman\Desktop\PROYECTOS_VS_CODE\PRUEBAS_DE_PYTHON\CheXpertPlus\PNG


# Model

In [3]:
# DINO ViT-S/16 hidden size is 384 
EMBEDDING_D_IMG = 384
N_PREFIX = (IMG_SIZE // 16) ** 2  # number of visual prefix tokens (including CLS)

def pick_heads(d_model, target_head_dim=64):
    h = max(1, round(d_model / target_head_dim))
    while d_model % h != 0: h -= 1
    return h

D_MODEL = 768
N_HEAD = pick_heads(D_MODEL, 64)  # -> 12


model = DinoGPT2Captioner(
    d_img=EMBEDDING_D_IMG,
    num_prefix_tokens=N_PREFIX,
    gpt2_name="gpt2",
    dino_model_id="facebook/dinov3-vits16-pretrain-lvd1689m",
    freeze_dino=True
).to(device)

# Print model parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params / 1_000_000:.2f} Millions")

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable model parameters: {trainable_params / 1_000_000:.2f} Millions")

# Print model footprint
model_footprint_in_gb = (total_params * 4) * (1e-9)  # assuming 4 bytes per parameter (float32)
print(f"Approximate model footprint: {model_footprint_in_gb:.2f} GB")

# after model init
#model.decoder.lm_head.weight = model.decoder.tok_emb.weight  # weight tying

Total model parameters: 146.33 Millions
Trainable model parameters: 124.74 Millions
Approximate model footprint: 0.59 GB


# Train Parameters

In [4]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2
)
loss = sequence_ce_loss
NUM_EPOCHS = 10
BATCHES_PER_EPOCH = 10

# Training

In [5]:
for epoch in range(NUM_EPOCHS):
    slice_train_loader = islice(train_loader, BATCHES_PER_EPOCH)
    slice_valid_loader = islice(valid_loader, BATCHES_PER_EPOCH)
    train_stats = train_one_epoch(model, slice_train_loader, optimizer, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss, grad_clip=1.0)
    val_stats = evaluate(model, slice_valid_loader, device, pad_id, num_batches=BATCHES_PER_EPOCH, loss_fn=loss)
    print(f"Epoch {epoch + 1}: Train Loss={train_stats['loss']:.4f}, PPL={train_stats['ppl']:.2f} | "
            f"Val Loss={val_stats['val_loss']:.4f}, Val PPL={val_stats['val_ppl']:.2f}")

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
Training: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]


Epoch 1: Train Loss=6.5548, PPL=1269.17 | Val Loss=5.5504, Val PPL=260.68


Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 2: Train Loss=5.4843, PPL=243.14 | Val Loss=5.0903, Val PPL=165.74


Training: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Epoch 3: Train Loss=5.1047, PPL=168.19 | Val Loss=4.8614, Val PPL=133.16


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 4: Train Loss=4.9339, PPL=141.39 | Val Loss=4.7477, Val PPL=118.38


Training: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 5: Train Loss=4.8142, PPL=126.06 | Val Loss=4.6583, Val PPL=108.25


Training: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.05it/s]


Epoch 6: Train Loss=4.7483, PPL=117.84 | Val Loss=4.6407, Val PPL=106.10


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Epoch 7: Train Loss=4.6956, PPL=111.46 | Val Loss=4.5851, Val PPL=100.69


Training: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.06it/s]


Epoch 8: Train Loss=4.6052, PPL=100.46 | Val Loss=4.5214, Val PPL=94.54


Training: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]
Evaluating: 100%|██████████| 10/10 [00:04<00:00,  2.05it/s]


Epoch 9: Train Loss=4.6529, PPL=105.57 | Val Loss=4.5060, Val PPL=92.40


Training: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
Evaluating: 100%|██████████| 10/10 [00:05<00:00,  2.00it/s]

Epoch 10: Train Loss=4.6073, PPL=102.13 | Val Loss=4.4974, Val PPL=91.83





# Test Parameters

In [6]:
BATCHES_PER_TEST = 1
GREEDY_DECODE = True
TEST_MAX_LEN = 256
TEST_TOP_P = 0.9
TEST_TEMPERATURE = 0.9

# Test

In [7]:
slice_test_loader = islice(test_loader, BATCHES_PER_TEST)
test_stats = evaluate(model, slice_test_loader, device, pad_id, num_batches=BATCHES_PER_TEST)
print(f"Test Loss={test_stats['val_loss']:.4f}, Test PPL={test_stats['val_ppl']:.2f}")

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]

Test Loss=4.4568, Test PPL=86.21





# Test Report Generation

In [8]:
with torch.no_grad():
    for pixel_values, ids_loader, paths, raw_labels in test_loader:
        pixel_values = pixel_values.to(device)
        gen_ids = model.generate(
            pixel_values=pixel_values,
            input_ids=ids_loader.to(device),
            max_new_tokens=64
        ).to(device)

        info = model.generate_with_logging(
            pixel_values=pixel_values,
            input_ids=ids_loader.to(device),
            tokenizer=tokenizer,
            preset="safe_sample",
            stop_sequences=["\n\n", "Impression:"],
            max_new_tokens=128,
        )
        print("out shape:", info["sequences"].shape)
        for i, s in enumerate(info["per_sample"]):
            print(f"[{i}] EOS={s['stopping']['hit_eos']} rep={s['repetition']}")
            print(s["text"].get("generated","")[:200])
            print("[Target text]", raw_labels[i])

        eval_results = evaluate_all_metrics(raw_labels, [s["text"]["generated"] for s in info["per_sample"]], evaluation_mode="CheXagent")
        for metric, scores in eval_results.items():
            print(f"{metric}: {scores}")



        print("Predictions (first batch):")
        for i in range(gen_ids.size(0)):
            text_gen = tokenizer.decode(gen_ids[i].tolist())
            text_tgt = tokenizer.decode(ids_loader[i].tolist())
            print(f"\nGEN {i+1}:", text_gen)
            print(f"TGT {i+1}:", text_tgt)
            try: 
                results = evaluate_all_metrics([text_tgt], [text_gen], evaluation_mode="CheXagent")
                for metric, scores in results.items():
                    print(f"{metric}: {scores}")
            except Exception as e:
                print("Error in evaluation:", e)
        del pixel_values, ids_loader, paths, raw_labels, gen_ids
        torch.cuda.empty_cache()
        break

out shape: torch.Size([8, 240])
[0] EOS=False rep={'max_token_run': 1, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
interstitial pulmonary edema and mild pulmonary edophorax are not significantly changed. small bilateral subcutaneous pleural fluid remain in place. no evidence of pneumothorsax. persistent bibasilar 
[Target text] interval placement of a right internal jugular venous sheath with the distal tip in the proximal superior vena cava. no pneumothorax. stable position of nasogastric tube feeding tube tracheostomy canula left internal jugular central venous catheter and left upper extremity picc. no significant interval change in hyperexpanded lung volumes right basilar opacities small bilateral pleural effusions tenting of the right hemidiaphragm and biapical pleural thickening.
[1] EOS=False rep={'max_token_run': 1, 'max_repeat_trigram': 1, 'max_repeat_4gram': 1}
two leads are obtained. persistent low lung volumes and left pleural effusion. small bilateral pleural fluid is 

In [9]:
text = "1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES WITH PATCHY TUBE AND NASOGASTRIC TUBES, RIGHT LOWER MEDIASTINAL SIDED CATHETER.  NO SIGNIFICANT CHANGE IN THE PREVIOUS STUDYDEMONSTRATE ATELECTASIS O"
print("Original text:", text)
encoded = tokenizer.encode(text)
words = text.split()
print("Number of words:", len(words), "Number of tokens:", len(encoded), "pad_id:", pad_id, "eos_id:", eos_id, "bos_id:", bos_id)
print("BOS token id:", tokenizer.bos_token_id, "EOS token id:", tokenizer.eos_token_id, "PAD token id:", tokenizer.pad_token_id)
print(encoded)
for token_id in encoded:    
    print(f"Token ID: {token_id}, Token: {tokenizer.decode([token_id])}")

print("\nAfter lowercasing:")
textlower = text.lower()
encoded = tokenizer.encode(textlower)
words = textlower.split()
print("Number of words:", len(words), "Number of tokens:", len(encoded), "pad_id:", pad_id, "eos_id:", eos_id, "bos_id:", bos_id)
print("BOS token id:", tokenizer.bos_token_id, "EOS token id:", tokenizer.eos_token_id, "PAD token id:", tokenizer.pad_token_id)
print(encoded)
for token_id in encoded:    
    print(f"Token ID: {token_id}, Token: {tokenizer.decode([token_id])}")

Original text: 1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES WITH PATCHY TUBE AND NASOGASTRIC TUBES, RIGHT LOWER MEDIASTINAL SIDED CATHETER.  NO SIGNIFICANT CHANGE IN THE PREVIOUS STUDYDEMONSTRATE ATELECTASIS O
Number of words: 27 Number of tokens: 76 pad_id: 50256 eos_id: 50256 bos_id: 50256
BOS token id: 50256 EOS token id: 50256 PAD token id: 50256
[50256, 16, 13, 220, 3563, 17534, 9447, 7036, 12509, 9792, 23255, 45, 1847, 449, 7340, 37232, 13349, 2246, 30383, 13315, 350, 11417, 56, 309, 10526, 36, 5357, 7210, 7730, 1921, 5446, 2149, 309, 10526, 1546, 11, 33621, 406, 36048, 26112, 40, 11262, 17961, 311, 2389, 1961, 327, 12599, 2767, 1137, 13, 220, 8005, 36771, 30643, 8643, 5870, 27746, 3268, 3336, 22814, 12861, 20958, 49348, 35755, 3620, 1340, 18601, 6158, 5161, 36, 16779, 1921, 1797, 440, 50256]
Token ID: 50256, Token: 
Token ID: 16, Token: 1
Token ID: 13, Token: .
Token ID: 220, Token:  
Token ID: 3563, Token:  ST
Token ID: 17534, Token: ABLE
Token ID: 9447, Token:  SM
Token ID

In [10]:
import re
import string

def clean_text(text: str) -> str:
    # lowercase
    text = text.lower()

    # remove enumerators like "1." or "23." but KEEP decimals like "2.5"
    # (?<!\d) ensures no digit right before; (?!\d) ensures no digit right after the dot
    text = re.sub(r'(?<!\d)\b\d+\.(?!\d)', ' ', text)

    # remove all punctuation EXCEPT "."
    punctuation = string.punctuation.replace('.', '')
    text = text.translate(str.maketrans('', '', punctuation))

    # normalize spaces around periods to " . " → ". "
    text = re.sub(r'\s*\.\s*', '. ', text)

    # collapse multiple spaces and trim
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# Example
text = "1.  STABLE SMALL LEFT INTERNAL JUGULAR OPACITIES... 2.5 cm nodule; item 2. next. 3. Done."
print(clean_text(text))


# Example
text = """
 1.  INTERVAL PLACEMENT OF A RIGHT INTERNAL JUGULAR VENOUS SHEATH 
WITH THE DISTAL TIP IN THE PROXIMAL SUPERIOR VENA CAVA.  NO 
PNEUMOTHORAX.
 
 2.  STABLE POSITION OF NASOGASTRIC TUBE, FEEDING TUBE, TRACHEOSTOMY 
CANULA, LEFT INTERNAL JUGULAR CENTRAL VENOUS CATHETER, AND LEFT UPPER 
EXTREMITY PICC.  
 
 3.  NO SIGNIFICANT INTERVAL CHANGE IN HYPEREXPANDED LUNG VOLUMES, 
RIGHT BASILAR OPACITIES, SMALL BILATERAL PLEURAL EFFUSIONS, TENTING 
OF THE RIGHT HEMIDIAPHRAGM AND BIAPICAL PLEURAL THICKENING. 
 
 """
cleaned_text = clean_text(text)
print(cleaned_text)


stable small left internal jugular opacities. . . 2. 5 cm nodule item next. done.
interval placement of a right internal jugular venous sheath with the distal tip in the proximal superior vena cava. no pneumothorax. stable position of nasogastric tube feeding tube tracheostomy canula left internal jugular central venous catheter and left upper extremity picc. no significant interval change in hyperexpanded lung volumes right basilar opacities small bilateral pleural effusions tenting of the right hemidiaphragm and biapical pleural thickening.


In [11]:
encoded = tokenizer.encode(cleaned_text)
words = cleaned_text.split()
print("Number of words:", len(words), "Number of tokens:", len(encoded), "pad_id:", pad_id, "eos_id:", eos_id, "bos_id:", bos_id)
print("BOS token id:", tokenizer.bos_token_id, "EOS token id:", tokenizer.eos_token_id, "PAD token id:", tokenizer.pad_token_id)
print(encoded)
for token_id in encoded:    
    print(f"Token ID: {token_id}, Token: {tokenizer.decode([token_id])}")

Number of words: 65 Number of tokens: 112 pad_id: 50256 eos_id: 50256 bos_id: 50256
BOS token id: 50256 EOS token id: 50256 PAD token id: 50256
[50256, 3849, 2100, 13127, 286, 257, 826, 5387, 45808, 934, 8710, 516, 673, 776, 351, 262, 1233, 282, 8171, 287, 262, 14793, 4402, 9098, 410, 8107, 269, 4170, 13, 645, 29631, 849, 273, 897, 13, 8245, 2292, 286, 25221, 519, 459, 1173, 12403, 13017, 12403, 491, 4891, 455, 9145, 460, 4712, 1364, 5387, 45808, 934, 4318, 8710, 516, 3797, 43332, 290, 1364, 6727, 8963, 414, 8301, 66, 13, 645, 2383, 16654, 1487, 287, 20606, 21510, 79, 12249, 12317, 15343, 826, 1615, 1794, 1034, 330, 871, 1402, 24537, 3339, 1523, 914, 15880, 11105, 278, 286, 262, 826, 339, 13602, 72, 6570, 22562, 76, 290, 3182, 499, 605, 3339, 1523, 6546, 3101, 13, 50256]
Token ID: 50256, Token: 
Token ID: 3849, Token: inter
Token ID: 2100, Token: val
Token ID: 13127, Token:  placement
Token ID: 286, Token:  of
Token ID: 257, Token:  a
Token ID: 826, Token:  right
Token ID: 5387, Token: