# scFoundation Embedding Evaluation – Dataset 2
This interactive notebook evaluates the `X_scFoundation` embedding and **prints** cluster quality metrics instead of saving them to a file.


## Imports & parameters

In [3]:

import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score, confusion_matrix
import pathlib, os

# --- Analysis parameters ----------------------------------------------------
EMBED_KEY = 'X_scFoundation'     # adata.obsm entry that is evaluated
K_NEIGHBORS = 50                 # n_neighbors for the neighbor graph
COMPONENTS_RANGE = range(2, 16)  # GMM component counts tried in the BIC sweep (2..15)
RANDOM_STATE = 0                 # seed passed to the neighbor, clustering and UMAP calls

# --- Input / output locations -----------------------------------------------
DATA_PATH = pathlib.Path('/work/scratch/ndickenmann/scfoundation_dataset2.h5ad')
OUTPUT_DIR = pathlib.Path('/work/scratch/ndickenmann/dataset2_results')
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)  # create the results folder if needed
sc.settings.figdir = OUTPUT_DIR                # point scanpy's figure directory at it


## Load dataset

In [None]:
# Stop with an explicit message when the input file is absent.
if DATA_PATH.is_file():
    adata = sc.read_h5ad(DATA_PATH)
else:
    raise FileNotFoundError(f"Data file not found: {DATA_PATH}")

# Summarise the loaded object and report whether the embedding to evaluate is stored in it.
print(adata)
has_embedding = EMBED_KEY in adata.obsm_keys()
print('Embedding present:', has_embedding)


## Build 50‑nearest‑neighbor graph

In [None]:
# The neighbor graph is computed on the embedding under evaluation (use_rep),
# with the neighbourhood size and seed taken from the parameters cell.
neighbor_kwargs = dict(
    n_neighbors=K_NEIGHBORS,
    use_rep=EMBED_KEY,
    random_state=RANDOM_STATE,
)
sc.pp.neighbors(adata, **neighbor_kwargs)

In [None]:
# Checkpoint the AnnData together with its neighbor graph.
# The file name is derived from K_NEIGHBORS (instead of a hard-coded "50") so it
# stays truthful if the neighbourhood size is changed in the parameters cell.
adata.write_h5ad(OUTPUT_DIR / f'scfoundation_dataset2_with_{K_NEIGHBORS}_neighbors.h5ad')

## Clustering

In [None]:

# Graph-based community detection; the labels are stored in
# adata.obs['leiden'] and adata.obs['louvain'] respectively.
sc.tl.leiden(adata, random_state=RANDOM_STATE, key_added='leiden')
sc.tl.louvain(
    adata,
    flavor="igraph",
    random_state=RANDOM_STATE,
    key_added='louvain',
)

In [None]:

# All model-based clusterings below are fitted on the same embedding matrix.
embedding = adata.obsm[EMBED_KEY]

# k-means with k set to the number of distinct annotated cell states
n_states = adata.obs['cell_state'].nunique()
km = KMeans(n_clusters=n_states, random_state=RANDOM_STATE).fit(embedding)
adata.obs['kmeans'] = km.labels_.astype(str)

# GMM with the same fixed number of components
gmm_fixed = GaussianMixture(n_components=n_states, covariance_type='full', random_state=RANDOM_STATE).fit(embedding)
adata.obs['gmm_fixed'] = gmm_fixed.predict(embedding).astype(str)

# BIC-optimised GMM: sweep the candidate component counts and keep the model
# with the lowest BIC. The winning model is retained during the sweep rather
# than re-fitted afterwards; with the same data and seed a second fit would
# only repeat the same expensive computation.
bic_vals = []
best_k, gmm_best, best_bic = None, None, np.inf
for k in COMPONENTS_RANGE:
    gm = GaussianMixture(n_components=k, covariance_type='full', random_state=RANDOM_STATE).fit(embedding)
    bic = gm.bic(embedding)
    bic_vals.append(bic)
    # Strict "<" keeps the first minimum, matching np.argmin tie-breaking.
    if bic < best_bic:
        best_k, gmm_best, best_bic = k, gm, bic
if gmm_best is None:
    raise ValueError('BIC sweep produced no usable model; check COMPONENTS_RANGE and the embedding.')
adata.obs['gmm_bic'] = gmm_best.predict(embedding).astype(str)

print('Best k by BIC:', best_k)
cluster_keys = ['leiden', 'louvain', 'kmeans', 'gmm_fixed', 'gmm_bic']


## Print cluster quality metrics

In [None]:

def _score_clustering(key):
    """Score the clustering stored in ``adata.obs[key]``.

    Returns a dict holding the method name, ARI and NMI computed against
    ``adata.obs['cell_state']``, and the silhouette score of the clusters in
    the ``EMBED_KEY`` embedding.
    """
    reference = adata.obs['cell_state']
    assigned = adata.obs[key]
    # silhouette_score receives integer category codes instead of the raw labels
    cluster_codes = pd.Categorical(assigned).codes
    return {
        'method': key,
        'ARI': adjusted_rand_score(reference, assigned),
        'NMI': normalized_mutual_info_score(reference, assigned),
        'silhouette': silhouette_score(adata.obsm[EMBED_KEY], cluster_codes),
    }


# One row per clustering method, shown with three decimals.
metrics = [_score_clustering(key) for key in cluster_keys]
metrics_df = pd.DataFrame(metrics).set_index('method')
display(metrics_df.style.format('{:.3f}'))


## Confusion matrices (saved as CSV/PNG)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

for key in cluster_keys:
    # Cross-tabulate annotation (rows) against cluster labels (columns).
    # sklearn's confusion_matrix builds a square matrix over the UNION of both
    # label sets, which does not fit an (n_true x n_pred) frame when the
    # annotation names and the cluster ids differ; crosstab yields exactly the
    # rectangular contingency table wanted here.
    cm_df = pd.crosstab(adata.obs['cell_state'], adata.obs[key])
    cm_df = cm_df.rename_axis(index=None, columns=None)  # plain labels, no axis names in the CSV
    display(cm_df)                         # show the DataFrame

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.heatmap(cm_df, annot=True, fmt='d', ax=ax)  # draw a heatmap
    ax.set_title(f'Confusion matrix: {key}')
    ax.set_ylabel('True')
    ax.set_xlabel('Predicted')
    # Save before show(): the section promises CSV and PNG outputs.
    fig.savefig(OUTPUT_DIR / f'confusion_{key}.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)                         # release the figure; one is created per method

    cm_df.to_csv(OUTPUT_DIR / f'confusion_{key}.csv')

## UMAP visualisations

In [None]:

# Compute the UMAP layout once; every plot below reuses it.
sc.tl.umap(adata, random_state=RANDOM_STATE)

# Annotated cell states first.
sc.pl.umap(adata, color=['cell_state'], show=True)

# Then the first of the candidate batch columns that exists in adata.obs, if any.
batch_key = None
for candidate in ('donor_id', 'sample'):
    if candidate in adata.obs.columns:
        batch_key = candidate
        break
if batch_key:
    sc.pl.umap(adata, color=[batch_key], show=True)

# Finally one panel per clustering method.
sc.pl.umap(adata, color=cluster_keys, show=True)


## Save annotated AnnData

In [None]:

# The input file, the output directory and the neighbor-graph checkpoint all
# belong to dataset 2, so the annotated result is named accordingly (it was
# previously written as 'dataset1_annotated.h5ad', mislabelling its contents).
adata.write_h5ad(OUTPUT_DIR / 'dataset2_annotated.h5ad')
print('Annotated AnnData saved.')
