In [1]:
import json

from urllib.request import urlopen
from PIL import Image
import torch
from huggingface_hub import hf_hub_download
from open_clip import create_model_and_transforms, get_tokenizer
from open_clip.factory import HF_HUB_PREFIX, _MODEL_CONFIGS


# Load the model and config files
model_name = "biomedclip_local"

with open("checkpoints/open_clip_config.json", "r") as f:
    config = json.load(f)
    model_cfg = config["model_cfg"]
    preprocess_cfg = config["preprocess_cfg"]


if (not model_name.startswith(HF_HUB_PREFIX)
    and model_name not in _MODEL_CONFIGS
    and config is not None):
    _MODEL_CONFIGS[model_name] = model_cfg

tokenizer = get_tokenizer(model_name)

model, _, preprocess = create_model_and_transforms(
    model_name=model_name,
    pretrained="checkpoints/open_clip_pytorch_model.bin",
    **{f"image_{k}": v for k, v in preprocess_cfg.items()},
)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [34]:
# Zero-shot image classification
template = 'this is a photo of a'
labels = [
    'adenocarcinoma histopathology',
    'brain MRI',
    'covid line chart',
    'squamous cell carcinoma histopathology',
    'immunohistochemistry histopathology',
    'bone X-ray',
    'chest X-ray',
    'pie chart',
    'hematoxylin and eosin histopathology'
]

dataset_url = 'https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/'
test_imgs = [
    'squamous_cell_carcinoma_histopathology.jpeg',
    'H_and_E_histopathology.jpg',
    'bone_X-ray.jpg',
    'adenocarcinoma_histopathology.jpg',
    'covid_line_chart.png',
    'IHC_histopathology.jpg',
    'chest_X-ray.jpg',
    'brain_MRI.jpg',
    'pie_chart.png'
]
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()

context_length = 256

images = torch.stack([preprocess(Image.open(urlopen(dataset_url + img))) for img in test_imgs]).to(device)
texts = tokenizer([template + l for l in labels], context_length=context_length).to(device)
with torch.no_grad():
    image_features, text_features, logit_scale = model(images, texts)

    logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
    sorted_indices = torch.argsort(logits, dim=-1, descending=True)

    logits = logits.cpu().numpy()
    sorted_indices = sorted_indices.cpu().numpy()

top_k = -1

for i, img in enumerate(test_imgs):
    pred = labels[sorted_indices[i][0]]

    top_k = len(labels) if top_k == -1 else top_k
    print(img.split('/')[-1] + ':')
    for j in range(top_k):
        jth_index = sorted_indices[i][j]
        print(f'{labels[jth_index]}: {logits[i][jth_index]}')
    print('\n')

squamous_cell_carcinoma_histopathology.jpeg:
squamous cell carcinoma histopathology: 0.9993261098861694
adenocarcinoma histopathology: 0.0005834370385855436
immunohistochemistry histopathology: 5.3123229008633643e-05
hematoxylin and eosin histopathology: 3.730620301212184e-05
pie chart: 2.1773455260998276e-10
covid line chart: 9.560737618263815e-11
brain MRI: 9.659328024935743e-12
chest X-ray: 7.296285564660845e-14
bone X-ray: 4.939202597809868e-14


H_and_E_histopathology.jpg:
adenocarcinoma histopathology: 0.9988821148872375
squamous cell carcinoma histopathology: 0.0010510372230783105
pie chart: 4.1645635064924136e-05
chest X-ray: 1.2630177479877602e-05
brain MRI: 9.431676517124288e-06
hematoxylin and eosin histopathology: 2.600068455649307e-06
immunohistochemistry histopathology: 4.84693998714647e-07
covid line chart: 3.019886563038199e-09
bone X-ray: 6.696965898500551e-11


bone_X-ray.jpg:
chest X-ray: 0.9959741234779358
bone X-ray: 0.00390210235491395
pie chart: 6.584016227861866

In [31]:
t = tokenizer("normal", context_length=context_length)
t

tensor([[   2, 2488,    3,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,  

In [29]:
import pandas as pd
metadata_csv = "/home/E19_FYP_Domain_Gen_Data/metadata.csv" 
patches_dir = "/home/E19_FYP_Domain_Gen_Data/patches"       
metadata_df = pd.read_csv(metadata_csv, index_col=[0])
metadata_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 455954 entries, 0 to 455953
Data columns (total 8 columns):
 #   Column   Non-Null Count   Dtype
---  ------   --------------   -----
 0   patient  455954 non-null  int64
 1   node     455954 non-null  int64
 2   x_coord  455954 non-null  int64
 3   y_coord  455954 non-null  int64
 4   tumor    455954 non-null  int64
 5   slide    455954 non-null  int64
 6   center   455954 non-null  int64
 7   split    455954 non-null  int64
dtypes: int64(8)
memory usage: 31.3 MB


In [39]:
metadata_df['split'].value_counts()

split
0    410359
1     45595
Name: count, dtype: int64

# Selecting a source domain (Center 0)

In [43]:
center0_df = metadata_df[metadata_df['center'] == 0].reset_index(drop=True)
center0_df['split'].value_counts()

split
0    53425
1     6011
Name: count, dtype: int64

# Testing on Center 0

In [None]:
import os
import json
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from open_clip import create_model_and_transforms, get_tokenizer
from open_clip.factory import HF_HUB_PREFIX, _MODEL_CONFIGS
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    classification_report,
    roc_auc_score
)
from tqdm import tqdm

# 1. Paths & constants
METADATA_CSV = "/home/E19_FYP_Domain_Gen_Data/metadata.csv"
PATCHES_DIR  = "/home/E19_FYP_Domain_Gen_Data/patches"
CONFIG_PATH  = "checkpoints/open_clip_config.json"
WEIGHTS_PATH = "checkpoints/open_clip_pytorch_model.bin"
MODEL_NAME   = "biomedclip_local"
CONTEXT_LENGTH = 256
BATCH_SIZE   = 512
NUM_WORKERS  = 4
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Load metadata and filter center=0
metadata_df = pd.read_csv(METADATA_CSV, index_col=0)
center0_df  = metadata_df[metadata_df.center == 0].copy()

# 3. Build filenames and full filepaths
center0_df["filename"] = center0_df.apply(
    lambda r: f"patch_patient_{r.patient:03d}_node_{r.node}_x_{r.x_coord}_y_{r.y_coord}.png",
    axis=1
)
center0_df["filepath"] = center0_df.apply(
    lambda r: os.path.join(
        PATCHES_DIR,
        f"patient_{r.patient:03d}_node_{r.node}",
        r.filename
    ), 
    axis = 1
)

# 4. Split into train/val/test
train_df = center0_df[center0_df.split == 0]
test_df  = center0_df[center0_df.split  == 1]

# 5. Define a Dataset for loading & preprocessing
class BiomedCLIPDataset(Dataset):
    def __init__(self, df, preprocess):
        self.filepaths = df["filepath"].tolist()
        self.labels    = df["tumor"].astype(int).tolist()
        self.preproc   = preprocess

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        img = Image.open(self.filepaths[idx]).convert("RGB")
        img = self.preproc(img)           # yields a torch.Tensor (C,H,W)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label

# 6. Load BiomedCLIP model + tokenizer + preprocess
with open(CONFIG_PATH, "r") as f:
    cfg = json.load(f)
model_cfg, preproc_cfg = cfg["model_cfg"], cfg["preprocess_cfg"]

# register local config if needed
if (not MODEL_NAME.startswith(HF_HUB_PREFIX)
    and MODEL_NAME not in _MODEL_CONFIGS):
    _MODEL_CONFIGS[MODEL_NAME] = model_cfg

tokenizer = get_tokenizer(MODEL_NAME)
model, _, preprocess = create_model_and_transforms(
    model_name=MODEL_NAME,
    pretrained=WEIGHTS_PATH,
    **{f"image_{k}": v for k,v in preproc_cfg.items()}
)

model = model.to(DEVICE).eval()

# 7. Prepare DataLoader for test set
test_ds = BiomedCLIPDataset(test_df, preprocess)
test_loader = DataLoader(
    test_ds, batch_size=BATCH_SIZE,
    shuffle=False, num_workers=NUM_WORKERS
)

# 8. Pre-tokenize your two prompts once
prompts = [
   "Tumor is not present in this image"
   "This is an image of a tumor"
]
text_inputs = tokenizer(
    prompts, context_length=CONTEXT_LENGTH
).to(DEVICE)

with torch.no_grad():
    text_feats = model.encode_text(text_inputs)               # (2, D)
    text_feats = text_feats / text_feats.norm(dim=1, keepdim=True)
    logit_scale = model.logit_scale.exp()

    all_preds  = []
    all_probs  = []
    all_labels = []

    # 9. Inference loop
    for imgs, labels in tqdm(test_loader, desc="Evaluating"):
        imgs   = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        img_feats = model.encode_image(imgs)                    # (B, D)
        img_feats = img_feats / img_feats.norm(dim=1, keepdim=True)

        logits = logit_scale * (img_feats @ text_feats.t())     # (B, 2)
        probs  = logits.softmax(dim=1)                          # (B, 2)
        preds  = logits.argmax(dim=1)                           # (B,)

        all_preds .append(preds.cpu())
        all_probs .append(probs[:, 0].cpu())  # tumor-class prob
        all_labels.append(labels.cpu())

    # concatenate
    y_pred = torch.cat(all_preds).numpy()
    y_prob = torch.cat(all_probs).numpy()
    y_true = torch.cat(all_labels).numpy()

# 10. Compute & print metrics
acc   = accuracy_score(y_true, y_pred)
cm    = confusion_matrix(y_true, y_pred)
report = classification_report(y_true, y_pred, digits=4)
auc   = roc_auc_score(y_true, y_prob)

print(f"\nTest Accuracy:    {acc:.4f}")
print(f"ROC AUC:          {auc:.4f}")
print("\nConfusion Matrix:")
print(cm)
print("\nClassification Report:")
print(report)


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
Evaluating: 100%|██████████| 12/12 [00:18<00:00,  1.52s/it]


Test Accuracy:    0.8197
ROC AUC:          0.0937

Confusion Matrix:
[[2285  747]
 [ 337 2642]]

Classification Report:
              precision    recall  f1-score   support

           0     0.8715    0.7536    0.8083      3032
           1     0.7796    0.8869    0.8298      2979

    accuracy                         0.8197      6011
   macro avg     0.8255    0.8203    0.8190      6011
weighted avg     0.8259    0.8197    0.8189      6011




