In [2]:
# import packages
# need to run on GPU 
import argparse
from pathlib import Path
import tempfile
import shutil
import pickle
import json 
import time
import warnings
import copy
import os
import scanpy as sc
import numpy as np
import pandas as pd
from pathlib import Path
import torch
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from torch_geometric.loader import DataLoader
import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.loss import (
    masked_mse_loss,
    criterion_neg_log_bernoulli,
    masked_relative_error,
)
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id

import session_info

In [3]:
# import GEARS
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction

In [4]:
# define datasets
dataset_name = "southard_rpe1_crispri"   # change if needed
test_train_config_path = "/.mounts/labs/steinlab/scratch/mtong/datasets/seq/southard_RPE1_CRISPRi/southard_data/results/set2conditions.json"
patience = 1 # how many epochs with increasing error are allowed
epochs = 15
pool_size = 100
seed = 1
working_dir = "/.mounts/labs/steinlab/scratch/mtong/datasets/seq/southard_RPE1_CRISPRi/southard_data"
result_id = "sept_24_scGPT"

# create out_dir
out_dir = f"{working_dir}/results/scGPT"
set_seed(seed)

In [5]:
# settings for data prcocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value = 0  # for padding values
pert_pad_id = 2

n_hvg = 0  # number of highly variable genes
include_zero_gene = "all"  # include zero expr genes in training input, "all", "batch-wise", "row-wise", or False
max_seq_len = 1536

# settings for training
MLM = True  # whether to use masked language modeling, currently it is always on.
CLS = False  # celltype classification objective
CCE = False  # Contrastive cell embedding objective
MVC = False  # Masked value prediction for cell embedding
ECS = False  # Elastic cell similarity objective
cell_emb_style = "cls"
mvc_decoder_style = "inner product, detach"
amp = True
load_model = "/.mounts/labs/steinlab/scratch/mtong/datasets/model/scGPT_pretrained" # preloaded data can be downloaded from link in scGPT Github
load_param_prefixs = [
    "encoder",
    "value_encoder",
    "transformer_encoder",
]

# settings for optimizer
lr = 1e-4  # or 1e-4
batch_size = 64
eval_batch_size = 64
epochs = epochs
schedule_interval = 1
early_stop = 5

# settings for the model
embsize = 512  # embedding dimension
d_hid = 512  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 12  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # number of heads in nn.MultiheadAttention
n_layers_cls = 3
dropout = 0.2  # dropout probability
use_fast_transformer = True  # whether to use fast transformer

# logging
log_interval = 100
device = "cuda"

In [6]:
# loading data into PertData object
pert_data_folder = Path("/.mounts/labs/steinlab/scratch/mtong/datasets/seq/southard_RPE1_CRISPRi")
pert_data = PertData(pert_data_folder)
pert_data.load(data_path = f"{working_dir}/{dataset_name}")

conds = pert_data.adata.obs["condition"].cat.remove_unused_categories().cat.categories.tolist()
gene_names = pert_data.adata.var["gene_name"].values.tolist() + ["ctrl"]
good_conds = np.array(conds)[[len(c) == 1 or (c[0] in gene_names and c[1] in gene_names) for c in [co.split("+") for co in conds]]]
pert_data.adata = pert_data.adata[[c in good_conds for c in pert_data.adata.obs["condition"]],:]

with open(test_train_config_path, "r") as json_file:
    set2conditions = json.load(json_file)

Local copy of pyg dataset is detected. Loading...
Done!


In [7]:
# filter out problematic conditions
set2conditions["train"] = [c for c in set2conditions["train"] if c in good_conds]
set2conditions["test"] = [c for c in set2conditions["test"] if c in good_conds]
set2conditions["val"] = [c for c in set2conditions["val"] if c in good_conds]

print(set2conditions)
pert_data.set2conditions = set2conditions
pert_data.split = "custom"
pert_data.subgroup = None
pert_data.seed = 1
pert_data.train_gene_set_size = 0.75
pert_data.get_dataloader(batch_size = batch_size, test_batch_size = eval_batch_size)
logger = scg.logger 

Creating dataloaders....
Done!


{'test': ['RPL18A+ctrl', 'RPS8+ctrl', 'POGLUT3+ctrl', 'EIF3A+ctrl', 'RPL15+ctrl', 'POLR3E+ctrl', 'UXT+ctrl', 'PPP1R37+ctrl', 'RPL4+ctrl', 'KPNB1+ctrl', 'RPL38+ctrl', 'MED30+ctrl', 'RPS3A+ctrl', 'RPS15A+ctrl', 'TYK2+ctrl', 'RPS19+ctrl', 'EIF3CL+ctrl', 'ERCC3+ctrl', 'EXOSC8+ctrl', 'EIF1AX+ctrl', 'NUP214+ctrl', 'FUNDC2+ctrl', 'EIF3E+ctrl', 'RPTOR+ctrl', 'EXOSC5+ctrl', 'EIF3M+ctrl', 'RPS26+ctrl', 'GTF2E1+ctrl', 'TSR2+ctrl', 'MED1+ctrl', 'RPL35+ctrl', 'RPL23+ctrl', 'RPS4X+ctrl', 'RPL6+ctrl'], 'train': ['ctrl', 'ADAM10+ctrl', 'RPL14+ctrl', 'MED21+ctrl', 'MED11+ctrl', 'UTP18+ctrl', 'NUP205+ctrl', 'WDR43+ctrl', 'SIRT7+ctrl', 'MED18+ctrl', 'RPL37A+ctrl', 'NUP98+ctrl', 'EIF3I+ctrl', 'RPL32+ctrl', 'RPL39+ctrl', 'GLE1+ctrl', 'TMX2+ctrl', 'URI1+ctrl', 'RPS14+ctrl', 'XRCC5+ctrl', 'MED9+ctrl', 'RPS29+ctrl', 'RPS6+ctrl', 'TBCB+ctrl', 'POLR3D+ctrl', 'RPS13+ctrl', 'EXOSC3+ctrl', 'TAF5+ctrl', 'NUP85+ctrl', 'NUP88+ctrl', 'MED12+ctrl', 'RPS2+ctrl', 'RPLP0+ctrl', 'RPL30+ctrl', 'MED20+ctrl', 'EIF3D+ctrl', 'U

In [8]:
# load model
model_dir = Path(load_model)
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"
vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

pert_data.adata.var["id_in_vocab"] = [
    1 if gene in vocab else -1 for gene in pert_data.adata.var["gene_name"]
]
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
logger.info(
    f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
    f"in vocabulary of size {len(vocab)}."
)
genes = pert_data.adata.var["gene_name"].tolist()
# not a lot of genes match here, might have to go back and convert gene names to standardized form to match pert_data 

scGPT - INFO - match 3605/5129 genes in vocabulary of size 60697.


In [9]:
# modified code for model loading - does not require flash attention
ntokens = len(vocab)  # size of vocabulary
model = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=False,  # no flash attention installed
)

if load_model is not None:
    pretrained_dict = torch.load(model_file)

    # Convert fused Wqkv weights to standard in_proj_weight/in_proj_bias
    for layer_idx in range(nlayers):
        wqkv_key = f'transformer_encoder.layers.{layer_idx}.self_attn.Wqkv.weight'
        bqkv_key = f'transformer_encoder.layers.{layer_idx}.self_attn.Wqkv.bias'

        if wqkv_key in pretrained_dict:
            wqkv = pretrained_dict[wqkv_key]
            bqkv = pretrained_dict[bqkv_key]

            d = wqkv.shape[0] // 3  # hidden size

            pretrained_dict[f'transformer_encoder.layers.{layer_idx}.self_attn.in_proj_weight'] = torch.cat(
                [wqkv[:d, :], wqkv[d:2*d, :], wqkv[2*d:, :]], dim=0
            )
            pretrained_dict[f'transformer_encoder.layers.{layer_idx}.self_attn.in_proj_bias'] = torch.cat(
                [bqkv[:d], bqkv[d:2*d], bqkv[2*d:]], dim=0
            )

            # remove fused keys
            del pretrained_dict[wqkv_key]
            del pretrained_dict[bqkv_key]

    model_dict = model.state_dict()
    # Filter pretrained_dict to only include matching shapes
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
    for k, v in pretrained_dict.items():
        logger.info(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)
    logger.info(f"Loaded checkpoint {model_file} with Wqkv conversion.")

model.to(device)


scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Si

TransformerGenerator(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (pert_encoder): Embedding(3, 512, padding_idx=2)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias

In [10]:
# model
with open(model_config_file, "r") as f:
    model_configs = json.load(f)
logger.info(
    f"Resume model from {model_file}, the model args will override the "
    f"config {model_config_file}."
)
embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]

vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)

scGPT - INFO - Resume model from /.mounts/labs/steinlab/scratch/mtong/datasets/model/scGPT_pretrained/best_model.pt, the model args will override the config /.mounts/labs/steinlab/scratch/mtong/datasets/model/scGPT_pretrained/args.json.


In [12]:
# define functions
criterion = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, schedule_interval, gamma=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=amp)


def train(model: nn.Module, train_loader: torch.utils.data.DataLoader) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    total_loss, total_mse = 0.0, 0.0
    start_time = time.time()

    num_batches = len(train_loader)
    for batch, batch_data in enumerate(train_loader):
        batch_size = len(batch_data.y)
        batch_data.to(device)
        x: torch.Tensor = batch_data.x  # (batch_size * n_genes, 2)
        ori_gene_values = x[:, 0].view(batch_size, n_genes)
        pert_flags = x[:, 1].long().view(batch_size, n_genes)
        target_gene_values = batch_data.y  # (batch_size, n_genes)

        if include_zero_gene in ["all", "batch-wise"]:
            if include_zero_gene == "all":
                input_gene_ids = torch.arange(n_genes, device=device, dtype=torch.long)
            else:
                input_gene_ids = (
                    ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
                )
            # sample input_gene_id
            if len(input_gene_ids) > max_seq_len:
                input_gene_ids = torch.randperm(len(input_gene_ids), device=device)[
                    :max_seq_len
                ]
            input_values = ori_gene_values[:, input_gene_ids]
            input_pert_flags = pert_flags[:, input_gene_ids]
            target_values = target_gene_values[:, input_gene_ids]

            mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
            mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)

            # src_key_padding_mask = mapped_input_gene_ids.eq(vocab[pad_token])
            src_key_padding_mask = torch.zeros_like(
                input_values, dtype=torch.bool, device=device
            )

        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                mapped_input_gene_ids,
                input_values,
                input_pert_flags,
                src_key_padding_mask=src_key_padding_mask,
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
            )
            output_values = output_dict["mlm_output"]

            masked_positions = torch.ones_like(
                input_values, dtype=torch.bool
            )  # Use all
            loss = loss_mse = criterion(output_values, target_values, masked_positions)

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        # torch.cuda.empty_cache()

        total_loss += loss.item()
        total_mse += loss_mse.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            # ppl = math.exp(cur_loss)
            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} |"
            )
            total_loss = 0
            total_mse = 0
            start_time = time.time()


def evaluate(model: nn.Module, val_loader: torch.utils.data.DataLoader) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0

    with torch.no_grad():
        for batch, batch_data in enumerate(val_loader):
            batch_size = len(batch_data.y)
            batch_data.to(device)
            x: torch.Tensor = batch_data.x  # (batch_size * n_genes, 2)
            ori_gene_values = x[:, 0].view(batch_size, n_genes)
            pert_flags = x[:, 1].long().view(batch_size, n_genes)
            target_gene_values = batch_data.y  # (batch_size, n_genes)

            if include_zero_gene in ["all", "batch-wise"]:
                if include_zero_gene == "all":
                    input_gene_ids = torch.arange(n_genes, device=device)
                else:  # when batch-wise
                    input_gene_ids = (
                        ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
                    )

                # sample input_gene_id
                if len(input_gene_ids) > max_seq_len:
                    input_gene_ids = torch.randperm(len(input_gene_ids), device=device)[
                        :max_seq_len
                    ]
                input_values = ori_gene_values[:, input_gene_ids]
                input_pert_flags = pert_flags[:, input_gene_ids]
                target_values = target_gene_values[:, input_gene_ids]

                mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
                mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)

                # src_key_padding_mask = mapped_input_gene_ids.eq(vocab[pad_token])
                src_key_padding_mask = torch.zeros_like(
                    input_values, dtype=torch.bool, device=input_values.device
                )
            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = model(
                    mapped_input_gene_ids,
                    input_values,
                    input_pert_flags,
                    src_key_padding_mask=src_key_padding_mask,
                    CLS=CLS,
                    CCE=CCE,
                    MVC=MVC,
                    ECS=ECS,
                    do_sample=True,
                )
                output_values = output_dict["mlm_output"]

                masked_positions = torch.ones_like(
                    input_values, dtype=torch.bool, device=input_values.device
                )
                loss = criterion(output_values, target_values, masked_positions)
            total_loss += loss.item()
            total_error += masked_relative_error(
                output_values, target_values, masked_positions
            ).item()
    return total_loss / len(val_loader), total_error / len(val_loader)

In [13]:
# define prediction function - mean values
def predict(
    model: TransformerGenerator, pert_list, pool_size = None
):
    adata = pert_data.adata
    ctrl_adata = adata[adata.obs["condition"] == "ctrl"]
    if pool_size is None:
        pool_size = len(ctrl_adata.obs)
    gene_list = pert_data.gene_names.values.tolist()
    for pert in pert_list:
        for i in pert:
            if i not in gene_list:
                raise ValueError(
                    "The gene is not in the perturbation graph. Please select from GEARS.gene_list!"
                )

    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        results_pred = {}
        for pert in pert_list:
            cell_graphs = create_cell_graph_dataset_for_prediction(
                pert, ctrl_adata, gene_list, device, num_samples=pool_size
            )
            loader = DataLoader(cell_graphs, batch_size=eval_batch_size, shuffle=False)
            preds = []
            for batch_data in loader:
                pred_gene_values = model.pred_perturb(
                    batch_data, include_zero_gene, gene_ids=gene_ids, amp=amp
                )
                preds.append(pred_gene_values)
            preds = torch.cat(preds, dim=0)
            results_pred["_".join(pert)] = np.mean(preds.detach().cpu().numpy(), axis=0)

    return results_pred


In [14]:
best_val_loss = float("inf")
best_model = copy.deepcopy(model)
patience = patience

In [15]:
# model training/fine-tuning
for epoch in range(0, epochs):
    epoch_start_time = time.time()
    train_loader = pert_data.dataloader["train_loader"]
    valid_loader = pert_data.dataloader["val_loader"]

    train(
        model,
        train_loader,
    )
    val_loss, val_mre = evaluate(
        model,
        valid_loader,
    )
    elapsed = time.time() - epoch_start_time
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} |"
    )
    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        logger.info(f"Best model with score {best_val_loss:5.4f}")
        patience = 0
    else:
        patience += 1
        if patience >= early_stop:
            logger.info(f"Early stop at epoch {epoch}")
            break

    scheduler.step()

scGPT - INFO - | epoch   0 | 100/819 batches | lr 0.0001 | ms/batch 709.35 | loss  0.12 | mse  0.12 |
scGPT - INFO - | epoch   0 | 200/819 batches | lr 0.0001 | ms/batch 380.02 | loss  0.09 | mse  0.09 |
scGPT - INFO - | epoch   0 | 300/819 batches | lr 0.0001 | ms/batch 380.17 | loss  0.09 | mse  0.09 |
scGPT - INFO - | epoch   0 | 400/819 batches | lr 0.0001 | ms/batch 380.18 | loss  0.09 | mse  0.09 |
scGPT - INFO - | epoch   0 | 500/819 batches | lr 0.0001 | ms/batch 380.00 | loss  0.09 | mse  0.09 |
scGPT - INFO - | epoch   0 | 600/819 batches | lr 0.0001 | ms/batch 380.17 | loss  0.09 | mse  0.09 |
scGPT - INFO - | epoch   0 | 700/819 batches | lr 0.0001 | ms/batch 379.97 | loss  0.09 | mse  0.09 |
scGPT - INFO - | epoch   0 | 800/819 batches | lr 0.0001 | ms/batch 379.99 | loss  0.09 | mse  0.09 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   0 | time: 354.84s | valid loss/mse 0.0924 |
scG

In [16]:
# save models and generate predictions
conds = pert_data.adata.obs["condition"].cat.remove_unused_categories().cat.categories.tolist()
split_conds = [x.split("+") for x in conds]
split_conds = [list(filter(lambda y: y != "ctrl", x)) for x in split_conds]

out_dir_model = "/.mounts/labs/steinlab/scratch/mtong/datasets/seq/southard_RPE1_CRISPRi/southard_data/results/sept_24_scGPT"
torch.save(best_model.state_dict(), f"{out_dir_model}/full_best_model.pt")

all_pred_vals = predict(best_model, split_conds, pool_size=pool_size)
all_pred_vals = {k: v.tolist() for k, v in all_pred_vals.items()}

In [17]:
# save ground truth
ground_truth_vals = {}
for cond in conds:
    obs_idx = pert_data.adata.obs['condition'] == cond
    mean_expr = np.asarray(pert_data.adata[obs_idx, :].X.mean(axis=0)).ravel()
    # Remove "+ctrl" from the key to match prediction
    key = cond.replace("+ctrl", "")
    ground_truth_vals[key] = mean_expr.tolist()

# convert empty string key to 'ctrl'
if '' in all_pred_vals:
    all_pred_vals['ctrl'] = all_pred_vals.pop('')

# sanity check
pred_conditions = set(all_pred_vals.keys())
gt_conditions = set(ground_truth_vals.keys())
assert pred_conditions == gt_conditions, f"Mismatch in conditions: {pred_conditions ^ gt_conditions}"

In [18]:
# save 
with open(f"{out_dir}/all_predictions.json", 'w', encoding="utf8") as handle:
    json.dump(all_pred_vals, handle, indent = 4)
with open(f"{out_dir}/gene_names.json", 'w', encoding="utf8") as handle:
    json.dump(pert_data.adata.var["gene_name"].values.tolist(), handle, indent = 4)
with open(f"{out_dir}/all_ground_truth.json", 'w', encoding="utf8") as handle:
    json.dump(ground_truth_vals, handle, indent=4)

session_info.show()
print("Python done")

Python done
