In [None]:
### Import Libraries.

import os
import warnings
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import magic
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from tqdm import tqdm
from scipy.stats import wilcoxon
from sklearn.neighbors import NearestNeighbors
from matplotlib.colors import LinearSegmentedColormap, Normalize

# Custom Margaret metric-learning imports
from train_metric import train_metric_learner
from utils.util import get_start_cell_cluster_id
from utils.plot import (
    plot_connectivity_graph,
    plot_trajectory_graph_v2,
    generate_plot_embeddings
)
from models.ti.connectivity import (
    compute_directed_cluster_connectivity,
    compute_undirected_cluster_connectivity
)
from models.ti.graph import compute_connectivity_graph
from models.ti.pseudotime_v2 import compute_pseudotime
import statsmodels.formula.api as smf

In [None]:
### Load Data.

# Change into the project directory first: the file name given to the reader
# below is relative, so it is resolved against this working directory.
project_dir = "/folder/"
h5ad_name = 'adata.h5ad'

os.chdir(project_dir)
adata = ad.read_h5ad(h5ad_name)

In [None]:
### MAGIC Imputation on HVGs

# Neighbour graph computed on the 'X_harmony' representation, 40 neighbours.
sc.pp.neighbors(
    adata,
    n_neighbors = 40,
    use_rep = 'X_harmony'
)

# Index of the genes flagged in adata.var['highly_variable'].
hvg_mask = adata.var['highly_variable']
hvg_genes = adata.var[hvg_mask].index

# Seeded MAGIC operator with the approximate solver; it is fitted on the
# highly-variable-gene columns of X only and its output is stored in obsm.
magic_op = magic.MAGIC(
    random_state = 0,
    solver = 'approximate'
)
adata.obsm['X_magic'] = magic_op.fit_transform(adata[:, hvg_genes].X)

In [None]:
### Metric Learning Embedding.

# All metric-learner settings gathered in one place; they are unpacked as
# keyword arguments into train_metric_learner below.
metric_learner_config = dict(
    n_episodes = 30,
    n_metric_epochs = 10,
    obsm_data_key = 'X_magic',   # train on the MAGIC output stored in obsm
    code_size = 10,
    backend = 'leiden',
    device = 'cpu',              # use a GPU device instead if one is available
    cluster_kwargs = {'random_state': 0, 'resolution': 1},
    nn_kwargs = {'random_state': 0, 'n_neighbors': 15},
    trainer_kwargs = {'optimizer': 'SGD', 'lr': 0.01, 'batch_size': 256}
)

# Warnings are silenced only for the duration of the training call.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    train_metric_learner(adata, **metric_learner_config)

# Plotting coordinates derived from adata.obsm['metric_embedding']
# (presumably written by the training call above - confirm in train_metric).
umap_config = dict(
    method = 'umap',
    n_neighbors = 40,
    spread = 3
)
adata.obsm['X_met_embedding'] = generate_plot_embeddings(
    adata.obsm['metric_embedding'],
    **umap_config
)

In [None]:
### Mapping and Connectivity.

# Numeric community id for every cluster label in adata.obs['Cluster_Column'].
# Each label must be listed exactly once: a dict literal keeps only the last
# value given for a repeated key, which would silently drop the other ids.
mapping = {
    'Cluster_1': 1,
    'Cluster_2': 2,
    'Cluster_3': 3,
    'Cluster_4': 4
}
adata.obs['Cluster_Column_Num'] = adata.obs['Cluster_Column'].map(mapping)

# Series.map yields NaN for labels that are missing from `mapping`. Stop here
# with a clear message instead of handing NaN community ids to the
# connectivity and pseudotime steps that consume `communities`.
unmapped_mask = adata.obs['Cluster_Column_Num'].isna()
if unmapped_mask.any():
    missing_labels = adata.obs.loc[unmapped_mask, 'Cluster_Column'].unique().tolist()
    raise ValueError(
        f"'Cluster_Column' contains labels without an entry in `mapping`: {missing_labels}"
    )

# Per-cell community ids as a NumPy array, plus the cell-cell connectivity
# matrix stored in adata.obsp['connectivities'].
communities = adata.obs['Cluster_Column_Num'].to_numpy()
adj_conn = adata.obsp['connectivities']

In [None]:
### Compute Undirected and Directed Connectivity.

# Both helpers take the per-cell community ids and the cell-cell connectivity
# matrix and return a pair of results. Note that the cut-off keyword differs
# between the two: `z_threshold` for the undirected helper, `threshold` for
# the directed one.
un_connectivity, un_z_score = compute_undirected_cluster_connectivity(
    communities,
    adj_conn,
    z_threshold = 0.2
)
connectivity, z_score = compute_directed_cluster_connectivity(
    communities,
    adj_conn,
    threshold = 0.2
)

In [None]:
### Compute Pseudotime.

# Cell(s) from which pseudotime is measured.
start_cell_ids = ['Starting_Cell']

# Community id(s) of the start cell(s), looked up from `communities`.
start_cluster_ids = get_start_cell_cluster_id(
    adata,
    start_cell_ids,
    communities
)

# Pseudotime is computed from the cell-cell distance matrix in
# adata.obsp['distances'] together with the directed cluster connectivity.
cell_distances = adata.obsp['distances']
pseudotime = compute_pseudotime(
    adata,
    start_cell_ids,
    cell_distances,
    connectivity
)

# Keep the result on the AnnData object alongside the other cell annotations.
adata.obs['metric_pseudotime_v2'] = pseudotime

In [None]:
### Plot Connectivity and Pseudotime

# Undirected cluster-connectivity graph drawn using the 'X_met_embedding'
# coordinates and the per-cell community ids.
plot_connectivity_graph(
    adata.obsm['X_met_embedding'],
    communities,
    un_connectivity,
    mode = 'undirected',
    offset = 0.2,
    cmap = 'Blues',
    node_size = 750
)

plot_trajectory_graph_v2(
    pseudotime, connectivity, communities, connectivity,
    node_positions = None, offset = 0.5, figsize = (7, 7),
    node_size = 1000, font_size = 10