In [1]:
import veloagent
import anndata
import torch
import matplotlib.pyplot as plt
import numpy as np
import scvelo as scv
import scanpy as sc

Global seed set to 0


## Load AnnData

In [2]:
# Load the dataset from disk.
# NOTE(review): this is an absolute, machine-specific path; the notebook will not
# run elsewhere until it is replaced by a configurable data directory.
adata = anndata.read_h5ad('/Users/brentyoon/Documents/COMP402/code/dataset/breast_cancer.h5ad')

# Keep only genes whose name does not start with "ENSG"
# (presumably unannotated Ensembl IDs -- TODO confirm).
genes_to_keep = ~adata.var_names.str.startswith("ENSG")

# Column-slicing an AnnData returns a view. Later cells write into this object
# (neighbors graph, adata.uns), so materialise it as an independent copy here
# instead of relying on an implicit view-to-copy conversion on first write.
adata = adata[:, genes_to_keep].copy()

In [None]:
# Build the neighbor graph, then store a fixed-width table of neighbor column
# indices per cell in adata.uns['neighbors']['indices'] (padded with -1).
n_neighbors = 30
# Width of the table; one less than n_neighbors
# (presumably because the cell itself is not counted -- TODO confirm).
max_neighbors = n_neighbors - 1

sc.pp.neighbors(adata, n_neighbors=n_neighbors)

# Work on a private CSR copy so adata.obsp is left untouched. Reading the
# sparse structure directly avoids allocating a dense (cells x cells) matrix.
connectivities = adata.obsp['connectivities'].tocsr().copy()
connectivities.sum_duplicates()   # merge repeated entries, as a dense conversion would
connectivities.eliminate_zeros()  # drop explicitly stored zeros: they are not neighbors
connectivities.sort_indices()     # ascending column order, matching np.nonzero on a dense row

# -1 indicates "no neighbor" in slots a cell does not fill.
neighbor_indices = np.full((connectivities.shape[0], max_neighbors), -1, dtype=int)
row_bounds = zip(connectivities.indptr[:-1], connectivities.indptr[1:])
for cell, (start, stop) in enumerate(row_bounds):
    # Rows with more than max_neighbors non-zero entries are truncated to the
    # lowest column indices (not necessarily the strongest connections).
    cols = connectivities.indices[start:stop][:max_neighbors]
    neighbor_indices[cell, :len(cols)] = cols

adata.uns['neighbors']['indices'] = neighbor_indices

## Perturbation parameters

In [None]:
# Option 1: cluster relationships given as explicit pairs of cluster names.
cluster_edges_opt1 = [
    ('cancer1', 'stromal cells'),
    ('cancer2', 'stromal cells'),
    ('cancer2', 'fibroblasts'),
]

In [3]:
# Option 2: a flat list of cluster names instead of pairs.
cluster_edges_opt2 = [
    'cancer1',
    'stromal cells',
    'cancer2',
    'fibroblasts',
]

In [4]:
# cluster_name: key used for the cluster labels (passed as plot colour and to the scoring call).
# my_pert_param: value forwarded as `pert_param` to the perturbation scoring call.
cluster_name, my_pert_param = 'clusters', 'alpha'

## Velocity projection before perturbation

In [None]:
# Compute the velocity graph from the 'velocity' key, then draw the stream plot
# on the UMAP basis, coloured by the cluster labels.
scv.tl.velocity_graph(adata, vkey='velocity')
scv.pl.velocity_embedding_stream(
    adata,
    basis='umap',
    vkey='velocity',
    color=[cluster_name],
    legend_loc="right margin",
    dpi=150,
)

## Calculating gene score for perturbation

In [5]:
# Score genes for perturbation using the "option 2" cluster list and metric_option=2.
# NOTE(review): the saved output of this cell is `KeyError: 'velocity_u'`, so the
# loaded data presumably lacks an entry named 'velocity_u' that veloagent
# expects -- confirm which preprocessing step is supposed to create it.
scores = veloagent.perturbation_score(
    adata,
    cluster_name,
    cluster_edges_opt2,
    vel_key='velocity',
    metric_option=2,
    pert_param=my_pert_param,
)

Processing gene 0/1011


KeyError: 'velocity_u'

In [None]:
# Display the results ordered by the 'score' column, lowest first.
scores.sort_values(by='score', ascending=True)

In [None]:
# Histogram of the scores: 100 bins on a 10x6 inch figure.
scores.hist(figsize=(10, 6), bins=100)
# Label the current axes, tighten the layout, then render.
plt.ylabel('Num Genes')
plt.xlabel('Score')
plt.tight_layout()
plt.show()

In [None]:
# Optional: write the results to a CSV file in the current working directory.
scores.to_csv(path_or_buf='perturbations.csv')

## Perturbation

In [None]:
# Genes selected for perturbation.
# NOTE(review): these symbols use mouse-style capitalisation, while the loaded
# dataset is filtered on "ENSG"-prefixed names -- confirm that every entry is
# actually present in adata.var_names.
gene_list = [
    'Gm15564', 'Slc1a2', 'Mbp', 'Cdk8',
    'Camk1d', 'Plp1', 'Kcnma1', 'Meg3',
    'Mobp', 'Pcdh9', 'Zbtb20', 'Ptprd',
    'Rora', 'Lsamp', 'Qk', 'Ppp3ca',
    'Nrxn1', 'Syt1', 'Celf2', 'Pde10a',
    'Trpm3', 'Grin2a', 'Hexb', 'Fgfr2',
]