In [11]:
import torch

In [1]:
import wandb
run = wandb.init()
artifact = run.use_artifact('protein-optimization/sc_diff/model-fekp2uq8:v10', type='model')
artifact_dir = artifact.download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjohnyang[0m ([33mprotein-optimization[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [2]:
artifact_dir

'./artifacts/model-fekp2uq8:v10'

In [3]:
from cellot.models.cond_score_module import CondScoreModule

In [4]:
ckpt_path = f'{artifact_dir}/model.ckpt'

In [5]:
YAML_STR = '''
DEBUG: False
TARGET: all
LATENT_DIM: 50
COND_CLASSES: 189
SEED: 42
AE_PATH: /Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/full_ae
VAL_SIZE: 0.1
DEVICES: 1

diffuser:
  min_b: 0.01
  max_b: 1.0
  schedule: exponential
  score_scaling: var
  coordinate_scaling: 1.0
  latent_dim: ${LATENT_DIM}
  dt: 0.01
  min_t: 0

ae:
  name: scgen
  beta: 0.0
  dropout: 0.0
  hidden_units: [512, 512]
  latent_dim: 50

score_network:
  latent_dim: ${LATENT_DIM}
  cond_classes: ${COND_CLASSES}
  model_dim: 64   # Adjusted to 64
  n_layers: 12    # Adjusted to 12
  nhead: 8
  dim_feedforward: 2048
  dropout: 0.1
  ffn_hidden_dim: 1024


data:
  type: cell
  source: control
  condition: drug
  path: /Mounts/rbg-storage1/users/johnyang/cellot/datasets/scrna-sciplex3/hvg.h5ad
  target: trametinib

datasplit:
  groupby: drug   
  name: train_test
  test_size: 0.2
  random_state: 0
  
dataloader:
  batch_size: 256   # Adjusted to 256
  shuffle: true
  num_workers: 80
  
experiment:
  name: base
  mode: train
  num_loader_workers: 0
  port: 12319
  dist_mode: single
  use_wandb: True
  ckpt_path: null
  wandb_logger:
    project: sc_diff
    name: ${experiment.name}
    dir: /Mounts/rbg-storage1/users/johnyang/cellot/
    log_model: all
    tags: ['experimental']
  lr: 0.0001


trainer:
  accelerator: 'gpu'
  check_val_every_n_epoch: 50
  log_every_n_steps: 100
  num_sanity_val_steps: 1
  enable_progress_bar: True
  enable_checkpointing: True
  fast_dev_run: False
  profiler: simple
  max_epochs: 10000
  strategy: auto
  enable_model_summary: True
  overfit_batches: 0.0
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_predict_batches: 1.0
'''

In [6]:
from omegaconf import OmegaConf
config = OmegaConf.create(YAML_STR)

In [7]:
ckpt_path = f'{artifact_dir}/model.ckpt'

In [8]:
from cellot.train.utils import get_free_gpu
replica_id = int(get_free_gpu())

cuda
Using GPUs: 1


In [9]:
device = f'cuda:{replica_id}'

In [10]:
lm = CondScoreModule.load_from_checkpoint(hparams=config, checkpoint_path=ckpt_path).to(device)
print('')

Dropout is 0.1



In [12]:
# %%
import cellot.models
from cellot.data.cell import load_cell_data
import torch
from cellot.models.ae import AutoEncoder
from pathlib import Path

def load_data(config, **kwargs):
    data_type = config.get("data.type", "cell")
    if data_type in ["cell", "cell-merged", "tupro-cohort"]:
        loadfxn = load_cell_data

    elif data_type == "toy":
        loadfxn = load_toy_data

    else:
        raise ValueError

    return loadfxn(config, **kwargs)


def load_model(config, device, restore=None, **kwargs):
    # def load_autoencoder_model(config, restore=None, **kwargs):
    
    def load_optimizer(config, params):
        kwargs = dict(config.get("optim", {}))
        assert kwargs.pop("optimizer", "Adam") == "Adam"
        optim = torch.optim.Adam(params, **kwargs)
        return optim


    def load_networks(config, **kwargs):
        kwargs = kwargs.copy()
        kwargs.update(dict(config.get("ae", {})))
        name = kwargs.pop("name")

        # if name == "scgen":
        model = AutoEncoder

        # elif name == "cae":
        #     model = ConditionalAutoEncoder
        # else:
        #     raise ValueError

        return model(**kwargs)
    
    model = load_networks(config, **kwargs)
    optim = load_optimizer(config, model.parameters())

    if restore is not None and Path(restore).exists():
        print('Loading model from checkpoint')
        ckpt = torch.load(restore, map_location=device)
        model.load_state_dict(ckpt["model_state"])
        optim.load_state_dict(ckpt["optim_state"])
        # if config.model.name == "scgen" and "code_means" in ckpt:
        #     model.code_means = ckpt["code_means"]
            
    # logger.info(f'Model on device {next(model.parameters()).device}')

    return model, optim

def load(config, device, restore=None, include_model_kwargs=False, **kwargs):

    loader, model_kwargs = load_data(config, include_model_kwargs=True, **kwargs)

    model, opt = load_model(config, device, restore=restore, **model_kwargs)

    return model, opt, loader
# %% [markdown]
# ### Training

# %%
restore_path = '/Mounts/rbg-storage1/users/johnyang/cellot/saved_weights/ae/ae.pt'
ae = load_model(config, 'cuda', restore=restore_path, input_dim=1000)

Loading model from checkpoint


In [14]:
ae[0].decode

<bound method AutoEncoder.decode of AutoEncoder(
  (encoder_net): Sequential(
    (0): Linear(in_features=1000, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=50, bias=True)
  )
  (decoder_net): Sequential(
    (0): Linear(in_features=50, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=1000, bias=True)
  )
  (mse): MSELoss()
)>

In [20]:
import numpy as np

In [65]:
autoencoder = ae[0].to(device)

In [87]:
def inference(lm, batch, lamb=4):
    with torch.inference_mode():
        lm.eval()
        all_genes_x, y = batch
        latent_x = autoencoder.eval().encode(all_genes_x)
        
        x_t, _ = lm.diffuser.forward_marginal(latent_x.detach().cpu().numpy(), t=1.0)
        
        for i, t in enumerate(np.arange(1.0, 0, -0.001)):
            x_t = torch.tensor(x_t).float().to(lm.device)
            uncond_score = lm.score_network((x_t, torch.zeros_like(y).to(device)), t)
            cond_score = lm.score_network((x_t, y), t)
            pred_score = (1 + lamb) * cond_score - lamb * uncond_score
            
            x_t = lm.diffuser.reverse(x_t=x_t.detach().cpu().numpy(), score_t=pred_score.detach().cpu().numpy(), t=t, dt=lm.dt, center=False)
        
        x_0 = torch.tensor(x_t, dtype=torch.float).to(lm.device)
        
        recon = autoencoder.eval().decode(x_0)
        return recon
        
        

In [59]:
from cellot.data.cell import read_single_anndata
def load_markers():
    data = read_single_anndata(config, path=None)
    key = f'marker_genes-{config.data.condition}-rank'

    # rebuttal preprocessing stored marker genes using
    # a generic marker_genes-condition-rank key
    # instead of e.g. marker_genes-drug-rank
    # let's just patch that here:
    if key not in data.varm:
        key = 'marker_genes-condition-rank'
        print('WARNING: using generic condition marker genes')

    sel_mg = (
        data.varm[key][config.data.target]
        .sort_values()
        .index
    )
    marker_gene_indices = [i for i, gene in enumerate(data.var_names) if gene in sel_mg]

    return sel_mg, marker_gene_indices

sel_mg, gene_idxs = load_markers()[:50]
sel_mg

2023-07-06 16:20:36,112 Loaded cell data with TARGET trametinib and OBS SHAPE (20842, 16)


Index(['ENSG00000198074.9', 'ENSG00000019186.9', 'ENSG00000108846.15',
       'ENSG00000115414.18', 'ENSG00000231185.6', 'ENSG00000112541.13',
       'ENSG00000117983.17', 'ENSG00000145819.15', 'ENSG00000184588.17',
       'ENSG00000165376.10',
       ...
       'ENSG00000138617.14', 'ENSG00000072274.12', 'ENSG00000249364.5',
       'ENSG00000243193.4', 'ENSG00000070601.9', 'ENSG00000135253.13',
       'ENSG00000050628.20', 'ENSG00000165646.11', 'ENSG00000154415.7',
       'ENSG00000130830.14'],
      dtype='object', name='id', length=1000)

In [60]:
from cellot.data.utils import load_ae_cell_data
loader = load_ae_cell_data(config)#, ae=autoencoder.cpu(), encode_latents=True)#, sel_mg=sel_mg)

In [61]:
for batch in loader.test:
    print(batch[0].shape)
    break

torch.Size([256, 1000])


In [73]:
from tqdm import tqdm

In [74]:
from cellot.losses.mmd import mmd_distance
import numpy as np

def compute_mmd_loss(lhs, rhs, gammas):
    return np.mean([mmd_distance(lhs, rhs, g) for g in gammas])

gammas = np.logspace(1, -3, num=50)

In [89]:
losses = []
for batch in tqdm(loader.test):
    batch = [x.to(device) for x in batch]
    # for i in range(batch[0].shape[0]):
        # single = batch[0][i].unsqueeze(0), batch[1][i].unsqueeze(0)
    recon = inference(lm, batch, lamb=4)
    # for i in range(batch[0].shape[0]):
    mmd_loss = compute_mmd_loss(recon.detach().cpu().numpy()[:, gene_idxs], batch[0].detach().cpu().numpy()[:, gene_idxs], gammas)
    losses.append(mmd_loss)

 12%|█▎        | 2/16 [01:05<07:18, 31.33s/it]

In [84]:
compute_mmd_loss(recon.detach().cpu().numpy()[:, gene_idxs], batch[0].detach().cpu().numpy()[:, gene_idxs], gammas)

0.009160959883056003

In [83]:
compute_mmd_loss(recon.detach().cpu().numpy(), batch[0].detach().cpu().numpy(), gammas)

0.009160959883056003

In [82]:
losses

[0.00966372249177559,
 0.009888802105523125,
 0.00956955381808379,
 0.009733738261301722,
 0.00983408870076948,
 0.00938515146746336,
 0.009916365548519561,
 0.010070365515457046,
 0.009820912044018793,
 0.009621424285639458,
 0.009814309236616163,
 0.009255622948887217,
 0.009843751405248296,
 0.009771178342321945,
 0.010108215091519131,
 0.009160959883056003]

In [81]:
np.mean(losses)

0.009716135071637542

In [68]:
recon

tensor([[-0.0161,  0.0037,  0.0200,  ...,  0.0086, -0.0070,  0.0121],
        [ 0.0071,  0.0036,  0.0015,  ...,  0.0187, -0.0065,  0.0047],
        [ 0.0249, -0.0080,  0.0110,  ...,  0.0056, -0.0065,  0.0002],
        ...,
        [ 0.0064,  0.0034,  0.0082,  ...,  0.0227, -0.0098,  0.0181],
        [-0.0132,  0.0020, -0.0040,  ...,  0.0119, -0.0128,  0.0113],
        [ 0.0057, -0.0019,  0.0118,  ...,  0.0175,  0.0226, -0.0063]],
       device='cuda:1')

In [69]:
recon.shape

torch.Size([256, 1000])

In [72]:
sel_mg_recon = recon[:, gene_idxs]
sel_mg_gt = batch[0][:, gene_idxs]
assert sel_mg_recon.shape == sel_mg_gt.shape

In [41]:
compute_mmd_loss(batch[0].detach().cpu().numpy(), recon.detach().cpu().numpy(), gammas)

ValueError: Incompatible dimension for X and Y matrices: X.shape[1] == 50 while Y.shape[1] == 1000