# Spatial UMAPs

In [None]:
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
import pandas as pd
import glob
import os
import scanpy as sc
from tqdm import tqdm


def process_spatial_data(
    dataset,
    file_path,
    write_h5ad,
    csv_file="spatial_umap.csv",
    model_path="../spatial_umaps_results",
    sort_ag_but_no_gene2vec_model_path="../spatial_split1_sort_ag_results",
    all_genes=False,
    seed=55,
):
    """
    Collect the UMAP embeddings of several trained runs into one AnnData.

    The configuration CSV is filtered to the requested experiment and
    ``all_genes`` setting. For every remaining row a run-directory name is
    built from the row's settings and looked up on disk. The ``X_umap`` of
    each run's ``test_adata.h5ad`` is then stored in the ``obsm`` of the
    test split of ``file_path`` and the result is written to ``write_h5ad``.

    Parameters:
    -----------
    dataset : str
        Value matched against the CSV "Experiment Name" column
        (e.g., "spatial_cancer_split1").
    file_path : str
        Path to the h5ad file (e.g., "/ovarian_cancer/pairs_0_1.h5ad").
        Its ``obs`` must contain a "split" column.
    write_h5ad : str
        Output filename for the processed h5ad file.
    csv_file : str, default="spatial_umap.csv"
        CSV listing the runs. The columns read are "Experiment Name",
        "all_genes", "learning_rate", "batchsize", "F_Gene",
        "F_Expression" and "F_Cell".
    model_path : str, default="../spatial_umaps_results"
        Folder searched for run directories.
    sort_ag_but_no_gene2vec_model_path : str
        Folder searched instead of ``model_path`` for split1 runs that use
        the "sorting" expression encoding with all genes and a gene
        encoding other than "pca_gene2vec".
    all_genes : bool, default=False
        Only CSV rows whose "all_genes" column equals this value are used.
    seed : int, default=55
        Seed that appears in the run-directory name.

    Returns:
    --------
    AnnData
        Copy of the rows of ``file_path`` with ``obs["split"] == "test"``,
        with one ``obsm["X_<gene>_<expression>"]`` entry per matched run.

    Raises:
    -------
    KeyError
        If a CSV row holds an experiment, gene, expression or cell value
        that has no entry in the mapping tables below.
    """
    df = pd.read_csv(csv_file)
    df = df.dropna(how='all')

    df = df[df["Experiment Name"] == dataset]
    df = df[df["all_genes"] == all_genes]

    print(len(df))

    # CSV value -> class name used inside the run-directory name
    gene_mapping = {
        "identity": "IdentityFg",
        "pca_hyenadna": "HyenaDNAFg",
        "pca_genept": "GenePTFg",
        "pca_gene2vec": "Gene2VecFg",
        "pca_esm2": "ESM2Fg",
    }

    expression_mapping = {
        "nonzero_2nn": "NonzeroIdentityFe",
        "scfoundation": "scfound",
        "sorting": "SortingFe",
        "binning": "BinningFe",
    }

    cell_mapping = {
        "geneformer": "GeneformerFc",
        "scgpt": "ScGPTFc",
    }

    # CSV value -> display name used in the obsm key
    name_expression_mapping = {
        "nonzero_2nn": "Continuous",
        "scfoundation": "Autobin",
        "sorting": "Sorting",
        "binning": "Binning",
    }

    name_gene_mapping = {
        "identity": "Random",
        "pca_hyenadna": "HyenaDNA",
        "pca_genept": "GenePT",
        "pca_gene2vec": "Gene2vec",
        "pca_esm2": "ESM2",
    }

    dataset_mapping = {
        "spatial_cancer_split1": "ovarian_cancer_pairs_0_1",
        "spatial_cancer_split2": "ovarian_cancer_pairs_1_3",
    }

    matched_paths = []
    payload = []
    seen_uuids = set()

    # Rows with any missing value are skipped.
    for idx, row in df.dropna().iterrows():
        lr = row['learning_rate']
        bz = int(row['batchsize'])
        f_gene = row['F_Gene']
        f_expression = row['F_Expression']
        f_cell = row['F_Cell']
        experiment = row["Experiment Name"]
        drop_zeros = not row['all_genes']

        run_dataset = dataset_mapping[experiment]
        gene_part = gene_mapping[f_gene]
        expression_part = expression_mapping[f_expression]
        cell_part = cell_mapping[f_cell]

        run_name = (
            f"Heimdall.fg.{gene_part}_Heimdall.fe.{expression_part}"
            f"_Heimdall.fc.{cell_part}_{run_dataset}_lr{lr}_bz{bz}_seed{seed}"
        )

        # One family of runs lives in a separate results folder.
        in_special_folder = (
            f_expression == "sorting"
            and experiment == "spatial_cancer_split1"
            and not drop_zeros
            and f_gene != "pca_gene2vec"
        )
        models_folder = sort_ag_but_no_gene2vec_model_path if in_special_folder else model_path

        full_pattern = os.path.join(models_folder, run_name)
        matches = glob.glob(full_pattern)

        if not matches:
            print(f"[WARNING] No match found for config at row {idx}: {full_pattern}, trying different configuration...")

            # Second attempt: same name with an "_ag<flag>" suffix.
            # NOTE(review): the suffix is filled with drop_zeros (the negation
            # of all_genes) - confirm this matches the run-directory naming.
            run_name = f"{run_name}_ag{drop_zeros}"
            full_pattern = os.path.join(models_folder, run_name)
            matches = glob.glob(full_pattern)

            if not matches:
                print(f"DID NOT FIND: {full_pattern}")
                continue

            print(f" Found!! for {idx}: {run_name}")

        uuid = name_gene_mapping[f_gene] + "_" + name_expression_mapping[f_expression]
        # The key ignores F_Cell, so two rows can map to the same obsm slot.
        if uuid in seen_uuids:
            print(f"[WARNING] Duplicate embedding key {uuid} at row {idx}; the later run overwrites the earlier one")
        seen_uuids.add(uuid)

        matched_paths.extend(matches)
        payload.append({"uuid": uuid, "path": matches[0]})

    # Output
    for path in matched_paths:
        print(path)

    print(len(matched_paths))

    paths_df = pd.DataFrame(payload)
    print(paths_df)

    adata = sc.read_h5ad(file_path)
    test_adata = adata[adata.obs["split"] == "test"].copy()

    # Copy each run's UMAP coordinates under its own obsm key.
    for entry in tqdm(payload):
        sel_adata = sc.read_h5ad(os.path.join(entry["path"], "test_adata.h5ad"))
        test_adata.obsm["X_" + entry["uuid"]] = sel_adata.obsm["X_umap"].copy()

    test_adata.write(write_h5ad)
    return test_adata

    


def visualize_embeddings(test_adata):
    """
    Draw the stored UMAP embeddings of ``test_adata`` on a model-by-task grid.

    Rows are gene-embedding models and columns are expression encodings.
    For each panel the first ``obsm`` key that starts with "X_" and contains
    both the model and the task name is plotted, coloured by
    ``obs["celltypes"]`` (which must be categorical). Panels with no such
    key are switched off. A single cell-type legend is placed below the
    grid and the figure is shown.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import scanpy as sc
    import matplotlib.patches as mpatches

    # Grid layout: one row per model, one column per task
    models = ["Random", "HyenaDNA", "ESM2", "GenePT", "Gene2vec"]
    tasks  = ["Sorting", "Binning", "Autobin", "Continuous"]
    ct_cats = test_adata.obs["celltypes"].cat.categories
    n_rows = len(models)
    n_cols = len(tasks)

    # Figure and axes with explicit spacing between panels
    sns.set(style="whitegrid", context="talk")
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(20, 20),
        gridspec_kw={"hspace": 0.3, "wspace": 0.2},
    )
    axes = axes.reshape(n_rows, n_cols)

    # One colour per cell type, shared by every panel and the legend
    palette = sns.color_palette("husl", n_colors=len(ct_cats))
    ct_colors = dict(zip(ct_cats, palette))

    embedding_keys = [k for k in test_adata.obsm if k.startswith("X_")]

    for model, axis_row in zip(models, axes):
        for task, ax in zip(tasks, axis_row):
            # First key mentioning both names, or None when there is none
            key = next(
                (k for k in embedding_keys if model in k and task in k),
                None,
            )
            if key is None:
                ax.axis("off")
                continue

            sc.pl.embedding(
                test_adata,
                basis=key[2:],  # scanpy expects the name without the "X_" prefix
                color="celltypes",
                palette=[ct_colors[name] for name in ct_cats],
                size=20,
                alpha=0.7,
                frameon=False,
                legend_loc=None,
                ax=ax,
                show=False,
            )

            # Same title offset on every panel so the titles line up
            ax.set_title(f"{model} {task}", fontsize=16, y=1.03, weight = "bold")
            ax.set_aspect("equal", "box")

    # Leave a strip at the bottom of the figure for the legend
    fig.tight_layout(rect=[0, 0.07, 1, 1])

    # Shared legend: one patch per cell type
    handles = []
    for name in ct_cats:
        handles.append(mpatches.Patch(color=ct_colors[name], label=name))

    fig.legend(
        handles=handles,
        title="Cell Type",
        loc="lower center",
        bbox_to_anchor=(0.5, 0.02),
        ncol=min(6, len(ct_cats)),
        fontsize=12,
        title_fontsize=14,
        frameon=False,
    )

    plt.show()



# spatial umap 35 nonzero

In [None]:
# Run settings: split1 experiment, all_genes disabled
seed = 55  # seed that appears in the run-directory names
all_genes = False  # select the CSV rows whose all_genes column is False
dataset = "spatial_cancer_split1"  # value of the CSV "Experiment Name" column
file_path = "/work/magroup/shared/Heimdall/data/ovarian_cancer/pairs_0_1.h5ad"  # h5ad file to read
write_h5ad = "spatial_split1_nonzero_umaps.h5ad"  # h5ad file to write

# Locations of the run table and of the trained-run folders
csv_file = "/work/magroup/nzh/Heimdall-dev/umaps/spatial_umap.csv"
model_path = "/work/magroup/nzh/Heimdall-dev/spatial_umaps_results"  # default results folder
sort_ag_but_no_gene2vec_model_path = "/work/magroup/nzh/Heimdall-dev/spatial_split1_sort_ag_results"  # special-case results folder

# Gather the per-run UMAPs into one AnnData
test_adata = process_spatial_data(
    dataset=dataset,
    file_path=file_path,
    write_h5ad=write_h5ad,
    csv_file=csv_file,
    model_path=model_path,
    sort_ag_but_no_gene2vec_model_path=sort_ag_but_no_gene2vec_model_path,
    all_genes=all_genes,
    seed=seed,
)

# Plot the grid of embeddings
visualize_embeddings(test_adata)

# spatial umap 35 all_genes

In [None]:
# Run settings: split1 experiment, all_genes enabled
seed = 55  # seed that appears in the run-directory names
all_genes = True  # select the CSV rows whose all_genes column is True
dataset = "spatial_cancer_split1"  # value of the CSV "Experiment Name" column
file_path = "/work/magroup/shared/Heimdall/data/ovarian_cancer/pairs_0_1.h5ad"  # h5ad file to read
write_h5ad = "spatial_split1_allgenes_umaps.h5ad"  # h5ad file to write

# Locations of the run table and of the trained-run folders
csv_file = "/work/magroup/nzh/Heimdall-dev/umaps/spatial_umap.csv"
model_path = "/work/magroup/nzh/Heimdall-dev/spatial_umaps_results"
sort_ag_but_no_gene2vec_model_path = "/work/magroup/nzh/Heimdall-dev/spatial_split1_sort_ag_results"

# Gather the per-run UMAPs into one AnnData
test_adata = process_spatial_data(
    dataset=dataset,
    file_path=file_path,
    write_h5ad=write_h5ad,
    csv_file=csv_file,
    model_path=model_path,
    sort_ag_but_no_gene2vec_model_path=sort_ag_but_no_gene2vec_model_path,
    all_genes=all_genes,
    seed=seed,
)

# Plot the grid of embeddings
visualize_embeddings(test_adata)

# spatial umap 113 nonzero

In [None]:
# Run settings: split2 experiment, all_genes disabled
# NOTE(review): this cell uses seed 56 while the other three cells use 55 -
# confirm that 56 is the intended seed for these runs.
seed = 56  # seed that appears in the run-directory names
all_genes = False  # select the CSV rows whose all_genes column is False
dataset = "spatial_cancer_split2"  # value of the CSV "Experiment Name" column
file_path = "/work/magroup/shared/Heimdall/data/ovarian_cancer/pairs_1_3.h5ad"  # h5ad file to read
write_h5ad = "spatial_split2_nonzero_umaps.h5ad"  # h5ad file to write

# Locations of the run table and of the trained-run folders
csv_file = "/work/magroup/nzh/Heimdall-dev/umaps/spatial_umap.csv"
model_path = "/work/magroup/nzh/Heimdall-dev/spatial_umaps_results"
sort_ag_but_no_gene2vec_model_path = "/work/magroup/nzh/Heimdall-dev/spatial_split1_sort_ag_results"

# Gather the per-run UMAPs into one AnnData
test_adata = process_spatial_data(
    dataset=dataset,
    file_path=file_path,
    write_h5ad=write_h5ad,
    csv_file=csv_file,
    model_path=model_path,
    sort_ag_but_no_gene2vec_model_path=sort_ag_but_no_gene2vec_model_path,
    all_genes=all_genes,
    seed=seed,
)

# Plot the grid of embeddings
visualize_embeddings(test_adata)

# spatial umap 113 all_genes

In [None]:
# Run settings: split2 experiment, all_genes enabled
seed = 55  # seed that appears in the run-directory names
all_genes = True  # select the CSV rows whose all_genes column is True
dataset = "spatial_cancer_split2"  # value of the CSV "Experiment Name" column
file_path = "/work/magroup/shared/Heimdall/data/ovarian_cancer/pairs_1_3.h5ad"  # h5ad file to read
write_h5ad = "spatial_split2_allgenes_umaps.h5ad"  # h5ad file to write

# Locations of the run table and of the trained-run folders
csv_file = "/work/magroup/nzh/Heimdall-dev/umaps/spatial_umap.csv"
model_path = "/work/magroup/nzh/Heimdall-dev/spatial_umaps_results"
sort_ag_but_no_gene2vec_model_path = "/work/magroup/nzh/Heimdall-dev/spatial_split1_sort_ag_results"

# Gather the per-run UMAPs into one AnnData
test_adata = process_spatial_data(
    dataset=dataset,
    file_path=file_path,
    write_h5ad=write_h5ad,
    csv_file=csv_file,
    model_path=model_path,
    sort_ag_but_no_gene2vec_model_path=sort_ag_but_no_gene2vec_model_path,
    all_genes=all_genes,
    seed=seed,
)

# Plot the grid of embeddings
visualize_embeddings(test_adata)