### Gene networks

In [None]:
import os
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

import torch._dynamo  # flex attention is compiled through dynamo

# On a dynamo compilation error, fall back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True

# NOTE(review): hard-coded GPU index; adjust for the host this notebook runs on.
DEVICE = torch.device("cuda", 7)
sc.set_figure_params(figsize=(4, 4))

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext, JacobianContext

In [None]:
# CellariumGPT checkpoint plus the reference AnnData and gene metadata it needs.
ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"
CHECKPOINT_PATH = "/mnt/cellariumgpt-xfer/100M_long_run/run_001/lightning_logs/version_1/checkpoints/epoch=2-step=252858.ckpt"
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")
REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")

# Inference context on DEVICE using the memory-efficient attention backend.
# NOTE(review): `ctx` is not referenced by any later cell visible here.
ctx_kwargs = {
    "cellarium_gpt_ckpt_path": CHECKPOINT_PATH,
    "ref_adata_path": REF_ADATA_PATH,
    "gene_info_tsv_path": GENE_INFO_PATH,
    "device": DEVICE,
    "attention_backend": "mem_efficient",
}
ctx = CellariumGPTInferenceContext(**ctx_kwargs)

In [None]:
ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"

# Metacell dataset to analyze (uncomment exactly one).
# dataset_name = "extract_40__0"  # neuron
dataset_name = "extract_100__3"  # CD8+ T cell
# dataset_name = "extract_100__1"  # monocyte
# dataset_name = "extract_40__1"  # oligo
# dataset_name = "extract_50__6"  # CM

# Layout: metacells/<name>.h5ad and metacells/matmul_highest_10k/<name>__jacobian__marginal_mean.pt
metacells_dir = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "metacells")
jacobian_dir = os.path.join(metacells_dir, "matmul_highest_10k")

adata_path = os.path.join(metacells_dir, f"{dataset_name}.h5ad")
jacobian_pt_path = os.path.join(jacobian_dir, f"{dataset_name}__jacobian__marginal_mean.pt")
gene_info_tsv_path = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

# Analysis outputs live next to the Jacobian dumps.
OUTPUT_ROOT_PATH = os.path.join(jacobian_dir, "analysis")
os.makedirs(OUTPUT_ROOT_PATH, exist_ok=True)

In [None]:
# Load the Jacobian dump together with its metacell AnnData and gene metadata.
jac_ctx = JacobianContext.from_old_jacobian_pt_dump(
    jacobian_pt_path,
    adata_path,
    gene_info_tsv_path,
)
print(jac_ctx)

In [None]:
# Normalize the Jacobian and drop lowly expressed prompt/query genes.
process_kwargs = dict(
    jacobian_normalization_strategy="mean",
    feature_normalization_strategy="query_z_score",
    query_response_amp_min_pct=10,
    min_prompt_gene_tpm=5,
    min_query_gene_tpm=5,
)
jac_ctx.process(**process_kwargs)

In [None]:
# Gene-gene graph from positively correlated Jacobian features (kNN, no self loops).
# NOTE(review): `beta` is presumably an exponent applied to the correlations -- confirm.
jac_ctx.compute_adjacency_matrix(
    adjacency_strategy="positive_correlation",
    n_neighbors=10,
    beta=6.0,
    self_loop=False,
)

In [None]:
# Leiden partition of the gene graph; a larger resolution typically gives more, smaller communities.
jac_ctx.compute_leiden_communites(resolution=5.0)

In [None]:
# Number of distinct Leiden labels (may include -1 if some genes are unassigned).
np.unique(jac_ctx.leiden_membership).size

In [None]:
n_lambda = 10  # presumably the number of graph eigenvalues used for the estimate -- confirm
jac_ctx.compute_spectral_dimension(n_lambda_for_estimation=n_lambda)

In [None]:
# Spectral-dimension diagnostic drawn on a fresh single-panel figure.
fig, ax = plt.subplots(nrows=1, ncols=1)
jac_ctx.plot_spectral_dimension(ax=ax)

### Embedding

In [None]:
# MDE embedding of the gene graph, computed on DEVICE.
jac_ctx.make_mde_embedding(
    device=DEVICE,
)

In [None]:
# Hand-picked marker genes to highlight in the embedding. Only symbols that are
# query genes of the current graph are kept, then mapped to gene IDs.
_query_symbols = jac_ctx.query_gene_symbols
_symbol_to_id = jac_ctx.gene_symbol_to_gene_id_map

snap_n_gene_symbols = [
    symbol
    for symbol in (
        "GAP43", "NRXN3", "HOMER1", "IL1RAPL2", "EPHA3", "RIMS1", "SV2B",
        "TRIM9", "SVOP", "RPH3A", "SYT12", "SYT1", "R3HDM2", "PDE4B",
        "DCC", "SLC4A10", "DNM3", "GRM1", "EGR4", "JUNB", "TFDP2",
    )
    if symbol in _query_symbols
]
snap_n_gene_ids = [_symbol_to_id[symbol] for symbol in snap_n_gene_symbols]

muscle_gene_symbols = [
    symbol
    for symbol in ("TTN", "MYL3", "MYL4", "MYL7", "TNNC1", "TNNI1")
    if symbol in _query_symbols
]
muscle_gene_ids = [_symbol_to_id[symbol] for symbol in muscle_gene_symbols]

def get_gene_familities(jac_ctx: JacobianContext, prefix_list: list[str]) -> tuple[list[str], list[str]]:
    """Collect the query genes whose symbol starts with any of the given prefixes.

    Matches are grouped by prefix in ``prefix_list`` order and, within a prefix,
    follow ``jac_ctx.query_gene_symbols`` order. A gene matching several prefixes
    is listed once per matching prefix.

    Returns:
        ``(gene_ids, gene_symbols)`` as parallel lists; the symbols are re-derived
        from the IDs through ``jac_ctx.gene_id_to_gene_symbol_map``.
    """
    gene_ids = []
    for prefix in prefix_list:
        gene_ids.extend(
            jac_ctx.gene_symbol_to_gene_id_map[symbol]
            for symbol in jac_ctx.query_gene_symbols
            if symbol.startswith(prefix)
        )
    gene_symbols = [jac_ctx.gene_id_to_gene_symbol_map[gene_id] for gene_id in gene_ids]
    return gene_ids, gene_symbols

# Gene families selected by symbol prefix.
mito_gene_ids, mito_gene_symbols = get_gene_familities(jac_ctx, ["MT-"])
ribo_gene_ids, ribo_gene_symbols = get_gene_familities(jac_ctx, ["RPS", "RPL"])
hla_gene_ids, hla_gene_symbols = get_gene_familities(jac_ctx, ["HLA"])
ifi_gene_ids, ifi_gene_symbols = get_gene_familities(jac_ctx, ["IFI"])

# Sets highlighted in the MDE plot: name -> (gene IDs, gene symbols, color).
# Toggle entries by (un)commenting; assign None below to disable highlighting.
highlight_gene_sets = dict()
highlight_gene_sets["Mito"] = (mito_gene_ids, mito_gene_symbols, "red")
highlight_gene_sets["Ribo"] = (ribo_gene_ids, ribo_gene_symbols, "blue")
# highlight_gene_sets["SNAP-n"] = (snap_n_gene_ids, snap_n_gene_symbols, "green")
# highlight_gene_sets["HLA"] = (hla_gene_ids, hla_gene_symbols, "green")
# highlight_gene_sets["IFI"] = (ifi_gene_ids, ifi_gene_symbols, "orange")
highlight_gene_sets["Muscle"] = (muscle_gene_ids, muscle_gene_symbols, "purple")

# highlight_gene_sets = None

In [None]:
# Render the embedding with the selected gene sets highlighted.
jac_ctx.plot_mde_embedding(
    highlight_gene_sets=highlight_gene_sets,
)

### Batch processing (basic)

In [None]:
ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"

# Layout: metacells/<dataset>.h5ad and
#         metacells/matmul_highest_10k/<dataset>__jacobian__marginal_mean.pt
adata_root_path = os.path.join(
    ROOT_PATH, "data", "cellariumgpt_validation", "metacells")
jacobian_pt_root_path = os.path.join(adata_root_path, "matmul_highest_10k")

suffix = "__jacobian__marginal_mean.pt"

# Only files carrying the exact Jacobian suffix are dumps. Sort them because
# os.listdir returns entries in arbitrary order, and the dataset order must be
# reproducible across runs and across cells that rebuild these lists.
jacobian_pt_files = sorted(
    f for f in os.listdir(jacobian_pt_root_path) if f.endswith(suffix))
jacobian_pt_paths = [os.path.join(jacobian_pt_root_path, f) for f in jacobian_pt_files]

# Dataset name = file name with the trailing suffix removed (only at the end).
dataset_names = [f[:-len(suffix)] for f in jacobian_pt_files]

adata_paths = [
    os.path.join(adata_root_path, f"{name}.h5ad")
    for name in dataset_names]

gene_info_tsv_path = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

output_root_path = os.path.join(jacobian_pt_root_path, "analysis")
os.makedirs(output_root_path, exist_ok=True)

In [None]:
def generate_genes_to_highlight(
    jac_ctx: JacobianContext,
    include: tuple[str, ...] = ("Mito", "Ribo"),
) -> dict[str, tuple[list[str], list[str], str]]:
    """Build the gene sets to highlight in ``jac_ctx.plot_mde_embedding``.

    Every set is restricted to genes present in ``jac_ctx.query_gene_symbols``.

    Args:
        jac_ctx: processed Jacobian context; provides the query genes and the
            symbol <-> ID maps.
        include: names of the sets to return, in the returned order. Available:
            "Mito", "Ribo", "SNAP-n", "HLA", "IFI", "Muscle". The default returns
            the Mito and Ribo sets.

    Returns:
        ``{set_name: (gene_ids, gene_symbols, color)}``.

    Raises:
        KeyError: if ``include`` contains an unknown set name.
    """
    query_symbols = jac_ctx.query_gene_symbols
    symbol_to_id = jac_ctx.gene_symbol_to_gene_id_map
    id_to_symbol = jac_ctx.gene_id_to_gene_symbol_map

    def from_symbols(symbols: list[str]) -> tuple[list[str], list[str]]:
        # Explicit marker list, filtered to the query genes.
        kept = [s for s in symbols if s in query_symbols]
        return [symbol_to_id[s] for s in kept], kept

    def from_prefixes(prefixes: list[str]) -> tuple[list[str], list[str]]:
        # Gene family by symbol prefix; symbols are re-derived from the IDs.
        matched = [s for p in prefixes for s in query_symbols if s.startswith(p)]
        ids = [symbol_to_id[s] for s in matched]
        return ids, [id_to_symbol[i] for i in ids]

    snap_n_gene_symbols = [
        'GAP43', 'NRXN3', 'HOMER1', 'IL1RAPL2', 'EPHA3', 'RIMS1', 'SV2B',
        'TRIM9', 'SVOP', 'RPH3A', 'SYT12', 'SYT1', 'R3HDM2', 'PDE4B',
        'DCC', 'SLC4A10', 'DNM3', 'GRM1', 'EGR4', 'JUNB', 'TFDP2',
    ]
    muscle_gene_symbols = ['TTN', 'MYL3', 'MYL4', 'MYL7', 'TNNC1', 'TNNI1']

    # name -> (builder, builder argument, color); only requested sets are computed.
    available = {
        "Mito": (from_prefixes, ["MT-"], 'red'),
        # ribosomal protein genes: both the RPS* and RPL* families
        "Ribo": (from_prefixes, ["RPS", "RPL"], 'blue'),
        "SNAP-n": (from_symbols, snap_n_gene_symbols, 'green'),
        "HLA": (from_prefixes, ["HLA"], 'teal'),
        "IFI": (from_prefixes, ["IFI"], 'orange'),
        "Muscle": (from_symbols, muscle_gene_symbols, 'purple'),
    }

    highlight_gene_sets = {}
    for name in include:
        builder, arg, color = available[name]
        gene_ids, gene_symbols = builder(arg)
        highlight_gene_sets[name] = (gene_ids, gene_symbols, color)
    return highlight_gene_sets

In [None]:
import pickle
import pymde
from tqdm.notebook import tqdm

n_datasets = len(dataset_names)

for i_dataset in tqdm(range(n_datasets)):
    # Common path prefix for every artifact written for this dataset.
    out_prefix = os.path.join(output_root_path, dataset_names[i_dataset])
    jacobian_pt_path, adata_path = jacobian_pt_paths[i_dataset], adata_paths[i_dataset]

    # --- load ---
    jac_ctx = JacobianContext.from_old_jacobian_pt_dump(jacobian_pt_path, adata_path, gene_info_tsv_path)
    print(jac_ctx)

    # --- graph construction (TPM cutoffs are 10 here vs 5 in the single-dataset section) ---
    jac_ctx.process(
        jacobian_normalization_strategy="mean",
        feature_normalization_strategy="query_z_score",
        query_response_amp_min_pct=10,
        min_prompt_gene_tpm=10,
        min_query_gene_tpm=10,
    )
    jac_ctx.compute_adjacency_matrix(
        adjacency_strategy="positive_correlation",
        n_neighbors=10,
        beta=6.0,
        self_loop=False,
    )
    jac_ctx.compute_leiden_communites(resolution=5.0, min_community_size=2)
    jac_ctx.compute_spectral_dimension(n_lambda_for_estimation=10)

    # --- spectral-dimension figure ---
    fig, ax = plt.subplots()
    jac_ctx.plot_spectral_dimension(ax=ax)
    fig.savefig(f"{out_prefix}__spectral_dimension.png", dpi=300, bbox_inches="tight")
    plt.close(fig)  # do not accumulate open figures across iterations

    # --- MDE embedding figure ---
    jac_ctx.make_mde_embedding(
        n_neighbors=7,
        repulsive_fraction=10,
        attractive_penalty=pymde.penalties.Log1p,
        repulsive_penalty=pymde.penalties.Log,
        device=DEVICE,
    )
    highlight_gene_sets = generate_genes_to_highlight(jac_ctx)
    fig = jac_ctx.plot_mde_embedding(highlight_gene_sets=highlight_gene_sets)
    fig.write_image(f"{out_prefix}__embedding.png", scale=2)

    # --- community table and pickled context ---
    community_df = pd.DataFrame({
        'gene_ids': jac_ctx.query_var_names,
        'gene_symbols': jac_ctx.query_gene_symbols,
        'leiden_membership': jac_ctx.leiden_membership,
    })
    community_df.to_csv(f"{out_prefix}__leiden_membership.csv", index=False)

    with open(f"{out_prefix}__jac_ctx.pkl", "wb") as f:
        pickle.dump(jac_ctx, f)

### Batch processing (GSEA) — disabled (the cells below are commented out)

In [None]:
# n_leiden = len(np.unique(jac_ctx.leiden_membership))
# jac_ctx.leiden_to_query_indices = {
#     leiden_id: np.where(jac_ctx.leiden_membership == leiden_id)[0]
#     for leiden_id in range(n_leiden)
# }

In [None]:
# from sklearn.metrics import silhouette_samples

# jac_ctx.silhouette_samples_q = silhouette_samples(jac_ctx.z_qp, jac_ctx.leiden_membership)

In [None]:
# jac_ctx.leiden_sillohette_coefficients = []
# for leiden_id in range(n_leiden):
#     indices = jac_ctx.leiden_to_query_indices[leiden_id]
#     scores = jac_ctx.silhouette_samples_q[indices]
#     mean_scores = np.mean(scores)
#     jac_ctx.leiden_sillohette_coefficients.append(mean_scores)

In [None]:
# from tqdm.notebook import tqdm
# import gseapy as gp


# gene_set_path = os.path.join(ROOT_PATH, "data", "gmt", "c5.go.v2024.1.Hs.symbols.gmt")

# for leiden_id in tqdm(range(n_leiden)):

#     scores = []
#     s = set(jac_ctx.leiden_to_query_indices[leiden_id])
#     for q in range(len(jac_ctx.query_gene_symbols)):
#         if q in s:
#             scores.append(1)
#         else:
#             scores.append(0)
        
#     gene_list = pd.DataFrame({
#         "gene_symbol": jac_ctx.query_gene_symbols,
#         "score": scores}
#     )

#     output_dir = os.path.join(OUTPUT_ROOT_PATH, dataset_name, f"leiden_cluster_{leiden_id}")
#     gsea_results = gp.prerank(
#         rnk=gene_list,
#         gene_sets=gene_set_path,
#         min_size=1,
#         max_size=5000,
#         permutation_num=1000,
#         graph_num=90,
#         outdir=output_dir
#     )


### Studying gene programs and their relationships

In [None]:
ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"

# Layout: metacells/<dataset>.h5ad and
#         metacells/matmul_highest_10k/<dataset>__jacobian__marginal_mean.pt
adata_root_path = os.path.join(
    ROOT_PATH, "data", "cellariumgpt_validation", "metacells")
jacobian_pt_root_path = os.path.join(adata_root_path, "matmul_highest_10k")

suffix = "__jacobian__marginal_mean.pt"

# Keep only files with the exact Jacobian suffix, sorted: os.listdir order is
# arbitrary, and the dataset order must be reproducible across runs/cells.
jacobian_pt_files = sorted(
    f for f in os.listdir(jacobian_pt_root_path) if f.endswith(suffix))
jacobian_pt_paths = [os.path.join(jacobian_pt_root_path, f) for f in jacobian_pt_files]

# Dataset name = file name with the trailing suffix removed (only at the end).
dataset_names = [f[:-len(suffix)] for f in jacobian_pt_files]

adata_paths = [
    os.path.join(adata_root_path, f"{name}.h5ad")
    for name in dataset_names]

# Per-dataset Leiden membership tables written by the batch-processing loop.
output_root_path = os.path.join(jacobian_pt_root_path, "analysis")
leiden_result_paths = [
    os.path.join(output_root_path, f"{name}__leiden_membership.csv")
    for name in dataset_names]

#### Make a manifest of all available cell types

In [None]:
import pickle
from tqdm.notebook import tqdm
from collections import defaultdict

# Per-metacell summary (column -> list of values) and, per dataset, the gene
# symbols of each Leiden community: {dataset_name: {leiden_id: [symbols]}}.
cell_manifest = defaultdict(list)
leiden_dict = dict()

# vague cell types; metacells annotated with these are skipped entirely
skip_cell_type_set = {'unknown', 'leukocyte', 'myeloid cell', 'lymphocyte'}

# obs metadata copied from the first obs row of each metacell file
# (assumes these columns are constant within a file -- TODO confirm)
metadata_columns = ["tissue", "disease", "development_stage", "assay", "suspension_type"]

for i_dataset in tqdm(range(len(dataset_names))):
    name = dataset_names[i_dataset]

    adata = sc.read_h5ad(adata_paths[i_dataset])
    cell_type = adata.obs["cell_type"].values[0]

    # Decide whether to skip before unpickling the (potentially large) JacobianContext.
    if cell_type in skip_cell_type_set:
        continue

    with open(os.path.join(output_root_path, f"{name}__jac_ctx.pkl"), "rb") as f:
        jac_ctx = pickle.load(f)

    # Leiden label -1 marks unassigned genes; it is not a community.
    leiden_ids = [x for x in np.unique(jac_ctx.leiden_membership) if x != -1]

    cell_manifest["dataset_name"].append(name)
    cell_manifest["cell_type"].append(cell_type)
    for column in metadata_columns:
        cell_manifest[column].append(adata.obs[column].values[0])

    # useful statistics
    cell_manifest["n_umi"].append(adata.obs["total_mrna_umis"].values.mean())
    cell_manifest["n_expr_genes"].append((adata.X.sum(0) > 0).sum())
    cell_manifest["n_graph_query_genes"].append(len(jac_ctx.query_var_names))
    cell_manifest["n_graph_prompt_genes"].append(len(jac_ctx.prompt_var_names))
    # count real communities only (the -1 "unassigned" label is excluded)
    cell_manifest["n_leiden_communities"].append(len(leiden_ids))
    cell_manifest["spectral_dim"].append(jac_ctx.spectral_dim)

    leiden_dict[name] = {
        leiden_id: [
            jac_ctx.query_gene_symbols[i]
            for i in np.where(jac_ctx.leiden_membership == leiden_id)[0].tolist()]
        for leiden_id in leiden_ids
    }

cell_manifest_df = (
    pd.DataFrame(cell_manifest)
    .sort_values("dataset_name")
    .set_index("dataset_name", drop=True))
cell_manifest_df['main_dataset_name'] = cell_manifest_df.index.str.split("__").str[0]

In [None]:
# Peek at the first rows of the per-metacell manifest.
cell_manifest_df.head()

In [None]:
# show all cells
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(cell_manifest_df.head(10))

In [None]:
import json

# Persist the manifest (CSV) and the per-dataset community gene lists (pickle);
# the LDA section reloads both from these files.
cell_manifest_df.to_csv(os.path.join(output_root_path, "cell_manifest.csv"))

leiden_dict_path = os.path.join(output_root_path, "leiden_dict.pkl")
with open(leiden_dict_path, "wb") as f:
    pickle.dump(leiden_dict, f)

#### Jaccard analysis and hierarchical clustering

In [None]:
# Restrict the community analysis to a subset of metacells, or None for all. Example:
# included_datasets = cell_manifest_df[cell_manifest_df["main_dataset_name"] == "extract_100"].index.values
included_datasets = None

In [None]:
# Flatten {dataset: {leiden_id: genes}} into parallel per-community lists,
# dropping communities with fewer than `min_community_size` genes.
min_community_size = 5

dataset_leiden_id_to_flat_map = {}      # dataset -> {leiden_id -> flat index}
all_communities_flat_gene_lists = []    # flat index -> gene symbols (list)
all_communities_flat_gene_sets = []     # flat index -> gene symbols (set)
community_labels_tuple = []             # flat index -> (cell_type, leiden_id)
community_labels = []                   # flat index -> "cell_type, leiden_id"

for dataset_name, dataset_leiden_dict in leiden_dict.items():
    if included_datasets is not None and dataset_name not in included_datasets:
        continue
    flat_map = dataset_leiden_id_to_flat_map[dataset_name] = {}
    cell_type = cell_manifest_df.loc[dataset_name].cell_type
    for leiden_id, gene_list in dataset_leiden_dict.items():
        if len(gene_list) >= min_community_size:
            flat_map[leiden_id] = len(all_communities_flat_gene_lists)
            all_communities_flat_gene_lists.append(gene_list)
            all_communities_flat_gene_sets.append(set(gene_list))
            community_labels_tuple.append((cell_type, leiden_id))
            community_labels.append(f"{cell_type}, {leiden_id}")

counter = len(all_communities_flat_gene_lists)

In [None]:
# Pairwise Jaccard similarity between all retained communities.
# The matrix is symmetric with a unit diagonal, so only the upper triangle is
# computed and mirrored (half the set operations of the full double loop).
n_communities = len(all_communities_flat_gene_lists)
jaccard_matrix = np.eye(n_communities)

for i in tqdm(range(n_communities)):
    set_i = all_communities_flat_gene_sets[i]
    for j in range(i + 1, n_communities):
        set_j = all_communities_flat_gene_sets[j]
        n_inter = len(set_i & set_j)
        n_union = len(set_i) + len(set_j) - n_inter
        # Two empty sets have no defined Jaccard index; treat them as dissimilar.
        jaccard = n_inter / n_union if n_union else 0.0
        jaccard_matrix[i, j] = jaccard_matrix[j, i] = jaccard

In [None]:
# import numpy as np
# import seaborn as sns
# import matplotlib.pyplot as plt
# import pandas as pd
# from scipy.spatial.distance import squareform

# import scipy.cluster.hierarchy as hc

# linkage = hc.linkage(squareform(1. - jaccard_matrix), method='ward')

# # Convert the matrix to a DataFrame for easier labeling (optional)
# labels = [f"Set {i+1}" for i in range(jaccard_matrix.shape[0])]
# df = pd.DataFrame(jaccard_matrix, index=labels, columns=labels)

# # Use Seaborn's clustermap with precomputed distance
# sns.set(style="white")

# cluster_map = sns.clustermap(
#     df,
#     row_linkage=linkage,
#     col_linkage=linkage,
#     linewidths=0,
#     cmap='Reds',
#     figsize=(10, 10),        # Size of the figure
#     dendrogram_ratio=(0.2, 0.2),  # Adjust dendrogram size
#     cbar_pos=(0.02, 0.8, 0.03, 0.18)  # Adjust color bar position
# )

# plt.title('Clustered Heatmap with Precomputed Similarity', fontsize=14)


In [None]:
import plotly.graph_objects as go
import plotly.figure_factory as ff
import numpy as np
from scipy.spatial.distance import squareform
from scipy.cluster import hierarchy as hc

# Cluster communities on Jaccard distance (1 - similarity).
# NOTE(review): Ward linkage formally assumes Euclidean distances, which
# 1 - Jaccard is not; 'average' linkage may be more appropriate.
linkage = hc.linkage(squareform(1.0 - jaccard_matrix), method='ward')

# Leaf order of that tree; the heatmap rows and columns follow it.
dendro_row = hc.dendrogram(linkage, no_plot=True)
heatmap_order = dendro_row['leaves']
clustered_matrix = jaccard_matrix[np.ix_(heatmap_order, heatmap_order)]
reordered_labels = [community_labels[i] for i in heatmap_order]

# By default ff.create_dendrogram re-clusters the rows of its input (Euclidean
# pdist + complete linkage), producing a different tree and leaf order than
# `linkage`, so the drawn dendrograms would not describe the heatmap. Hand it
# the precomputed tree instead. It receives the ORIGINAL matrix and labels and
# reorders them itself, which reproduces `heatmap_order`.
use_precomputed_tree = dict(
    distfun=lambda X: None,        # distances unused: the tree is already known
    linkagefun=lambda d: linkage,
)

# Top dendrogram (traces hidden; its x-axis leaf positions place the heatmap columns)
fig = ff.create_dendrogram(
    jaccard_matrix,
    orientation='bottom',
    labels=community_labels,
    **use_precomputed_tree,
)
fig.for_each_trace(lambda trace: trace.update(visible=False, line=dict(width=1)))
for trace in fig['data']:
    trace['yaxis'] = 'y2'

# Side dendrogram on its own x axis, left of the heatmap
dendro_side = ff.create_dendrogram(
    jaccard_matrix,
    orientation='right',
    labels=community_labels,
    **use_precomputed_tree,
)
dendro_side.for_each_trace(lambda trace: trace.update(line=dict(width=1)))
for trace in dendro_side['data']:
    trace['xaxis'] = 'x2'
    fig.add_trace(trace)

# Heatmap at the dendrogram leaf coordinates; both now follow heatmap_order.
heatmap = go.Heatmap(
    x=fig['layout']['xaxis']['tickvals'],
    y=dendro_side['layout']['yaxis']['tickvals'],
    z=clustered_matrix,
    colorscale='Blues',
    zmin=0,
    zmax=0.2,  # color scale saturates at 0.2 (diagonal entries are 1)
)
fig.add_trace(heatmap)

# Figure-level layout; x tick labels are made transparent
fig.update_layout(
    width=1000,
    height=1200,
    showlegend=False,
    hovermode='closest',
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    xaxis_tickfont=dict(color='rgba(0,0,0,0)'),
)

# x axes: heatmap (x) on the right, side dendrogram (x2) on the left
fig.update_layout(
    xaxis={
        'domain': [.15, 1],
        'mirror': False,
        'showgrid': False,
        'showline': False,
        'zeroline': False,
        'ticks': ""
    },
    xaxis2={
        'domain': [0, .15],
        'mirror': False,
        'showgrid': False,
        'showline': False,
        'zeroline': False,
        'showticklabels': False,
        'ticks': ""
    }
)

# y axes: heatmap (y) full height, hidden top dendrogram (y2) near the top
fig.update_layout(
    yaxis={
        'domain': [0, 1],
        'mirror': False,
        'showgrid': False,
        'showline': False,
        'zeroline': False,
        'showticklabels': False,
        'ticks': ""
    },
    yaxis2={
        'domain': [.825, .975],
        'mirror': False,
        'showgrid': False,
        'showline': False,
        'zeroline': False,
        'showticklabels': False,
        'ticks': ""
    }
)

fig.show()


In [None]:
# Cell types of the metacells used above; `.loc[None]` would raise KeyError,
# so fall back to every metacell when no subset was selected.
(
    cell_manifest_df if included_datasets is None
    else cell_manifest_df.loc[included_datasets]
)["cell_type"].values.tolist()

In [None]:
import colorcet as cc
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# Fine cell type -> coarse lineage used to color the dendrogram leaves.
coarse_map = {
    'CD14-positive monocyte': 'Mono',
    'naive thymus-derived CD4-positive, alpha-beta T cell': 'T',
    'CD4-positive, alpha-beta cytotoxic T cell': 'T',
    'mucosal invariant T cell': 'T',
    'CD4-positive, alpha-beta T cell': 'T',
    'CD1c-positive myeloid dendritic cell': 'DC',
    'monocyte': 'Mono',
    'platelet': 'Platelet',
    'naive B cell': 'B',
    'gamma-delta T cell': 'T',
    'central memory CD4-positive, alpha-beta T cell': 'T',
    'CD8-positive, alpha-beta cytotoxic T cell': 'T',
    'CD16-positive, CD56-dim natural killer cell, human': 'NK',
    'CD14-low, CD16-positive monocyte': 'Mono',
    'CD8-positive, alpha-beta memory T cell': 'T',
    'effector memory CD4-positive, alpha-beta T cell': 'T',
    'naive thymus-derived CD8-positive, alpha-beta T cell': 'T',
    'memory B cell': 'B',
}
# Cell types absent from coarse_map (e.g. non-blood metacells) are grouped here
# instead of raising KeyError.
unmapped_coarse_label = 'Other'

# Labels look like "<cell_type>, <leiden_id>"; cell types may themselves contain
# ", ", so split only on the last separator.
coarse_reordered_labels = [
    coarse_map.get(label.rsplit(', ', 1)[0], unmapped_coarse_label)
    for label in reordered_labels]

# Sorted so the label -> color assignment is stable across runs (set order is not).
all_coarse_labels = sorted(set(coarse_reordered_labels))
coarse_label_to_color_map = {label: cc.glasbey_light[i] for i, label in enumerate(all_coarse_labels)}
reordered_label_colors = [coarse_label_to_color_map[label] for label in coarse_reordered_labels]

# One-row RGB image: one colored cell per dendrogram leaf.
colors_rgb = [mcolors.to_rgb(color) for color in reordered_label_colors]
color_bar = np.array(colors_rgb).reshape(1, -1, 3)

plt.figure(figsize=(10, 1))
plt.imshow(color_bar, aspect='auto')
plt.axis('off')
plt.show()


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Stand-alone legend for the coarse cell-type color bar.
legend_handles = [
    mpatches.Patch(color=color, label=label)
    for label, color in coarse_label_to_color_map.items()
]

legend_fig, legend_ax = plt.subplots(figsize=(6, 2))
legend_ax.legend(handles=legend_handles, loc='center', frameon=False)
legend_ax.axis('off')
plt.show()


#### Studying spectral dimension across datasets

In [None]:
# Full per-metacell manifest (all rows).
cell_manifest_df

In [None]:
# Spectral dimension vs. metacell depth, with a least-squares line and its R^2.
fig, ax = plt.subplots()

x_col, y_col = "n_umi", "spectral_dim"
x_label, y_label = "Metacell number of UMIs", "Metacell spectral dimension"
x_values, y_values = cell_manifest_df[x_col], cell_manifest_df[y_col]

# degree-1 fit; for a simple linear fit R^2 equals the squared Pearson correlation
slope, intercept = np.polyfit(x_values, y_values, 1)
r_squared = np.corrcoef(x_values, y_values)[0, 1] ** 2
ax.plot(x_values, intercept + slope * x_values, color='red', label=f"R^2 = {r_squared:.2f}")

ax.scatter(x_values, y_values)
ax.set(xlabel=x_label, ylabel=y_label)
ax.legend()

In [None]:
# Spectral dimension vs. number of Leiden communities, with a least-squares line and its R^2.
fig, ax = plt.subplots()

x_col, y_col = "n_leiden_communities", "spectral_dim"
x_label, y_label = "Metacell # of leiden communities", "Metacell spectral dimension"
x_values, y_values = cell_manifest_df[x_col], cell_manifest_df[y_col]

# degree-1 fit; for a simple linear fit R^2 equals the squared Pearson correlation
slope, intercept = np.polyfit(x_values, y_values, 1)
r_squared = np.corrcoef(x_values, y_values)[0, 1] ** 2
ax.plot(x_values, intercept + slope * x_values, color='red', label=f"R^2 = {r_squared:.2f}")

ax.scatter(x_values, y_values)
ax.set(xlabel=x_label, ylabel=y_label)
ax.legend()

In [None]:
# Columns available for the spectral-dimension comparisons.
cell_manifest_df.columns

In [None]:
# Spectral dimension vs. number of genes in the graph, with a least-squares line and its R^2.
fig, ax = plt.subplots()

x_col, y_col = "n_graph_query_genes", "spectral_dim"
x_label, y_label = "Metacell # of genes in the graph", "Metacell spectral dimension"
x_values, y_values = cell_manifest_df[x_col], cell_manifest_df[y_col]

# degree-1 fit; for a simple linear fit R^2 equals the squared Pearson correlation
slope, intercept = np.polyfit(x_values, y_values, 1)
r_squared = np.corrcoef(x_values, y_values)[0, 1] ** 2
ax.plot(x_values, intercept + slope * x_values, color='red', label=f"R^2 = {r_squared:.2f}")

ax.scatter(x_values, y_values)
ax.set(xlabel=x_label, ylabel=y_label)
ax.legend()

In [None]:
# Spectral dimension vs. number of prompt genes, with a least-squares line and its R^2.
fig, ax = plt.subplots()

x_col, y_col = "n_graph_prompt_genes", "spectral_dim"
x_label, y_label = "Metacell # of prompt genes", "Metacell spectral dimension"
x_values, y_values = cell_manifest_df[x_col], cell_manifest_df[y_col]

# degree-1 fit; for a simple linear fit R^2 equals the squared Pearson correlation
slope, intercept = np.polyfit(x_values, y_values, 1)
r_squared = np.corrcoef(x_values, y_values)[0, 1] ** 2
ax.plot(x_values, intercept + slope * x_values, color='red', label=f"R^2 = {r_squared:.2f}")

ax.scatter(x_values, y_values)
ax.set(xlabel=x_label, ylabel=y_label)
ax.legend()

In [None]:
# Spectral dimension vs. number of expressed genes, with a least-squares line and its R^2.
fig, ax = plt.subplots()

x_col, y_col = "n_expr_genes", "spectral_dim"
x_label, y_label = "Metacell # of expressed genes", "Metacell spectral dimension"
x_values, y_values = cell_manifest_df[x_col], cell_manifest_df[y_col]

# degree-1 fit; for a simple linear fit R^2 equals the squared Pearson correlation
slope, intercept = np.polyfit(x_values, y_values, 1)
r_squared = np.corrcoef(x_values, y_values)[0, 1] ** 2
ax.plot(x_values, intercept + slope * x_values, color='red', label=f"R^2 = {r_squared:.2f}")

ax.scatter(x_values, y_values)
ax.set(xlabel=x_label, ylabel=y_label)
ax.legend()

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Order cell types on the x axis by ascending mean spectral dimension.
mean_spectral_dim = cell_manifest_df.groupby('cell_type')['spectral_dim'].mean().sort_values()

# Plot from an ordered-categorical copy instead of overwriting
# cell_manifest_df['cell_type'] in place; this keeps the cell idempotent and
# leaves the manifest untouched for later cells.
plot_df = cell_manifest_df.assign(
    cell_type=pd.Categorical(
        cell_manifest_df['cell_type'],
        categories=mean_spectral_dim.index,
        ordered=True))

plt.figure(figsize=(12, 10))
scatter = sns.scatterplot(
    x='cell_type',
    y='spectral_dim',
    hue='main_dataset_name',
    data=plot_df,
    s=100
)

plt.title('Spectral dimension by cell type')
plt.xlabel('Cell type')
plt.ylabel('Spectral dimension')
plt.xticks(rotation=90)

# Relabel legend entries as "<main_dataset_name> (<tissue>)". Keep only entries
# that are actual hue levels. Depending on the seaborn version, the hue variable
# name may or may not appear as an extra first entry, so dropping the first label
# by position can silently remove a real dataset.
first_rows = plot_df.drop_duplicates('main_dataset_name')
dataset_to_tissue = dict(zip(first_rows['main_dataset_name'], first_rows['tissue']))

handles, labels = scatter.get_legend_handles_labels()
kept_handles, legend_labels = [], []
for handle, label in zip(handles, labels):
    if label in dataset_to_tissue:
        kept_handles.append(handle)
        legend_labels.append(f"{label} ({dataset_to_tissue[label]})")

plt.legend(handles=kept_handles, labels=legend_labels, title='Metacell source', loc='upper left')

plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Box plot: distribution of spectral dimension per sequencing assay.
plt.figure(figsize=(4, 4))
scatter = sns.boxplot(data=cell_manifest_df, x='assay', y='spectral_dim')

plt.xlabel('Assay')
plt.ylabel('Spectral dimension')

plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Box plot: distribution of spectral dimension per suspension type.
plt.figure(figsize=(4, 4))
scatter = sns.boxplot(
    x='suspension_type',
    y='spectral_dim',
    data=cell_manifest_df,
)

plt.xlabel('Suspension type')
plt.ylabel('Spectral dimension')

plt.tight_layout()
plt.show()


#### LDA topic modeling of Leiden gene communities

In [None]:
import json
import pickle  # imported here so this section can start from the saved files

# Reload the per-metacell manifest written earlier.
cell_manifest_df = pd.read_csv(os.path.join(output_root_path, "cell_manifest.csv"), index_col=0)

# Reload the community gene lists: {dataset_name: {leiden_id: [gene symbols]}}.
with open(os.path.join(output_root_path, "leiden_dict.pkl"), "rb") as f:
    leiden_dict = pickle.load(f)

In [None]:
# Restrict to a subset of metacells, or None for all. Example:
# included_datasets = cell_manifest_df[cell_manifest_df["main_dataset_name"] == "extract_100"].index.values
included_datasets = None

# Flatten {dataset: {leiden_id: genes}} into parallel per-community lists
# ("documents" for LDA), dropping communities below `min_community_size` genes.
min_community_size = 5

dataset_leiden_id_to_flat_map = dict()
all_communities_flat_gene_lists = []
all_communities_flat_gene_sets = []
community_labels_tuple = []
community_labels = []

counter = 0
for dataset_name, dataset_leiden_dict in leiden_dict.items():
    if included_datasets is not None and dataset_name not in included_datasets:
        continue
    flat_map = dataset_leiden_id_to_flat_map.setdefault(dataset_name, dict())
    cell_type = cell_manifest_df.loc[dataset_name].cell_type
    kept_communities = [
        (lid, genes) for lid, genes in dataset_leiden_dict.items()
        if len(genes) >= min_community_size]
    for leiden_id, gene_list in kept_communities:
        flat_map[leiden_id] = counter
        all_communities_flat_gene_lists.append(gene_list)
        all_communities_flat_gene_sets.append(set(gene_list))
        community_labels_tuple.append((cell_type, leiden_id))
        community_labels.append(f"{cell_type}, {leiden_id}")
        counter += 1

In [None]:
from gensim.corpora import Dictionary
from gensim.models import LdaMulticore, LdaModel

# Each community is a "document" whose "words" are gene symbols.
dictionary = Dictionary(all_communities_flat_gene_lists)

# Optional vocabulary pruning (disabled):
# dictionary.filter_extremes(no_below=1, no_above=0.5)

# Bag-of-words corpus, one entry per community.
corpus = list(map(dictionary.doc2bow, all_communities_flat_gene_lists))

In [None]:
import os
import multiprocessing

# Logical CPUs (hardware threads) on the machine. multiprocessing.cpu_count()
# reports this same number; it is NOT the physical core count.
logical_cores = os.cpu_count()

# CPUs this process may actually run on (CPU affinity mask, e.g. taskset/cpuset);
# falls back to the logical count on platforms without sched_getaffinity.
usable_cores = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else logical_cores

print(f"Logical cores available: {logical_cores}")
print(f"Usable cores (CPU affinity): {usable_cores}")


In [None]:
import logging

# Show gensim's training progress. INFO is enough for progress messages;
# DEBUG is far more verbose and floods the notebook output.
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

# Single-process LDA. LdaMulticore (imported alongside) would be the
# parallel variant with a `workers` argument; it is not used here.
num_topics = 200  # number of topics
lda_model = LdaModel(
    corpus=corpus,
    id2word=dictionary,
    num_topics=num_topics,
    # alpha='auto',
    passes=20,        # number of passes through the corpus
    random_state=42,  # fixed seed for reproducible training
)

In [None]:
# Top 10 words (genes) of every topic, as gensim's formatted weight*word strings.
topic_summaries = lda_model.show_topics(num_topics=num_topics, num_words=10, formatted=True)
for topic_id, words_and_probs in topic_summaries:
    print(f"Topic {topic_id}: {words_and_probs}")


In [None]:
import csv
import os

# Export the top genes of every LDA topic to CSV: one row per topic,
# "Topic ID" followed by the topic's top genes as returned by show_topics.
NUM_TOP_WORDS = 50  # genes exported per topic; also sets the header width

topics = []
for topic_id, words_and_probs in lda_model.show_topics(
    num_topics=num_topics, num_words=NUM_TOP_WORDS, formatted=False
):
    topic_words = [word for word, _ in words_and_probs]
    topics.append([topic_id] + topic_words)

output_csv = "./output/lda_topics.csv"
# open(..., 'w') fails if the target directory does not exist yet
os.makedirs(os.path.dirname(output_csv), exist_ok=True)

with open(output_csv, mode='w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    # The header width must match the number of words written per row.
    writer.writerow(["Topic ID"] + [f"Word {i+1}" for i in range(NUM_TOP_WORDS)])
    writer.writerows(topics)

print(f"Topics saved to {output_csv}")



In [None]:
from tqdm.notebook import tqdm

# Infer the topic mixture of every community (document) in the corpus.
# minimum_probability=0.0 keeps all topics: with gensim's default threshold, low-weight
# topics are dropped and the dense rows built below would be truncated instead of
# being full topic distributions.
doc_topic_distributions = [
    lda_model.get_document_topics(bow, minimum_probability=0.0)
    for bow in tqdm(corpus)
]

# Sparse (topic_id, prob) pairs -> dense (n_communities, num_topics) matrix.
num_topics = lda_model.num_topics
doc_topic_dense_np = np.zeros((len(doc_topic_distributions), num_topics))
for row_idx, doc_topics in enumerate(doc_topic_distributions):
    for topic_id, prob in doc_topics:
        doc_topic_dense_np[row_idx, topic_id] = prob

# List-of-lists view, kept for any cell that still uses the old name.
doc_topic_dense = doc_topic_dense_np.tolist()

In [None]:
# Average the community-level topic mixtures into one topic profile per metacell dataset.
# dataset_names, cell_type_names and dataset_topic_dense stay index-aligned.
dataset_topic_dense = []
dataset_names = []
cell_type_names = []

for dataset_name, leiden_id_to_flat_map in dataset_leiden_id_to_flat_map.items():
    flat_ids = list(leiden_id_to_flat_map.values())
    if not flat_ids:
        # A dataset whose communities were all below min_community_size has no rows;
        # np.mean over nothing is NaN, which would break the downstream UMAP fit.
        continue
    mean_topic_usage = doc_topic_dense_np[flat_ids].mean(axis=0)
    # renormalise so each profile sums to 1
    mean_topic_usage = mean_topic_usage / mean_topic_usage.sum()
    dataset_names.append(dataset_name)
    cell_type_names.append(cell_manifest_df.loc[dataset_name].cell_type)
    dataset_topic_dense.append(mean_topic_usage)

dataset_topic_dense_np = np.asarray(dataset_topic_dense)

In [None]:
# Project the per-community topic mixtures to 2-D with UMAP (fixed seed for reproducibility).
import umap

umap_model = umap.UMAP(
    n_components=2,
    random_state=42,
)
doc_topic_umap_embedding = umap_model.fit_transform(doc_topic_dense_np)

In [None]:
# Project the per-dataset mean topic profiles to 2-D with UMAP (fixed seed for reproducibility).
import umap

umap_model = umap.UMAP(
    n_components=2,
    random_state=42,
)
dataset_umap_embedding = umap_model.fit_transform(dataset_topic_dense_np)

In [None]:
import plotly.express as px
import pandas as pd
import colorcet as cc
import numpy as np

# Scatter of the community-level topic UMAP, coloured by the cell type of the
# metacell dataset each gene community came from.
#   community_labels_tuple   -- list of (cell_type, leiden_id), one per community
#   doc_topic_umap_embedding -- (n_communities, 2) array, rows aligned with it
labels = [cell_type for cell_type, _ in community_labels_tuple]
assert len(labels) == doc_topic_umap_embedding.shape[0], "labels and embedding rows are misaligned"

# sorted() instead of set iteration order: str hashing is randomized per interpreter,
# so set order (and therefore the colour of each cell type) would change between runs.
unique_labels = sorted(set(labels))

# Glasbey has a fixed number of colours; wrap around instead of raising IndexError
# when there are more labels than colours.
palette = cc.glasbey
label_to_color = {label: palette[i % len(palette)] for i, label in enumerate(unique_labels)}

df = pd.DataFrame({
    'x': doc_topic_umap_embedding[:, 0],
    'y': doc_topic_umap_embedding[:, 1],
    'label': labels,
})

fig = px.scatter(
    df,
    x='x',
    y='y',
    color='label',
    color_discrete_map=label_to_color,
    category_orders={'label': unique_labels},  # stable legend order
    hover_name='label',
    title='UMAP of gene-community topic mixtures',
    labels={'label': 'Cell type'},
)

fig.update_traces(marker=dict(size=2))  # many points: keep markers small
fig.update_layout(width=1500, height=800)
fig.show()


In [None]:
import plotly.express as px
import pandas as pd
import colorcet as cc
import numpy as np

# Scatter of the per-dataset UMAP (one point per metacell dataset), coloured and
# annotated by the dataset's cell type.
labels = cell_type_names          # one cell-type name per dataset
emb = dataset_umap_embedding      # (n_datasets, 2), rows aligned with labels
assert len(labels) == emb.shape[0], "labels and embedding rows are misaligned"

# sorted() instead of set iteration order: str hashing is randomized per interpreter,
# so set order (and therefore the colour of each cell type) would change between runs.
unique_labels = sorted(set(labels))

# Glasbey has a fixed number of colours; wrap around instead of raising IndexError
# when there are more labels than colours.
palette = cc.glasbey
label_to_color = {label: palette[i % len(palette)] for i, label in enumerate(unique_labels)}

df = pd.DataFrame({
    'x': emb[:, 0],
    'y': emb[:, 1],
    'label': labels,
})

fig = px.scatter(
    df,
    x='x',
    y='y',
    color='label',
    color_discrete_map=label_to_color,
    category_orders={'label': unique_labels},  # stable legend order
    hover_name='label',
    text='label',  # annotate each point with its cell type
    title='UMAP of per-dataset mean topic usage',
    labels={'label': 'Cell type'},
)

fig.update_traces(
    marker=dict(size=10, opacity=0.7),
    textposition="top center",
    textfont=dict(size=8),
)
fig.update_layout(width=1500, height=800)
fig.show()
