kernel moscot

# Set up

In [None]:
# Execute utils/colors.py through IPython's %run magic.
# NOTE(review): `annot2col` is used as a plotting palette further down before
# this notebook assigns it itself, so it is presumably defined by this script
# — confirm.
%run /work/DevM_analysis/utils/colors.py

In [None]:
import warnings
from typing import List, Literal, Optional, Tuple

import moscot as mt
import moscot.plotting as mtp
from moscot.problems.time import TemporalProblem
from tqdm.std import TqdmWarning

import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt

import muon as mu
import mudata as md
import scanpy as sc

import os

# Ignore generic UserWarning messages and tqdm's own warning class for the
# rest of the session (same filters, registered in the same order).
for _ignored_category in (UserWarning, TqdmWarning):
    warnings.simplefilter("ignore", _ignored_category)

In [None]:
# Root folder for everything this notebook writes (plots and result tables).
# The trailing slash matters: later cells build paths as f"{output_dir}data/...".
output_dir = "/work/DevM_analysis/03.trajectory/Moscot_Temporal/"
# Name of the adata.obs column holding the cell-type annotation used for
# colouring, ordering and grouping below.
new_anno = "anno_wnn_v51"

# Load data

In [None]:
# Read the AnnData object from disk; the bare `adata` on the last line makes
# the notebook display its summary.
adata = sc.read_h5ad("/work/DevM_analysis/02.abundance/Milo_FL_PCW250401/data/FL_multiVI.h5ad")
adata

In [None]:
# Plot the "X_umap" embedding coloured by the annotation column, using the
# annot2col mapping as palette.
sc.pl.embedding(adata, basis="X_umap", color=[new_anno], palette = annot2col)

In [None]:
# Same embedding, coloured by the "donorID" column.
sc.pl.embedding(adata, basis="X_umap", color=["donorID"])

# Prepare the TemporalProblem

In [None]:
# Create the moscot temporal problem on top of `adata`. The time column of
# `adata.obs` is converted in the next cell, before `prepare` is called.
tp = TemporalProblem(adata)

In [None]:
# Convert the time-point column in one chain: string -> nullable integer ->
# categorical. The bare expression at the end displays the resulting column.
adata.obs["PCW"] = (
    adata.obs["PCW"]
    .astype("str")
    .astype("Int64")
    .astype("category")
)
adata.obs["PCW"]

In [None]:
# Prepare the problem using "PCW" as the time key. `joint_attr` names the
# representation used for the cost (presumably adata.obsm["X_multiVI"] —
# confirm against the moscot documentation).
tp = tp.prepare(time_key="PCW", joint_attr="X_multiVI")

# Solve

In [None]:
# Solve the prepared problem. The iteration cap is written as an integer
# literal (ten million) because the solver expects an int, not a float such
# as 1e7.
tp = tp.solve(epsilon=1e-3, scale_cost="mean", max_iterations=10_000_000)

# Identifying ancestors and descendants

## HEATMAP

In [None]:
# Hex colour for every cell-type label. The insertion order of this dict is
# also used below as the display order of the labels.
# NOTE(review): `annot2col` is already used as a palette earlier in the
# notebook; this assignment rebinds it — confirm both definitions agree.
annot2col = {
  "HSC" : "#E41A1C",
  "GP" : "#E0FFFF",
  "Granulocyte" : "#B3CDE3",
  "MEMP-t" : "#E6AB02",
  "MEMP" : "#FF7F00",
  "MEP" : "#CD661D",
  "MEMP-Mast-Ery" : "#FDCDAC",
  "MEMP-Ery" : "#E9967A",
  "Early-Ery" : "#CD5555",
  "Late-Ery" : "#8B0000",
  "MEMP-MK" : "#663C1F",
  "MK" : "#40E0D0",
  "MastP-t" : "#1E90FF",
  "MastP" : "#1F78B4",
  "Mast" : "#253494",
  "MDP" : "#E6F5C9",
  "Monocyte" : "#005A32",
  "Kupffer" : "#00EE00",
  "cDC1" : "#B3DE69",
  "cDC2" : "#ADFF2F",
  "pDC" : "#4DAF4A",
  "ASDC" : "#CDC673",
  "LMPP" : "#FFF2AE",
  "LP" : "#FFD92F",
  "Cycling-LP" : "#FFFF33",
  "PreProB" : "#FFF0F5",
  "ProB-1" : "#FFB5C5",
  "ProB-2" : "#E78AC3",
  "Large-PreB" : "#CD1076",
  "Small-PreB" : "#FF3E96",
  "IM-B" : "#FF00FF",
  "NK" : "#A020F0",
  "ILCP" : "#49006A",
  "T" : "#984EA3"
}
# All annotation labels, in the display order given by the annot2col dict.
o = list(annot2col)

# Keep only the labels that exist as categories of the annotation column,
# preserving that display order. The categories are collected once into a set
# so each membership test does not rescan the category array.
_present_labels = set(adata.obs[new_anno].cat.categories)
order = [label for label in o if label in _present_labels]

In [None]:
## Colormap for the heatmaps: YlOrRd sampled at 256 levels, with the lowest
## level replaced by opaque white so that 0s are drawn in white.
col_map = plt.get_cmap("YlOrRd", 256)
white_code = np.array([1.0, 1.0, 1.0, 1.0])  # RGBA, fully opaque white
newcolors = col_map(np.linspace(0, 1, 256))
newcolors[0] = white_code
newcmap = matplotlib.colors.ListedColormap(newcolors)
newcmap

In [None]:
# Display the retained annotation labels in their display order.
order

In [None]:
def get_cell_transition(source, target, forward,
                        groups_oi = order, suffix_key = "", min_cells = 10,
                        keep = order, save_folder = os.path.join(output_dir, "plot"),
                        anno_key = new_anno):
    """Compute, plot and save the annotation-level transition matrix between two time points.

    The function works on the notebook-level objects ``tp`` (the solved
    problem), ``order`` (display order of the labels) and ``newcmap``
    (heatmap colormap).

    Parameters
    ----------
    source, target
        Values of ``tp.adata.obs["PCW"]`` selecting the two time points.
    forward
        Forwarded to ``tp.cell_transition``. When true, the source groups are
        restricted to ``groups_oi`` and the target groups to ``keep``; when
        false the roles are swapped (target restricted to ``groups_oi``,
        source to ``keep``).
    groups_oi
        Annotation labels of interest (see ``forward``).
    suffix_key
        Text appended to the result key and therefore to the file name.
    min_cells
        A label is only used at a time point if at least this many cells
        carry it there.
    keep
        Annotation labels allowed on the other side (see ``forward``).
    save_folder
        Directory receiving the PDF; it is created when missing.
    anno_key
        Name of the ``obs`` column holding the annotation. Defaults to the
        notebook-level ``new_anno``.

    Notes
    -----
    The defaults of ``groups_oi``, ``keep``, ``save_folder`` and ``anno_key``
    are bound when the function is defined. The result is stored under the
    key ``{source}_to_{target}_{direction}{suffix_key}`` (passed as
    ``key_added``), which is also printed. Nothing is returned.
    """
    obs = tp.adata.obs

    def _frequent_labels(time_point):
        # Labels with at least `min_cells` cells at `time_point`, listed in
        # the notebook-level display order.
        counts = obs.loc[obs["PCW"] == time_point, anno_key].value_counts()
        frequent = set(counts[counts >= min_cells].index)
        return [label for label in order if label in frequent]

    group_source = _frequent_labels(source)
    group_target = _frequent_labels(target)

    if forward:
        group_source = [label for label in group_source if label in groups_oi]
        group_target = [label for label in group_target if label in keep]
        direction = "forward"
    else:
        group_target = [label for label in group_target if label in groups_oi]
        group_source = [label for label in group_source if label in keep]
        direction = "backward"

    key = f"{source}_to_{target}_{direction}{suffix_key}"
    print(key)
    tp.cell_transition(
        source=source,
        target=target,
        source_groups={anno_key: group_source},
        target_groups={anno_key: group_target},  # groups used for aggregation
        forward=forward,
        key_added=key,
    )

    fig = mtp.cell_transition(
        tp,
        fontsize=6,
        figsize=(8, 8),
        return_fig=True,
        key=key,
        cmap=newcmap,
    )

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(save_folder, exist_ok=True)
    fig.savefig(
        os.path.join(save_folder, f"Heatmap_cell-transition_{key}.pdf"),
        bbox_inches="tight",
    )

In [None]:
# Consecutive [source, target] time-point pairs: [5, 6], [6, 7], ..., [17, 18].
params = [[start, start + 1] for start in range(5, 18)]

In [None]:
# Forward transition heatmaps for every consecutive pair of time points, with
# all retained labels on both sides, saved under plot/All-cell-types.
all_types_dir = os.path.join(output_dir, "plot/All-cell-types")
for src, tgt in params:
    get_cell_transition(
        source=src,
        target=tgt,
        forward=True,
        suffix_key="",
        groups_oi=order,
        keep=order,
        save_folder=all_types_dir,
    )

In [None]:
# Backward transition heatmaps for every consecutive pair of time points, with
# all retained labels on both sides, saved under plot/All-cell-types.
all_types_dir = os.path.join(output_dir, "plot/All-cell-types")
for src, tgt in params:
    get_cell_transition(
        source=src,
        target=tgt,
        forward=False,
        suffix_key="",
        groups_oi=order,
        keep=order,
        save_folder=all_types_dir,
    )

## SANKEY

In [None]:
# Annotation labels used as source groups for the Sankey computation.
cluster_subset = [
    "HSC",
    "LP", "Cycling-LP",
    "PreProB", "ProB-1", "ProB-2", "Large-PreB", "Small-PreB", "IM-B",
]

# Group definitions: the subset above on the source side, every retained
# label on the target side.
sankey_source_groups = {new_anno: cluster_subset}
sankey_target_groups = {new_anno: order}

# Compute the Sankey data between time points 5 and 6.
tp.sankey(
    source=5,
    target=6,
    source_groups=sankey_source_groups,
    target_groups=sankey_target_groups,
    threshold=0.05,
    # order_annotations=cluster_subset[::-1],
    normalize=True,
)

In [None]:
# Draw the Sankey diagram for `tp` (presumably from the result stored by the
# preceding tp.sankey call — confirm against the moscot plotting docs).
mtp.sankey(tp, dpi=100, figsize=(5, 5), fontsize=10, interpolate_color=True)

## Save

In [None]:
# Persist the AnnData object attached to the problem. The path is built with
# os.path.join so it does not depend on `output_dir` ending with a slash, and
# the data folder is created first because the write fails when it is missing.
data_dir = os.path.join(output_dir, "data")
os.makedirs(data_dir, exist_ok=True)
tp.adata.write_h5ad(os.path.join(data_dir, "adata_FL_allPCW_with-Moscot-res.h5ad"))

In [None]:
# Write the forward transition matrix of every consecutive time-point pair
# (5 to 6 up to 17 to 18) to its own CSV file.
stored_transitions = tp.adata.uns["moscot_results"]["cell_transition"]
for s in range(5, 18):
    pair_key = f"{s}_to_{s + 1}_forward"
    mr = stored_transitions[pair_key]["transition_matrix"]
    mr.to_csv(f"{output_dir}data/Moscot_allFL_{pair_key}.csv")

In [None]:
# Write the backward transition matrix of every consecutive time-point pair
# (5 to 6 up to 17 to 18) to its own CSV file.
stored_transitions = tp.adata.uns["moscot_results"]["cell_transition"]
for s in range(5, 18):
    pair_key = f"{s}_to_{s + 1}_backward"
    mr = stored_transitions[pair_key]["transition_matrix"]
    mr.to_csv(f"{output_dir}data/Moscot_allFL_{pair_key}.csv")