In [2]:
import numpy as np
import pandas as pd

import sys
import os
import json
import logging
from pathlib import Path, PurePath
from collections import OrderedDict
from itertools import chain
from tqdm import tqdm

from training.datasets import CellPainting

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor

from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs

from hticnn.models import GAPNet

In [3]:
# Data locations. The image root may be overridden through the environment so
# the notebook can run where the dataset is mounted elsewhere; the default is
# the original location.
img_path = os.environ.get("CELLPAINTING_IMG_PATH", "/publicdata/cellpainting/npzs/chembl24/")
# CSV of query samples (later indexed by SAMPLE_KEY and read for INCHIKEY).
val = "cellpainting-test-phenotype-imgpermol.csv"
# CSV of reference samples the queries are ranked against.
classes = "cellpainting-split-test-imgpermol.csv"

In [4]:
# Trained GAPNet checkpoint; set GAPNET_CHECKPOINT to use a checkpoint stored
# outside the working directory. Defaults to the original relative path.
model_path = os.environ.get("GAPNET_CHECKPOINT", "gapnet.pth.tar")

In [5]:
# Load tensors onto the CPU first: a checkpoint saved from GPU tensors would
# otherwise fail to load on a CPU-only machine. The model (and therefore the
# copied weights) is moved to `device` after load_state_dict.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location="cpu")
state_dict = checkpoint["state_dict"]

In [6]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The architecture must match the checkpoint for load_state_dict to succeed:
# input_shape (5, 520, 696), fc_units=1024, 209 output classes, dropout disabled.
model = torch.nn.DataParallel(
    GAPNet(
        fc_units=1024,
        dropout=0,
        num_classes=209,
        input_shape=[5, 520, 696],
    )
)

In [7]:
# Copy the checkpoint weights into the model (strict key matching, the default)
# and move it to the selected device; the returned module is displayed below.
model.load_state_dict(state_dict, strict=True)
model.to(device)

DataParallel(
  (module): GAPNet(
    (block1): Sequential(
      (0): Conv2d(5, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): SELU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (block2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): SELU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): SELU(inplace=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): SELU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (block3): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): SELU(inplace=True)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): SELU(inplace=True)
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1

In [8]:
# Keep all but the last three layers of the classifier head so the network emits
# an intermediate embedding (1024-d per the feature shape reported below) rather
# than its final outputs — presumably the last layers form the 209-class output.
model.module.classifier = torch.nn.Sequential(*tuple(model.module.classifier.children())[:-3])

In [9]:
# Display the architecture after truncating the classifier head.
model

DataParallel(
  (module): GAPNet(
    (block1): Sequential(
      (0): Conv2d(5, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): SELU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (block2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): SELU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): SELU(inplace=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): SELU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (block3): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): SELU(inplace=True)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): SELU(inplace=True)
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1

In [10]:
def encode_image(input_tensor, model, device):
    """Run ``model`` on one input batch without tracking gradients.

    Args:
        input_tensor: A NumPy array (converted without a copy where possible),
            an existing tensor, or anything else ``torch.as_tensor`` accepts.
        model: Callable applied to the converted input, e.g. the feature extractor.
        device: Device the input is moved to before the forward pass.

    Returns:
        The model's output for the batch.
    """
    with torch.no_grad():
        # as_tensor matches from_numpy for ndarrays but also accepts tensors,
        # which from_numpy rejects with a TypeError.
        input_tensor = torch.as_tensor(input_tensor).to(device)
        output = model(input_tensor)

    return output

In [11]:
def get_features(dataset, model, device, num_workers=20, batch_size=32):
    """Embed every image in ``dataset`` and L2-normalise the embeddings.

    Args:
        dataset: Dataset whose items are dicts with an ``"input"`` image tensor
            and an ``"ID"`` sample identifier.
        model: Feature extractor; switched to eval mode before inference.
        device: Device the image batches are moved to.
        num_workers: DataLoader worker processes (default keeps the original 20).
        batch_size: Images per forward pass (default keeps the original 32).

    Returns:
        ``(features, ids)``: the concatenated unit-norm feature rows (presumably
        shaped ``(len(dataset), feature_dim)``) and a flat list of IDs in the
        same order.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    all_image_features = []
    all_ids = []

    print(f"get_features {device}")
    print(len(dataset))

    # Inference only: disable train-time behaviour such as dropout.
    model.eval()

    loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size)
    with torch.no_grad():
        for batch in tqdm(loader):
            img_features = model(batch["input"].to(device))
            # Unit-normalise so dot products are cosine similarities. normalize()
            # clamps the norm, so an all-zero embedding does not turn into NaN.
            img_features = torch.nn.functional.normalize(img_features, dim=-1)

            all_image_features.append(img_features)
            all_ids.append(batch["ID"])

    if not all_image_features:
        # torch.cat([]) would fail with an unhelpful message.
        raise ValueError("get_features: dataset produced no batches")

    return torch.cat(all_image_features), list(chain.from_iterable(all_ids))

In [12]:
def main(df, model, img_path, image_resolution):
    """Build a CellPainting dataset from ``df`` and return its normalised features.

    Args:
        df: Sample table argument forwarded to ``CellPainting`` (a CSV path here).
        model: Feature extractor, already loaded and placed on its device.
        img_path: Root directory of the image files.
        image_resolution: Currently unused; kept for call compatibility.

    Returns:
        ``(features, ids)`` exactly as returned by ``get_features``.
    """
    # Send batches to the device the model's weights live on, so inputs and
    # weights cannot end up on different devices. Fall back to the original
    # cuda/cpu choice if the model has no parameters.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    preprocess_val = ToTensor()

    # Build the dataset (named `dataset` so it does not shadow the global `val` path).
    dataset = CellPainting(df,
                           img_path,
                           transforms=preprocess_val)

    print("getting_features")
    return get_features(dataset, model, device)

In [13]:
# Embed the query split (the `val` CSV).
val_gapnet_features, ids = main(df=val, model=model, img_path=img_path, image_resolution=520)

44102
getting_features
get_features cuda
44102


100%|███████████████████████████████████████| 1379/1379 [02:51<00:00,  8.03it/s]


In [14]:
# Sanity check: exactly one feature row per returned sample ID.
assert val_gapnet_features.shape[0] == len(ids), "feature rows and IDs are misaligned"
val_gapnet_features.shape

torch.Size([44102, 1024])

In [15]:
# Embed the reference split, then bring both feature matrices to host memory
# for the similarity computation.
class_img_features, class_ids = main(df=classes, model=model, img_path=img_path, image_resolution=520)
val_gapnet_features, class_img_features = val_gapnet_features.cpu(), class_img_features.cpu()

2115
getting_features
get_features cuda
2115


100%|███████████████████████████████████████████| 67/67 [00:17<00:00,  3.89it/s]


In [18]:
# INCHIKEY of every reference sample, in the same order as the feature rows.
classes_df = pd.read_csv(classes).set_index("SAMPLE_KEY")
# A duplicated SAMPLE_KEY would make .loc return extra rows and silently
# misalign INCHIKEYs with the feature rows.
assert classes_df.index.is_unique, "SAMPLE_KEY values in the classes CSV must be unique"
class_inchis = classes_df.loc[class_ids, "INCHIKEY"]

In [19]:
# INCHIKEY of every query sample, in the same order as the feature rows.
val_df = pd.read_csv(val).set_index("SAMPLE_KEY")
# A duplicated SAMPLE_KEY would make .loc return extra rows and silently
# misalign INCHIKEYs with the feature rows.
assert val_df.index.is_unique, "SAMPLE_KEY values in the val CSV must be unique"
val_inchis = val_df.loc[ids, "INCHIKEY"]

In [20]:
# INCHIKEY -> column index of that reference sample in `logits`; if an INCHIKEY
# occurs more than once, its last occurrence wins.
class_dict = {inchi: i for i, inchi in enumerate(class_inchis)}

In [21]:
# Target column for every query: the reference entry sharing its INCHIKEY.
# A query molecule absent from the reference set raises KeyError.
ground_truth = np.array([class_dict[inchi] for inchi in val_inchis], dtype=int)

In [22]:
# Cosine similarities (rows were unit-normalised): one row per query,
# one column per reference sample.
logits = torch.matmul(val_gapnet_features, class_img_features.t())

In [23]:
# Reference columns ordered by similarity, best first, for every query.
ranking = torch.argsort(logits, descending=True)
# int64 matches argsort's index dtype; int16 would overflow past 32767 classes.
t = torch.tensor(ground_truth, dtype=torch.long).view(-1, 1)

# Each ranking row is a permutation of the columns, so exactly one entry per row
# equals the target; its column position is that query's 0-based rank.
preds = torch.where(ranking == t)[1]
preds = preds.detach().cpu().numpy()

# Recall@k in percent: share of queries whose correct match is in the top k.
metrics = {}
for k in [1, 5, 10]:
    metrics[f"R@{k}"] = np.mean(preds < k) * 100

print(metrics)

{'R@1': 0.36279533807990566, 'R@5': 1.0747811890617207, 'R@10': 1.7958369234955331}


In [25]:
from scipy.stats import binomtest

n_samples = val_gapnet_features.shape[0]

mdict, cis = {}, {}

for metric, value in metrics.items():
    # `value` is a percentage. Round (rather than truncate) when converting it
    # back to a count: value * n / 100 can land just below the true integer.
    successes = int(round(value * n_samples / 100))
    btest = binomtest(k=successes, n=n_samples)
    mdict[metric] = btest.proportion_estimate * 100
    # Note: this interval is a proportion in [0, 1], not a percentage like mdict.
    cis[metric] = btest.proportion_ci(confidence_level=0.95)

print(mdict)
print(cis)

{'R@1': 0.36279533807990566, 'R@5': 1.0747811890617207, 'R@10': 1.7958369234955331}
{'R@1': ConfidenceInterval(low=0.0030883812227122357, high=0.0042344008536001725), 'R@5': ConfidenceInterval(low=0.009806440210406235, high=0.011754395603765509), 'R@10': ConfidenceInterval(low=0.016739425902711765, high=0.01924134506550992)}
