In [1]:
# %%
from pathlib import Path

import torch
import numpy as np
import random
import pickle
from absl import logging
from absl.flags import FLAGS
from cellot import losses
from cellot.utils.loaders import load
from cellot.models.cellot import compute_loss_f, compute_loss_g, compute_w2_distance
from cellot.train.summary import Logger
from cellot.data.utils import cast_loader_to_iterator
from cellot.models.ae import compute_scgen_shift
from tqdm import trange

from cellot.models.ae import AutoEncoder

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
logger = logging.getLogger("data_logger")
logger.setLevel(logging.INFO)

import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Linear

In [2]:
DEBUG = False
TARGET = 'all' if not DEBUG else 'abexinostat'
LATENT_DIM = 50
COND_CLASSES = 189 if not DEBUG else 2

from pathlib import Path
outdir_path = '/Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/full_ae'
outdir = Path(outdir_path)

# %%
outdir.mkdir(exist_ok=True, parents=True)

cachedir = outdir / "cache"
cachedir.mkdir(exist_ok=True)

In [3]:

import torch
import GPUtil
import os

def get_free_gpu():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    # Set environment variables for which GPUs to use.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    chosen_gpu = ''.join(
        [str(x) for x in GPUtil.getAvailable(order='memory')])
    os.environ["CUDA_VISIBLE_DEVICES"] = chosen_gpu
    print(f"Using GPUs: {chosen_gpu}")
    return chosen_gpu

status = cachedir / "status"
status.write_text("running")

device = f'cuda:{get_free_gpu()}'

cuda
Using GPUs: 1


In [4]:

# %%
import omegaconf

if DEBUG:
    n_iters = 250000
    batch_size = 256
else:
    n_iters = 250000
    batch_size = 256

yaml_str = f"""
model:
   name: scgen
   beta: 0.0
   dropout: 0.0
   hidden_units: [512, 512]
   latent_dim: 50

optim:
   lr: 0.001
   optimizer: Adam
   weight_decay: 1.0e-05

scheduler:
   gamma: 0.5
   step_size: 100000

training:
  cache_freq: 10000
  eval_freq: 2500
  logs_freq: 250
  n_iters: {n_iters}

data:
  type: cell
  source: control
  condition: drug
  path: /Mounts/rbg-storage1/users/johnyang/cellot/datasets/scrna-sciplex3/hvg.h5ad
  target: {TARGET}

datasplit:
    groupby: drug   
    name: train_test
    test_size: 0.2
    random_state: 0

dataloader:
    batch_size: {batch_size}
    shuffle: true
"""

config = omegaconf.OmegaConf.create(yaml_str)


In [5]:

# %% [markdown]
# ### Utils

# %%
def load_lr_scheduler(optim, config):
    if "scheduler" not in config:
        return None

    return torch.optim.lr_scheduler.StepLR(optim, **config.scheduler)

def check_loss(*args):
    for arg in args:
        if torch.isnan(arg):
            raise ValueError


def load_item_from_save(path, key, default):
    path = Path(path)
    if not path.exists():
        return default

    ckpt = torch.load(path)
    if key not in ckpt:
        logging.warn(f"'{key}' not found in ckpt: {str(path)}")
        return default

    return ckpt[key]

# %%
import cellot.models
# from cellot.data.cell import load_cell_data


In [6]:
def load_data(config, **kwargs):
    data_type = config.get("data.type", "cell")
    if data_type in ["cell", "cell-merged", "tupro-cohort"]:
        loadfxn = load_cell_data

    elif data_type == "toy":
        loadfxn = load_toy_data

    else:
        raise ValueError

    return loadfxn(config, **kwargs)


def load_model(config, device, restore=None, **kwargs):
    # def load_autoencoder_model(config, restore=None, **kwargs):
    
    def load_optimizer(config, params):
        kwargs = dict(config.get("optim", {}))
        assert kwargs.pop("optimizer", "Adam") == "Adam"
        optim = torch.optim.Adam(params, **kwargs)
        return optim


    def load_networks(config, **kwargs):
        kwargs = kwargs.copy()
        kwargs.update(dict(config.get("model", {})))
        name = kwargs.pop("name")

        if name == "scgen":
            model = AutoEncoder

        # elif name == "cae":
        #     model = ConditionalAutoEncoder
        else:
            raise ValueError

        return model(**kwargs)
    
    model = load_networks(config, **kwargs)
    optim = load_optimizer(config, model.parameters())

    if restore is not None and Path(restore).exists():
        print('Loading model from checkpoint')
        ckpt = torch.load(restore, map_location=device)
        model.load_state_dict(ckpt["model_state"])
        optim.load_state_dict(ckpt["optim_state"])
        if config.model.name == "scgen" and "code_means" in ckpt:
            model.code_means = ckpt["code_means"]
            
    # logger.info(f'Model on device {next(model.parameters()).device}')

    return model, optim

def load(config, device, restore=None, include_model_kwargs=False, **kwargs):

    loader, model_kwargs = load_data(config, include_model_kwargs=True, **kwargs)

    model, opt = load_model(config, device, restore=restore, **model_kwargs)

    # if include_model_kwargs:
    #     return model, opt, loader, model_kwargs

    return model, opt, loader

# %% [markdown]
# ### Training


In [7]:
ae = load_model(config, 'cuda', restore=cachedir / "last.pt", input_dim=1000)

Loading model from checkpoint


KeyError: 'model_state'

In [10]:
from cellot.data.cell import *

def load_cell_data(
    config,
    data=None,
    split_on=None,
    return_as="loader",
    include_model_kwargs=False,
    pair_batch_on=None,
    **kwargs
):

    if isinstance(return_as, str):
        return_as = [return_as]

    assert set(return_as).issubset({"anndata", "dataset", "loader"})
    config.data.condition = config.data.get("condition", "drug")
    condition = config.data.condition
    
    data = read_single_anndata(config, **kwargs)

    # if "ae_emb" in config.data:
        # load path to autoencoder
        # assert config.get("model.name", "cellot") == "cellot"
    # path_ae = Path(outdir_path)
    # model_kwargs = {"input_dim": data.n_vars}
    # config_ae = load_config('/Mounts/rbg-storage1/users/johnyang/cellot/configs/models/scgen.yaml')
    # ae_model, _ = load_autoencoder_model(
    #     config_ae, restore=path_ae / "cache/model.pt", **model_kwargs
    # )

    inputs = torch.Tensor(
        data.X if not sparse.issparse(data.X) else data.X.todense()
    )

    # genes = data.var_names.to_list()
    # data = anndata.AnnData(
    #     ae[0].eval().encode(inputs).detach().numpy(),
    #     obs=data.obs.copy(),
    #     uns=data.uns.copy(),
    # )
    # data.uns["genes"] = genes

    # cast to dense and check for nans
    if sparse.issparse(data.X):
        data.X = data.X.todense()
    assert not np.isnan(data.X).any()

    dataset_args = dict()
    model_kwargs = {}

    model_kwargs["input_dim"] = data.n_vars

    # if config.get("model.name") == "cae":
    condition_labels = sorted(data.obs[condition].cat.categories)
    model_kwargs["conditions"] = condition_labels
    dataset_args["obs"] = condition
    dataset_args["categories"] = condition_labels

    if "training" in config:
        pair_batch_on = config.training.get("pair_batch_on", pair_batch_on)

    if split_on is None:
        if config.model.name == "cellot":
            # datasets & dataloaders accessed as loader.train.source
            split_on = ["split", "transport"]
            if pair_batch_on is not None:
                split_on.append(pair_batch_on)

        elif (config.model.name == "scgen" or config.model.name == "cae"
              or config.model.name == "popalign"):
            split_on = ["split"]

        else:
            raise ValueError

    if isinstance(split_on, str):
        split_on = [split_on]

    for key in split_on:
        assert key in data.obs.columns

    if len(split_on) > 0:
        splits = {
            (key if isinstance(key, str) else ".".join(key)): data[index]
            for key, index in data.obs[split_on].groupby(split_on).groups.items()
        }

        dataset = nest_dict(
            {
                key: AnnDataDataset(val.copy(), **dataset_args)
                for key, val in splits.items()
            },
            as_dot_dict=True,
        )
    else:
        dataset = AnnDataDataset(data.copy(), **dataset_args)

    if "loader" in return_as:
        kwargs = dict(config.dataloader)
        kwargs.setdefault("drop_last", True)
        loader = cast_dataset_to_loader(dataset, **kwargs)

    returns = list()
    for key in return_as:
        if key == "anndata":
            returns.append(data)

        elif key == "dataset":
            returns.append(dataset)

        elif key == "loader":
            returns.append(loader)

    if include_model_kwargs:
        returns.append(model_kwargs)

    if len(returns) == 1:
        return returns[0]

    # returns.append(data)

    return tuple(returns)

In [12]:
cond_datasets = load_cell_data(config, return_as="dataset")

2023-07-12 04:14:44,839 Loaded cell data with TARGET all and OBS SHAPE (762039, 16)


In [13]:
cond_datasets.test.dataset.adata.obs['drug']

AttributeError: 'AnnDataDataset' object has no attribute 'dataset'

In [17]:
t = cond_datasets.test

In [21]:
t.adata

AnnData object with n_obs × n_vars = 152485 × 1000
    obs: 'size_factor', 'cell_type', 'replicate', 'dose', 'drug_code', 'pathway_level_1', 'pathway_level_2', 'product_name', 'target', 'pathway', 'drug', 'drug-dose', 'drug_code-dose', 'n_genes', 'transport', 'split'
    var: 'gene_short_name', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'pca', 'rank_genes_groups'
    obsm: 'X_pca'
    varm: 'PCs', 'marker_genes-drug-rank', 'marker_genes-drug-score'

In [16]:
cond_datasets.test.categories

['2_methoxyestradiol',
 '_jq1',
 'a_366',
 'abexinostat',
 'abt_737',
 'ac480',
 'ag_14361',
 'ag_490',
 'aicar',
 'alendronate_sodium_trihydrate',
 'alisertib',
 'altretamine',
 'alvespimycin_hcl',
 'amg_900',
 'aminoglutethimide',
 'amisulpride',
 'anacardic_acid',
 'andarine',
 'ar_42',
 'at9283',
 'aurora_a_inhibitor_i',
 'avagacestat',
 'az_960',
 'azacitidine',
 'azd1480',
 'barasertib',
 'baricitinib',
 'belinostat',
 'bisindolylmaleimide_ix',
 'bms_265246',
 'bms_536924',
 'bms_754807',
 'bms_911543',
 'bosutinib',
 'brd4770',
 'busulfan',
 'capecitabine',
 'carmofur',
 'cediranib',
 'celecoxib',
 'cep_33779',
 'cerdulatinib',
 'cimetidine',
 'clevudine',
 'control',
 'costunolide',
 'crizotinib',
 'cudc_101',
 'cudc_907',
 'curcumin',
 'cyc116',
 'cyclocytidine_hcl',
 'dacinostat',
 'danusertib',
 'daphnetin',
 'dasatinib',
 'decitabine',
 'disulfiram',
 'divalproex_sodium',
 'droxinostat',
 'eed226',
 'ellagic_acid',
 'enmd_2076',
 'enmd_2076_l__tartaric_acid',
 'entacapone',

In [27]:
sorted(t.adata.obs['drug'].cat.categories)

['2_methoxyestradiol',
 '_jq1',
 'a_366',
 'abexinostat',
 'abt_737',
 'ac480',
 'ag_14361',
 'ag_490',
 'aicar',
 'alendronate_sodium_trihydrate',
 'alisertib',
 'altretamine',
 'alvespimycin_hcl',
 'amg_900',
 'aminoglutethimide',
 'amisulpride',
 'anacardic_acid',
 'andarine',
 'ar_42',
 'at9283',
 'aurora_a_inhibitor_i',
 'avagacestat',
 'az_960',
 'azacitidine',
 'azd1480',
 'barasertib',
 'baricitinib',
 'belinostat',
 'bisindolylmaleimide_ix',
 'bms_265246',
 'bms_536924',
 'bms_754807',
 'bms_911543',
 'bosutinib',
 'brd4770',
 'busulfan',
 'capecitabine',
 'carmofur',
 'cediranib',
 'celecoxib',
 'cep_33779',
 'cerdulatinib',
 'cimetidine',
 'clevudine',
 'control',
 'costunolide',
 'crizotinib',
 'cudc_101',
 'cudc_907',
 'curcumin',
 'cyc116',
 'cyclocytidine_hcl',
 'dacinostat',
 'danusertib',
 'daphnetin',
 'dasatinib',
 'decitabine',
 'disulfiram',
 'divalproex_sodium',
 'droxinostat',
 'eed226',
 'ellagic_acid',
 'enmd_2076',
 'enmd_2076_l__tartaric_acid',
 'entacapone',

In [None]:
y_set = set()
for x in cond_datasets.test:
    y_set.add(x[1])