### Gene networks

TODO:
- Check LAG3 and TIGIT before and after snapping to the mean manifold.
  - Is the phenotype preserved after snapping to the mean manifold?

In [None]:
import os
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

from cellarium.ml import CellariumModule, CellariumPipeline

# GPU used both as `map_location` when loading Jacobian results and for the pymde embeddings.
DEVICE = torch.device('cuda', 7)

In [None]:
# Base directory for data, gene info and Jacobian outputs, plus the training checkpoint directory.
ROOT_PATH, CHECKPOINTS_PATH = (
    "/mnt/cellariumgpt-xfer/mb-ml-dev-vm",
    "/mnt/cellariumgpt-xfer/100M_long_run/run_001/lightning_logs/version_0/checkpoints",
)

### Load an AnnData Extract

We use it for its category mappings (assay and suspension type) and for its gene IDs (`var_names`).

In [None]:
# Reference AnnData extract; here it supplies obs category lists and the gene IDs (var_names).
adata_path = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
adata = sc.read_h5ad(filename=adata_path)

In [None]:
ref_obs = adata.obs

# For each categorical obs column, keep its category list as "names" and, because no
# separate human-readable labels are available here, an independent copy as "labels".
# (suspension_type is included for uniformity even though it may have no ontology.)
gene_ontology_infos = {
    obs_column: {
        "names": list(ref_obs[obs_column].cat.categories),
        "labels": list(ref_obs[obs_column].cat.categories),
    }
    for obs_column in ("assay_ontology_term_id", "suspension_type")
}

In [None]:
# gene IDs, gene symbols, useful maps
model_var_names = np.asarray(adata.var_names)
model_var_names_set = set(model_var_names)
var_name_to_index_map = {var_name: i for i, var_name in enumerate(model_var_names)}

gene_info_tsv_path = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")
gene_info_df = pd.read_csv(gene_info_tsv_path, sep="\t")

# Gene symbol -> ENSEMBL ID. Rows with a missing symbol or ID are skipped. NaN never compares
# equal to anything (not even float('nan')), so pd.isna is the reliable missing-value test.
# If a symbol appears more than once, the last row wins.
gene_symbol_to_gene_id_map = {
    gene_symbol: gene_id
    for gene_symbol, gene_id in zip(gene_info_df['Gene Symbol'], gene_info_df['ENSEMBL Gene ID'])
    if not (pd.isna(gene_symbol) or pd.isna(gene_id))
}

# Inverse map; model genes without a known symbol map to their own ID.
gene_id_to_gene_symbol_map = {gene_id: gene_symbol for gene_symbol, gene_id in gene_symbol_to_gene_id_map.items()}
for gene_id in model_var_names:
    gene_id_to_gene_symbol_map.setdefault(gene_id, gene_id)

### Load a Jacobian result

In [None]:
def _to_numpy(value) -> np.ndarray:
    """Return `value` as a NumPy array, detaching and moving torch tensors to host memory first."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def process_jacobian(
        dataset_name: str,
        jacobian_point: str,
        *,
        root_path: "str | None" = None,
        device=None,
        max_lfc: float = 1.) -> dict:
    """Load a saved Jacobian result and derive in-silico deletion effects and gene-gene similarities.

    The file read is
    ``<root_path>/cellariumgpt_playground/output/jacobian__{dataset_name}__{jacobian_point}.pt``.

    Args:
        dataset_name: Dataset identifier embedded in the file name.
        jacobian_point: Expansion-point identifier embedded in the file name.
        root_path: Base directory; defaults to the notebook-level ``ROOT_PATH``.
        device: ``map_location`` for ``torch.load``; defaults to the notebook-level ``DEVICE``.
        max_lfc: Symmetric clip bound for the log2 fold changes.

    Returns:
        Dict with the raw loaded results plus, as NumPy arrays:

        - ``x_g``: prompt gene values.
        - ``jacobian_gp``: the Jacobian.
        - ``y_gp``: first-order post-deletion values.
        - ``z_gp``: clipped log2 fold changes.
        - ``z_unit_gp``: row-normalized ``z_gp``.
        - ``dist_gg``: cosine similarity between rows of ``z_gp``.
    """
    if root_path is None:
        root_path = ROOT_PATH
    if device is None:
        device = DEVICE

    jacobian_pt_path = os.path.join(
        root_path, "cellariumgpt_playground", "output", f"jacobian__{dataset_name}__{jacobian_point}.pt")

    jacobian_results_dict = torch.load(jacobian_pt_path, map_location=device)

    # Work entirely in NumPy: mixing a (possibly GPU) torch tensor with NumPy arrays fails.
    x_g = _to_numpy(jacobian_results_dict['prompt_gene_values_g'])
    jacobian_gp = _to_numpy(jacobian_results_dict['jacobian_qg'])  # p for perturb

    # First-order Taylor expansion of an in-silico deletion of gene p:
    #   y[q, p] ~= x[q] - J[q, p] * x[p], clipped at zero.
    # Rows are presumably query genes and columns prompt genes (per the 'jacobian_qg' key).
    y_gp = np.clip(x_g[:, None] - jacobian_gp * x_g[None, :], a_min=0., a_max=None)

    # log2 fold change per (gene, deletion). log2(0) gives -inf, which the clip bounds; 0/0 cases
    # give NaN, which are zeroed. Silence the expected divide/invalid warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        z_gp = np.log2(y_gp) - np.log2(x_g[:, None])
    z_gp[np.isnan(z_gp)] = 0.
    z_gp = np.clip(z_gp, a_min=-max_lfc, a_max=max_lfc)

    # Row-normalize so z_unit_gp @ z_unit_gp.T is a cosine similarity. All-zero rows stay zero
    # vectors instead of becoming NaN, which would corrupt neighbor rankings and embeddings.
    norm_g1 = np.linalg.norm(z_gp, axis=1, keepdims=True)
    z_unit_gp = np.divide(z_gp, norm_g1, out=np.zeros_like(z_gp), where=norm_g1 > 0)
    dist_gg = z_unit_gp @ z_unit_gp.T

    return {
        "jacobian_results_dict": jacobian_results_dict,
        "x_g": x_g,
        "jacobian_gp": jacobian_gp,
        "y_gp": y_gp,
        "z_gp": z_gp,
        "z_unit_gp": z_unit_gp,
        "dist_gg": dist_gg
    }

def get_gene_neighborhood(
        processed_jacobian_dict: dict,
        target_gene_symbol: str,
        similarity_var: str = 'dist_gg',
        *,
        symbol_to_id_map: "dict | None" = None) -> dict:
    """Rank all query genes by their similarity to a target gene.

    Args:
        processed_jacobian_dict: Output of ``process_jacobian``.
        target_gene_symbol: Gene symbol of the target gene.
        similarity_var: Key of the per-gene array whose target row is used for ranking
            (e.g. ``'dist_gg'`` or ``'z_gp'``).
        symbol_to_id_map: Symbol -> gene ID map; defaults to the notebook-level
            ``gene_symbol_to_gene_id_map``.

    Returns:
        Dict with the target's symbol, index and ID, its similarity row
        (``target_neighbor_dist_g``) and ``sort_order`` (indices, most similar first; NaNs last).

    Raises:
        KeyError: If the symbol is unknown or its gene ID is not among the query genes.
        ValueError: If the query and prompt gene lists differ.
    """
    if symbol_to_id_map is None:
        symbol_to_id_map = gene_symbol_to_gene_id_map
    try:
        target_gene_id = symbol_to_id_map[target_gene_symbol]
    except KeyError:
        raise KeyError(f"Unknown gene symbol: {target_gene_symbol!r}") from None

    jacobian_results_dict = processed_jacobian_dict['jacobian_results_dict']
    # Query and prompt gene lists must match exactly so one gene index addresses both rows
    # and columns. array_equal also rejects length mismatches, which `==` could broadcast away.
    if not np.array_equal(
            np.asarray(jacobian_results_dict['query_var_names']),
            np.asarray(jacobian_results_dict['prompt_var_names'])):
        raise ValueError("query_var_names and prompt_var_names must be identical and in the same order")

    var_name_to_index_map = {
        var_name: index for index, var_name in enumerate(jacobian_results_dict['query_var_names'])}
    try:
        target_gene_index = var_name_to_index_map[target_gene_id]
    except KeyError:
        raise KeyError(
            f"Gene {target_gene_symbol!r} ({target_gene_id}) is not among the query genes") from None

    target_neighbor_dist_g = processed_jacobian_dict[similarity_var][target_gene_index]
    # Descending order. Map NaN to -inf so undefined similarities rank last instead of first.
    ranking_key_g = np.where(np.isnan(target_neighbor_dist_g), -np.inf, target_neighbor_dist_g)
    sort_order = np.argsort(ranking_key_g)[::-1]

    return {
        "target_gene_symbol": target_gene_symbol,
        "target_gene_index": target_gene_index,
        "target_gene_id": target_gene_id,
        "target_neighbor_dist_g": target_neighbor_dist_g,
        "sort_order": sort_order
    }

def plot_neighborhood_distance_distribution(neighborhood_dict: dict, bins: int = 50):
    """Plot a histogram of the target gene's similarity values over all query genes.

    The values are the target's row of the chosen similarity array. Larger means more similar:
    ``sort_order`` ranks them in descending order. So the axis is labeled a similarity,
    not a distance.

    Args:
        neighborhood_dict: Output of ``get_gene_neighborhood``.
        bins: Number of histogram bins.
    """
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.hist(neighborhood_dict['target_neighbor_dist_g'], bins=bins)
    ax.set_xlabel(f'Similarity to {neighborhood_dict["target_gene_symbol"]}')
    ax.set_ylabel('Number of genes')

def neighborhood_to_df(
        neighborhood_dict: dict,
        processed_jacobian_dict: dict,
        *,
        id_to_symbol_map: "dict | None" = None) -> pd.DataFrame:
    """Tabulate a gene neighborhood in ranked order.

    Args:
        neighborhood_dict: Output of ``get_gene_neighborhood``.
        processed_jacobian_dict: The ``process_jacobian`` output the neighborhood was computed from.
        id_to_symbol_map: Gene ID -> symbol map; defaults to the notebook-level
            ``gene_id_to_gene_symbol_map``. IDs without a symbol are shown as-is.

    Returns:
        DataFrame with columns ``gene_id``, ``gene_symbol`` and ``cosine_distance``, one row per
        gene in ``sort_order``. The last column keeps its historical name but holds the target's
        similarity values (higher = more similar).
    """
    if id_to_symbol_map is None:
        id_to_symbol_map = gene_id_to_gene_symbol_map
    query_var_names = processed_jacobian_dict["jacobian_results_dict"]["query_var_names"]
    sort_order = neighborhood_dict['sort_order']
    sorted_gene_ids = [query_var_names[i] for i in sort_order]
    return pd.DataFrame({
        "gene_id": sorted_gene_ids,
        # Fall back to the ID: Jacobian query genes may be absent from the symbol map.
        "gene_symbol": [id_to_symbol_map.get(gene_id, gene_id) for gene_id in sorted_gene_ids],
        "cosine_distance": neighborhood_dict['target_neighbor_dist_g'][sort_order]
    })

In [None]:
# Neighborhood of one target gene in two Jacobian results, each identified by
# (dataset name, expansion point).
target_gene_symbol = "PDCD1"
similarity_var = "z_gp"  # rank by the target's row of clipped log2 fold changes

dataset_name_1, jacobian_point_1 = "luca_CD8_ex_LUAD", "actual"
processed_jacobian_dict_1 = process_jacobian(dataset_name_1, jacobian_point_1)
neighborhood_dict_1 = get_gene_neighborhood(
    processed_jacobian_dict_1, target_gene_symbol, similarity_var)
df_1 = neighborhood_to_df(neighborhood_dict_1, processed_jacobian_dict_1)

dataset_name_2, jacobian_point_2 = "luca_CD8_act_normal", "marginal_mean"
processed_jacobian_dict_2 = process_jacobian(dataset_name_2, jacobian_point_2)
neighborhood_dict_2 = get_gene_neighborhood(
    processed_jacobian_dict_2, target_gene_symbol, similarity_var)
df_2 = neighborhood_to_df(neighborhood_dict_2, processed_jacobian_dict_2)

In [None]:
# Bottom of the first ranking: the last 100 rows.
with pd.option_context('display.max_rows', 200):
    display(df_1.iloc[-100:])

In [None]:
# Top of the second ranking: the first 50 rows.
with pd.option_context('display.max_rows', 200):
    display(df_2.iloc[:50])

In [None]:
import plotly.express as px
import pandas as pd

# Scatter the target gene's similarity profile under Jacobian 1 (x) against Jacobian 2 (y).
# Pair the two profiles by gene ID, not by position: the two result files need not list their
# query genes in the same order. Only genes present in both are plotted.
query_var_names_1 = list(processed_jacobian_dict_1['jacobian_results_dict']['query_var_names'])
query_var_names_2 = list(processed_jacobian_dict_2['jacobian_results_dict']['query_var_names'])
index_2_by_gene_id = {gene_id: j for j, gene_id in enumerate(query_var_names_2)}
shared_index_1 = [i for i, gene_id in enumerate(query_var_names_1) if gene_id in index_2_by_gene_id]
shared_index_2 = [index_2_by_gene_id[query_var_names_1[i]] for i in shared_index_1]

# Hover labels; fall back to the gene ID when no symbol is known.
query_gene_symbols = [
    gene_id_to_gene_symbol_map.get(query_var_names_1[i], query_var_names_1[i])
    for i in shared_index_1]

df = pd.DataFrame({
    'x': np.asarray(neighborhood_dict_1["target_neighbor_dist_g"])[shared_index_1],
    'y': np.asarray(neighborhood_dict_2["target_neighbor_dist_g"])[shared_index_2],
    'label': query_gene_symbols
})

# Create the scatter plot
fig = px.scatter(df, x='x', y='y', hover_name='label', title='Interactive Scatter Plot')

# Update marker size
fig.update_traces(marker=dict(size=2))  # Adjust the size as needed

# Update layout to decrease the width of the plot
fig.update_layout(
    width=800,  # Adjust the width as needed
    xaxis=dict(
        showgrid=True,
        title=dataset_name_1,
    ),
    yaxis=dict(
        showgrid=True,
        title=dataset_name_2
    )
)

# Show the plot
fig.show()

### Embedding

In [None]:
import pymde

In [None]:
# Embed the first Jacobian result and highlight its own neighborhood.
processed_jacobian_dict, neighborhood_dict = processed_jacobian_dict_1, neighborhood_dict_1

In [None]:
# Inspect which fields the selected neighborhood dict provides.
neighborhood_dict.keys()

In [None]:
# Neighbor-preserving MDE problem over the row-normalized log-fold-change profiles
# (one row per query gene): 10 neighbors, inverse-power repulsion.
mde = pymde.preserve_neighbors(
    processed_jacobian_dict["z_unit_gp"],
    n_neighbors=10,
    repulsive_fraction=0.9,
    repulsive_penalty=pymde.penalties.InvPower,
    device=DEVICE,
    verbose=True,
)

In [None]:
# Solve the MDE problem and bring the embedding back to host memory as NumPy.
embedding_g2 = mde.embed(verbose=True).cpu().numpy()

In [None]:
import plotly.express as px
import pandas as pd

query_var_names = processed_jacobian_dict["jacobian_results_dict"]["query_var_names"]
# Hover labels for every embedded gene. Fall back to the gene ID when no symbol is known,
# since a Jacobian's query genes may not all appear in the symbol map.
query_gene_symbols = [
    gene_id_to_gene_symbol_map.get(gene_id, gene_id)
    for gene_id in query_var_names]

# Create a DataFrame for Plotly
df = pd.DataFrame({
    'x': embedding_g2[:, 0],
    'y': embedding_g2[:, 1],
    'label': query_gene_symbols
})

# Base layer: all genes as small points, labeled on hover.
fig = px.scatter(df, x='x', y='y', hover_name='label', title='Interactive Scatter Plot')
fig.update_traces(marker=dict(size=2))  # Adjust the size as needed

# Neighborhood indices are positions in the gene list of the Jacobian they were computed from.
# Check they address the same genes in this embedding before highlighting anything.
target_idx = neighborhood_dict['target_gene_index']
if query_var_names[target_idx] != neighborhood_dict['target_gene_id']:
    raise ValueError(
        "neighborhood_dict and processed_jacobian_dict use different gene orderings; "
        "highlighted points would be the wrong genes")

# Target gene in green.
fig.add_scatter(
    x=[embedding_g2[target_idx, 0]],
    y=[embedding_g2[target_idx, 1]],
    mode='markers',
    marker=dict(color='green', size=10),
    hovertext=[query_gene_symbols[target_idx]],
    showlegend=False
)

# Top-100 ranked neighbors (excluding the target itself) in red.
highlight_indices = [idx for idx in neighborhood_dict['sort_order'][:100] if idx != target_idx]
fig.add_scatter(
    x=embedding_g2[highlight_indices, 0],
    y=embedding_g2[highlight_indices, 1],
    mode='markers',
    marker=dict(color='red', size=6),
    hovertext=[query_gene_symbols[idx] for idx in highlight_indices],
    showlegend=False
)

# Square, gridless canvas without tick labels (MDE axes carry no units).
fig.update_layout(
    width=700,  # Adjust the width as needed
    height=700,
    plot_bgcolor='white',
    xaxis=dict(
        showgrid=False,
        showticklabels=False,
        title='MDE_1'
    ),
    yaxis=dict(
        showgrid=False,
        showticklabels=False,
        title='MDE_2'
    )
)

# Show the plot
fig.show()

In [None]:
# Embed the second Jacobian result and pair it with the neighborhood computed from that same
# result, so the highlighted target/neighbor indices refer to this Jacobian's gene ordering.
processed_jacobian_dict = processed_jacobian_dict_2
neighborhood_dict = neighborhood_dict_2

In [None]:
# Inspect which fields the selected neighborhood dict provides.
neighborhood_dict.keys()

In [None]:
# Neighbor-preserving MDE problem over the row-normalized log-fold-change profiles
# (one row per query gene): 10 neighbors, inverse-power repulsion.
mde = pymde.preserve_neighbors(
    processed_jacobian_dict["z_unit_gp"],
    n_neighbors=10,
    repulsive_fraction=0.9,
    repulsive_penalty=pymde.penalties.InvPower,
    device=DEVICE,
    verbose=True,
)

In [None]:
# Solve the MDE problem and bring the embedding back to host memory as NumPy.
embedding_g2 = mde.embed(verbose=True).cpu().numpy()

In [None]:
import plotly.express as px
import pandas as pd

query_var_names = processed_jacobian_dict["jacobian_results_dict"]["query_var_names"]
# Hover labels for every embedded gene. Fall back to the gene ID when no symbol is known,
# since a Jacobian's query genes may not all appear in the symbol map.
query_gene_symbols = [
    gene_id_to_gene_symbol_map.get(gene_id, gene_id)
    for gene_id in query_var_names]

# Create a DataFrame for Plotly
df = pd.DataFrame({
    'x': embedding_g2[:, 0],
    'y': embedding_g2[:, 1],
    'label': query_gene_symbols
})

# Base layer: all genes as small points, labeled on hover.
fig = px.scatter(df, x='x', y='y', hover_name='label', title='Interactive Scatter Plot')
fig.update_traces(marker=dict(size=2))  # Adjust the size as needed

# Neighborhood indices are positions in the gene list of the Jacobian they were computed from.
# Check they address the same genes in this embedding before highlighting anything.
target_idx = neighborhood_dict['target_gene_index']
if query_var_names[target_idx] != neighborhood_dict['target_gene_id']:
    raise ValueError(
        "neighborhood_dict and processed_jacobian_dict use different gene orderings; "
        "highlighted points would be the wrong genes")

# Target gene in green.
fig.add_scatter(
    x=[embedding_g2[target_idx, 0]],
    y=[embedding_g2[target_idx, 1]],
    mode='markers',
    marker=dict(color='green', size=10),
    hovertext=[query_gene_symbols[target_idx]],
    showlegend=False
)

# Top-100 ranked neighbors (excluding the target itself) in red.
highlight_indices = [idx for idx in neighborhood_dict['sort_order'][:100] if idx != target_idx]
fig.add_scatter(
    x=embedding_g2[highlight_indices, 0],
    y=embedding_g2[highlight_indices, 1],
    mode='markers',
    marker=dict(color='red', size=6),
    hovertext=[query_gene_symbols[idx] for idx in highlight_indices],
    showlegend=False
)

# Square, gridless canvas without tick labels (MDE axes carry no units).
fig.update_layout(
    width=700,  # Adjust the width as needed
    height=700,
    plot_bgcolor='white',
    xaxis=dict(
        showgrid=False,
        showticklabels=False,
        title='MDE_1'
    ),
    yaxis=dict(
        showgrid=False,
        showticklabels=False,
        title='MDE_2'
    )
)

# Show the plot
fig.show()