In [None]:
import numpy as np
import scanpy as sc
import scvelo as scv
import matplotlib.pyplot as plt
import pandas as pd
from torchdiffeq import odeint
import torch
import torch.nn as nn
import warnings
from torch.utils.data import DataLoader
# Filter out DeprecationWarnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
from scipy.stats import norm
import gc
import os
from IPython.display import clear_output
import time
import pickle

# Build the working AnnData object from the scVelo gastrulation-erythroid dataset.
dataset = scv.datasets.gastrulation_erythroid()

dataset.X = dataset.layers["spliced"]
# Keep the raw counts under their own layer names. They are copied so that any
# in-place rewriting of the "spliced"/"unspliced" layers during preprocessing
# cannot alias them (NOTE(review): whether scVelo mutates in place was not
# verified; the copy is purely defensive).
dataset.layers["unspliced_counts"] = dataset.layers["unspliced"].copy()
dataset.layers["spliced_counts"] = dataset.layers["spliced"].copy()
scv.pp.filter_and_normalize(dataset, min_shared_counts=20, n_top_genes=2000)
scv.pp.moments(dataset, n_pcs=30, n_neighbors=30)

# Per-cell feature matrices laid out column-wise as [unspliced genes | spliced genes].
u_s_normalized = np.concatenate(
    [dataset.layers["unspliced"].toarray(), dataset.layers["spliced"].toarray()], axis=1
)
u_s_counts = np.concatenate(
    [dataset.layers["unspliced_counts"].toarray(), dataset.layers["spliced_counts"].toarray()], axis=1
)
# "Mu"/"Ms" are read as dense arrays here (expected to be produced by scv.pp.moments above).
mu_ms = np.concatenate([dataset.layers["Mu"], dataset.layers["Ms"]], axis=1)

adata = sc.AnnData(X=mu_ms)
# Fix: the counts layer previously received `u_s_normalized`, so the raw counts
# were never stored and both layers held identical data.
adata.layers["u_s_counts"] = u_s_counts
adata.layers["u_s_normalized"] = u_s_normalized
adata.layers["mu_ms"] = mu_ms

adata.obs = dataset.obs.copy()
var_names = dataset.var_names.copy()
# Every gene appears twice: once as "<gene>_u" (unspliced) and once as "<gene>_s" (spliced).
adata.var_names = np.concatenate([var_names + "_u", var_names + "_s"])

# Show the embedding shipped with the dataset first, keeping a copy of it
# under "X_umap_original" before it is replaced below.
adata.obsm["X_umap_original"] = dataset.obsm["X_umap"]
adata.obsm["X_umap"] = dataset.obsm["X_umap"]
sc.pl.umap(adata, color="celltype")

# Recompute neighbours and UMAP on the concatenated Mu/Ms matrix (adata.X).
sc.pp.neighbors(adata, use_rep="X")
sc.tl.umap(adata)
adata.obsm["X_umap_new"] = adata.obsm["X_umap"]
keys = "celltype"
sc.pl.umap(adata, color=keys, legend_loc="on data")
sc.pl.umap(adata, color=keys)

# Normalise feature-name casing: first character upper-case, the rest lower-case.
adata.var_names = [name.lower().capitalize() for name in adata.var_names.tolist()]

# Summarise expression per (feature, stage, cell type) as mean and standard deviation.
# Requires `adata` with obs columns "stage" and "celltype".
gene_names = adata.var_names
stages = np.unique(adata.obs["stage"])  # Get unique stages
cell_types = np.unique(adata.obs["celltype"])  # Get unique cell types
num_genes = len(gene_names)
num_stages = len(stages)
num_cell_types = len(cell_types)

# Tensor layout: genes x stages x cell types x 2 (index 0 = mean, index 1 = std)
results_tensor = np.zeros((num_genes, num_stages, num_cell_types, 2))

# Map stage / cell-type labels to their index along the tensor axes
stage_to_index = {stage: i for i, stage in enumerate(stages)}
cell_type_to_index = {cell_type: i for i, cell_type in enumerate(cell_types)}

# Pull the data out of AnnData once. The cell mask depends only on
# (stage, cell type), so it is built once per group and all genes are reduced
# in a single vectorised call instead of one AnnData view per gene.
# NOTE(review): assumes adata.X is a dense array (it is built from np.concatenate above).
expression = np.asarray(adata.X)
stage_labels = adata.obs["stage"].to_numpy()
cell_type_labels = adata.obs["celltype"].to_numpy()

for stage, stage_index in stage_to_index.items():
    in_stage = stage_labels == stage
    for cell_type, cell_type_index in cell_type_to_index.items():
        group = expression[in_stage & (cell_type_labels == cell_type)]  # cells x genes
        if group.shape[0] == 0:
            # No cells in this combination: keep the zeros already in the tensor
            # (the same value NaN -> 0 conversion would have produced).
            continue
        valid = ~np.isnan(group)
        # NaN entries are excluded from both statistics; float64 accumulation
        # avoids precision loss when summing many low-precision values.
        results_tensor[:, stage_index, cell_type_index, 0] = np.mean(
            group, axis=0, dtype=np.float64, where=valid
        )
        results_tensor[:, stage_index, cell_type_index, 1] = np.std(
            group, axis=0, dtype=np.float64, where=valid
        )

# Genes whose values are all NaN within a group yield NaN statistics; store them as 0.
results_tensor = np.nan_to_num(results_tensor)
np.save('tensor_erythroid_gastrulation.npy', results_tensor)


In [None]:


# Optional shortcut: reload the cached statistics tensor instead of recomputing it.
# results_tensor = np.load('tensor_erythroid_gastrulation.npy')

from itertools import chain

# Axis labels and sizes for the genes x stages x cell types tensor.
gene_names = adata.var_names
stages, cell_types = (np.unique(adata.obs[column]) for column in ("stage", "celltype"))
num_genes, num_stages, num_cell_types = len(gene_names), len(stages), len(cell_types)

# Where per-gene checkpoints and pickled predictions are written;
# both directories are created when missing.
models_directory = "models_erythroid_v2"
output_directory = "output_directory_v2"
for directory in (models_directory, output_directory):
    os.makedirs(directory, exist_ok=True)

# Results keyed by gene index, filled in by the training loop.
predicted_states_output = {}
predicted_rates_output = {}

class ODEFunc(nn.Module):
    """Right-hand side dy/dt = f(y) of the neural ODE, parameterised by a small MLP.

    Args:
        state_dim: Length of the state vector. Defaults to 10, i.e. a
            (mean, std) pair for each of 5 cell types.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, state_dim=10, hidden_dim=50):
        super().__init__()
        # The attribute must stay named `net` so state_dict keys remain
        # compatible with previously saved checkpoints.
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, state_dim),  # rate of change for every state entry
        )

    def forward(self, t, y):
        # `t` is part of the signature odeint calls with; the dynamics ignore it.
        return self.net(y)


NUM_EPOCHS = 1000       # Adjust epochs as needed
LEARNING_RATE = 0.01
SEED = 0                # base seed; offset per gene so each fit is reproducible on its own

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Derive all sizes from the tensor (genes x stages x cell types x 2) instead of
# hard-coding 7 stages / 5 cell types / 70 values.
n_stages, n_cell_types, n_stats = results_tensor.shape[1:]
state_dim = n_cell_types * n_stats
output_shape = (n_stages, n_cell_types, n_stats)
t = torch.linspace(0, 1, steps=n_stages).to(device, dtype=torch.float32)

# Only a handful of features are fitted: indices 0-4 and 2000-2004.
# NOTE(review): with features ordered [genes_u | genes_s] and 2000 genes kept by
# filtering, these would be the same five genes' unspliced and spliced entries;
# confirm the feature count is 4000 before relying on that.
for gene_ind in chain(range(0, 5), range(2000, 2005)):
    gene = results_tensor[gene_ind]
    print(f"Working on gene index {gene_ind}")

    # One row per stage; each row is the flattened (cell type, mean/std) state.
    ground_truth_states = torch.tensor(gene.reshape(n_stages, state_dim)).to(
        device, dtype=torch.float32
    )
    # The state at the first stage is the initial condition of the ODE.
    input_states = ground_truth_states[0]

    torch.manual_seed(SEED + gene_ind)
    model = ODEFunc(state_dim).to(device, dtype=torch.float32)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    lowest_loss = float('inf')
    best_model_state = None
    save_path = os.path.join(models_directory, f"best_model_gene_{gene_ind}.pt")

    for epoch in range(NUM_EPOCHS):
        optimizer.zero_grad()

        pred_states = odeint(model, input_states, t=t)
        loss = torch.mean((pred_states - ground_truth_states) ** 2)

        # Snapshot BEFORE the optimizer step so the stored weights are the ones
        # that actually produced this loss. state_dict() returns live references,
        # hence the explicit clone. Using a float keeps no autograd graph alive.
        loss_value = loss.item()
        if loss_value < lowest_loss:
            lowest_loss = loss_value
            best_model_state = {
                key: value.detach().clone() for key, value in model.state_dict().items()
            }

        loss.backward()
        optimizer.step()

    if best_model_state is None:
        # The loss was never finite (e.g. NaN from the first epoch); nothing to save.
        print(f"Training failed for gene index {gene_ind}: loss was never finite. Skipping.")
        continue

    # Write the checkpoint once per gene rather than on every improving epoch.
    torch.save(best_model_state, save_path)
    print(f"Training completed. Best model saved to '{save_path}' with loss: {lowest_loss}")

    # Reload from disk into a fresh model so the predictions come from exactly
    # what was saved (this replaces a call to a helper not defined in this notebook).
    model = ODEFunc(state_dim).to(device, dtype=torch.float32)
    model.load_state_dict(torch.load(save_path, map_location=device))
    model.eval()

    with torch.no_grad():
        predicted_states = odeint(model, input_states, t)
        # Evaluate the learned derivative at every predicted state.
        predicted_rates = model(t, predicted_states)

    predicted_states_output[gene_ind] = predicted_states.cpu().numpy().reshape(output_shape)
    predicted_rates_output[gene_ind] = predicted_rates.cpu().numpy().reshape(output_shape)

    gc.collect()
    torch.cuda.empty_cache()

    # Re-dump the accumulated results after every gene so an interrupted run
    # still leaves usable output on disk.
    with open(os.path.join(output_directory, 'predicted_states_output.pkl'), 'wb') as f:
        pickle.dump(predicted_states_output, f)

    with open(os.path.join(output_directory, 'predicted_rates_output.pkl'), 'wb') as g:
        pickle.dump(predicted_rates_output, g)

    print(f"Saved predicted states and rates up to gene index {gene_ind}.")