In [1]:
import torch
import torch.optim as optim
import numpy as np
import scanpy as sc
import scvelo as scv
import sys, os
current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))
from model import NETWORK  # Ensure that model.py is saved in the same directory
from dataloaders import * # Ensure that dataloaders.py is saved in the same directory
from utils import *
from sklearn.manifold import Isomap
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import dijkstra
import gc
import pandas as pd
import seaborn as sns

In [None]:
# Setup configuration
latent_dim = 64  # Latent dimension size, can be adjusted
hidden_dim = 512  # Hidden dimension size for the encoder and decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

n_components = 100
n_knn_search = 10
dataset_name = "gastrulation_erythroid"
cell_type_key = "celltype"
model_name = "imVelo"

# Architecture settings passed to NETWORK below. They presumably have to match the
# configuration model.pth was trained with, otherwise load_state_dict (strict by
# default) raises on missing keys or shape mismatches.
num_genes = 2000
nhead = 1  # original: 1
embedding_dim = 128 * nhead  # original: 128
num_encoder_layers = 1  # original: 1
num_bins = 50
batch_size = 128  # Batch size for the inference dataloader
# Training settings kept for reference; not referenced by the cells of this notebook.
epochs = 10  # Number of epochs for training
learning_rate = 1e-4  # Learning rate for the optimizer
lambda1 = 1e-1  # Weight for heuristic loss
lambda2 = 1  # Weight for discrepancy loss
K = 11  # Number of neighbors for heuristic loss

# Load data
adata = sc.read_h5ad("gastrulation_erythroid_common_smoothing.h5ad")

# Normalise cell-type labels to strings, store them as a categorical column and assign one
# tab20 colour per category (scanpy/scvelo read category colours from uns["<key>_colors"]).
adata.obs[cell_type_key] = [str(cat) for cat in list(adata.obs[cell_type_key])]
adata.obs[cell_type_key] = pd.Series(adata.obs[cell_type_key], dtype="category")
unique_categories = adata.obs[cell_type_key].cat.categories
rgb_colors = sns.color_palette("tab20", len(unique_categories))
hex_colors = ['#%02x%02x%02x' % (int(r*255), int(g*255), int(b*255)) for r, g, b in rgb_colors]
adata.uns[f"{cell_type_key}_colors"] = hex_colors
print(dataset_name)
# Keep copies of the raw unspliced/spliced layers under counts_* names.
adata.layers['counts_unspliced'] = adata.layers["unspliced"].copy()
adata.layers['counts_spliced'] = adata.layers["spliced"].copy()

# Initialize the model and restore the trained weights
model = NETWORK(input_dim=adata.shape[1]*2, latent_dim=latent_dim,
                hidden_dim=hidden_dim, emb_dim=embedding_dim,
                nhead=nhead, num_encoder_layers=num_encoder_layers,
                num_genes=num_genes, num_bins=num_bins).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved from a GPU
# run also loads on a CPU-only machine. NOTE: torch.load unpickles the file -- only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load('model.pth', map_location=device))

# eval() switches dropout / batch-norm layers to inference behaviour.
model.eval()


In [None]:
# Run the trained model over all cells once and write its per-cell outputs back into `adata`.
# Only the third loader returned by setup_dataloaders_binning is used; each batch yields
# (tokens, data, batch_indices), where batch_indices are the adata rows of that batch.
_, _, full_data_loader = setup_dataloaders_binning(adata,
                                                    batch_size=batch_size,
                                                    num_genes=num_genes,
                                                    num_bins=num_bins)

# Pre-allocate result containers. Velocity layers share the (cells x genes) shape of Mu/Ms;
# "pred" and "cell_embeddings" are assumed to be (cells x 2*genes) -- TODO confirm against NETWORK.
adata.layers["velocity_u"] = np.zeros_like(adata.layers["Mu"], dtype=np.float32)
adata.layers["velocity"] = np.zeros_like(adata.layers["Ms"], dtype=np.float32)
adata.obsm["pred"] = np.zeros((adata.shape[0], adata.shape[1] * 2), dtype=np.float32)
adata.obsm["cell_embeddings"] = np.zeros((adata.shape[0], adata.shape[1] * 2), dtype=np.float32)
adata.layers["pp"] = np.zeros_like(adata.layers["Mu"])  # Same shape as Mu
adata.layers["nn"] = np.zeros_like(adata.layers["Mu"])  # Same shape as Mu
adata.layers["pn"] = np.zeros_like(adata.layers["Mu"])  # Same shape as Mu
adata.layers["np"] = np.zeros_like(adata.layers["Mu"])  # Same shape as Mu

# Model output key -> pre-allocated array it is written into (resolved once, outside the loop).
output_targets = {
    "v_u": adata.layers["velocity_u"],
    "v_s": adata.layers["velocity"],
    "pred": adata.obsm["pred"],
    "cell_embeddings": adata.obsm["cell_embeddings"],
    "pp": adata.layers["pp"],
    "nn": adata.layers["nn"],
    "pn": adata.layers["pn"],
    "np": adata.layers["np"],
}

model.eval()
with torch.no_grad():
    n_batches = len(full_data_loader)
    for batch_idx, (tokens, data, batch_indices) in enumerate(full_data_loader):
        print(f"Batch {batch_idx+1}/{n_batches}")
        out_dic = model(tokens.to(device), data.to(device))

        # Plain NumPy integer indices make the fancy indexing below unambiguous even when
        # the loader yields the row indices as a (CPU) torch tensor.
        rows = np.asarray(batch_indices)
        # Convert to numpy inside the loop to keep peak (GPU) memory low.
        for out_key, target in output_targets.items():
            target[rows] = out_dic[out_key].detach().cpu().numpy()

        # Explicit memory cleanup
        del tokens, data, out_dic
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

In [None]:
# Number of NaN entries in the predicted unspliced velocities (non-zero means the model produced NaNs).
np.sum(np.isnan(adata.layers["velocity_u"]))

In [None]:
# Number of NaN entries in the predicted spliced velocities (non-zero means the model produced NaNs).
np.sum(np.isnan(adata.layers["velocity"]))

In [6]:
# Place unspliced and spliced quantities side by side (cells x 2*genes) so each pair can be
# used as a single representation (e.g. sc.pp.neighbors(..., use_rep="velocity")).
# np.hstack returns a new array: later in-place edits of the layers are NOT reflected here.
adata.obsm["MuMs"] = np.hstack((adata.layers["Mu"], adata.layers["Ms"]))
adata.obsm["velocity"] = np.hstack((adata.layers["velocity_u"], adata.layers["velocity"]))

In [7]:
# Flip the sign of the predicted velocities (presumably to match the direction convention
# scVelo expects -- TODO confirm). The flag in adata.uns makes the cell idempotent: re-running
# it no longer silently flips the sign back.
# NOTE: if the inference cell is re-run (fresh, unflipped layers), drop the flag first:
#   adata.uns.pop("velocity_sign_flipped", None)
if not adata.uns.get("velocity_sign_flipped", False):
    adata.layers["velocity_u"] *= -1
    adata.layers["velocity"] *= -1
    # obsm["velocity"] is a concatenated copy of these layers; keep its sign consistent.
    if "velocity" in adata.obsm:
        adata.obsm["velocity"] *= -1
    adata.uns["velocity_sign_flipped"] = True

In [None]:
# Metrics to visualise once scVelo has processed the predicted velocities.
keys = ["velocity_confidence", "velocity_length", "velocity_pseudotime"]
# Default kNN graph, then scVelo's velocity graph, per-cell confidence/length and
# velocity pseudotime, all computed from layers["velocity"] (vkey="velocity").
sc.pp.neighbors(adata)
scv.tl.velocity_graph(adata, vkey="velocity")
scv.tl.velocity_confidence(adata, vkey="velocity")
scv.tl.velocity_pseudotime(adata, vkey="velocity")
# Assumes obsm["X_umap"] already exists in the loaded file (no UMAP is computed before this cell) -- TODO confirm.
sc.pl.umap(adata, color=keys)
scv.pl.velocity_embedding_stream(adata, color=keys, basis="umap")

In [None]:
keys = [cell_type_key, "stage"]
# UMAP from a kNN graph built on the concatenated (unspliced, spliced) velocity vectors.
# NOTE: this overwrites the default neighbour graph and obsm["X_umap"] from earlier cells.
sc.pp.neighbors(adata=adata, use_rep="velocity")
sc.tl.umap(adata=adata)
sc.pl.umap(adata=adata, color=keys)

In [None]:
# UMAP from a kNN graph built on the model's per-cell embeddings.
# NOTE: this again overwrites the default neighbour graph and obsm["X_umap"].
sc.pp.neighbors(adata=adata, use_rep="cell_embeddings")
sc.tl.umap(adata=adata)
sc.pl.umap(adata=adata, color=[cell_type_key, "stage"])

In [None]:
# Rank genes by cluster-specific velocity (scVelo), then show the top gene names per cell type.
scv.tl.rank_velocity_genes(adata, vkey="velocity", groupby=cell_type_key)
pd.DataFrame(data=adata.uns["rank_velocity_genes"]["names"])

In [12]:
# Genes whose phase portraits are plotted below.
gene_names = [
    "Actb",
    "Hba-x",
    "Rap1b",
    "Rpl18a",
]

In [None]:
# Re-embed each developmental stage on its own and plot its velocity stream plus
# per-gene phase portraits.
os.makedirs("plots", exist_ok=True)
for stage in adata.obs["stage"].unique():
    adata_tmp = adata[adata.obs["stage"] == stage].copy()
    # Subsetting row-slices obs/obsm but copies uns unchanged, so the full-size
    # uns["velocity_graph"] and the parent-UMAP projection in obsm["velocity_umap"] would
    # be stale. Rebuild neighbours, UMAP, velocity graph and projected velocities here.
    sc.pp.neighbors(adata_tmp)
    sc.tl.umap(adata_tmp)
    scv.tl.velocity_graph(adata_tmp)
    scv.tl.velocity_embedding(adata_tmp, basis="umap")
    scv.pl.velocity_embedding_stream(adata_tmp, color=[cell_type_key], basis="umap")
    for gene_name in gene_names:
        # One file per stage/gene (assuming plot_phase_plane writes to save_path) so later
        # figures do not overwrite earlier ones.
        plot_phase_plane(adata_tmp, gene_name, dataset_name, 11,
                         u_scale=0.1, s_scale=0.1, cell_type_key=cell_type_key,
                         save_path=f"plots/{dataset_name}_{stage}_{gene_name}.png")

In [None]:
# Phase portraits on the full dataset, for each gene coloured first by cell type, then by stage.
os.makedirs("plots", exist_ok=True)
for gene_name in gene_names:
    for color_key in (cell_type_key, "stage"):
        # Unique file per gene/colouring (assuming plot_phase_plane writes to save_path);
        # a single shared path would keep only the last figure.
        plot_phase_plane(adata, gene_name, dataset_name, 11,
                         u_scale=0.1, s_scale=0.1, cell_type_key=color_key,
                         save_path=f"plots/{dataset_name}_{gene_name}_{color_key}.png")
    