# Spatial UMAPs

In [None]:
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
import pandas as pd
import glob
import os
import scanpy as sc
from tqdm import tqdm

def process_data(csv_file, dataset, file_path, write_h5ad, models_folder, alt_models_folder, split_col, all_genes=False, seed=55):
    """
    Gather per-model UMAP coordinates for one experiment into a single AnnData.

    Reads the run table in ``csv_file``, keeps the rows belonging to
    ``dataset`` / ``all_genes``, locates the result directory of every
    remaining run, and copies the ``X_umap`` embedding found in that
    directory's ``test_adata.h5ad`` into ``obsm`` of the test split of
    ``file_path``. The combined object is written to ``write_h5ad``.

    Parameters:
    -----------
    csv_file : str
        Path to a CSV with the columns "Experiment Name", "all_genes",
        "learning_rate", "batchsize", "F_Gene", "F_Expression" and "F_Cell".
        Rows containing any missing value are discarded.
    dataset : str
        Value of "Experiment Name" to keep (e.g., "sctab_split1_all"). It must
        be a key of the internal ``dataset_mapping`` table.
    file_path : str
        Path to the h5ad file whose test split receives the embeddings.
    write_h5ad : str
        Output filename for the processed h5ad file.
    models_folder : str
        Directory searched first for a run's result directory.
    alt_models_folder : str
        Directory searched second, using the same directory name.
    split_col : str
        Column of ``adata.obs``; only cells whose value is "test" are kept.
    all_genes : bool, default=False
        Value compared (with ``==``) against the CSV's "all_genes" column.
        NOTE(review): the comparison is type-sensitive, so the value must have
        the same type pandas inferred for that column - confirm against the CSV.
    seed : int, default=55
        Seed embedded in the result directory name.

    Returns:
    --------
    The test-split AnnData with one ``obsm["X_<Gene>_<Expression>"]`` entry per
    matched run.

    Raises:
    -------
    KeyError
        If a row uses a gene/expression/cell/dataset identifier that is missing
        from the mapping tables below.
    """

    df = pd.read_csv(csv_file)
    # df = df.dropna(how='all')
    df = df.dropna()

    df = df[df["Experiment Name"] == dataset]
    df = df[df["all_genes"] == all_genes]

    print(len(df))

    # CSV identifiers -> class-name fragments used in result directory names
    gene_mapping = {
        "identity": "IdentityFg",
        "pca_hyenadna": "HyenaDNAFg",
        "pca_genept": "GenePTFg",
        "pca_gene2vec": "Gene2VecFg",
        "pca_esm2": "ESM2Fg",
    }

    expression_mapping = {
        "nonzero_2nn": "NonzeroIdentityFe",
        "scfoundation": "scfound",
        "sorting": "SortingFe",
        "binning": "BinningFe",
    }

    cell_mapping = {
        "geneformer": "GeneformerFc",
        "scgpt": "ScGPTFc",
    }

    # CSV identifiers -> display names used to build the obsm keys
    name_expression_mapping = {
        "nonzero_2nn": "Continuous",
        "scfoundation": "Autobin",
        "sorting": "Sorting",
        "binning": "Binning",
    }

    name_gene_mapping = {
        "identity": "Identity",
        "pca_hyenadna": "HyenaDNA",
        "pca_genept": "GenePT",
        "pca_gene2vec": "Gene2vec",
        "pca_esm2": "ESM2",
    }

    dataset_mapping = {
        "sctab_split1_all": "new_sctab",
    }

    # Every directory matched, and one {"uuid", "path"} record per matched row
    matched_paths = []
    payload = []

    # ``df`` is already free of missing values, so no second dropna() is needed.
    for idx, row in df.iterrows():
        lr = row['learning_rate']
        bz = int(row['batchsize'])
        f_gene = row['F_Gene']
        f_expression = row['F_Expression']
        f_cell = row['F_Cell']

        # Dedicated local names: the ``dataset`` and ``all_genes`` arguments
        # must not be overwritten while iterating.
        dataset_tag = dataset_mapping[row["Experiment Name"]]
        drop_zeros = not row['all_genes']

        # Map to folder name parts
        gene_part = gene_mapping[f_gene]
        expression_part = expression_mapping[f_expression]
        cell_part = cell_mapping[f_cell]

        base_name = f"Heimdall.fg.{gene_part}_Heimdall.fe.{expression_part}_Heimdall.fc.{cell_part}_{dataset_tag}_lr{lr}_bz{bz}_seed{seed}"

        # Locations tried in order; the first one that exists wins.
        candidates = [
            (models_folder, base_name),
            (alt_models_folder, base_name),
            (models_folder, f"{base_name}_ag{drop_zeros}"),
        ]

        matches = []
        for attempt, (folder, pattern) in enumerate(candidates):
            full_pattern = os.path.join(folder, pattern)
            matches = glob.glob(full_pattern)
            if matches:
                if attempt > 0:
                    print(f" Found!! for {idx}: {pattern}")
                break
            if attempt < len(candidates) - 1:
                print(f"[WARNING] No match found for config at row {idx}: {full_pattern}, trying different configuration...")
            else:
                print(f"[ERROR]: No patthern found for {full_pattern}")

        if matches:
            matched_paths.extend(matches)
            # Always record the directory that was actually found; the previous
            # code indexed the (empty) result of the second attempt when the
            # third attempt succeeded, which raised IndexError.
            payload.append({
                "uuid": name_gene_mapping[f_gene] + "_" + name_expression_mapping[f_expression],
                "path": matches[0]
            })

    # Output
    for path in matched_paths:
        print(path)

    print(len(matched_paths))

    paths_df = pd.DataFrame(payload)
    print(paths_df)

    # data_path = "/work/magroup/shared/Heimdall/data/"
    adata = sc.read_h5ad(file_path)
    test_adata = adata[adata.obs[split_col] == "test"].copy()

    for entry in tqdm(payload):
        sel_adata = sc.read_h5ad(os.path.join(entry["path"], "test_adata.h5ad"))
        # NOTE(review): assumes sel_adata holds the same cells, in the same
        # order, as test_adata - confirm against the code that wrote it.
        test_adata.obsm["X_" + entry["uuid"]] = sel_adata.obsm["X_umap"].copy()

    test_adata.write(write_h5ad)
    return test_adata

def visualize_embeddings(test_adata):
    """
    Draw a grid of embeddings stored in ``test_adata.obsm``, coloured by cell type.

    Rows of the grid are the gene-embedding names and columns the
    expression-encoding names listed below. For each (model, task) pair the
    first ``obsm`` key containing both names is plotted; a pair without such a
    key gets a blank panel. A single legend shared by all panels is placed
    under the grid and the figure is shown with ``plt.show()``.

    Parameters:
    -----------
    test_adata :
        AnnData-like object. ``test_adata.obs["cell_type"]`` must be
        categorical (its ``.cat.categories`` are used to build the palette).
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import scanpy as sc
    import matplotlib.patches as mpatches

    models = ["Identity", "HyenaDNA", "ESM2", "GenePT", "Gene2vec"]
    tasks  = ["Sorting", "Binning", "Autobin", "Continuous"]
    cats   = test_adata.obs["cell_type"].cat.categories

    # 1) One colour per category, shared by every panel and by the legend
    sns.set(style="white", context="talk")
    palette = sns.color_palette("husl", n_colors=len(cats))
    color_key = {c: palette[i] for i, c in enumerate(cats)}

    # 2) Create grid with fixed spacing (no constrained_layout);
    #    squeeze=False guarantees a 2-D array of axes.
    fig, axes = plt.subplots(
        len(models),
        len(tasks),
        figsize=(20, 20),
        gridspec_kw={'hspace': 0.3, 'wspace': 0.2},
        squeeze=False
    )

    # 3) Plot each embedding and force title to y=1.03
    for i, model in enumerate(models):
        for j, task in enumerate(tasks):
            ax = axes[i, j]
            keys = [k for k in test_adata.obsm if model in k and task in k]
            if not keys:
                ax.axis("off")
                continue

            # Drop the first two characters of the key (the "X_" prefix of the
            # keys this notebook writes) to obtain the basis name.
            basis = keys[0][2:]
            sc.pl.embedding(
                test_adata,
                basis=basis,
                color="cell_type",
                palette=[color_key[c] for c in cats],
                size=40,
                alpha=1,
                frameon=False,
                legend_loc=None,
                ax=ax,
                show=False
            )
            ax.set_aspect("equal", "box")
            # Em dash written as an escape: the literal had been corrupted into
            # mojibake by a wrong text encoding.
            ax.set_title(f"{model} \u2014 {task}", fontsize=16, y=1.03)

    # 4) Reserve space at bottom for legend
    fig.tight_layout(rect=[0, 0.07, 1, 1])

    # 5) Unified legend below
    handles = [mpatches.Patch(color=color_key[c], label=c) for c in cats]
    fig.legend(
        handles=handles,
        title="Cell Type",
        loc="lower center",
        bbox_to_anchor=(0.5, 0.02),
        ncol=min(6, len(cats)),
        fontsize=12,
        title_fontsize=14,
        frameon=False
    )

    plt.show()


In [None]:
# Run configuration: rows of the CSV whose "all_genes" column equals "FALSE".
csv_file = "tissue_umap.csv"
dataset = "sctab_split1_all"
all_genes = "FALSE"
seed = 56
split_col = "split1"

# Input AnnData and the file the combined result is written to.
file_path = "/work/magroup/shared/Heimdall/data/sctab/tissue_splits_spencer/scTab_GItract_train.h5ad"
write_h5ad = "sctab_split1_nonzero_umaps.h5ad"

# Result directories. The "-og" folder is assigned and then immediately
# replaced, so both lookups use models_folder in this cell.
models_folder = "/work/magroup/nzh/Heimdall-dev/new_sctab_split1_allgenes_results"
alt_models_folder = "/work/magroup/nzh/Heimdall-dev/new_sctab_split1_allgenes_results-og"
alt_models_folder = models_folder

# Collect the embeddings into one AnnData ...
test_adata = process_data(
    csv_file=csv_file,
    dataset=dataset,
    file_path=file_path,
    write_h5ad=write_h5ad,
    models_folder=models_folder,
    alt_models_folder=alt_models_folder,
    split_col=split_col,
    all_genes=all_genes,
    seed=seed,
)

# ... and plot them.
visualize_embeddings(test_adata)

In [None]:
# Run configuration: rows of the CSV whose "all_genes" column equals "TRUE".
csv_file = "tissue_umap.csv"
dataset = "sctab_split1_all"
all_genes = "TRUE"
seed = 55
split_col = "split1"

# Input AnnData and the file the combined result is written to.
file_path = "/work/magroup/shared/Heimdall/data/sctab/tissue_splits_spencer/scTab_GItract_train.h5ad"
write_h5ad = "sctab_split1_allgenes.h5ad"

# Result directories: primary location first, "-og" folder as the fallback.
models_folder = "/work/magroup/nzh/Heimdall-dev/new_sctab_split1_allgenes_results"
alt_models_folder = "/work/magroup/nzh/Heimdall-dev/new_sctab_split1_allgenes_results-og"

# Collect the embeddings into one AnnData ...
test_adata = process_data(
    csv_file=csv_file,
    dataset=dataset,
    file_path=file_path,
    write_h5ad=write_h5ad,
    models_folder=models_folder,
    alt_models_folder=alt_models_folder,
    split_col=split_col,
    all_genes=all_genes,
    seed=seed,
)

# ... and plot them.
visualize_embeddings(test_adata)