In [31]:
import yaml
import torch
from models.model_pretrain import ALBEF
from models.tokenization_bert import BertTokenizer
from torchvision import transforms
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score

In [2]:
# Load config
config = yaml.safe_load(open("configs/Pretrain.yaml"))
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = ALBEF(config=config, text_encoder="bert-base-uncased", tokenizer=tokenizer, init_deit=False)

In [3]:
# Load checkpoint
ckpt = torch.load("output_mimic_a40_transformations/checkpoint_29.pth", map_location="cpu")

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model.load_state_dict(ckpt["model"], strict=False)
model.to(device)
model.eval()

Using device: cpu


ALBEF(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
     

In [9]:
labels_csv = "../data/vindr_cxr/annotations/image_labels_test.csv"
images_root = Path("../data/vindr_cxr/test")

In [10]:
df = pd.read_csv(labels_csv)
df.head()

Unnamed: 0,image_id,Aortic enlargement,Atelectasis,Calcification,Cardiomegaly,Clavicle fracture,Consolidation,Edema,Emphysema,Enlarged PA,...,Pneumothorax,Pulmonary fibrosis,Rib fracture,Other lesion,COPD,Lung tumor,Pneumonia,Tuberculosis,Other disease,No finding
0,e0dc2e79105ad93532484e956ef8a71a,0,1,1,1,0,0,0,0,0,...,1,0,0,0,0,0,1,0,1,0
1,0aed23e64ebdea798486056b4f174424,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,1,0,0,0
2,aa15cfcfca7605465ca0513902738b95,0,0,0,0,0,0,0,0,0,...,0,1,0,0,0,0,0,1,0,0
3,665c4a6d2693dc0286d65ab479c9b169,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
4,42da2c134b53cb5594774d3d29faac59,1,0,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0


In [14]:
df.shape

(3000, 30)

In [11]:
id_col = "image_id"
label_cols = df.columns[1:]

In [12]:
def has_png(row):
    img_path = images_root / f"{row[id_col]}.png"
    return img_path.exists()

In [13]:
df["has_png"] = df.apply(has_png, axis=1)
df_valid = df[df["has_png"]].reset_index(drop=True)

In [15]:
df_valid.shape

(3000, 30)

In [16]:
df_subset = df_valid.iloc[:10].copy()

In [17]:
y_true = df_subset[label_cols].values.astype(int)
num_labels = y_true.shape[1]

In [18]:
num_labels

28

In [19]:
prompts = list(label_cols)
print("Number of labels:", len(prompts))

text_inputs = tokenizer(
    prompts,
    padding=True,
    truncation=True,
    max_length=25,
    return_tensors="pt",
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

Number of labels: 28


In [20]:
def get_text_features(model, text_inputs):
    text_output = model.text_encoder.bert(
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
        return_dict=True,
        mode="text",
    )
    text_cls = text_output.last_hidden_state[:, 0, :]   
    text_feat = model.text_proj(text_cls)               
    text_feat = F.normalize(text_feat, dim=-1)
    return text_feat

with torch.no_grad():
    text_feat = get_text_features(model, text_inputs)   
    text_feat = text_feat.to(device)

In [21]:
normalize = transforms.Normalize(
    (0.48145466, 0.4578275, 0.40821073),
    (0.26862954, 0.26130258, 0.27577711),
)
transform = transforms.Compose([
    transforms.Resize((config["image_res"], config["image_res"])),
    transforms.ToTensor(),
    normalize,
])

In [22]:
def get_image_features(model, images):
    image_embeds = model.visual_encoder(images)         
    image_cls = image_embeds[:, 0, :]                   
    image_feat = model.vision_proj(image_cls)           
    image_feat = F.normalize(image_feat, dim=-1)
    return image_feat

In [24]:
all_scores = []   # list of (L,) arrays

model.eval()
with torch.no_grad():
    for idx, row in df_subset.iterrows():
        image_id = row[id_col]
        img_path = images_root / f"{image_id}.png"
        img = Image.open(img_path).convert("RGB")

        img_t = transform(img).unsqueeze(0).to(device)   # (1,3,H,W)
        image_feat = get_image_features(model, img_t)    # (1, D)

        # cosine similarity: (1, L)
        sims = image_feat @ text_feat.t()
        scores = sims.squeeze(0).cpu().numpy()           # (L,)
        all_scores.append(scores)

        print(f"[{idx+1}/{10}] {image_id} processed")

[1/10] e0dc2e79105ad93532484e956ef8a71a processed
[2/10] 0aed23e64ebdea798486056b4f174424 processed
[3/10] aa15cfcfca7605465ca0513902738b95 processed
[4/10] 665c4a6d2693dc0286d65ab479c9b169 processed
[5/10] 42da2c134b53cb5594774d3d29faac59 processed
[6/10] c7179539654a1b3b7977e56e7e3009d5 processed
[7/10] bfd1974dc9778aadb407a11b57ab748f processed
[8/10] 618777b8305b062583337d9a6b7a3d4e processed
[9/10] e54b5a593bc03c789ecdc18d8270964e processed
[10/10] 3019aec706bd013e1e3348564fbfd086 processed


In [27]:
scores = np.vstack(all_scores)   
print("scores shape:", scores.shape)
print("y_true shape:", y_true.shape)

scores shape: (10, 28)
y_true shape: (10, 28)


In [32]:
print("\nPer-label ROC–AUC (on first 10 images):")
auc_results = {}
for j, label in enumerate(label_cols):
    y = y_true[:, j]
    # need at least one positive and one negative to compute AUC
    if y.sum() == 0 or y.sum() == len(y):
        print(f"{label:30s} : AUC = N/A (only one class present)")
        auc_results[label] = None
        continue
    try:
        auc = roc_auc_score(y, scores[:, j])
        auc_results[label] = auc
        print(f"{label:30s} : AUC = {auc:.3f}")
    except ValueError as e:
        print(f"{label:30s} : AUC error ({e})")
        auc_results[label] = None


Per-label ROC–AUC (on first 10 images):
Aortic enlargement             : AUC = 0.083
Atelectasis                    : AUC = 1.000
Calcification                  : AUC = 0.208
Cardiomegaly                   : AUC = 1.000
Clavicle fracture              : AUC = N/A (only one class present)
Consolidation                  : AUC = 0.889
Edema                          : AUC = N/A (only one class present)
Emphysema                      : AUC = N/A (only one class present)
Enlarged PA                    : AUC = N/A (only one class present)
ILD                            : AUC = 0.625
Infiltration                   : AUC = 0.889
Lung Opacity                   : AUC = N/A (only one class present)
Lung cavity                    : AUC = N/A (only one class present)
Lung cyst                      : AUC = N/A (only one class present)
Mediastinal shift              : AUC = 0.667
Nodule/Mass                    : AUC = 0.312
Pleural effusion               : AUC = 1.000
Pleural thickening             : 

In [34]:
# For each image, check if the top-scoring label is one of the positive GT labels
top1_idx = scores.argmax(axis=1)   # (N,)
hits = 0
for i in range(10):
    if y_true[i, top1_idx[i]] == 1:
        hits += 1

top1_acc = hits / 10
print(f"\nTop-1 accuracy (first {10} images): {top1_acc:.3f}  ({hits}/{10} correct)")


Top-1 accuracy (first 10 images): 0.000  (0/10 correct)
