In [None]:
import numpy as np
import pandas as pd
import sys
import os
import torch
# Put the local "src/" directory at the front of the import search path
# (presumably where the data_utils / model_utils modules imported below live).
sys.path.insert(0, os.path.abspath("src/"))

# Pick the torch device name: CUDA when torch reports it, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from data_utils import CLOOMDataset_Dataset
from model_utils import PrecalculatedModel

In [None]:
from amumo import model as am_model
from amumo import data as am_data
from amumo import widgets as am_widgets
from amumo import utils as am_utils

In [None]:
# Settings passed to the dataset constructor when the data is loaded.
seed = 31415
batch_size = 100

# The data folder is resolved against the parent of the working directory.
current_directory = os.path.abspath('.')  # absolute path of the working directory
basepath = parent_directory = os.path.dirname(current_directory)  # one level up
datapath = os.path.join(basepath, "amumo/data")

In [None]:
# Build the dataset object from the configured path, seed and batch size,
# then unpack what get_data() returns: images, molecules and the dataset name.
dataset_cloome = CLOOMDataset_Dataset(datapath, seed, batch_size)
(
    cloome_images,
    cloome_molecules,
    cloome_dataset_name,
) = dataset_cloome.get_data()

# Last expression of the cell: show the dataset name.
cloome_dataset_name

In [None]:
def _first_positions(available_ids, wanted_ids, label):
    """Return, for each wanted id, the index of its first occurrence in ``available_ids``.

    ``available_ids`` is passed through ``np.asarray`` first so that any type
    coercion numpy applies to the stored ids is respected; the lookup table is
    built once instead of scanning the ids for every wanted key.
    Assumes the ids are one-dimensional and hashable -- TODO confirm.

    Raises
    ------
    IndexError
        If any wanted id is absent. ``label`` is only used in the message.
    """
    first_seen = {}
    for position, key in enumerate(np.asarray(available_ids).tolist()):
        # setdefault keeps the FIRST position when an id occurs more than once.
        first_seen.setdefault(key, position)

    wanted = np.asarray(wanted_ids).tolist()
    missing = [key for key in wanted if key not in first_seen]
    if missing:
        preview = ", ".join(repr(key) for key in missing[:5])
        raise IndexError(
            f"{len(missing)} {label} sample key(s) not found in the feature ids "
            f"(first missing: {preview})"
        )
    return [first_seen[key] for key in wanted]


def get_features(molecule_features, image_features, dataset=None):
    """Load precomputed embeddings and select the rows matching the dataset's samples.

    Parameters
    ----------
    molecule_features : str or path-like
        File readable by ``torch.load`` containing a mapping with the keys
        ``"mol_features"`` and ``"mol_ids"``.
    image_features : str or path-like
        File readable by ``torch.load`` containing a mapping with the keys
        ``"img_features"`` and ``"img_ids"``.
    dataset : optional
        Object whose ``dataset`` attribute provides the columns
        ``"SAMPLE_KEY_mol"`` and ``"SAMPLE_KEY_img"``. When omitted, the
        notebook-level ``dataset_cloome`` is used.

    Returns
    -------
    tuple
        ``(mol_features_sample, img_features_sample)``: the selected feature
        rows, on the CPU, after ``am_utils.l2_norm``. Rows follow the order of
        the sample-key columns.

    Raises
    ------
    IndexError
        If a sample key does not occur among the stored ids.
    """
    if dataset is None:
        dataset = dataset_cloome
    samples = dataset.dataset

    # NOTE(review): torch.load unpickles its input; only use it on trusted files.
    # Load straight to the CPU: the selected rows are moved there anyway, so
    # staging the full feature matrices on another device would be wasted work.
    mol_store = torch.load(molecule_features, map_location="cpu")
    img_store = torch.load(image_features, map_location="cpu")

    # Map every sample key to the row index of its first matching stored id.
    img_feature_idcs = _first_positions(
        img_store["img_ids"], samples["SAMPLE_KEY_img"].values, "image"
    )
    mol_feature_idcs = _first_positions(
        mol_store["mol_ids"], samples["SAMPLE_KEY_mol"].values, "molecule"
    )

    # Select the rows, ensure they live on the CPU, then normalise them.
    mol_features_sample = am_utils.l2_norm(mol_store["mol_features"][mol_feature_idcs].cpu())
    img_features_sample = am_utils.l2_norm(img_store["img_features"][img_feature_idcs].cpu())

    return mol_features_sample, img_features_sample

In [None]:
# Embedding files named for the CLOOB run (molecule file first, image file second).
molecule_features_cloob, image_features_cloob = (
    "/.../cloob_2022-04-09-09-47-00_mol_embedings_test.pkl",
    "/.../cloob_2022-04-09-09-47-00_img_embedings_test.pkl",
)

# Embedding files named for the CLIP run, in the same order.
molecule_features_clip, image_features_clip = (
    "/.../clip_2022-04-13-16-14-59_mol_embedings_test.pkl",
    "/.../clip_2022-04-13-16-14-59_img_embedings_test.pkl",
)

# Extract the per-sample molecule / image features for each model.
molecule_features_cloob_sample, image_features_cloob_sample = get_features(
    molecule_features=molecule_features_cloob,
    image_features=image_features_cloob,
)
molecule_features_clip_sample, image_features_clip_sample = get_features(
    molecule_features=molecule_features_clip,
    image_features=image_features_clip,
)

In [None]:
# Wrap the precalculated CLOOB features in a model object for the explorer.
cloob_model = am_model.PrecalculatedModel(
    'precalculated_cloob',
    cloome_dataset_name,
    image_features_cloob_sample,
    molecule_features_cloob_sample,
)

# Build the explorer widget over the loaded images and molecules.
cloob_widget = am_widgets.CLIPExplorerWidget(
    cloome_dataset_name,
    cloome_images,
    cloome_molecules,
    models=[cloob_model],
)

# Last expression of the cell: render the widget.
cloob_widget

In [None]:
# Wrap the precalculated CLIP features in a model object for the explorer.
clip_model = am_model.PrecalculatedModel(
    'precalculated_clip',
    cloome_dataset_name,
    image_features_clip_sample,
    molecule_features_clip_sample,
)

# Build the explorer widget over the loaded images and molecules.
clip_widget = am_widgets.CLIPExplorerWidget(
    cloome_dataset_name,
    cloome_images,
    cloome_molecules,
    models=[clip_model],
)

# Render the widget. The cell previously ended with `clip_model`, which only
# showed the model object's repr and left the widget built above undisplayed.
clip_widget