# cluster the embeddings:

new_esm_embeds.pt  new_gene2vec_embeds.pt  new_genept_embeds.pt  new_hyenaDNA_1m_last_genes.pt

gene2vec: `/work/magroup/shared/Heimdall/data/pca_embeds/new_gene2vec_embeds.pt`

esm2: `/work/magroup/shared/Heimdall/data/pca_embeds/new_esm_embeds.pt`

genept: `/work/magroup/shared/Heimdall/data/pca_embeds/new_genept_embeds.pt`

hyenaDNA: `/work/magroup/shared/Heimdall/data/pca_embeds/new_hyenaDNA_1m_last_genes.pt`


In [None]:
import torch 
import scanpy as sc
import numpy as np
import scanpy as sc
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random
import pandas as pd
from goatools.obo_parser import GODag
from collections import defaultdict
from tqdm import tqdm
import os
import sys

# Make modules that live one directory above the working directory importable.
current_dir = os.getcwd()

# "<cwd>/.." resolved to an absolute, normalised path.
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Prepend it to the module search path, but only when it is not already
# listed, so re-running this cell does not stack duplicate entries.
if parent_dir not in sys.path:
    sys.path[:0] = [parent_dir]

# Modules from one level up can now be imported

from Heimdall.utils import (
    symbol_to_ensembl_from_ensembl,
)

def convert_to_ensembl_ids(adata, data_dir, species="human"):
    """Re-index the genes of an AnnData object through a symbol-to-Ensembl lookup.

    The object is modified in place: the lookup's ``mapping_full`` is stored
    in ``adata.uns``, the original index is kept in ``var["gene_symbol"]``,
    ``var["gene_ensembl"]`` is filled from ``mapping_combined``, and the
    ``var`` index itself is mapped through ``mapping_reduced``.

    Args:
        adata: AnnData object whose ``var`` index holds the gene symbols.
        data_dir: Directory forwarded to ``symbol_to_ensembl_from_ensembl``.
        species: Species name forwarded to the lookup (default ``"human"``).

    Returns:
        A tuple ``(adata, mapping)``: the same ``adata`` object that was
        passed in (now re-indexed) and the mapping object returned by
        ``symbol_to_ensembl_from_ensembl``.

    """
    gene_symbols = adata.var.index.tolist()
    lookup = symbol_to_ensembl_from_ensembl(
        data_dir=data_dir,
        genes=gene_symbols,
        species=species,
    )

    # Keep the complete mapping alongside the data for later reference.
    adata.uns["gene_mapping:symbol_to_ensembl"] = lookup.mapping_full

    # Preserve the original symbols before the index is replaced.
    adata.var["gene_symbol"] = adata.var.index
    symbol_lookup = lookup.mapping_combined.get
    adata.var["gene_ensembl"] = adata.var["gene_symbol"].map(symbol_lookup)

    # Swap the index over to the reduced mapping's values.
    remapped_index = adata.var.index.map(lookup.mapping_reduced)
    adata.var.index = remapped_index
    adata.var.index.name = "index"

    return adata, lookup




# Global Paths

In [None]:
# Dataset to analyse (an .h5ad file read with scanpy in the next cell).
# Alternative dataset kept for reference:
# data_path = "/work/magroup/shared/Heimdall/data/" + "/sctab/tissue_splits_spencer/scTab_GItract_train.h5ad"
data_path = "/work/magroup/shared/Heimdall/data/ovarian_cancer/pairs_1_3.h5ad"

# Example inputs
# Exactly one `embeds = torch.load(...)` line should be active; move the
# comment marker to analyse a different embedding file.
# `embeds` is later tested with `in` and indexed by gene ID, so it is
# presumably a mapping from gene ID to embedding vector -- TODO confirm.
# embeds = torch.load("/work/magroup/shared/Heimdall/data/pca_embeds/new_genept_embeds.pt")
embeds = torch.load("/work/magroup/shared/Heimdall/data/pca_embeds/new_gene2vec_embeds.pt")
# embeds = torch.load("/work/magroup/shared/Heimdall/data/pca_embeds/new_esm_embeds.pt")
# embeds = torch.load("/work/magroup/shared/Heimdall/data/pca_embeds/new_hyenaDNA_1m_last_genes.pt")

# GO annotation table (CSV) and the GO ontology file (OBO) used to build the
# gene -> broad functional category assignment.
go_human_df_path = "/work/magroup/shared/Heimdall/data/GO_terms/GO_human.csv"
go_dag_path = "/work/magroup/nzh/Heimdall-dev/gene_embeds_analysis/go-basic.obo"

# Load Dataset of Interest

In [None]:

# Read the dataset from disk.
dataset = sc.read_h5ad(data_path)

# convert_to_ensembl_ids modifies the object it receives and returns it, so
# `adata` and `dataset` refer to the same re-indexed AnnData afterwards.
adata, mapping = convert_to_ensembl_ids(dataset, data_dir="./", species="human")

# Every entry of the (re-indexed) gene index, in order.
genes = dataset.var.index.tolist()  # Full list
dataset.var

# GO Annotations

In [None]:
# Gene IDs to classify: every entry of the dataset's gene index.
ensembl_ids = list(dataset.var.index)  # Full list


# Step 1: Load the GO annotation table.
# The file is read with no header row; the column names are supplied here.
go_human_df = pd.read_csv(
    go_human_df_path,
    header=None,
    names=["rownum", "ensembl_id", "gene_name", "go_id", "go_term_name"],
)

# Keep only rows that actually carry a GO ID (drop NaN).
go_human_df = go_human_df.dropna(subset=["go_id"])

# Step 2: Build ensembl_id -> list of GO terms.
# Iterating the two needed columns directly avoids DataFrame.iterrows(),
# which builds a full Series per row and is slow on a large table.
ensembl_to_go = defaultdict(list)
for annotated_gene, annotated_go_id in zip(
    go_human_df["ensembl_id"], go_human_df["go_id"]
):
    ensembl_to_go[annotated_gene].append(annotated_go_id)

# Step 3: Load GO DAG (for parent-child lookup)
godag = GODag(go_dag_path)

# Step 4: Define broad functional categories.
# Each category is anchored by one GO term; a gene falls in a category when
# one of its GO terms is that anchor or has the anchor among its ancestors.
broad_categories = {
    "Signal transduction": ["GO:0007165"],
    "Transcription regulation": ["GO:0006355"],
    "Translation": ["GO:0006412"],
    "Metabolism": ["GO:0008152"],
    "Cell cycle": ["GO:0007049"],
    "Apoptosis": ["GO:0006915"],
    "Developmental process": ["GO:0032502"],
    "Immune response": ["GO:0006955"],
    "Response to stimulus": ["GO:0050896"],
    "Transport": ["GO:0006810"],
    "Protein modification": ["GO:0036211"],
    "Cell adhesion": ["GO:0007155"],
    "Cytoskeleton organization": ["GO:0007010"],
    "DNA repair": ["GO:0006281"],
    "Chromatin organization": ["GO:0006325"],
    "Membrane organization": ["GO:0061024"],
    "Autophagy": ["GO:0006914"],
    "RNA processing": ["GO:0006396"],
    "Ion transport": ["GO:0006811"],
    "Unknown function": [],
}

# Reverse mapping: GO ID -> Category
go_to_category = {}
for category, go_ids in broad_categories.items():
    for go_id in go_ids:
        go_to_category[go_id] = category

# Step 5: Helper function to assign a broad category
def assign_broad_category(go_terms, category_map=None, dag=None):
    """Pick one broad functional category for a collection of GO terms.

    A GO term contributes a category when it is itself a key of the
    category map, or when one of its ancestors in the GO DAG is.

    When several categories match, the one that appears first in
    ``category_map`` is returned. The choice is therefore reproducible
    between runs; taking an arbitrary element of a set of strings is not,
    because string hashing is randomised per interpreter process.

    Args:
        go_terms: Iterable of GO IDs annotated to one gene.
        category_map: Mapping from anchor GO ID to category name. Defaults
            to the module-level ``go_to_category``.
        dag: Object whose ``get(go_id)`` returns a term exposing
            ``get_all_parents()``, or a falsy value for unknown IDs.
            Defaults to the module-level ``godag``.

    Returns:
        The selected category name, or ``"Unknown function"`` when no GO
        term maps to any category.
    """
    # Resolve the defaults at call time so the globals may be (re)defined
    # after this function.
    if category_map is None:
        category_map = go_to_category
    if dag is None:
        dag = godag

    seen_categories = set()
    for go_id in go_terms:
        if go_id in category_map:
            # The term is itself one of the category anchors.
            seen_categories.add(category_map[go_id])
            continue
        term = dag.get(go_id)
        if term:
            # Otherwise inherit the categories of any anchor ancestors.
            seen_categories.update(
                category_map[ancestor_id]
                for ancestor_id in term.get_all_parents()
                if ancestor_id in category_map
            )

    # Deterministic tie-break: first matching category in declaration order.
    for category in category_map.values():
        if category in seen_categories:
            return category
    return "Unknown function"

# Step 6: The IDs to classify are the dataset's genes (`ensembl_ids`).
# For a quick check, a small hand-picked list can be substituted instead:
# ensembl_ids = [
#     "ENSG00000141510",
#     "ENSG00000171862",
#     "ENSG00000157764",
#     "ENSG00000281022",  # MED22 from your CSV
#     "ENSG00000280858",  # RPL7A from your CSV
# ]


# Step 7: Classify each ID; IDs without any GO annotation are classified
# from an empty term list.
gene_to_broad_category = {
    gene: assign_broad_category(ensembl_to_go.get(gene, []))
    for gene in tqdm(ensembl_ids)
}

# Step 8: Print one "<id>: <category>" line per gene.
print("\nBroad Functional Categories:")
for gene_id, category in gene_to_broad_category.items():
    print(f"{gene_id}: {category}")

# Tabular form of the assignment, one row per gene.
output_df = pd.DataFrame(
    {
        "ensembl_id": [*gene_to_broad_category],
        "broad_category": [*gene_to_broad_category.values()],
    }
)
# Optional: save the table to CSV.
# output_df.to_csv("ensembl_broad_function_assignment.csv", index=False)
# print("\nSaved output to ensembl_broad_function_assignment.csv")


# Getting the UMAP for the embeds, colored by BROAD CATEGORY

In [None]:
# Keep only the genes for which an embedding is available.
valid_genes = [eid for eid in genes if eid in embeds]
if not valid_genes:
    # Fail early with a clear message instead of an obscure AnnData error.
    raise ValueError(
        "None of the dataset's gene IDs were found in the loaded embeddings."
    )

# One row per gene, one column per embedding dimension.
X = np.array([embeds[eid] for eid in valid_genes])
adata = ad.AnnData(X)
adata.obs['ensembl_id'] = valid_genes

# Attach the broad category of every gene.
obs_with_category = adata.obs.merge(output_df, on='ensembl_id', how='left')
# merge() returns a fresh integer RangeIndex; AnnData expects string obs
# names, so convert explicitly rather than relying on an implicit conversion.
obs_with_category.index = obs_with_category.index.astype(str)
# Cast explicitly so later `.cat` access does not depend on any conversion
# performed as a side effect of plotting.
obs_with_category['broad_category'] = (
    obs_with_category['broad_category'].astype('category')
)
adata.obs = obs_with_category

# UMAP computed on the raw embedding matrix (use_rep='X').
sc.pp.neighbors(adata, n_neighbors=15, use_rep='X')
sc.tl.umap(adata)

# Plot, coloured by broad category as announced in the cell heading.
sc.pl.umap(adata, color='broad_category', legend_loc='right margin', size=20)

# Prettier Plotting

In [None]:


# Name shown in the figure title. It must match the embedding file loaded
# into `embeds` (currently new_gene2vec_embeds.pt); update it whenever a
# different embedding file is loaded.
embedding_name = "gene2vec"

# 1. Prepare categories.
# Make sure the column is categorical before using the `.cat` accessor;
# the cast is a no-op when it already is.
adata.obs['broad_category'] = adata.obs['broad_category'].astype('category')
categories = adata.obs['broad_category'].cat.categories
n_categories = len(categories)

# 2. Generate a good distinct color palette
palette = sns.color_palette("tab20", n_colors=n_categories)  # or "hls" for more neon

# 3. Shuffle the palette so neighbouring categories do not get the paired
#    light/dark shades that tab20 places next to each other.
random.seed(10)  # For reproducibility
random.shuffle(palette)

# 4. Create a mapping from category to color (shuffled)
category_to_color = dict(zip(categories, palette))

# 5. Make a figure larger manually
fig, ax = plt.subplots(figsize=(12, 10))

# 6. Actually plot
sc.pl.umap(
    adata,
    color='broad_category',
    # Colors listed in category order.
    palette=[category_to_color[c] for c in categories],
    ax=ax,
    size=50,                   # Bigger points
    alpha=1,
    frameon=False,             # No frame
    legend_loc='right margin', # Legend out of the way
    legend_fontsize=14,
    title=f'UMAP of Broad Categories With {embedding_name}',
    show=True,
)
