In [12]:
import sys
import os

# Make the repository's source root (the parent of the current working
# directory) importable ahead of any installed packages.
src_path = os.path.dirname(os.getcwd())
sys.path.insert(0, src_path)

import json
import logging
import numpy as np
import pandas as pd
from pathlib import Path, PurePath
from collections import OrderedDict
from itertools import chain

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import clip.clip as clip
from training.datasets import CellPainting
from clip.clip import _transform
from clip.model import convert_weights, CLIPGeneral
from tqdm import tqdm

from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs

from sklearn.metrics import accuracy_score, top_k_accuracy_score

from huggingface_hub import hf_hub_download

In [13]:
# Checkpoint file and the Hugging Face Hub repository it is fetched from.
FILENAME = "cloome-retrieval-zero-shot.pt"
REPO_ID = "anasanchezf/cloome"

# hf_hub_download returns the local path of the downloaded file.
checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

In [14]:
# CLOOB
# Evaluation settings consumed by main() and load() further down.
model = "RN50"  # selects training/model_configs/<model>.json inside load()
image_resolution = [520, 696]  # forwarded to _transform(); presumably [height, width] -- TODO confirm
img_path = "<your-image-path>"  # placeholder: set before running; passed to CellPainting
mol_path = "morgan_chiral_fps_1024.hdf5"  # passed to CellPainting; presumably precomputed molecule fingerprints -- verify
val = "<your-path>/cellpainting-split-test-imgpermol.csv"  # placeholder: handed to main() as `df`


In [15]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if any) to float32 in place.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element gradients and would wrongly skip a single zero gradient.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [16]:
def load(model_path, device, model, image_resolution):
    """Build a ``CLIPGeneral`` model from its JSON config and load checkpoint weights.

    Args:
        model_path: Path to a checkpoint file containing a ``"state_dict"`` entry.
        device: Device the model is moved to (e.g. ``"cuda"`` or ``"cpu"``).
        model: Architecture name; ``/`` is replaced by ``-`` to locate
            ``training/model_configs/<name>.json`` under ``src_path``.
        image_resolution: Forwarded (twice) to ``_transform`` for the
            evaluation transform.

    Returns:
        Tuple ``(model, transform)``: the model in eval mode on ``device`` and
        the non-training transform built by ``_transform``.
    """
    # Deserialize onto the CPU so a checkpoint saved on a GPU also opens on a
    # CPU-only host; the weights follow the model in ``.to(device)`` below.
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    model_config_file = os.path.join(src_path, f"training/model_configs/{model.replace('/', '-')}.json")
    print('Loading model from', model_config_file)
    assert os.path.exists(model_config_file), f"Model config not found: {model_config_file}"
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    clip_model = CLIPGeneral(**model_info)

    if str(device) == "cpu":
        clip_model.float()
    print(device)

    # Strip the "module." prefix only from keys that actually carry it, so a
    # checkpoint saved without that prefix is not corrupted by blind slicing.
    prefix = 'module.'
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    clip_model.load_state_dict(new_state_dict)
    clip_model.to(device)
    clip_model.eval()

    return clip_model, _transform(image_resolution, image_resolution, is_train=False)

In [17]:
def get_features(dataset, model, device, batch_size=64, num_workers=20):
    """Encode every sample of ``dataset`` into L2-normalized image and molecule features.

    Each batch is expected to be a pair ``(imgs, mols)`` of mappings, where
    ``imgs["input"]`` / ``mols["input"]`` hold the model inputs and
    ``imgs["ID"]`` holds the sample identifiers.

    Args:
        dataset: Dataset to iterate with a ``DataLoader``.
        model: Model exposing ``encode_image`` and ``encode_text``.
        device: Device the inputs are moved to before encoding.
        batch_size: Samples per batch (default 64, the previous fixed value).
        num_workers: ``DataLoader`` worker processes (default 20, the previous
            fixed value); lower it on machines with fewer cores.

    Returns:
        Tuple ``(image_features, text_features, ids)``: the two concatenated
        feature tensors (left on ``device``) and a flat list of sample IDs in
        iteration order.
    """
    all_image_features = []
    all_text_features = []
    all_ids = []

    print(f"get_features {device}")
    print(len(dataset))

    loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size)
    with torch.no_grad():
        for imgs, mols in tqdm(loader):
            images, mols = imgs["input"], mols["input"]
            ids = imgs["ID"]

            img_features = model.encode_image(images.to(device))
            text_features = model.encode_text(mols.to(device))

            # Unit-normalize so that dot products are cosine similarities.
            img_features = img_features / img_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            all_image_features.append(img_features)
            all_text_features.append(text_features)
            all_ids.extend(ids)

    return torch.cat(all_image_features), torch.cat(all_text_features), all_ids

In [18]:
def main(df, model_path, model, img_path, mol_path, image_resolution):
    """Load the model, build the evaluation dataset and compute its features.

    Args:
        df: Passed through to ``CellPainting`` as its first argument (the
            notebook supplies the path of the split CSV).
        model_path: Checkpoint path forwarded to ``load``.
        model: Architecture name forwarded to ``load``.
        img_path: Passed through to ``CellPainting``.
        mol_path: Passed through to ``CellPainting``.
        image_resolution: Forwarded to ``load`` and ``_transform``.

    Returns:
        Tuple ``(image_features, text_features, ids)`` as produced by
        ``get_features``.
    """
    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())

    # The transform returned by load() is not used here; the dataset gets the
    # dataset-normalized crop transform built below instead.
    model, _ = load(model_path, device, model, image_resolution)

    preprocess_val = _transform(image_resolution, image_resolution, is_train=False, normalize="dataset", preprocess="crop")

    # Load the dataset
    val = CellPainting(df,
                       img_path,
                       mol_path,
                       transforms = preprocess_val)

    # Calculate the image features
    print("getting_features")
    val_img_features, val_text_features, val_ids = get_features(val, model, device)

    return val_img_features, val_text_features, val_ids

In [19]:
def get_metrics(image_features, text_features):
    """Compute retrieval rankings and rank statistics in both directions.

    Row ``i`` of ``image_features`` is treated as the match of row ``i`` of
    ``text_features``.

    Returns:
        Tuple ``(rankings, all_top_samples, all_preds, metrics, logits)``.
        Each of the first three and ``logits`` is a dict keyed by
        ``"image_to_text"`` / ``"text_to_image"``:
        ``rankings`` holds the descending argsort of the similarity matrix,
        ``all_preds`` the 0-based position of the match for every query,
        ``all_top_samples`` the query indices whose match is in the top 10,
        and ``logits`` the similarity matrices themselves. ``metrics`` maps
        ``"<direction>_mean_rank"``, ``"<direction>_median_rank"`` and
        ``"<direction>_R@{1,5,10}"`` to their values (ranks are 1-based).
    """
    sim_image_to_text = image_features @ text_features.t()
    logits = {
        "image_to_text": sim_image_to_text,
        "text_to_image": sim_image_to_text.t(),
    }

    # Column vector of the expected match index for every query row.
    target = torch.arange(len(text_features)).view(-1, 1).to(sim_image_to_text.device)

    metrics, rankings, all_top_samples, all_preds = {}, {}, {}, {}
    for direction in logits:
        order = torch.argsort(logits[direction], descending=True)
        rankings[direction] = order

        # Column at which each row's target appears in its sorted candidates.
        positions = torch.where(order == target)[1].detach().cpu().numpy()
        all_preds[direction] = positions
        all_top_samples[direction] = np.where(positions < 10)[0]

        metrics[f"{direction}_mean_rank"] = positions.mean() + 1
        metrics[f"{direction}_median_rank"] = np.floor(np.median(positions)) + 1
        metrics.update({f"{direction}_R@{k}": np.mean(positions < k) for k in [1, 5, 10]})

    return rankings, all_top_samples, all_preds, metrics, logits

In [20]:
# Encode the whole evaluation split with the downloaded checkpoint.
val_img_features, val_text_features, val_ids = main(
    df=val,
    model_path=checkpoint_path,
    model=model,
    img_path=img_path,
    mol_path=mol_path,
    image_resolution=image_resolution,
)

2
Loading model from /publicwork/sanchez/cloob/src/training/model_configs/RN50.json
cuda
2115
getting_features
get_features cuda
2115


100%|███████████████████████████████████████████████████████████████████████████| 34/34 [00:34<00:00,  1.02s/it]


In [21]:
# Rank every image against every molecule (and vice versa) over the full split.
rankings, all_top_samples, all_preds, metrics, logits = get_metrics(
    image_features=val_img_features,
    text_features=val_text_features,
)

In [22]:
# Display the metrics dict returned by get_metrics (ranks are 1-based, R@k are fractions).
metrics

{'image_to_text_mean_rank': 673.3040189125295,
 'image_to_text_median_rank': 524.0,
 'image_to_text_R@1': 0.030260047281323876,
 'image_to_text_R@5': 0.06619385342789598,
 'image_to_text_R@10': 0.08416075650118203,
 'text_to_image_mean_rank': 673.0912529550827,
 'text_to_image_median_rank': 549.0,
 'text_to_image_R@1': 0.03309692671394799,
 'text_to_image_R@5': 0.062411347517730496,
 'text_to_image_R@10': 0.0789598108747045}

In [23]:
# Column vector [0, 1, ..., N-1] on the CPU, with one entry per row of
# val_text_features; shown here for inspection.
ground_truth = torch.arange(len(val_text_features)).unsqueeze(1).to("cpu")
ground_truth

tensor([[   0],
        [   1],
        [   2],
        ...,
        [2112],
        [2113],
        [2114]])

In [24]:
# Sampled image -> molecule retrieval: rank each image's matching molecule
# against 99 randomly drawn non-matching molecules (100 candidates in total).
# NOTE: `all_preds` is rebound here and no longer holds the dict returned by
# get_metrics.
rng = np.random.default_rng(0)  # fixed seed so the sampled negatives are reproducible

# Move the whole similarity matrix to host memory once instead of row by row.
image_to_text_logits = logits["image_to_text"].cpu().numpy()
candidate_ids = np.arange(image_to_text_logits.shape[1])

all_preds = []

for i, logs in enumerate(image_to_text_logits):
    positive = logs[i]
    # Draw the negatives from every candidate except the true match.
    negatives_ind = rng.choice(np.delete(candidate_ids, i), 99, replace=False)
    negatives = logs[negatives_ind]

    # Index 0 holds the positive; its position in the descending order is its
    # 0-based rank among the 100 candidates.
    sampled_logs = np.hstack([positive, negatives])
    ranking = np.flip(np.argsort(sampled_logs))
    pred = np.where(ranking == 0)[0]
    all_preds.append(pred)


all_preds = np.vstack(all_preds)
print(all_preds)

# Recall@k as a percentage of queries.
r1 = np.mean(all_preds < 1) * 100
r5 = np.mean(all_preds < 5) * 100
r10 = np.mean(all_preds < 10) * 100
print(r1, r5, r10)

# Number of queries whose match is within the top k.
n1 = int(np.sum(all_preds < 1))
n5 = int(np.sum(all_preds < 5))
n10 = int(np.sum(all_preds < 10))
print(n1, n5, n10)

[[ 0]
 [18]
 [51]
 ...
 [ 0]
 [42]
 [45]]
10.212765957446807 21.465721040189127 30.49645390070922
216 454 645


In [25]:
# Sampled molecule -> image retrieval: rank each molecule's matching image
# against 99 randomly drawn non-matching images (100 candidates in total).
rng = np.random.default_rng(0)  # fixed seed so the sampled negatives are reproducible

# Move the whole similarity matrix to host memory once instead of row by row.
text_to_image_logits = logits["text_to_image"].cpu().numpy()
# Candidates are the columns of this matrix, so size the pool from its width.
candidate_ids = np.arange(text_to_image_logits.shape[1])

all_preds_t = []

for i, logs in enumerate(text_to_image_logits):
    positive = logs[i]
    # Draw the negatives from every candidate except the true match.
    negatives_ind = rng.choice(np.delete(candidate_ids, i), 99, replace=False)
    negatives = logs[negatives_ind]

    # Index 0 holds the positive; its position in the descending order is its
    # 0-based rank among the 100 candidates.
    sampled_logs = np.hstack([positive, negatives])
    ranking = np.flip(np.argsort(sampled_logs))
    pred = np.where(ranking == 0)[0]
    all_preds_t.append(pred)

all_preds_t = np.vstack(all_preds_t)
print(all_preds_t)

# Recall@k as a percentage of queries.
r1_t = np.mean(all_preds_t < 1) * 100
r5_t = np.mean(all_preds_t < 5) * 100
r10_t = np.mean(all_preds_t < 10) * 100
print(r1_t, r5_t, r10_t)

# Number of queries whose match is within the top k.
n1_t = int(np.sum(all_preds_t < 1))
n5_t = int(np.sum(all_preds_t < 5))
n10_t = int(np.sum(all_preds_t < 10))
print(n1_t, n5_t, n10_t)

[[ 0]
 [41]
 [50]
 ...
 [ 0]
 [32]
 [27]]
10.401891252955082 21.087470449172578 29.408983451536642
220 446 622


In [27]:
from scipy.stats import binomtest

# 95% confidence interval for the full-split image -> molecule R@1.
# Both inputs are read from the computed results instead of being pasted in
# by hand, so this cell stays correct when the evaluation is re-run.
n_samples = len(val_img_features)
value = metrics["image_to_text_R@1"] * 100  # R@1 as a percentage

# Recover the integer number of top-1 hits from the percentage.
successes = int(round(value * n_samples / 100))
print(successes)

btest = binomtest(k=successes, n=n_samples)
result = btest.proportion_estimate * 100
# Note: the interval bounds are proportions in [0, 1], whereas `result` is a
# percentage.
ci = btest.proportion_ci(confidence_level=0.95)

print(result)
print(ci)

64
3.0260047281323876
ConfidenceInterval(low=0.023380231995578697, high=0.03847873199349217)
