### _in silico_ perturbation by cell type prompting

In [None]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

# Default figure size (4, 4) for all scanpy plots in this notebook.
sc.set_figure_params(figsize=(4, 4))

# Torch device handle for the first CUDA GPU.
# NOTE(review): not referenced in the cells visible here (the MDE embedding cell
# passes device="cuda" directly) — confirm whether later cells rely on it.
DEVICE = torch.device('cuda:0')

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext, \
    GeneNetworkAnalysisBase, \
    load_gene_info_table

In [None]:
# Root directory holding the input data for this notebook (machine-specific;
# adjust when running elsewhere).
ROOT_PATH = "/home/mehrtash/data"
# AnnData extract under ROOT_PATH.
# NOTE(review): not used in the cells visible here.
REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
# Gene-info table (TSV) passed to `load_gene_info_table`.
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

In [None]:
# AnnData of the cell types used for validation (per the file name). The path is
# built from ROOT_PATH so relocating the data directory only requires changing
# ROOT_PATH, instead of also editing a hard-coded absolute path here.
val_adata = sc.read_h5ad(
    os.path.join(ROOT_PATH, "data", "cellariumgpt_artifacts", "cell_types_for_validation_filtered.h5ad"))

In [None]:
import pickle

# Load the precomputed linear response results (the "100M" / cell index 13 file,
# per the path components).
linear_response_file_path = os.path.join(ROOT_PATH, "data", "linear_response", "100M", "linear_response_cell_index_13.pkl")

# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open(linear_response_file_path, "rb") as f:
    linear_response_dict = pickle.load(f)

In [None]:
# Load the gene-info table, restricted to the query genes of the linear response
# results. Per the returned names, the call also yields symbol -> ID and
# ID -> symbol lookup maps.
gene_info_df, gene_symbol_to_gene_id_map, gene_id_to_gene_symbol_map = load_gene_info_table(
    GENE_INFO_PATH, included_gene_ids=linear_response_dict["query_gene_ids"])

In [None]:
# List the entries available in the linear response results, one per line.
for key in linear_response_dict:
    print(key)

In [None]:
# Distribution of all r_squared_qp values (flattened), 50 bins, log-scaled counts.
# The trailing semicolon suppresses the notebook echo of plt.hist's return value.
plt.hist(linear_response_dict['r_squared_qp'].flatten(), bins=50, log=True);

In [None]:
# Distribution of all slope_qp values (flattened), 50 bins, log-scaled counts.
# The trailing semicolon suppresses the notebook echo of plt.hist's return value.
plt.hist(linear_response_dict['slope_qp'].flatten(), bins=50, log=True);

In [None]:
# Build a single-row AnnData that only carries the prompt metadata in `.obs`;
# the 1x1 zero matrix is a placeholder for X.
_prompt_metadata = linear_response_dict["prompt_metadata_dict"]
_obs_df = pd.DataFrame({
    "cell_type": [_prompt_metadata["cell_type"]],
    "tissue": [_prompt_metadata["tissue"]],
    "assay": [linear_response_dict["assay"]],
    "suspension_type": [linear_response_dict["suspension_type"]],
    "total_mrna_umis": [linear_response_dict["total_mrna_umis"]],
    # The remaining fields are fixed values rather than read from the results.
    "disease": "normal",
    "development_stage": "unknown",
    "sex": "male",
})
adata_prop = sc.AnnData(X=np.zeros((1, 1)), obs=_obs_df)

In [None]:
# Display the single metadata row for a visual check.
adata_prop.obs

In [None]:
# Reuse the gene-info path defined near the top of the notebook instead of
# rebuilding it, so this cell cannot drift from the table that
# `load_gene_info_table` reads.
gene_info_tsv_path = GENE_INFO_PATH

# Slopes whose linear fit has r^2 below this threshold are treated as "no response".
min_r_squared = 0.25

# Work on a copy so the loaded results are left untouched.
response_qp = linear_response_dict["slope_qp"].copy()
# NOTE(review): if r_squared_qp contains NaN, those entries compare False here and
# keep their slope — confirm that is intended.
response_qp[linear_response_dict["r_squared_qp"] < min_r_squared] = 0.

query_marginal_mean_q = linear_response_dict["gene_marginal_mean_q"]
query_marginal_std_q = linear_response_dict["gene_marginal_std_q"]

# note: prompt and query genes are the same in these experiments, so the prompt
# statistics are taken from the query ("_q") entries.
prompt_marginal_mean_p = linear_response_dict["gene_marginal_mean_q"]
prompt_marginal_std_p = linear_response_dict["gene_marginal_std_q"]

# Analysis context over the thresholded response matrix; the same gene ID list is
# used for both the query and the prompt axes.
network_ctx = GeneNetworkAnalysisBase(
    adata_obs=adata_prop.obs,
    gene_info_tsv_path=gene_info_tsv_path,
    query_var_names=linear_response_dict["query_gene_ids"].tolist(),
    prompt_var_names=linear_response_dict["query_gene_ids"].tolist(),
    response_qp=response_qp,
    prompt_marginal_mean_p=prompt_marginal_mean_p,
    prompt_marginal_std_p=prompt_marginal_std_p,
    query_marginal_mean_q=query_marginal_mean_q,
    query_marginal_std_q=query_marginal_std_q,
    verbose=True
)

In [None]:
# Active processing configuration: "log1p" response normalization with z-scored
# features, TPM cutoffs of 0.1 and protein-coding genes only.
# NOTE(review): the exact meaning of each option is defined by
# GeneNetworkAnalysisBase.process, which is not visible here — consult its
# documentation before changing values.
network_ctx.process(
    response_normalization_strategy="log1p",
    feature_normalization_strategy="z_score",
    feature_max_value=None,
    query_response_amp_min_pct=None,
    min_prompt_gene_tpm=0.1,
    min_query_gene_tpm=0.1,
    norm_pseudo_count=0.,  # not needed for log1p normalization strategy
    query_hv_top_k=None,
    z_trans_func=None,
    included_gene_biotypes={'protein_coding'},
)

# Alternative configuration (disabled): "corr" response normalization with TPM
# cutoffs of 1 and a 1e-3 pseudo-count; no biotype filter is passed.
# network_ctx.process(
#     response_normalization_strategy="corr",
#     feature_normalization_strategy="z_score",
#     feature_max_value=None,
#     query_response_amp_min_pct=None,
#     min_prompt_gene_tpm=1,
#     min_query_gene_tpm=1,
#     norm_pseudo_count=1e-3,
#     query_hv_top_k=None,
#     z_trans_func=None,
# )

# Alternative configuration (disabled): "mean" response normalization with TPM
# cutoffs of 0.1 and a 1e-3 pseudo-count; no biotype filter is passed.
# network_ctx.process(
#     response_normalization_strategy="mean",
#     feature_normalization_strategy="z_score",
#     feature_max_value=None,
#     query_response_amp_min_pct=None,
#     min_prompt_gene_tpm=0.1,
#     min_query_gene_tpm=0.1,
#     norm_pseudo_count=1e-3,
#     query_hv_top_k=None,
#     z_trans_func=None,
# )

In [None]:
# Build the gene-gene adjacency matrix from the processed responses, using the
# "positive_correlation" strategy with 50 neighbors and no self loops.
# NOTE(review): `beta` presumably sharpens the correlation-to-weight mapping
# (12 vs. the commented-out 1) — confirm against the library documentation.
network_ctx.compute_adjacency_matrix(
    adjacency_strategy="positive_correlation",
    n_neighbors=50,
    beta=12.,
    # beta=1.,
    self_loop=False)

In [None]:
# Print up to 100 prompt genes ranked by their `a_pp` entry against MT-CO1,
# largest value first (a_pp is presumably the adjacency matrix computed earlier).
_mt_co1_gene_id = network_ctx.gene_symbol_to_gene_id_map['MT-CO1']
i = network_ctx.prompt_gene_id_to_idx_map[_mt_co1_gene_id]
_ascending = np.argsort(network_ctx.a_pp[i, :])
inds = _ascending[::-1]  # reverse to get descending order
for j in inds[:100]:
    _symbol = network_ctx.prompt_gene_symbols[j]
    _weight = network_ctx.a_pp[i, j]
    print(_symbol, _weight)

In [None]:
# Run Leiden community detection on the gene network at resolution 10.0.
# The next cell reads the result from `network_ctx.leiden_membership`.
network_ctx.compute_leiden_communites(
    resolution=10.0)

In [None]:
# Number of distinct Leiden community labels.
len(np.unique(network_ctx.leiden_membership))

#### Embedding

In [None]:
# pymde is referenced only by the commented-out `repulsive_penalty` option below.
import pymde

# Compute the embedding of the gene network (MDE, per the method name).
# NOTE(review): the device is hard-coded as "cuda" rather than taken from the
# DEVICE constant defined at the top of the notebook — confirm this is intended.
network_ctx.make_mde_embedding(
    n_neighbors=20,
    # repulsive_penalty=pymde.penalties.Log,
    max_iter=1000,
    init="quadratic",
    device="cuda")

In [None]:
def get_gene_familities(network_ctx: GeneNetworkAnalysisBase, prefix_list: list[str]) -> tuple[list[str], list[str]]:
    """Collect the prompt genes whose symbol starts with any of the given prefixes.

    Args:
        network_ctx: Object exposing ``prompt_gene_symbols``,
            ``gene_symbol_to_gene_id_map`` and ``gene_id_to_gene_symbol_map``.
        prefix_list: Gene symbol prefixes to match, e.g. ``["RPS", "RPL"]``.

    Returns:
        A ``(gene_ids, gene_symbols)`` pair of parallel lists, ordered by prefix
        first and by position in ``prompt_gene_symbols`` second. Each gene ID
        appears at most once, even when prefixes overlap (e.g. ``"RP"`` and
        ``"RPS"``). Symbols are looked up from the gene IDs through
        ``gene_id_to_gene_symbol_map``.

    Raises:
        KeyError: If a matched symbol or its gene ID is missing from the maps.
    """
    matched_symbols = [
        gene_symbol
        for prefix in prefix_list
        for gene_symbol in network_ctx.prompt_gene_symbols
        if gene_symbol.startswith(prefix)
    ]
    # dict.fromkeys drops repeated gene IDs while keeping first-seen order, so a
    # gene matched by several prefixes is not returned more than once.
    gene_ids = list(dict.fromkeys(
        network_ctx.gene_symbol_to_gene_id_map[gene_symbol] for gene_symbol in matched_symbols))
    gene_symbols = [network_ctx.gene_id_to_gene_symbol_map[gene_id] for gene_id in gene_ids]
    return gene_ids, gene_symbols

# Gene families selected by symbol prefix; each call returns (gene_ids, gene_symbols).
mito_gene_ids, mito_gene_symbols = get_gene_familities(network_ctx, ["MT-"])
ribo_gene_ids, ribo_gene_symbols = get_gene_familities(network_ctx, ["RPS", "RPL"])
hla_gene_ids, hla_gene_symbols = get_gene_familities(network_ctx, ["HLA"])
ifi_gene_ids, ifi_gene_symbols = get_gene_familities(network_ctx, ["IFI"])
trav_gene_ids, trav_gene_symbols = get_gene_familities(network_ctx, ["TRAV"])
hb_gene_ids, hb_gene_symbols = get_gene_familities(network_ctx, ["HBA", "HBB"])

# Candidate "SNAP-N" gene symbols; only those present among the prompt genes are
# kept, together with their gene IDs.
_snap_n_candidates = (
    'GAP43', 'NRXN3', 'HOMER1', 'IL1RAPL2', 'EPHA3', 'RIMS1', 'SV2B',
    'TRIM9', 'SVOP', 'RPH3A', 'SYT12', 'SYT1', 'R3HDM2', 'PDE4B',
    'DCC', 'SLC4A10', 'DNM3', 'GRM1', 'EGR4', 'JUNB', 'TFDP2',
)

snap_n_gene_symbols = []
snap_n_gene_ids = []
for _candidate in _snap_n_candidates:
    if _candidate in network_ctx.prompt_gene_symbols:
        snap_n_gene_symbols.append(_candidate)
        snap_n_gene_ids.append(network_ctx.gene_symbol_to_gene_id_map[_candidate])

# Gene sets to highlight in the embedding plot:
# label -> (gene_ids, gene_symbols, color).
# Commented-out entries are optional sets that are currently disabled.
highlight_gene_sets = {
    "Mito": (mito_gene_ids, mito_gene_symbols, 'red'),
    "Ribo": (ribo_gene_ids, ribo_gene_symbols, 'blue'),
    "HLA": (hla_gene_ids, hla_gene_symbols, 'green'),
    "IFI": (ifi_gene_ids, ifi_gene_symbols, 'orange'),
    "TRAV": (trav_gene_ids, trav_gene_symbols, 'purple'),
    # "HB": (hb_gene_ids, hb_gene_symbols, 'black'),
    # "TTN": ([network_ctx.gene_symbol_to_gene_id_map['TTN']], ['TTN'], 'black'),
    # "SNAP-N": (snap_n_gene_ids, snap_n_gene_symbols, 'brown'),
}

# Uncomment to disable highlighting altogether.
# highlight_gene_sets = None

In [None]:
# Plot the embedding with the selected gene sets highlighted.
network_ctx.plot_mde_embedding(highlight_gene_sets=highlight_gene_sets)

In [None]:
# Alphabetically sorted gene symbols whose Leiden membership label equals 5.
sorted(np.asarray(network_ctx.prompt_gene_symbols)[network_ctx.leiden_membership == 5])

In [None]:
import plotly.express as px

# Per-gene plotting table: 2D embedding coordinates, the gene symbol shown on
# hover, and a clipped log1p expression value used for coloring.
df = pd.DataFrame({
    'x': network_ctx.embedding_p2[:, 0],
    'y': network_ctx.embedding_p2[:, 1],
    'label': network_ctx.prompt_gene_symbols,
    # Marginal means are rescaled to sum to 10,000 before the log1p transform.
    # np.log1p is used instead of np.log(1 + x) to avoid precision loss for
    # small x. Values are clipped to [0, 0.5], presumably so that highly
    # expressed genes do not dominate the color scale.
    'log1p_mean': np.clip(
        np.log1p(network_ctx.prompt_marginal_mean_p * 10_000 / network_ctx.prompt_marginal_mean_p.sum()),
        0, 0.5)
})

# Interactive scatter plot colored by the clipped log1p mean expression.
fig = px.scatter(
    df,
    x='x',
    y='y',
    hover_name='label',
    color='log1p_mean',
    color_continuous_scale=[(0,'yellow'), (0.5, 'blue'), (1,'red')],
)

# Use small markers so individual genes stay distinguishable.
fig.update_traces(marker=dict(size=2))  # Adjust the size as needed

# Fix the figure to 800x800 on a white background and hide grid lines and tick labels.
fig.update_layout(
    width=800,  # Adjust the width as needed
    height=800,
    plot_bgcolor='white',
    xaxis=dict(
        showgrid=False,
        showticklabels=False,
        title='MDE_1'
    ),
    yaxis=dict(
        showgrid=False,
        showticklabels=False,
        title='MDE_2'
    )
)