In [4]:
import torch

# Show PyTorch's reference documentation for scaled dot-product attention.
sdpa = torch.nn.functional.scaled_dot_product_attention
print(sdpa.__doc__)


scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> Tensor:

    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
    and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
    specified as a keyword argument.

    .. code-block:: python

        # Efficient implementation equivalent to the following:
        def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
            L, S = query.size(-2), key.size(-2)
            scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
            if is_causal:
                assert attn_mask is None
                temp_mask = torch.ones(

In [1]:
import os
import torch
from torch.utils.data import DataLoader
from utils.train_comparison import *
from utils.processing import image_transform
from utils.data.chexpert_dataset import CheXpertDataset

In [2]:
# Build the tokenizer (GPT-2 variant requested) and keep its special-token
# ids at module level; the generation cell reads `eos_id` and `pad_id`.
tok = build_tokenizer_from_labels(gpt2=True)
pad_id, eos_id, bos_id = tok.pad_token_id, tok.eos_token_id, tok.bos_token_id

Using GPT2 tokenizer.


In [3]:
import os  # only used by the (currently disabled) num_workers option below

from utils.data.dataloaders import create_dataloaders

# Dataset roots, relative to the notebook's working directory.
CHEXPERT_DIR = "Datasets/CheXpertPlus"
MIMIC_DIR = "Datasets/MIMIC"

# CheXpert locations handed to create_dataloaders.
chexpert_paths = dict(
    chexpert_data_path=f"{CHEXPERT_DIR}/PNG",  # base PNG folder
    chexpert_data_csv=f"{CHEXPERT_DIR}/df_chexpert_plus_240401_findings.csv",
)

# MIMIC locations handed to create_dataloaders.
mimic_paths = dict(
    mimic_data_path=MIMIC_DIR,
    mimic_splits_csv=f"{MIMIC_DIR}/mimic-cxr-2.0.0-split.csv.gz",
    mimic_metadata_csv=f"{MIMIC_DIR}/mimic-cxr-2.0.0-metadata-findings-only.csv",
    mimic_reports_path=f"{MIMIC_DIR}/cxr-record-list.csv.gz",  # must contain 'path'
    mimic_images_dir=f"{MIMIC_DIR}/matched_images_and_masks_mimic_224/images",
)

# Extra keyword arguments forwarded to create_dataloaders. All tuning options
# are switched off for now; uncomment the ones you need:
kwargs = dict(
    # num_workers=os.cpu_count() // 2 if os.cpu_count() else 4,  # adjust on your VM
    # persistent_workers=True,           # reuses workers between iterations
    # prefetch_factor=4,                 # each worker prefetches batches
    # pin_memory=True,                   # if using CUDA
    # drop_last=False,
)

# Test-split loader built from both datasets.
test_loader = create_dataloaders(
    chexpert_paths, mimic_paths,
    batch_size=4, split="test", sampling_ratio=0.7,
    **kwargs,
)

Filtering rows with missing PNGs...
[INFO] Kept 19/63 rows with existing PNGs


In [4]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from utils.text_metrics import evaluate_all_metrics, save_metrics_to_json
# Helpers for building the complete model and loading its checkpoint
from utils.models.complete_model import create_complete_model, load_complete_model

# ---------------------------------------------------------------------------
# Evaluation of the best checkpoint on the test loader.
#
# Steps: build the model, load the checkpoint, greedily generate one report
# per image for at most MAX_EVAL_BATCHES batches, then score all generated
# texts against the references in a single evaluate_all_metrics call.
# ---------------------------------------------------------------------------

# --- Configuration (everything a reader might want to tune) ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEGMENTER_MODEL_PATH="models/dino_unet_decoder_finetuned.pth"
best_model_path = "checkpoints/model_best.pth"
EPOCHS_TRAINED = 100               # reported in the summary line / results file name
MAX_EVAL_BATCHES = 200             # cap on the number of test batches evaluated
MAX_NEW_TOKENS = 100               # generation length limit per sample
TRAINING_TIME_SECONDS = 3600 * 7   # manually tracked training wall time
SAVE_METRICS = False               # set True to also write the metrics to JSON

# --- Model ------------------------------------------------------------------
model = create_complete_model(device=DEVICE, SEGMENTER_MODEL_PATH=SEGMENTER_MODEL_PATH)

# The checkpoint is deserialized on CPU; load_state_dict then copies the
# tensors into the model's existing parameters.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
ckpt = torch.load(best_model_path, map_location="cpu")

# strict=False tolerates key mismatches, so surface them instead of silently
# evaluating a model whose weights were only partially restored.
load_result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "[WARN] Checkpoint/model key mismatch: "
        f"{len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected"
    )
model.eval()

# --- Generation loop --------------------------------------------------------
generated_text, target_text = [], []
iteration = 0  # stays 0 if the loader yields nothing

with torch.inference_mode():
    for iteration, (pixel_values, ids_loader, paths, raw_labels) in enumerate(test_loader, start=1):
        pixel_values = pixel_values.to(model.device, non_blocking=True)

        # Visual path
        patches = model.encoder(pixel_values)                           # [B,Np,Cenc]
        projected_patches = model.linear_projection(patches)            # [B,Np,n_embd]

        # Segmentation path per layer
        segmented_layers = model.segmenter(pixel_values, model.num_layers) # [B,n_layers,H,W] (per current decoder)

        # Generate (disable all plotting/diagnostics for speed).
        # NOTE(review): no attention_mask is passed; the recorded run emits a
        # warning that it cannot be inferred because pad and eos tokens
        # coincide. Consider passing an explicit mask - confirm the custom
        # decoder accepts one.
        gen_ids = model.decoder.generate(
            inputs_embeds=projected_patches,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            repetition_penalty=1.2,
            eos_token_id=eos_id,
            pad_token_id=pad_id,
            use_cache=True,
            segmentation_mask=segmented_layers,
            prefix_allowed_length=0,
            plot_attention_mask=False,
            plot_attention_mask_layer=[],
            plot_attention_map=False,
            plot_attention_map_layer=[],
            plot_attention_map_generation=0,
        )
        # Move only the ids needed for decoding to CPU
        texts = model.tokenizer.batch_decode(gen_ids.detach().cpu(), skip_special_tokens=True)

        # Accumulate for final metric pass (metrics often run on CPU/strings anyway).
        # NOTE(review): presumably `ids_loader` holds the reference report
        # strings for the batch despite its name - verify against the
        # loader's collate function.
        generated_text.extend(texts)
        target_text.extend(ids_loader)

        if iteration >= MAX_EVAL_BATCHES:  # your test cap
            break

# --- Sanity checks before scoring -------------------------------------------
if not generated_text:
    raise RuntimeError("test_loader yielded no batches; nothing to evaluate")
assert len(generated_text) == len(target_text), (
    f"generated/reference count mismatch: {len(generated_text)} vs {len(target_text)}"
)

# --- Metrics ----------------------------------------------------------------
# Evaluate once per model
eval_results = evaluate_all_metrics(
    generated=generated_text,
    original=target_text,
    evaluation_mode="CheXagent"
)

print(f"\nOverall results for model trained {EPOCHS_TRAINED} epochs:")
for metric, scores in eval_results.items():
    print(f"{metric}: {scores}")

# add training walltime you tracked
eval_results["training_time_seconds"] = TRAINING_TIME_SECONDS

# Optionally persist the metrics (disabled by default, as in the original run).
if SAVE_METRICS:
    save_metrics_to_json(
        eval_results,
        f"lstm-vs-gpt/results/cloud_best_model_{EPOCHS_TRAINED}_Chexpert.json"
    )

Loaded segmenter weights from models/dino_unet_decoder_finetuned.pth


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Using device: cuda:0

Overall results for model trained 100 epochs:
chexbert_f1_weighted: 0.22205567772508789
chexbert_f1_micro: 0.29654100272055967
chexbert_f1_macro: 0.14572258507884892
chexbert_f1_micro_5: 0.4147615937295885
chexbert_f1_macro_5: 0.29443589263681913
radgraph_f1_RG_E: 0.20190294429311714
radgraph_f1_RG_ER: 0.16685646459235928
