In [1]:
import torch

In [2]:
import wandb
run = wandb.init()
artifact = run.use_artifact('protein-optimization/sc_diff/model-ylwt16su:v59', type='model')
artifact_dir = artifact.download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjohnyang[0m ([33mprotein-optimization[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model-ylwt16su:v59, 96.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


In [3]:
artifact_dir

'./artifacts/model-ylwt16su:v59'

In [5]:
from cellot.models.cond_score_module import CondScoreModuleV2

In [6]:
ckpt_path = f'{artifact_dir}/model.ckpt'

In [85]:
YAML_STR = '''
DEBUG: False
TARGET: abexinostat
LATENT_DIM: 50
COND_CLASSES: 190
SEED: 42
AE_PATH: /Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/full_ae
VAL_SIZE: 0.1
DEVICES: 1
WARM_START: False
WARM_START_PATH: null
MODEL_CLASS: CondScoreModule


diffuser:
  min_b: 0.01
  max_b: 1.0
  schedule: exponential
  score_scaling: var
  coordinate_scaling: 1.0
  latent_dim: ${LATENT_DIM}
  dt: 0.01
  min_t: 0

ae:
  name: scgen
  beta: 0.0
  dropout: 0.0
  hidden_units: [512, 512]
  latent_dim: 50

score_network:
  latent_dim: ${LATENT_DIM}
  cond_classes: ${COND_CLASSES}
  model_dim: 256   # Adjusted to 64
  n_layers: 6    # Adjusted to 12
  nhead: 8
  dim_feedforward: 2048
  dropout: 0.1
  ffn_hidden_dim: 1024
  extra_null_cond_embedding: False


data:
  type: cell
  source: control
  condition: drug
  path: /Mounts/rbg-storage1/users/johnyang/cellot/datasets/scrna-sciplex3/hvg.h5ad
  target: ${TARGET}

datasplit:
  groupby: drug   
  name: train_test
  test_size: 0.2
  random_state: 0
  
dataloader:
  batch_size: 256   # Adjusted to 256
  shuffle: true
  num_workers: 80
  
experiment:
  name: base
  mode: train
  num_loader_workers: 0
  port: 12319
  dist_mode: single
  use_wandb: True
  ckpt_path: null
  wandb_logger:
    project: sc_diff
    name: ${experiment.name}
    dir: /Mounts/rbg-storage1/users/johnyang/cellot/
    log_model: all
    tags: ['experimental']
  lr: 0.0001


trainer:
  accelerator: 'gpu'
  check_val_every_n_epoch: 50
  log_every_n_steps: 100
  num_sanity_val_steps: 1
  enable_progress_bar: True
  enable_checkpointing: True
  fast_dev_run: False
  profiler: simple
  max_epochs: 10000
  strategy: auto
  enable_model_summary: True
  overfit_batches: 0.0
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_predict_batches: 1.0
'''

In [86]:
from omegaconf import OmegaConf
config = OmegaConf.create(YAML_STR)

In [87]:
ckpt_path = f'{artifact_dir}/model.ckpt'

In [88]:
from cellot.train.utils import get_free_gpu
replica_id = int(get_free_gpu())

cuda
Using GPUs: 7


In [89]:
device = f'cuda:{replica_id}'

In [90]:
config

{'DEBUG': False, 'TARGET': 'abexinostat', 'LATENT_DIM': 50, 'COND_CLASSES': 190, 'SEED': 42, 'AE_PATH': '/Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/full_ae', 'VAL_SIZE': 0.1, 'DEVICES': 1, 'WARM_START': False, 'WARM_START_PATH': None, 'MODEL_CLASS': 'CondScoreModule', 'diffuser': {'min_b': 0.01, 'max_b': 1.0, 'schedule': 'exponential', 'score_scaling': 'var', 'coordinate_scaling': 1.0, 'latent_dim': '${LATENT_DIM}', 'dt': 0.01, 'min_t': 0}, 'ae': {'name': 'scgen', 'beta': 0.0, 'dropout': 0.0, 'hidden_units': [512, 512], 'latent_dim': 50}, 'score_network': {'latent_dim': '${LATENT_DIM}', 'cond_classes': '${COND_CLASSES}', 'model_dim': 256, 'n_layers': 6, 'nhead': 8, 'dim_feedforward': 2048, 'dropout': 0.1, 'ffn_hidden_dim': 1024, 'extra_null_cond_embedding': False}, 'data': {'type': 'cell', 'source': 'control', 'condition': 'drug', 'path': '/Mounts/rbg-storage1/users/johnyang/cellot/datasets/scrna-sciplex3/hvg.h5ad', 'target': '${TARGET}'}, 'datasplit': {'groupby': 'dru

In [91]:
lm = CondScoreModuleV2.load_from_checkpoint(hparams=config, checkpoint_path=ckpt_path).to(device)
print('')

Dropout is 0.1



In [92]:
# %%
import cellot.models
from cellot.data.cell import load_cell_data
import torch
from cellot.models.ae import AutoEncoder
from pathlib import Path

def load_data(config, **kwargs):
    data_type = config.get("data.type", "cell")
    if data_type in ["cell", "cell-merged", "tupro-cohort"]:
        loadfxn = load_cell_data

    elif data_type == "toy":
        loadfxn = load_toy_data

    else:
        raise ValueError

    return loadfxn(config, **kwargs)


def load_model(config, device, restore=None, **kwargs):
    # def load_autoencoder_model(config, restore=None, **kwargs):
    
    def load_optimizer(config, params):
        kwargs = dict(config.get("optim", {}))
        assert kwargs.pop("optimizer", "Adam") == "Adam"
        optim = torch.optim.Adam(params, **kwargs)
        return optim


    def load_networks(config, **kwargs):
        kwargs = kwargs.copy()
        kwargs.update(dict(config.get("ae", {})))
        name = kwargs.pop("name")

        # if name == "scgen":
        model = AutoEncoder

        # elif name == "cae":
        #     model = ConditionalAutoEncoder
        # else:
        #     raise ValueError

        return model(**kwargs)
    
    model = load_networks(config, **kwargs)
    optim = load_optimizer(config, model.parameters())

    if restore is not None and Path(restore).exists():
        print('Loading model from checkpoint')
        ckpt = torch.load(restore, map_location=device)
        model.load_state_dict(ckpt["model_state"])
        optim.load_state_dict(ckpt["optim_state"])
        # if config.model.name == "scgen" and "code_means" in ckpt:
        #     model.code_means = ckpt["code_means"]
            
    # logger.info(f'Model on device {next(model.parameters()).device}')

    return model, optim

def load(config, device, restore=None, include_model_kwargs=False, **kwargs):

    loader, model_kwargs = load_data(config, include_model_kwargs=True, **kwargs)

    model, opt = load_model(config, device, restore=restore, **model_kwargs)

    return model, opt, loader
# %% [markdown]
# ### Training

# %%
restore_path = '/Mounts/rbg-storage1/users/johnyang/cellot/saved_weights/ae/ae.pt'
ae = load_model(config, 'cuda', restore=restore_path, input_dim=1000)

Loading model from checkpoint


In [93]:
ae[0].decode

<bound method AutoEncoder.decode of AutoEncoder(
  (encoder_net): Sequential(
    (0): Linear(in_features=1000, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=50, bias=True)
  )
  (decoder_net): Sequential(
    (0): Linear(in_features=50, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=1000, bias=True)
  )
  (mse): MSELoss()
)>

In [94]:
import numpy as np

In [95]:
autoencoder = ae[0].to(device)

In [115]:
def inference(lm, batch, lamb=4, dt=0.01, t_start=1.0, cond=True):
    with torch.inference_mode():
        lm.eval()
        all_genes_x, y = batch
        latent_x = autoencoder.eval().encode(all_genes_x)
        
        x_t, _ = lm.diffuser.forward_marginal(latent_x.detach().cpu().numpy(), t=t_start)
        
        for i, t in enumerate(np.arange(t_start, 0, -dt)):
            x_t = torch.tensor(x_t).float().to(lm.device)
            uncond_score = lm.score_network((x_t, torch.zeros_like(y).to(device)), t)
            if cond:
                cond_score = lm.score_network((x_t, y), t)
                pred_score = (1 + lamb) * cond_score - lamb * uncond_score
            else:
                pred_score = uncond_score
            
            x_t = lm.diffuser.reverse(x_t=x_t.detach().cpu().numpy(), score_t=pred_score.detach().cpu().numpy(), t=t, dt=lm.dt, center=False)
        
        x_0 = torch.tensor(x_t, dtype=torch.float).to(lm.device)
        
        recon = autoencoder.eval().decode(x_0)
        return recon
        
        

In [97]:
from cellot.data.cell import read_single_anndata
def load_markers():
    data = read_single_anndata(config, path=None)
    key = f'marker_genes-{config.data.condition}-rank'

    # rebuttal preprocessing stored marker genes using
    # a generic marker_genes-condition-rank key
    # instead of e.g. marker_genes-drug-rank
    # let's just patch that here:
    if key not in data.varm:
        key = 'marker_genes-condition-rank'
        print('WARNING: using generic condition marker genes')

    sel_mg = (
        data.varm[key][config.data.target]
        .sort_values()
        .index
    )
    marker_gene_indices = [i for i, gene in enumerate(data.var_names) if gene in sel_mg]

    return sel_mg, marker_gene_indices

sel_mg, gene_idxs = load_markers()
sel_mg

2023-07-10 13:16:11,922 Loaded cell data with TARGET abexinostat and OBS SHAPE (22070, 16)


Index(['ENSG00000160963.13', 'ENSG00000121101.15', 'ENSG00000230666.5',
       'ENSG00000175175.5', 'ENSG00000140986.7', 'ENSG00000100867.14',
       'ENSG00000173727.12', 'ENSG00000183032.11', 'ENSG00000117724.12',
       'ENSG00000259124.1',
       ...
       'ENSG00000116183.10', 'ENSG00000102683.7', 'ENSG00000117983.17',
       'ENSG00000205359.9', 'ENSG00000142347.16', 'ENSG00000232072.1',
       'ENSG00000135426.15', 'ENSG00000271856.1', 'ENSG00000185008.17',
       'ENSG00000196628.15'],
      dtype='object', name='id', length=1000)

In [98]:
from cellot.data.utils import *

In [99]:
def DEV_load_ae_cell_data(
        config,
        data=None,
        split_on=None,
        return_as="loader",
        include_model_kwargs=False,
        pair_batch_on=None,
        ae=None,
        encode_latents=False,
        sel_mg=None,
        **kwargs
    ):
        assert ae is not None or not encode_latents, "ae must be provided"
        
        if isinstance(return_as, str):
            return_as = [return_as]

        assert set(return_as).issubset({"anndata", "dataset", "loader"})
        config.data.condition = config.data.get("condition", "drug")
        condition = config.data.condition
        
        data = read_single_anndata(config, **kwargs)
        
        inputs = torch.Tensor(
            data.X if not sparse.issparse(data.X) else data.X.todense()
        )

        if encode_latents:
            genes = data.var_names.to_list()
            data = anndata.AnnData(
                ae.eval().encode(inputs).detach().numpy(),
                obs=data.obs.copy(),
                uns=data.uns.copy(),
            )
            data.uns["genes"] = genes


        # cast to dense and check for nans
        if sparse.issparse(data.X):
            data.X = data.X.todense()
        assert not np.isnan(data.X).any()

        if sel_mg is not None:
            data = data[:, sel_mg]

        dataset_args = dict()
        model_kwargs = {}

        model_kwargs["input_dim"] = data.n_vars

        # if config.get("model.name") == "cae":
        condition_labels = sorted(data.obs[condition].cat.categories)
        model_kwargs["conditions"] = condition_labels
        dataset_args["obs"] = condition
        dataset_args["categories"] = condition_labels

        if "training" in config:
            pair_batch_on = config.training.get("pair_batch_on", pair_batch_on)

        # if split_on is None:
            # if config.model.name == "cellot":
            #     # datasets & dataloaders accessed as loader.train.source
        split_on = ["split", "transport"]
        if pair_batch_on is not None:
            split_on.append(pair_batch_on)

            # if (config.ae.name == "scgen" #or config.ae.name == "cae"
            #     #or config.ae.name == "popalign"):
            # split_on = ["split"]

            # else:
            #     raise ValueError

        if isinstance(split_on, str):
            split_on = [split_on]

        for key in split_on:
            assert key in data.obs.columns

        if len(split_on) > 0:
            splits = {
                (key if isinstance(key, str) else ".".join(key)): data[index]
                for key, index in data.obs[split_on].groupby(split_on).groups.items()
            }

            dataset = nest_dict(
                {
                    key: AnnDataDataset(val.copy(), **dataset_args)
                    for key, val in splits.items()
                },
                as_dot_dict=True,
            )
        else:
            dataset = AnnDataDataset(data.copy(), **dataset_args)

        if "loader" in return_as:
            kwargs = dict(config.dataloader)
            kwargs.setdefault("drop_last", True)
            loader = cast_dataset_to_loader(dataset, **kwargs)

        returns = list()
        for key in return_as:
            if key == "anndata":
                returns.append(data)

            elif key == "dataset":
                returns.append(dataset)

            elif key == "loader":
                returns.append(loader)

        if include_model_kwargs:
            returns.append(model_kwargs)

        if len(returns) == 1:
            return returns[0]

        # returns.append(data)

        return tuple(returns)

In [100]:
datasets = DEV_load_ae_cell_data(config, return_as='dataset')#, ae=autoencoder.cpu(), encode_latents=True)#, sel_mg=sel_mg)

In [101]:
loader = cast_dataset_to_loader(datasets, batch_size=256, shuffle=False, drop_last=False)
loader

{'test': {'source': <torch.utils.data.dataloader.DataLoader at 0x7f9c6dc274f0>,
  'target': <torch.utils.data.dataloader.DataLoader at 0x7f9c6dc273a0>},
 'train': {'source': <torch.utils.data.dataloader.DataLoader at 0x7f9c6dc27820>,
  'target': <torch.utils.data.dataloader.DataLoader at 0x7f9c6dc272b0>}}

In [102]:
source = datasets.test.source.adata.X

In [103]:
target = datasets.test.target.adata.X

In [104]:
source[:, gene_idxs[:50]].shape

(3513, 50)

In [105]:
target[:, gene_idxs[:50]].shape

(901, 50)

In [106]:
from tqdm import tqdm

In [107]:
from cellot.losses.mmd import mmd_distance
import numpy as np

def compute_mmd_loss(lhs, rhs, gammas):
    return np.mean([mmd_distance(lhs, rhs, g) for g in gammas])

gammas = np.logspace(1, -3, num=50)

In [108]:
sel_target = target[:, gene_idxs[:50]]

In [110]:
gts = []
recons = []
for batch in tqdm(loader.test.source):
    batch = [x.to(device) for x in batch]
    gts.append(batch)
    recon = inference(lm, batch, lamb=4, dt=0.01, cond=True)
    recons.append(recon)

100%|██████████| 14/14 [00:20<00:00,  1.47s/it]


In [111]:
torch.where(gts[0][0][0, gene_idxs[:50]] > 0, 1, 0)

tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0], device='cuda:7')

In [112]:
torch.where(recons[0][0, gene_idxs[:50]] > 0.1, 1, 0)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 0], device='cuda:7')

In [113]:
all_recon = torch.cat(recons, dim=0)
all_recon.shape

torch.Size([3513, 1000])

In [114]:
compute_mmd_loss(all_recon[:, gene_idxs[:50]].detach().cpu().numpy(), sel_target, gammas)

0.03229137025773525