# Example notebook (simulated data)

In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Imports

In [14]:
from novaice.tl import ChemPertVAEModel, ChemPertMLPModel
import anndata as ad
import numpy as np

## Data

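Generate random example data: a gene-expression matrix (samples × genes) and one drug embedding vector per sample, stored together in a single `AnnData` object.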

In [15]:
# Create example data
# Dimensions of the simulated dataset.
n_samples = 100
n_genes = 500
embedding_dim = 768

# Seeded generator so that "Restart Kernel -> Run All" reproduces the exact
# same simulated data (the global, unseeded np.random state did not).
seed = 0
rng = np.random.default_rng(seed)

# Standard-normal draws, one row per sample.
gene_expr = rng.standard_normal((n_samples, n_genes))  # Gene expression
drug_emb = rng.standard_normal((n_samples, embedding_dim))  # Drug embeddings

# Expression matrix goes into X; the embeddings are attached per observation
# under the key that the model cells later pass as `drug_embedding_key`.
adata = ad.AnnData(X=gene_expr)
adata.obsm["drug_embedding"] = drug_emb

## VAE

In [16]:
# Setup and train the VAE model.
# setup_anndata is given the key under which the drug embeddings were stored
# on `adata` ("drug_embedding"); it has to run before the model is constructed
# (presumably it registers that field for the model - TODO confirm against novaice).
ChemPertVAEModel.setup_anndata(adata, drug_embedding_key="drug_embedding")
model = ChemPertVAEModel(adata)
# 50 epochs is enough for this small simulated example.
model.train(max_epochs=50)
# Predict gene expression with the trained model
predictions = model.predict_gene_expression()
# Get latent representation from the trained VAE
latent = model.get_latent_representation()

[34mINFO    [0m Generating sequential column names                                                                        


  accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
/Users/lucas-diedrich/mamba/envs/hackathon/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:166: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/lucas-diedrich/mamba/envs/hackathon/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
/Users/lucas-diedrich/mamba/envs/hackathon/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training:   0%|          | 0/50 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [18]:
# Setup and train the MLP model on the same AnnData object.
# NOTE: `model` and `predictions` are rebound here, so from this cell on they
# refer to the MLP model and its predictions, not to the VAE ones.
ChemPertMLPModel.setup_anndata(adata, drug_embedding_key="drug_embedding")
model = ChemPertMLPModel(adata)
model.train(max_epochs=50)
# Predict gene expression
predictions = model.predict_gene_expression()
# Get the prediction error (default method). Stored under its own name: it was
# previously assigned to `latent`, which mislabeled it and overwrote the VAE
# latent representation computed in the previous cell.
prediction_error = model.get_prediction_error()

[34mINFO    [0m Generating sequential column names                                                                        


  accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
/Users/lucas-diedrich/mamba/envs/hackathon/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:166: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/lucas-diedrich/mamba/envs/hackathon/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
/Users/lucas-diedrich/mamba/envs/hackathon/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training:   0%|          | 0/50 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [19]:
# Prediction error of the most recently trained model, scored with method="r2"
# (presumably the coefficient of determination, one value per entry of the
# returned array - TODO confirm the axis against the novaice docs).
model.get_prediction_error(method="r2")

array([ 0.08062749,  0.09162999,  0.0323052 ,  0.10616307,  0.09551314,
        0.10559062,  0.03638097, -0.04228011,  0.11437934,  0.11619559,
        0.13928162,  0.0689207 ,  0.06799291, -0.00320087,  0.07779458,
       -0.04287967,  0.077312  , -0.0151089 ,  0.12280114,  0.03974983,
        0.09240473,  0.0833765 ,  0.06498342,  0.04111685,  0.10994626,
        0.14451183,  0.07440296,  0.13742468,  0.08721003,  0.10177605,
        0.05887708,  0.1774986 ,  0.12156647,  0.02645941,  0.16160898,
        0.07046322,  0.18522587, -0.06495865,  0.06359713, -0.01569936,
        0.00361843,  0.0793735 ,  0.23512301,  0.05725976,  0.05842183,
        0.10827314,  0.1226686 ,  0.10583632,  0.06769605,  0.06339478,
        0.09591005,  0.06394573,  0.18466068,  0.05429771,  0.0886385 ,
        0.19613076,  0.10455579, -0.03458726,  0.07521188,  0.04424723,
        0.09827788,  0.03463786,  0.12628488,  0.13443344,  0.07028681,
        0.11729928,  0.14073009, -0.03576906,  0.04965198,  0.04