In [0]:
#create environment
!apt install python3.9-venv
!python -m venv ./venv/scgen_cuda124

In [0]:
!source venv/scgen_cuda124/bin/activate
!pip list

In [0]:
!pip install scvi-tools
!pip install scanpy
!pip install git+https://github.com/theislab/scgen.git
!pip install --quiet scgen[tutorials]

In [0]:
!pip install huggingface_hub

In [0]:
!pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

In [0]:
# `!env CUDA_HOME=...` only prints the environment of a short-lived subprocess;
# it never changes the kernel's environment. Set the variable on the kernel
# process itself so that libraries imported later can see it.
import os

os.environ["CUDA_HOME"] = "/usr/local/cuda"

### Loading Train Data

In [0]:
# Install Poetry, let it install a project's dependencies, then install Cython.
# NOTE(review): `poetry install` reads the pyproject.toml of the current working
# directory, which is not shown here — confirm the cell runs from the intended
# project folder.
# NOTE(review): `!pip` uses the `pip` found on the shell PATH, which may not be
# the kernel's environment — verify the packages are importable afterwards.
!pip install poetry
!poetry install
!pip install Cython


In [0]:
import typing

In [0]:
import torch

In [0]:
import logging
import sklearn
import numpy as np
import os
import scgen
import scanpy as sc

import perturbench

In [0]:
# Switch the working directory to the project folder; relative paths used later
# in the notebook (e.g. "../saved_models/...") resolve against it.
# NOTE(review): this is a user-specific absolute path — other users must adapt it.
project_dir = '/Workspace/Users/kl11@sanger.ac.uk/scEvalsJam'
os.chdir(project_dir)

In [0]:
!wget https://huggingface.co/datasets/scEvalsJam/datasets/resolve/main/1gene-norman.h5ad

In [0]:
!wget https://huggingface.co/datasets/scEvalsJam/datasets/resolve/eca3347678cda5b08843caf60c8705fff70c590b/1gene-replogle-essential-split.h5ad?download=true

In [0]:
# Fix NumPy's global random state so the random train/test split is repeatable.
SEED = 42
np.random.seed(SEED)

In [0]:
# Sanity check: the cell output shows whether the Norman file is present at this path.
norman_file = "/databricks/driver/1gene-norman.h5ad"
os.path.exists(norman_file)

In [0]:
# Load the Norman dataset from its .h5ad file into `adata`.
norman_path = "/databricks/driver/1gene-norman.h5ad"
adata = sc.read_h5ad(norman_path)

In [0]:
# Sizes for an 80/20 train/test split over the rows of `adata`.
n_rows = adata.shape[0]
train_size = int(n_rows * 0.8)   # truncated towards zero
test_size = n_rows - train_size  # the remainder goes to the test set


In [0]:
# Draw `train_size` observation names, without replacement, to form the training set.
train_idx = np.random.choice(
    adata.obs_names,
    size=train_size,
    replace=False,
)

In [0]:
# Boolean mask over observations: True where the observation name was sampled
# for training. Its negation selects the test observations.
in_train = adata.obs_names.isin(train_idx.tolist())

# `.copy()` turns each selection into an independent object instead of a view of `adata`.
train_set = adata[in_train].copy()
test_set = adata[~in_train].copy()

In [0]:
# This cell used to repeat the split from the previous cell verbatim, which
# recomputed and re-copied both subsets for no benefit. It now checks the
# result of that split instead.

# Every observation must land in exactly one of the two subsets.
assert train_set.shape[0] + test_set.shape[0] == adata.shape[0], \
    "train and test sets do not add up to the full dataset"
assert set(train_set.obs_names).isdisjoint(test_set.obs_names), \
    "train and test sets share observations"

# Show the resulting shapes as the cell output.
train_set.shape, test_set.shape

In [0]:
# Register the training data with scGen, naming the obs columns to use as the
# batch key and the labels key.
# NOTE(review): assumes `train_set.obs` has "condition" and "cell_type" columns
# — confirm against the dataset.
scgen.SCGEN.setup_anndata(
    train_set,
    batch_key="condition",
    labels_key="cell_type",
)

In [0]:
# Build the scGen model on the registered training data.
model = scgen.SCGEN(train_set)

# NOTE(review): this save runs before any training call in the notebook, so it
# stores the freshly initialised model — confirm that is intended.
model.save(
    "../saved_models/model_perturbation_prediction.pt",
    overwrite=True,
)

In [0]:
# Train for at most 100 epochs with mini-batches of 32, with early stopping
# enabled and a patience of 25.
model.train(
    max_epochs=100,
    batch_size=32,
    early_stopping=True,
    early_stopping_patience=25,
)

# Persist the trained model. The only other save in the notebook happens right
# after construction, before training, so without this line the trained model
# would never reach disk. Same path as that earlier save, overwriting it.
model.save("../saved_models/model_perturbation_prediction.pt", overwrite=True)

## Prediction

In [0]:
# Arguments for the prediction call, gathered in one place for easy editing.
# NOTE(review): presumably `ctrl_key` / `stim_key` must be values that occur in
# the "condition" column registered as the batch key, and `celltype_to_predict`
# a value of the "cell_type" column — verify that 'train', 'test' and 'CD4T'
# actually exist in this dataset.
predict_args = dict(
    ctrl_key='train',
    stim_key='test',
    celltype_to_predict='CD4T',
)
pred, delta = model.predict(**predict_args)

# Label every predicted observation so it can be told apart by its condition.
pred.obs['condition'] = 'pred'