### Load the trained score model (LM) from its W&B checkpoint

In [1]:
import wandb

# Fully-qualified W&B artifact (entity/project/name:version) holding the
# trained score-model checkpoint.
MODEL_ARTIFACT = 'protein-optimization/sc_diff/model-fekp2uq8:v10'

run = wandb.init()
artifact = run.use_artifact(MODEL_ARTIFACT, type='model')
artifact_dir = artifact.download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjohnyang[0m ([33mprotein-optimization[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [2]:
# Local directory the model artifact was downloaded into (it holds model.ckpt).
artifact_dir

'./artifacts/model-fekp2uq8:v10'

In [3]:
from cellot.models.cond_score_module import CondScoreModule

In [4]:
# Path to the Lightning checkpoint inside the downloaded artifact.
ckpt_path = f'{artifact_dir}/model.ckpt'

In [19]:
# Experiment config as YAML text; parsed with OmegaConf in the next cell and
# passed as `hparams` when restoring the score model. `${...}` entries are
# OmegaConf interpolations.
# NOTE(review): `ae.latent_dim` is hard-coded to 50 instead of `${LATENT_DIM}`;
# keep the two in sync if LATENT_DIM changes.
# NOTE(review): AE_PATH, data.path and wandb_logger.dir are absolute cluster
# paths, which ties this notebook to one filesystem.
YAML_STR = '''
DEBUG: False
TARGET: all
LATENT_DIM: 50
COND_CLASSES: 189
SEED: 42
AE_PATH: /Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/full_ae
VAL_SIZE: 0.1
DEVICES: 1

diffuser:
  min_b: 0.01
  max_b: 1.0
  schedule: exponential
  score_scaling: var
  coordinate_scaling: 1.0
  latent_dim: ${LATENT_DIM}
  dt: 0.01
  min_t: 0

ae:
  name: scgen
  beta: 0.0
  dropout: 0.0
  hidden_units: [512, 512]
  latent_dim: 50

score_network:
  latent_dim: ${LATENT_DIM}
  cond_classes: ${COND_CLASSES}
  model_dim: 64   # Adjusted to 64
  n_layers: 12    # Adjusted to 12
  nhead: 8
  dim_feedforward: 2048
  dropout: 0.1
  ffn_hidden_dim: 1024


data:
  type: cell
  source: control
  condition: drug
  path: /Mounts/rbg-storage1/users/johnyang/cellot/datasets/scrna-sciplex3/hvg.h5ad
  target: trametinib

datasplit:
  groupby: drug   
  name: train_test
  test_size: 0.2
  random_state: 0
  
dataloader:
  batch_size: 256   # Adjusted to 256
  shuffle: true
  num_workers: 80
  
experiment:
  name: base
  mode: train
  num_loader_workers: 0
  port: 12319
  dist_mode: single
  use_wandb: True
  ckpt_path: null
  wandb_logger:
    project: sc_diff
    name: ${experiment.name}
    dir: /Mounts/rbg-storage1/users/johnyang/cellot/
    log_model: all
    tags: ['experimental']
  lr: 0.0001


trainer:
  accelerator: 'gpu'
  check_val_every_n_epoch: 50
  log_every_n_steps: 100
  num_sanity_val_steps: 1
  enable_progress_bar: True
  enable_checkpointing: True
  fast_dev_run: False
  profiler: simple
  max_epochs: 10000
  strategy: auto
  enable_model_summary: True
  overfit_batches: 0.0
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_predict_batches: 1.0
'''

In [20]:
from omegaconf import OmegaConf
# Parse the YAML text into a DictConfig. Interpolations such as ${LATENT_DIM}
# stay unresolved in the stored object and are resolved when accessed.
config = OmegaConf.create(YAML_STR)

In [21]:
# NOTE(review): re-assigns exactly the same value as the earlier `ckpt_path`
# cell; redundant on a top-to-bottom run.
ckpt_path = f'{artifact_dir}/model.ckpt'

In [23]:
from cellot.train.utils import get_free_gpu
# GPU index to run on; the selection policy lives in cellot.train.utils
# (presumably the least-used device — TODO confirm).
replica_id = int(get_free_gpu())

cuda
Using GPUs: 1


In [24]:
# Torch device string for the selected GPU; the model and batches are moved here.
device = f'cuda:{replica_id}'

In [33]:
# Restore the trained score model onto the selected GPU. `hparams=config` is
# forwarded to the module constructor by load_from_checkpoint.
# `.eval()` disables dropout (score_network.dropout is 0.1). Batches are later
# pushed through `validation_step` by hand, outside a Trainer, so Lightning
# never switches the module to eval mode on its own. The cell ends in an
# assignment, so nothing is echoed (replaces the old `print('')` trick).
lm = (
    CondScoreModule.load_from_checkpoint(hparams=config, checkpoint_path=ckpt_path)
    .to(device)
    .eval()
)

Dropout is 0.1



### Load Data

In [25]:
# Display the parsed config; interpolations such as ${LATENT_DIM} show up
# unresolved in this repr.
config

{'DEBUG': False, 'TARGET': 'all', 'LATENT_DIM': 50, 'COND_CLASSES': 189, 'SEED': 42, 'AE_PATH': '/Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/full_ae', 'VAL_SIZE': 0.1, 'DEVICES': 1, 'diffuser': {'min_b': 0.01, 'max_b': 1.0, 'schedule': 'exponential', 'score_scaling': 'var', 'coordinate_scaling': 1.0, 'latent_dim': '${LATENT_DIM}', 'dt': 0.01, 'min_t': 0}, 'ae': {'name': 'scgen', 'beta': 0.0, 'dropout': 0.0, 'hidden_units': [512, 512], 'latent_dim': 50}, 'score_network': {'latent_dim': '${LATENT_DIM}', 'cond_classes': '${COND_CLASSES}', 'model_dim': 64, 'n_layers': 12, 'nhead': 8, 'dim_feedforward': 2048, 'dropout': 0.1, 'ffn_hidden_dim': 1024}, 'data': {'type': 'cell', 'source': 'control', 'condition': 'drug', 'path': '/Mounts/rbg-storage1/users/johnyang/cellot/datasets/scrna-sciplex3/hvg.h5ad', 'target': 'trametinib'}, 'datasplit': {'groupby': 'drug', 'name': 'train_test', 'test_size': 0.2, 'random_state': 0}, 'dataloader': {'batch_size': 256, 'shuffle': True, 'num_wor

In [26]:
# %%
import cellot.models
from cellot.data.cell import load_cell_data
import torch
from cellot.models.ae import AutoEncoder
from pathlib import Path

def load_data(config, **kwargs):
    """Dispatch to the data loader that matches ``config.data.type``.

    Args:
        config: experiment config (OmegaConf DictConfig or a plain mapping).
            ``data.type`` defaults to ``"cell"`` when absent.
        **kwargs: forwarded unchanged to the selected loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: if ``data.type`` is not a known data type.
    """
    # Look the nested key up explicitly. `config.get("data.type")` treats the
    # dotted string as one literal key (it does not walk into `data`), so it
    # always fell back to "cell" no matter what data.type was set to.
    data_cfg = config.get("data") or {}
    data_type = data_cfg.get("type", "cell")

    if data_type in ("cell", "cell-merged", "tupro-cohort"):
        loadfxn = load_cell_data
    elif data_type == "toy":
        # NOTE(review): load_toy_data is not imported in this notebook; this
        # branch raises NameError unless it is defined elsewhere.
        loadfxn = load_toy_data
    else:
        raise ValueError(f"Unknown data.type {data_type!r}")

    return loadfxn(config, **kwargs)


def load_model(config, device, restore=None, **kwargs):
    """Build the autoencoder and its Adam optimizer, optionally restoring both.

    Args:
        config: experiment config. ``config.ae`` supplies AutoEncoder kwargs
            (its ``name`` entry selects the architecture); the optional
            ``config.optim`` supplies Adam kwargs.
        device: torch device the model is moved to and the checkpoint is
            mapped onto.
        restore: optional path to a checkpoint dict holding ``model_state``
            and ``optim_state``.
        **kwargs: extra AutoEncoder kwargs (e.g. ``input_dim``); entries in
            ``config.ae`` override them.

    Returns:
        ``(model, optim)``.

    Raises:
        ValueError: if ``config.ae.name`` is not ``"scgen"``.
    """
    import warnings

    def load_optimizer(config, params):
        # Only Adam is supported; every remaining key is passed to Adam.
        kwargs = dict(config.get("optim", {}))
        assert kwargs.pop("optimizer", "Adam") == "Adam"
        return torch.optim.Adam(params, **kwargs)

    def load_networks(config, **kwargs):
        kwargs = kwargs.copy()
        kwargs.update(dict(config.get("ae", {})))
        name = kwargs.pop("name")
        # Only the plain AutoEncoder ("scgen") is wired up here;
        # ConditionalAutoEncoder ("cae") is not imported in this notebook.
        if name != "scgen":
            raise ValueError(f"Unsupported autoencoder name {name!r}; expected 'scgen'")
        return AutoEncoder(**kwargs)

    # Place the model on `device` before building the optimizer so optimizer
    # state lives next to the parameters. Previously `device` only affected
    # the checkpoint's map_location.
    model = load_networks(config, **kwargs).to(device)
    optim = load_optimizer(config, model.parameters())

    if restore is not None:
        if Path(restore).exists():
            print('Loading model from checkpoint')
            # torch.load unpickles the file: only restore trusted checkpoints.
            ckpt = torch.load(restore, map_location=device)
            model.load_state_dict(ckpt["model_state"])
            optim.load_state_dict(ckpt["optim_state"])
        else:
            # This used to be silent and hand back an untrained model.
            warnings.warn(
                f"restore path {restore} does not exist; returning an untrained model"
            )

    return model, optim

def load(config, device, restore=None, include_model_kwargs=False, **kwargs):
    """Load the data loaders plus the autoencoder and optimizer in one call.

    Args:
        config: experiment config.
        device: torch device passed through to ``load_model``.
        restore: optional checkpoint path passed through to ``load_model``.
        include_model_kwargs: when true, also return the model kwargs that
            ``load_data`` derived from the data.
        **kwargs: forwarded to ``load_data``.

    Returns:
        ``(model, opt, loader)``, or ``(model, opt, loader, model_kwargs)``
        when ``include_model_kwargs`` is true.
    """
    # Always ask the data loader for model kwargs: they carry data-derived
    # constructor arguments (presumably input_dim — TODO confirm).
    loader, model_kwargs = load_data(config, include_model_kwargs=True, **kwargs)

    model, opt = load_model(config, device, restore=restore, **model_kwargs)

    # `include_model_kwargs` used to be accepted but silently ignored.
    if include_model_kwargs:
        return model, opt, loader, model_kwargs
    return model, opt, loader
# %% [markdown]
# ### Training

# %%
# Restore the pretrained autoencoder (ae.pt holds model and optimizer state).
# load_model returns (model, optimizer); unpack so `ae` is the model itself
# rather than a tuple. Use the GPU selected above instead of the bare 'cuda'
# default device.
restore_path = '/Mounts/rbg-storage1/users/johnyang/cellot/saved_weights/ae/ae.pt'
# input_dim=1000 matches the 1000 genes in hvg.h5ad (see adata.var_names).
ae, ae_optim = load_model(config, device, restore=restore_path, input_dim=1000)

Loading model from checkpoint


In [46]:
import importlib

import cellot.data.utils

# `imp` is deprecated and was removed in Python 3.12; importlib.reload is the
# supported replacement. Importing the submodule explicitly means the reload
# no longer depends on an earlier cell having imported it as a side effect.
importlib.reload(cellot.data.utils)

<module 'cellot.data.utils' from '/Mounts/rbg-storage1/users/johnyang/cellot/cellot/data/utils.py'>

In [47]:
from cellot.data.utils import load_ae_cell_data
# Build train/test loaders without an autoencoder: ae=None together with
# encode_latents=False presumably keeps raw expression features rather than
# latent codes — TODO confirm against cellot.data.utils.
loader = load_ae_cell_data(config, ae=None, encode_latents=False)

In [48]:
# Number of cells in the training split.
len(loader.train.dataset)

16673

In [49]:
# Number of cells in the held-out test split (~20%, cf. datasplit.test_size).
len(loader.test.dataset)

4169

### Load AnnData and select marker genes

In [66]:
from cellot.data.cell import read_single_anndata
# Load the AnnData for the configured experiment (data.path); the log line
# below reports the target condition and obs shape.
adata = read_single_anndata(config)

2023-07-06 15:50:11,850 Loaded cell data with TARGET trametinib and OBS SHAPE (20842, 16)


In [67]:
# Normalise gene ids to str, mirroring how the OT-eval code casts DataFrame
# columns before comparing features.
adata.var_names = adata.var_names.astype(str)
adata.var_names

Index(['ENSG00000243620.1', 'ENSG00000271503.5', 'ENSG00000259124.1',
       'ENSG00000121101.15', 'ENSG00000160963.13', 'ENSG00000135346.8',
       'ENSG00000143839.14', 'ENSG00000100867.14', 'ENSG00000140986.7',
       'ENSG00000230666.5',
       ...
       'ENSG00000140795.12', 'ENSG00000232006.8', 'ENSG00000135821.17',
       'ENSG00000166960.16', 'ENSG00000187391.19', 'ENSG00000227124.8',
       'ENSG00000280081.3', 'ENSG00000270019.1', 'ENSG00000072182.12',
       'ENSG00000183242.11'],
      dtype='object', name='id', length=1000)

In [68]:
# Condition labels present in the loaded data (only source and target remain).
adata.obs['drug'].unique()

['control', 'trametinib']
Categories (2, object): ['control', 'trametinib']

In [69]:
# Sorted condition labels; a label's position here presumably matches the
# integer condition encoding used by the loaders — TODO confirm.
categories = sorted(adata.obs['drug'].cat.categories)
categories

['control', 'trametinib']

In [70]:
# Position of the target drug in the sorted label list.
categories.index('trametinib')

1

In [71]:
from cellot.data.cell import read_single_anndata
def load_markers(cfg=None, data=None):
    """Return gene ids ordered by marker rank for the configured target.

    Genes are sorted by ascending value in the target's column of the
    ``marker_genes-<condition>-rank`` varm table.

    Args:
        cfg: experiment config; defaults to the notebook-level ``config``.
        data: AnnData to read marker ranks from. When omitted it is re-read
            from disk via ``read_single_anndata(cfg, path=None)``; pass the
            already-loaded ``adata`` to skip that second read.

    Returns:
        pandas Index of gene ids, best-ranked first.
    """
    cfg = config if cfg is None else cfg
    if data is None:
        data = read_single_anndata(cfg, path=None)

    key = f'marker_genes-{cfg.data.condition}-rank'

    # Rebuttal preprocessing stored marker genes under a generic
    # marker_genes-condition-rank key instead of e.g. marker_genes-drug-rank;
    # fall back to it when the condition-specific key is missing.
    if key not in data.varm:
        key = 'marker_genes-condition-rank'
        print('WARNING: using generic condition marker genes')

    return data.varm[key][cfg.data.target].sort_values().index

# Number of top-ranked marker genes used for evaluation.
N_MARKERS = 50
sel_mg = load_markers()[:N_MARKERS]
sel_mg

2023-07-06 15:50:24,912 Loaded cell data with TARGET trametinib and OBS SHAPE (20842, 16)


Index(['ENSG00000198074.9', 'ENSG00000019186.9', 'ENSG00000108846.15',
       'ENSG00000115414.18', 'ENSG00000231185.6', 'ENSG00000112541.13',
       'ENSG00000117983.17', 'ENSG00000145819.15', 'ENSG00000184588.17',
       'ENSG00000165376.10', 'ENSG00000154529.14', 'ENSG00000182752.9',
       'ENSG00000251003.7', 'ENSG00000101144.12', 'ENSG00000117724.12',
       'ENSG00000157168.18', 'ENSG00000275395.5', 'ENSG00000185483.11',
       'ENSG00000108405.3', 'ENSG00000089199.9', 'ENSG00000254166.2',
       'ENSG00000215182.8', 'ENSG00000004948.13', 'ENSG00000227706.3',
       'ENSG00000065809.13', 'ENSG00000004799.7', 'ENSG00000144847.12',
       'ENSG00000107957.16', 'ENSG00000108602.17', 'ENSG00000059804.15',
       'ENSG00000047648.21', 'ENSG00000076706.16', 'ENSG00000003436.15',
       'ENSG00000229140.8', 'ENSG00000066279.17', 'ENSG00000153956.15',
       'ENSG00000086548.8', 'ENSG00000171408.13', 'ENSG00000005108.15',
       'ENSG00000138696.10', 'ENSG00000236213.1', 'ENSG0000003842

In [73]:
# Data restricted to the selected marker genes (an AnnData view, not a copy).
adata[:, sel_mg]

View of AnnData object with n_obs × n_vars = 20842 × 50
    obs: 'size_factor', 'cell_type', 'replicate', 'dose', 'drug_code', 'pathway_level_1', 'pathway_level_2', 'product_name', 'target', 'pathway', 'drug', 'drug-dose', 'drug_code-dose', 'n_genes', 'transport', 'split'
    var: 'gene_short_name', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'pca', 'rank_genes_groups'
    obsm: 'X_pca'
    varm: 'PCs', 'marker_genes-drug-rank', 'marker_genes-drug-score'

In [74]:
import importlib

import cellot.data.utils

# `imp` is deprecated and was removed in Python 3.12; importlib.reload is the
# supported replacement. Importing the submodule explicitly means the reload
# no longer depends on an earlier cell having imported it as a side effect.
importlib.reload(cellot.data.utils)

<module 'cellot.data.utils' from '/Mounts/rbg-storage1/users/johnyang/cellot/cellot/data/utils.py'>

In [75]:
from cellot.data.utils import load_ae_cell_data
# Rebuild the loaders, passing the selected marker genes (presumably to
# restrict the features to sel_mg — TODO confirm), again without
# autoencoder encoding.
loader = load_ae_cell_data(config, ae=None, encode_latents=False, sel_mg=sel_mg)

In [77]:
# Split name -> torch DataLoader.
loader

{'test': <torch.utils.data.dataloader.DataLoader at 0x7fae973af5b0>,
 'train': <torch.utils.data.dataloader.DataLoader at 0x7fae973af8b0>}

### end

In [50]:
from cellot.losses.mmd import mmd_distance
import numpy as np

def compute_mmd_loss(lhs, rhs, gammas):
    """Average ``mmd_distance`` between two samples over several kernel gammas.

    Args:
        lhs, rhs: samples passed straight through to ``mmd_distance``.
        gammas: iterable of kernel ``gamma`` values.

    Returns:
        Mean of ``mmd_distance(lhs, rhs, g)`` over every ``g`` in ``gammas``.
    """
    scores = []
    for gamma in gammas:
        scores.append(mmd_distance(lhs, rhs, gamma))
    return np.mean(scores)


# 50 gamma values log-spaced from 1e1 down to 1e-3.
gammas = np.logspace(1, -3, num=50)

In [80]:
# Take the first test batch for quick sanity checks. next(iter(...)) fails
# loudly on an empty loader. The old for/break form silently left a stale
# `ex_batch` from a previous run in place.
ex_batch = next(iter(loader.test))

In [82]:
# Sanity check: MMD between a batch and the same batch shifted by +1
# elementwise should be clearly non-zero.
# NOTE(review): assumes ex_batch[0] is the expression tensor — TODO confirm.
compute_mmd_loss(ex_batch[0], ex_batch[0] + 1, gammas)

0.1368182871707745

In [None]:
# FIXME(review): `lhs` and `rhs` are not defined in any visible cell, so this
# never-executed cell (In [None]) raises NameError on Restart & Run All.
compute_mmd_loss(lhs, rhs, gammas)

### OT Eval

In [15]:
from pathlib import Path
import numpy as np
import pandas as pd
from absl import app, flags
from cellot.utils.evaluate import (
    load_conditions,
    compute_knn_enrichment,
)
from cellot.losses.mmd import mmd_distance
from cellot.utils import load_config
from cellot.data.cell import read_single_anndata

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def compute_mmd_loss(lhs, rhs, gammas):
    """Mean of ``mmd_distance(lhs, rhs, g)`` over every ``g`` in ``gammas``.

    NOTE(review): this re-defines the identical function from the earlier MMD
    cell and silently shadows it; keep the two in sync or drop one.
    """
    per_gamma = list(map(lambda g: mmd_distance(lhs, rhs, g), gammas))
    return np.mean(per_gamma)


def compute_pairwise_corrs(df):
    """Pearson correlation for every unordered pair of columns in ``df``.

    Args:
        df: DataFrame of samples (rows) by features (columns).

    Returns:
        Unnamed Series indexed by a ``(lhs, rhs)`` MultiIndex with one entry
        per column pair above the diagonal (``lhs`` precedes ``rhs`` in
        column order). Undefined (NaN) correlations are dropped.
    """
    corr = df.corr().rename_axis(index='lhs', columns='rhs')
    # Strict upper triangle: each pair once, self-correlations excluded.
    upper = np.triu(np.ones(corr.shape, dtype=bool), k=1)
    # The old reset_index/set_index/squeeze() chain collapsed a single pair
    # (two-column input) into a scalar, which then crashed on `.rename()`.
    # Stacking directly always yields a Series. dropna() reproduces the legacy
    # stack() behaviour of discarding masked entries on every pandas version.
    return corr.where(upper).stack().dropna().rename(None)


def compute_evaluations(iterator, gammas=None, max_mmd_features=1000):
    """Yield ``(ncells, nfeatures, metric, value)`` rows comparing two samples.

    Args:
        iterator: yields ``(ncells, nfeatures, treated, imputed)`` where
            ``treated`` and ``imputed`` are cells-by-features DataFrames with
            matching columns.
        gammas: kernel ``gamma`` values averaged by ``compute_mmd_loss``.
            Defaults to 50 values log-spaced from 1e1 to 1e-3.
        max_mmd_features: ``mmd`` and the kNN-enrichment metrics are only
            computed when the feature count is strictly below this.

    Yields:
        One row per metric: l2 distance and correlation between per-feature
        means, stds and pairwise feature correlations; plus ``mmd``,
        ``enrichment-k50`` and ``enrichment-k100`` for narrow feature sets.
    """
    if gammas is None:
        gammas = np.logspace(1, -3, num=50)

    for ncells, nfeatures, treated, imputed in iterator:
        mut, mui = treated.mean(0), imputed.mean(0)
        stdt, stdi = treated.std(0), imputed.std(0)
        pwct = compute_pairwise_corrs(treated)
        pwci = compute_pairwise_corrs(imputed)

        yield ncells, nfeatures, 'l2-means', np.linalg.norm(mut - mui)
        yield ncells, nfeatures, 'l2-stds', np.linalg.norm(stdt - stdi)
        yield ncells, nfeatures, 'r2-means', pd.Series.corr(mut, mui)
        yield ncells, nfeatures, 'r2-stds', pd.Series.corr(stdt, stdi)
        yield ncells, nfeatures, 'r2-pairwise_feat_corrs', pd.Series.corr(pwct, pwci)
        yield ncells, nfeatures, 'l2-pairwise_feat_corrs', np.linalg.norm(pwct - pwci)

        if treated.shape[1] < max_mmd_features:
            mmd = compute_mmd_loss(treated, imputed, gammas=gammas)
            yield ncells, nfeatures, 'mmd', mmd

            # Only the enrichment table is needed; the kNN result is unused.
            _, enrichment = compute_knn_enrichment(imputed, treated)
            k50 = enrichment.iloc[:, :50].values.mean()
            k100 = enrichment.iloc[:, :100].values.mean()

            yield ncells, nfeatures, 'enrichment-k50', k50
            yield ncells, nfeatures, 'enrichment-k100', k100


In [None]:
def main(argv):
    """Score imputed vs. observed treated cells and write ``evals.csv``.

    Settings come from absl ``FLAGS``:
    - ``outdir``: experiment directory containing ``config.yaml``.
    - ``setting``, ``where``, ``embedding``, ``evalprefix``, ``n_reps``.
    - ``n_markers``: comma-separated counts and/or ``all``.
    - ``n_cells``: comma-separated subsample sizes.

    The imputed AnnData is cached as ``imputed.h5ad``. One row per
    ``(ncells, nfeatures, metric)`` is written to
    ``<outdir>/<prefix>/evals.csv``.

    NOTE(review): ``FLAGS`` is not defined in any visible cell; presumably this
    runs via ``absl.app.run(main)`` with the flags declared elsewhere.
    """
    expdir = Path(FLAGS.outdir)
    setting = FLAGS.setting
    where = FLAGS.where
    prefix = FLAGS.evalprefix
    n_reps = FLAGS.n_reps

    # An empty --embedding is treated like not passing one.
    embedding = FLAGS.embedding or None

    if FLAGS.n_markers is None:
        n_markers = None
    else:
        # Strip whitespace so "50, all" parses the same as "50,all".
        n_markers = [k.strip() for k in FLAGS.n_markers.split(',')]
    all_ncells = [int(x) for x in FLAGS.n_cells.split(',')]

    if prefix is None:
        prefix = f'evals_{setting}_{where}'
    outdir = expdir / prefix
    outdir.mkdir(exist_ok=True, parents=True)

    def iterate_feature_slices():
        """Yield ``(ncells, nfeatures, treated, imputed)`` subsamples to score."""
        config = load_config(expdir / 'config.yaml')
        if 'ae_emb' in config.data:
            assert config.model.name == 'cellot'
            # Point the embedding at the sibling model-scgen run directory.
            config.data.ae_emb.path = str(expdir.parent / 'model-scgen')
        cache = outdir / 'imputed.h5ad'

        _, treateddf, imputed = load_conditions(
                expdir, where, setting, embedding=embedding)

        imputed.write(cache)
        imputeddf = imputed.to_df()

        # Compare on string gene ids so both frames index features identically.
        imputeddf.columns = imputeddf.columns.astype(str)
        treateddf.columns = treateddf.columns.astype(str)

        assert imputeddf.columns.equals(treateddf.columns)

        # DataFrame.sample draws without replacement by default, so larger
        # subsample sizes are infeasible.
        max_cells = min(len(treateddf), len(imputeddf))

        def load_markers():
            data = read_single_anndata(config, path=None)
            key = f'marker_genes-{config.data.condition}-rank'

            # Rebuttal preprocessing stored marker genes under a generic
            # marker_genes-condition-rank key instead of e.g.
            # marker_genes-drug-rank; fall back to it when needed.
            if key not in data.varm:
                key = 'marker_genes-condition-rank'
                print('WARNING: using generic condition marker genes')

            sel_mg = (
                data.varm[key][config.data.target]
                .sort_values()
                .index
            )
            return sel_mg

        if n_markers is not None:
            # Cast to str to match the str-cast DataFrame columns above.
            markers = load_markers().astype(str)
            for k in n_markers:
                feats = list(markers) if k == 'all' else markers[:int(k)]

                for ncells in all_ncells:
                    # `continue`, not `break`: with an unsorted --n_cells list,
                    # breaking here silently dropped later, smaller sizes.
                    if ncells > max_cells:
                        continue
                    for _ in range(n_reps):
                        trt = treateddf[feats].sample(ncells)
                        imp = imputeddf[feats].sample(ncells)
                        yield ncells, k, trt, imp

        else:
            for ncells in all_ncells:
                if ncells > max_cells:
                    continue
                for _ in range(n_reps):
                    trt = treateddf.sample(ncells)
                    imp = imputeddf.sample(ncells)
                    yield ncells, 'all', trt, imp

    evals = pd.DataFrame(
        compute_evaluations(iterate_feature_slices()),
        columns=['ncells', 'nfeatures', 'metric', 'value'],
    )
    evals.to_csv(outdir / 'evals.csv', index=False)


In [14]:
# Score one validation batch by hand.
# FIXME(review): `dm` is not defined in any visible cell. This cell ran as
# In [14], before `lm` was loaded in In [33], so it fails on a fresh
# Restart & Run All; point it at the intended datamodule / loader.
# Outside a Trainer, Lightning neither switches the module to eval mode nor
# disables autograd, so do both here. This gives a deterministic score
# (no dropout) without building a graph.
lm.eval()
with torch.no_grad():
    batch = [b.to(device) for b in next(iter(dm.val_dataloader()))]
    mse = lm.validation_step(batch, None)

  rank_zero_warn(


In [15]:
# Value returned by lm.validation_step for the batch scored above.
mse

tensor(0.8200, device='cuda:0', dtype=torch.float64)