In [None]:
import os
from pathlib import Path
import json

from conch.open_clip_custom import create_model_from_pretrained
from conch.downstream.zeroshot_path import zero_shot_classifier, run_zeroshot

import torch
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# IPython normally echoes only the last bare expression of a cell; switch it to echo every one.
from IPython.core.interactiveshell import InteractiveShell
setattr(InteractiveShell, "ast_node_interactivity", "all")

In [None]:
# Work from the parent of the notebook's launch directory so the relative paths used below
# (./checkpoints/..., ./prompts/...) resolve. The launch directory is remembered in
# NOTEBOOK_DIR the first time this cell runs. Re-running the cell therefore reuses it instead
# of climbing one more level each time, as a bare os.chdir('..') would.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', Path.cwd())
root = (NOTEBOOK_DIR / '..').resolve()
os.chdir(root)

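This notebook is a complete example of zero-shot classification with CONCH that ensembles multiple class-name prompts and prompt templates. Use it to reproduce the zero-shot classification results on CRC100K (image size = 224 × 224). Before running, set `data_source` to your local copy of the CRC100K validation set and make sure the CONCH checkpoint is available at `./checkpoints/CONCH/pytorch_model.bin`. That path is relative to the parent of the notebook directory, which the cell above makes the working directory.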

In [None]:
# Model configuration: CONCH ViT-B/16, with the input size pinned to 224 to match CRC100K tiles.
model_cfg = 'conch_ViT-B-16'
checkpoint_path = './checkpoints/CONCH/pytorch_model.bin'
force_image_size = 224
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Build the model together with its matching image transform, then put it in eval mode.
# The result of .eval() is bound to `_` only to keep it out of the cell output
# (every expression is echoed in this notebook).
model, preprocess = create_model_from_pretrained(
    model_cfg,
    checkpoint_path,
    device=device,
    force_image_size=force_image_size,
)
_ = model.eval()

In [None]:
# Evaluation data: an ImageFolder tree (one sub-directory per class), transformed with the
# model's own preprocessing.
# NOTE: placeholder path. Point this at your local CRC100K validation set before running.
data_source = '/path/to/CRC100k/validation/set'
dataset = ImageFolder(data_source, transform=preprocess)
# shuffle=False iterates the dataset in its stored order, so any dumped per-sample outputs
# can be matched back to dataset indices.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

# Invert class_to_idx (folder name -> label index) into label index -> folder name.
# The guard only matters if the dataset is swapped for one without class_to_idx.
if hasattr(dataloader.dataset, 'class_to_idx'):
    idx_to_class = {v: k for k, v in dataloader.dataset.class_to_idx.items()}
else:
    raise ValueError('Dataset does not have a class_to_idx attribute')
print("num samples: ", len(dataloader.dataset))
print(idx_to_class)

In [None]:
# Load the prompt set. The loaded prompt set provides:
# - 'classnames': keyed by dataset folder name.
# - 'templates': passed through unchanged to zero_shot_classifier.
# The value type of each classnames entry is whatever zero_shot_classifier expects
# (presumably a list of synonyms to ensemble — TODO confirm).
# Key '0' selects one prompt set from the file.
prompt_file = './prompts/crc100k_prompts_all_per_class.json'
with open(prompt_file, encoding='utf-8') as f:
    prompts = json.load(f)['0']
classnames = prompts['classnames']
templates = prompts['templates']

# Class order comes from the dataset, so text classifier row i lines up with ImageFolder
# label i. Extra entries in the prompt file are ignored.
# A dataset class with no prompt entry is reported explicitly.
n_classes = len(idx_to_class)
missing_classes = [idx_to_class[idx] for idx in range(n_classes)
                   if str(idx_to_class[idx]) not in classnames]
if missing_classes:
    raise KeyError(f'{prompt_file} has no classnames entry for dataset classes: {missing_classes}')
classnames_text = [classnames[str(idx_to_class[idx])] for idx in range(n_classes)]
for class_idx, classname in enumerate(classnames_text):
    print(f'{class_idx}: {classname}')

In [None]:
# Turn the per-class prompts × templates into text-derived classifier weights
# (one ensembled weight per class, per this notebook's intro). Then show the resulting shape.
zeroshot_weights = zero_shot_classifier(
    model,
    classnames_text,
    templates,
    device=device,
)
print(zeroshot_weights.shape)

In [None]:
# Evaluate the zero-shot classifier over the whole dataloader.
# metrics requests 'bacc' (presumably balanced accuracy) and weighted F1.
# dump_results=True makes run_zeroshot also return `dump` (presumably per-sample outputs —
# check run_zeroshot).
results, dump = run_zeroshot(
    model,
    zeroshot_weights,
    dataloader,
    device,
    dump_results=True,
    metrics=['bacc', 'weighted_f1'],
)

In [None]:
# Print each metric returned by run_zeroshot, rounded to three decimals.
for metric_name, metric_value in results.items():
    print('{}: {:.3f}'.format(metric_name, metric_value))