In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import os
import hydra
from omegaconf import OmegaConf

# Add the parent of the current working directory to the import path.
# Guarded so that re-running this cell does not keep appending duplicates.
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# initialize hydra

In [14]:
from hydra.core.global_hydra import GlobalHydra
# Clear Hydra's global instance before hydra.initialize(...) is called in the
# next cell. NOTE(review): presumably this is what allows that cell to be
# re-run within one kernel session — confirm against the Hydra docs.
GlobalHydra.instance().clear()

In [15]:
# Initialize Hydra against the ../config directory.
hydra.initialize(config_path="../config", version_base="1.1")

# Choose which config to load
config_name = "config"  # Change this to use a different config
print(f"Loading config: {config_name}")

# Compose the config, overriding its `experiment` and `loss` entries
cfg = hydra.compose(
    config_name=config_name, 
    overrides=["experiment=essential_genes", "loss=perturbseq"]
)

# Display the loaded config
print(OmegaConf.to_yaml(cfg))

Loading config: config
dataset:
  _target_: datasets.perturbseq_dataset.PerturbseqDataset
  adata_path: /orcd/data/omarabu/001/Omnicell_datasets/essential_gene_knockouts_raw/essential_gene_knockouts_raw.h5ad
  pert_embedding_path: /orcd/data/omarabu/001/Omnicell_datasets/essential_gene_knockouts_raw/pert_embeddings/GenePT.pt
  control_pert: non-targeting
  pert_key: gene
  cell_key: cell_type
  split_mode: iid
  pca_components: ${experiment.pert_embedding_dim}
  seed: 42
  set_size: 100
  data_shape:
  - 11907
  heldout_perts:
  - SUPT5H
  - ATF5
  - SRSF1
  - PSMA3
  - SNRPD3
  - RPL30
  - EXOSC2
  - CDC73
  - NUP54
  - PRIM2
  - TSR2
  - RPS11
  - KPNB1
  - NACA
  - CSE1L
  - SF3B2
  - PHAX
  - POLR2G
  - RPS15A
  - SF3A2
  heldout_cell_types:
  - k562
encoder:
  _target_: encoder.perturbseq_encoders.DistributionEncoderResNetPertPredictor
  in_dim: ${dataset.data_shape[0]}
  latent_dim: ${experiment.latent_dim}
  hidden_dim: ${experiment.hidden_dim}
  set_size: ${experiment.set_size}

In [6]:
from torch.utils.data import DataLoader

# Build the dataset object described by cfg.dataset (its `_target_` entry in
# the printed config names the class that gets instantiated).
dataset = hydra.utils.instantiate(cfg.dataset)



No PCA applied, using 3072 components
Loaded 9220 sets (cell_type x gene combinations)


In [7]:
# Shuffled batches over `dataset`; the batch size comes from cfg.experiment.
dataloader = DataLoader(dataset, batch_size=cfg.experiment.batch_size, shuffle=True)

In [16]:
# Create the encoder from the `encoder` section of the composed config.
encoder = hydra.utils.instantiate(cfg.encoder)

In [17]:
# Create generator (with model already instantiated); later cells access the
# wrapped network through `generator.model`.
generator = hydra.utils.instantiate(cfg.generator)

In [18]:
# Collect the parameters to optimize: all encoder parameters plus the
# parameters of the network wrapped by the generator (`generator.model`).
model_parameters = list(encoder.parameters()) + list(generator.model.parameters())

# Create optimizer and scheduler. Each instantiate(...) result is itself called
# with the remaining argument (presumably the configs build partials — TODO confirm).
optimizer = hydra.utils.instantiate(cfg.optimizer)(params=model_parameters)
scheduler = hydra.utils.instantiate(cfg.scheduler)(optimizer=optimizer)

# Loss object built from cfg.loss (selected by the `loss=perturbseq` override).
loss_manager = hydra.utils.instantiate(cfg.loss)

# Create trainer
trainer = hydra.utils.instantiate(cfg.training)

In [19]:
# Override two trainer attributes for interactive use: enable the tqdm progress
# bar and set log_interval to 10 (units are defined by the trainer).
trainer.use_tqdm = True
trainer.log_interval = 10
# train(...) returns a pair that is unpacked into `output_dir` and `stats`.
# The output_dir argument is ../outputs resolved to an absolute path.
output_dir, stats = trainer.train(
    encoder=encoder,
    generator=generator,
    dataloader=dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_manager=loss_manager,
    output_dir=os.path.abspath('../outputs'),
    config=cfg
)

similar_experiments []


Epoch 1/100:   0%|          | 0/101 [00:03<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 438.00 MiB. GPU 0 has a total capacity of 79.10 GiB of which 150.69 MiB is free. Process 763708 has 75.74 GiB memory in use. Including non-PyTorch memory, this process has 3.20 GiB memory in use. Of the allocated memory 2.63 GiB is allocated by PyTorch, and 59.46 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [13]:
!ls /orcd/data/omarabu/001/njwfish/DistributionEmbeddings/outputs/essential_genes_exp_7febae6e4ed11221715d4828439f4f33

best_model.pt		 checkpoint_epoch_160.pt  checkpoint_epoch_40.pt
checkpoint_epoch_100.pt  checkpoint_epoch_180.pt  checkpoint_epoch_60.pt
checkpoint_epoch_120.pt  checkpoint_epoch_200.pt  checkpoint_epoch_80.pt
checkpoint_epoch_140.pt  checkpoint_epoch_20.pt   config.yaml


In [11]:
# Directory of the finished training run to analyse. It is defined once so the
# run directory and the config.yaml loaded from it cannot point at different runs.
run_dir = '/orcd/data/omarabu/001/njwfish/DistributionEmbeddings/outputs/essential_genes_exp_7febae6e4ed11221715d4828439f4f33'
config = {
    'dir': run_dir,
    # Config file stored in the run directory alongside its checkpoints.
    'config': OmegaConf.load(f'{run_dir}/config.yaml'),
}

In [12]:
import torch
import hydra

def instantiate_and_load_model(
    config,
    checkpoint_name: str = 'checkpoint_epoch_100.pt',
    map_location=None,
) -> tuple:
    """
    Rebuild the encoder and generator of a run and restore their weights.

    Args:
        config: Mapping with two entries:
            'dir'    -- path of the run directory that holds the checkpoint;
            'config' -- the run's config object, whose `.encoder` and
                        `.generator` sections are passed to
                        `hydra.utils.instantiate`.
        checkpoint_name: File name of the checkpoint inside `config['dir']`.
            Defaults to the epoch-100 checkpoint.
        map_location: Forwarded to `torch.load`. Pass e.g. 'cpu' to load the
            checkpoint without allocating GPU memory; `None` keeps torch's
            default placement.

    Returns:
        A tuple `(encoder, generator)`. The encoder weights are read from the
        checkpoint's 'encoder_state_dict' entry; the 'generator_state_dict'
        entry is loaded into `generator.model`.
    """
    import os

    cfg = config['config']
    encoder = hydra.utils.instantiate(cfg.encoder)
    generator = hydra.utils.instantiate(cfg.generator)

    checkpoint_path = os.path.join(config['dir'], checkpoint_name)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
    # Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(
        checkpoint_path, map_location=map_location, weights_only=False
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    generator.model.load_state_dict(checkpoint['generator_state_dict'])
    return encoder, generator

# Rebind the global `encoder` / `generator` to the models restored from the run.
# NOTE(review): the execution counts show this cell (In [12]) was run before the
# cells that re-create `encoder` / `generator` (In [16], In [17]); confirm which
# encoder the evaluation cells further down actually used.
encoder, generator = instantiate_and_load_model(config)

In [22]:
import numpy as np
import pandas as pd 
import torch
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import pearsonr

def generate_set_mean_predictions(encoder, sets, X, ctrl_key, pert_keys, device=None):
    """Encode a control set and several perturbation sets and return their deltas.

    Args:
        encoder: Module that is called on a batch containing one set
            (`rows.unsqueeze(0)`) and that exposes `mean_predictor`, which is
            applied to the stacked perturbation encodings. It is moved to
            `device` via `encoder.to(device)`.
        sets: Mapping from key to an index usable as `X[sets[key]]`.
        X: Container of all rows; it is indexed with `sets[key]` and the result
            is converted with `torch.tensor`.
        ctrl_key: Key of the control set in `sets`.
        pert_keys: Keys of the perturbation sets; output rows follow this order.
        device: Device for the encoder and the data. `None` (default) selects
            'cuda' when it is available and falls back to 'cpu' otherwise.

    Returns:
        Tuple of NumPy arrays, in this order:
            ctrl_X_mean        -- mean over dim 0 of the control rows,
            ctrl_S             -- encoder output for the control set,
            pert_X_delta       -- per-perturbation row mean minus ctrl_X_mean,
            pert_S_delta       -- per-perturbation encoding minus ctrl_S,
            pert_X_delta_recon -- mean_predictor(encodings) minus ctrl_X_mean.

    Raises:
        ValueError: If `pert_keys` is empty.

    Note:
        Gradients are not disabled here; callers that only need the values
        should wrap the call in `torch.no_grad()`.
    """
    # Materialize once so generators/iterators are also accepted.
    pert_keys = list(pert_keys)
    if not pert_keys:
        raise ValueError("pert_keys must contain at least one perturbation key")

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = encoder.to(device)

    def _rows(key):
        # Select the rows of one set and move them to the target device.
        return torch.tensor(X[sets[key]]).to(device)

    def _to_numpy(t):
        return t.detach().cpu().numpy()

    ctrl_X = _rows(ctrl_key)
    ctrl_S = encoder(ctrl_X.unsqueeze(0))
    ctrl_X_mean = torch.mean(ctrl_X, dim=0)

    # Handle one perturbation at a time so the raw rows of all perturbations
    # do not have to be held on the device simultaneously.
    enc_rows, enc_delta_rows, mean_rows = [], [], []
    for key in pert_keys:
        rows = _rows(key)
        enc = encoder(rows.unsqueeze(0))
        enc_rows.append(enc)
        enc_delta_rows.append(enc - ctrl_S)
        mean_rows.append(torch.mean(rows, dim=0).unsqueeze(0))

    pert_S = torch.cat(enc_rows, dim=0)
    pert_S_delta = torch.cat(enc_delta_rows, dim=0)
    pert_X_mean = torch.cat(mean_rows, dim=0)

    pert_X_delta = pert_X_mean - ctrl_X_mean.unsqueeze(0)
    pert_X_delta_recon = encoder.mean_predictor(pert_S) - ctrl_X_mean

    return (
        _to_numpy(ctrl_X_mean),
        _to_numpy(ctrl_S),
        _to_numpy(pert_X_delta),
        _to_numpy(pert_S_delta),
        _to_numpy(pert_X_delta_recon),
    )


def r2_score(y_true, y_pred):
    """Calculate R² using Pearson correlation."""
    r = pearsonr(y_true, y_pred, axis=1)
    return (r[0]**2).mean()

def solve_optimal_linear_predictor(Y, X, bias=True):
    """Least-squares fit of `Y ≈ X @ beta` (plus an intercept when `bias`).

    Note the argument order: the targets `Y` come first, the inputs `X` second.

    Args:
        Y: Targets with one row per sample.
        X: Inputs with one row per sample, shape (n_samples, n_features).
        bias: When True, a column of ones is appended to `X` so an intercept
            is fitted alongside the coefficients.

    Returns:
        `(beta, intercept)` when `bias` is True, where `beta` holds the
        coefficients for the original columns of `X` and `intercept` is the
        coefficient of the appended ones column; otherwise just `beta`.

    The system is solved with `np.linalg.lstsq` instead of explicitly
    inverting `X.T @ X`: this is better conditioned and still returns a
    (minimum-norm) solution when the columns of `X` are linearly dependent,
    where the explicit inverse fails or produces meaningless values.
    """
    X = np.asarray(X)
    if bias:
        X = np.hstack([X, np.ones((X.shape[0], 1))])
    beta, *_ = np.linalg.lstsq(X, Y, rcond=None)
    if bias:
        # Last row belongs to the appended ones column, i.e. the intercept.
        return beta[:-1], beta[-1]
    return beta

In [24]:
# Cell type to evaluate; it is also listed under `heldout_cell_types` in the
# config printed earlier in the notebook.
cell_type = 'k562'  
ctrl_key = dataset.control_pert
# Perturbations of this cell type, excluding the control and any perturbation
# that has no entry in dataset.pert_embeddings.
pert_keys = [k for k in dataset.sets[cell_type] if k != ctrl_key and k in dataset.pert_embeddings]
# Same filter applied to the evaluation sets.
eval_pert_keys = [k for k in dataset.eval_sets[cell_type] if k != ctrl_key and k in dataset.pert_embeddings]

# Only forward passes are needed here, so gradient tracking is switched off.
with torch.no_grad():
    # Control statistics and perturbation deltas computed from dataset.sets.
    ctrl_X, ctrl_S, X_delta, S_delta, X_delta_recon = generate_set_mean_predictions(
        encoder, dataset.sets[cell_type], dataset.X, ctrl_key, pert_keys
    )
    # Same quantities from dataset.eval_sets; the control outputs are discarded.
    _, _, X_delta_eval, S_delta_eval, X_delta_recon_eval = generate_set_mean_predictions(
        encoder, dataset.eval_sets[cell_type], dataset.X, ctrl_key, eval_pert_keys
    )


In [25]:

# Fit a linear map with intercept from the latent deltas (S_delta, inputs) to
# the mean deltas (X_delta, targets), then apply it to the evaluation latents.
beta, bias = solve_optimal_linear_predictor(X_delta, S_delta)
X_delta_pred_full = S_delta_eval @ beta + bias
# Displayed pair: (notebook-defined r2_score, i.e. mean squared Pearson r
# per row; sklearn mean_squared_error).
r2_score(X_delta_eval, X_delta_pred_full), mean_squared_error(X_delta_eval, X_delta_pred_full)

(0.9514674027501666, 0.08505320212344723)

In [None]:
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.kernel_ridge import KernelRidge
from sklearn.ensemble import RandomForestRegressor
import numpy as np

# Stack one embedding per perturbation, in the order of pert_keys /
# eval_pert_keys (the same key lists that were used to build the deltas).
Z = np.vstack([dataset.pert_embeddings[k] for k in pert_keys])
Z_eval = np.vstack([dataset.pert_embeddings[k] for k in eval_pert_keys])

# compute all interactions of Z: every product Z[b, i] * Z[b, j], flattened to
# Z.shape[1] ** 2 features per row.
# NOTE(review): the dataset log earlier in the notebook reports 3072 components;
# if Z has that width each row of Zi has ~9.4M entries — check memory first.
Zi = np.einsum('bi,bj->bij', Z, Z).reshape(Z.shape[0], -1)
Zi_eval = np.einsum('bi,bj->bij', Z_eval, Z_eval).reshape(Z_eval.shape[0], -1)



# Ridge regression from the interaction features to the latent deltas; the
# predicted latents are mapped back with (beta, bias) fitted in an earlier cell.
reg = Ridge(alpha=1.)
# reg = KernelRidge(kernel='polynomial', degree=3, alpha=0.1)
# reg = RandomForestRegressor()
reg.fit(Zi, S_delta)
S_delta_pred_eval_kr = reg.predict(Zi_eval).astype(np.float32)
X_delta_pred_gde = S_delta_pred_eval_kr @ beta + bias


# mean predict the delta: ridge regression straight from the interaction
# features to X_delta, used as the comparison for the latent route above.
# NOTE: this rebinds X_delta_pred_full, which an earlier cell also assigns.
reg = Ridge(alpha=1.)# grid_search.best_params_['alpha'])
reg.fit(Zi, X_delta)
X_delta_pred_full = reg.predict(Zi_eval).astype(np.float32)

# Displayed tuple: (score, MSE) for the latent route, then (score, MSE) for the
# direct route; r2_score here is the notebook-defined function.
r2_score(X_delta_eval, X_delta_pred_gde), mean_squared_error(X_delta_eval, X_delta_pred_gde), r2_score(X_delta_eval, X_delta_pred_full), mean_squared_error(X_delta_eval, X_delta_pred_full)

In [None]:
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.kernel_ridge import KernelRidge
from sklearn.ensemble import RandomForestRegressor
import numpy as np

# Stack one embedding per perturbation, in the order of pert_keys /
# eval_pert_keys. This repeats the feature construction of the previous cell.
Z = np.vstack([dataset.pert_embeddings[k] for k in pert_keys])
Z_eval = np.vstack([dataset.pert_embeddings[k] for k in eval_pert_keys])

# compute all interactions of Z: every product Z[b, i] * Z[b, j], flattened to
# Z.shape[1] ** 2 features per row.
Zi = np.einsum('bi,bj->bij', Z, Z).reshape(Z.shape[0], -1)
Zi_eval = np.einsum('bi,bj->bij', Z_eval, Z_eval).reshape(Z_eval.shape[0], -1)


# search over alpha values using 5-fold cross validation; scoring='r2' is
# scikit-learn's scorer, not the r2_score function defined in this notebook.
from sklearn.model_selection import GridSearchCV

param_grid = {'alpha': [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]}
ridge = Ridge()
grid_search = GridSearchCV(ridge, param_grid, cv=5, scoring='r2')
grid_search.fit(Zi, S_delta)

print(f"Best alpha: {grid_search.best_params_['alpha']}")
print(f"Best CV score: {grid_search.best_score_:.3f}")


# Latent route: refit ridge with the alpha selected above, predict the latent
# deltas and map them back with (beta, bias) fitted in an earlier cell.
reg = Ridge(alpha=grid_search.best_params_['alpha'])
# reg = KernelRidge(kernel='polynomial', degree=3, alpha=0.1)
# reg = RandomForestRegressor()
reg.fit(Zi, S_delta)
S_delta_pred_eval_kr = reg.predict(Zi_eval).astype(np.float32)
X_delta_pred_gde = S_delta_pred_eval_kr @ beta + bias


# Second search, now with X_delta as the target. `grid_search` is rebound here,
# so the result of the first search is no longer available afterwards.
grid_search = GridSearchCV(ridge, param_grid, cv=5, scoring='r2')
grid_search.fit(Zi, X_delta)

print(f"Best alpha: {grid_search.best_params_['alpha']}")
print(f"Best CV score: {grid_search.best_score_:.3f}")

# mean predict the delta
# NOTE(review): the alpha found by the second search is printed but not used;
# the direct route is fitted with a fixed alpha=0.1 while the latent route uses
# its tuned alpha. Confirm this asymmetry is intended.
reg = Ridge(alpha=0.1)# grid_search.best_params_['alpha'])
reg.fit(Zi, X_delta)
X_delta_pred_full = reg.predict(Zi_eval).astype(np.float32)

# Displayed tuple: (score, MSE) for the latent route, then (score, MSE) for the
# direct route; r2_score here is the notebook-defined function.
r2_score(X_delta_eval, X_delta_pred_gde), mean_squared_error(X_delta_eval, X_delta_pred_gde), r2_score(X_delta_eval, X_delta_pred_full), mean_squared_error(X_delta_eval, X_delta_pred_full)