In [None]:
import os
from pathlib import Path
import json

from conch.open_clip_custom import create_model_from_pretrained
from conch.downstream.zeroshot_path import zero_shot_classifier, run_mizero
from conch.downstream.wsi_datasets import WSIEmbeddingDataset

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pandas as pd 
import numpy as np

# Display the value of every top-level expression in a cell, not only the last one
# (IPython's default is to echo just the final expression of a cell).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
# Make the parent of the launch directory the working directory, so that the
# './'-relative paths used in this notebook resolve against it.
#
# `root` is computed only once per kernel session: '../' is resolved against the
# *current* working directory, so recomputing it after the chdir below (i.e. when
# this cell is re-run) would climb one directory higher on every execution.
if 'root' not in globals():
    root = Path('../').resolve()
os.chdir(root)

This notebook provides an example of zero-shot classification of whole-slide images (WSIs) by ensembling multiple prompts and prompt templates.

In [None]:
# Use the first CUDA device when torch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Checkpoint location, relative to the current working directory.
checkpoint_path = './checkpoints/CONCH/pytorch_model.bin'

# The second return value is not used in this notebook and is discarded.
model, _ = create_model_from_pretrained(
    model_cfg='conch_ViT-B-16',
    checkpoint_path=checkpoint_path,
    device=device,
)
# Put the model in eval mode; binding the return value to `_` keeps it out of the cell output.
_ = model.eval()

In [None]:
# Names of the csv columns used below; adjust them to match your own file.
index_col = 'slide_id'       # column with the slide ids
target_col = 'OncoTreeCode'  # column with the target labels
# Maps values in target_col to integers; rows with any other value are dropped below.
label_map = {'LUAD': 0, 'LUSC': 1}

# Path to the extracted embeddings, assumed to be saved as .pt files, 1 file per slide.
data_source = '/path/to/extracted-embeddings/'

# Read the csv (assumed to contain index_col and target_col), keep only the rows whose
# label is a key of label_map, and renumber the remaining rows from 0.
df = (
    pd.read_csv('path/to/csv')
    .loc[lambda frame: frame[target_col].isin(label_map.keys())]
    .reset_index(drop=True)
)

dataset = WSIEmbeddingDataset(
    data_source=data_source,
    df=df,
    index_col=index_col,
    target_col=target_col,
    label_map=label_map,
)

# One dataset item per batch, iterated in dataset order (no shuffling).
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

In [None]:
# Invert the dataset's label_map so that each value maps back to its key.
slide_dataset = dataloader.dataset
idx_to_class = dict(zip(slide_dataset.label_map.values(), slide_dataset.label_map.keys()))

print("num samples: ", len(slide_dataset))
print(idx_to_class)

In [None]:
# Load the prompt definitions and take the entry stored under key '0'.
prompt_file = './prompts/nsclc_prompts_all_per_class.json'
with open(prompt_file) as prompt_handle:
    prompts = json.load(prompt_handle)['0']

classnames, templates = prompts['classnames'], prompts['templates']
n_classes = len(classnames)

# Order the class-name entries by class index: index -> idx_to_class value -> classnames key.
class_keys = (str(idx_to_class[position]) for position in range(n_classes))
classnames_text = [classnames[key] for key in class_keys]

# Show which entry ended up at each class index.
for class_idx, classname in enumerate(classnames_text):
    print(f'{class_idx}: {classname}')

In [None]:
# Build the zero-shot classifier weights from the per-class names and the prompt templates.
zeroshot_weights = zero_shot_classifier(
    model,
    classnames_text,
    templates,
    device=device,
)
# Sanity check on the result's shape.
print(zeroshot_weights.shape)

In [None]:
# Run zero-shot evaluation over the dataloader, requesting the 'bacc' and 'weighted_f1'
# metrics; dump_results=True also asks for the second return value, `dump`.
results, dump = run_mizero(
    model,
    zeroshot_weights,
    dataloader,
    device,
    dump_results=True,
    metrics=['bacc', 'weighted_f1'],
)

In [None]:
# results maps metric name -> {j: score}. Pick the key j with the highest 'bacc'
# (the first one on ties) and report every metric at that same key.
# NOTE(review): what j denotes is defined by run_mizero, not visible here — confirm.
bacc_by_j = results['bacc']
candidate_js = list(bacc_by_j)
best_j_idx = np.argmax([bacc_by_j[j] for j in candidate_js])
best_j = candidate_js[best_j_idx]

for metric in results:
    print(f"{metric}: {results[metric][best_j]:.3f}")
