In [2]:
"""
Clustering module for EUCLID.
This module performs:
  - Conventional Leiden clustering on harmonized NMF embeddings.
  - A self-supervised, locally enhanced hierarchical (Euclid) clustering.
  - Assignment of colors to clusters.
  - Application of the learnt clustering tree to new data.
  - Anatomical naming of clusters and generation of cluster inspection PDFs.
  
All functions work on an AnnData object produced by the embedding module.
"""

import os
import pickle
import warnings
import random
import itertools
import json
from datetime import datetime
import cProfile
import pstats

import joblib
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
import squidpy as sq
import backSPIN #############################
import leidenalg
import networkx as nx
import igraph as ig

from matplotlib import colors as mcolors
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

from sklearn.cluster import KMeans, DBSCAN
from sklearn.decomposition import PCA, NMF
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

from xgboost import XGBClassifier
import xgboost as xgb
from imblearn.under_sampling import RandomUnderSampler

from scipy.cluster.hierarchy import linkage, fcluster
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks
from scipy.spatial.distance import squareform, pdist
from scipy.sparse import csr_matrix
from scipy.stats import mannwhitneyu, entropy
from statsmodels.stats.multitest import multipletests

from threadpoolctl import threadpool_limits
from tqdm import tqdm
from kneed import KneeLocator
from PyPDF2 import PdfMerger

# Set thread limits and suppress warnings
# Ask threadpoolctl to cap the native thread pools at 8 threads.
threadpool_limits(limits=8)
# NOTE(review): this is set after numpy/scikit-learn were imported above and uses 6 while the
# call above uses 8 — presumably it only affects runtimes initialised later; confirm which is intended.
os.environ['OMP_NUM_THREADS'] = '6'
# Silences every warning category for the rest of the session (including convergence
# and deprecation warnings raised by the libraries used below).
warnings.filterwarnings('ignore')

In [3]:
# =============================================================================
# Define a Node class for storing the hierarchical clustering tree
# =============================================================================
class Node:
    """
    One split of the hierarchical clustering tree.

    A node only stores state; the fitted objects are attached by the code that
    performs the split.

    Attributes
    ----------
    level : depth of the split that created this node.
    path : list describing the branch taken from the root (empty list by default).
    scaler, nmf, xgb_model : fitted objects attached after a split, ``None`` until then.
    feature_importances : feature importances at the split, ``None`` until set.
    children : dict mapping a branch key to the child node.
    factors_to_use : indices of the NMF factors used by the split, ``None`` until set.
    """

    def __init__(self, level, path=None):
        self.level = level
        # A fresh list per node, so nodes never share one default path object.
        self.path = [] if path is None else path
        self.scaler = self.nmf = self.xgb_model = None
        self.feature_importances = None
        self.children = dict()
        self.factors_to_use = None

In [4]:
# -------------------------------------------------------------------------
# Utility functions (internal)
# -------------------------------------------------------------------------
def _compute_seeded_NMF(data, gamma_min=0.8, gamma_max=1.5, gamma_num=100):
    """
    Fit an NMF whose number of factors and initialisation are derived from
    Leiden communities of the lipid-lipid correlation graph.

    Steps: (1) build the absolute correlation matrix between columns of ``data``,
    (2) run Leiden over a range of resolutions and keep the resolution maximising
    ``Q**0.7 * log(N_c + 1)`` (Q = modularity, N_c = number of communities),
    (3) pick one representative column per community, (4) fit an NMF initialised
    from those representatives.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame of pixels x lipids.
    gamma_min : float, optional
        Minimum gamma value for Leiden resolution search. (Default is 0.8)
    gamma_max : float, optional
        Maximum gamma value for Leiden resolution search. (Default is 1.5)
    gamma_num : int, optional
        Number of gamma values to try. (Default is 100)
    Returns
    -------
    nmfdf : pd.DataFrame
        NMF factor matrix (W) obtained with ``nmf.transform``, indexed like ``data``.
    factor_to_lipid : np.ndarray
        The H matrix (components x lipids).
    N_factors : int
        Number of factors (number of representative columns actually selected).
    nmf_model : NMF
        Fitted NMF model.
    """
    # 1. Correlation between columns (lipids); sign is dropped and self-correlation zeroed
    corr = np.corrcoef(data.values.T)
    corr_matrix = np.abs(corr)
    np.fill_diagonal(corr_matrix, 0)
    # Build a dummy AnnData whose "neighbors" graph is the correlation matrix, so that
    # sc.tl.leiden can be run on it. Only 'connectivities' is stored in obsp here; the
    # 'distances' key is declared in uns but no distances matrix is set in this function.
    adata_dummy = anndata.AnnData(X=np.zeros_like(corr_matrix))
    adata_dummy.obsp['connectivities'] = csr_matrix(corr_matrix)
    adata_dummy.uns['neighbors'] = {
        'connectivities_key': 'connectivities',
        'distances_key': 'distances',
        'params': {'n_neighbors': 10, 'method': 'custom'}
    }
    # Same graph as a networkx object, used only to score modularity
    G = nx.from_numpy_array(corr_matrix)
    gamma_values = np.linspace(gamma_min, gamma_max, num=gamma_num) 
    num_communities = []
    modularity_scores = []
    objective_values = []
    # 2. Scan resolutions; every gamma adds its own 'leiden_<gamma>' column to adata_dummy.obs
    for gamma in gamma_values:
        sc.tl.leiden(adata_dummy, resolution=gamma, key_added=f'leiden_{gamma}')
        clusters = adata_dummy.obs[f'leiden_{gamma}'].astype(int).values
        num_comms = len(np.unique(clusters))
        num_communities.append(num_comms)
        # assumes cluster labels are exactly 0..num_comms-1 — TODO confirm
        partition = [np.where(clusters == i)[0] for i in range(num_comms)]
        modularity = nx.community.modularity(G, partition)
        modularity_scores.append(modularity)
    epsilon = 1e-10
    alpha = 0.7
    # 3. Objective trading off modularity against the number of communities
    # NOTE(review): a negative modularity raised to the power 0.7 is not a real number — confirm Q >= 0 here.
    for Q, N_c in zip(modularity_scores, num_communities):
        f_gamma = Q**alpha * np.log(N_c + 1 + epsilon)
        objective_values.append(f_gamma)
    max_index = np.argmax(objective_values)
    best_gamma = gamma_values[max_index]
    best_num_comms = num_communities[max_index]
    # Leiden is re-run at the selected resolution rather than reusing the stored column
    sc.tl.leiden(adata_dummy, resolution=best_gamma, key_added='leiden_best')
    clusters = adata_dummy.obs['leiden_best'].astype(int).values
    N_factors = best_num_comms
    # 4. Choose representative lipid per cluster: the member with the smallest mean
    #    distance (1 - |corr|) to the members of its own cluster
    dist = 1 - corr_matrix
    np.fill_diagonal(dist, 0)
    dist = np.maximum(dist, dist.T)
    representatives = []
    for i in range(N_factors):
        cluster_members = np.where(clusters == i)[0]
        if len(cluster_members) > 0:
            mean_dist = dist[cluster_members][:, cluster_members].mean(axis=1)
            central_idx = cluster_members[np.argmin(mean_dist)]
            representatives.append(central_idx)
    # 5. Custom initialisation: W from the representative columns of the data,
    #    H from their (signed) correlation rows with negatives clipped to zero
    # NOTE(review): W_init is taken from `data` before the offset below is applied — confirm it is non-negative.
    W_init = data.values[:, representatives]
    H_init = corr[representatives, :]
    H_init[H_init < 0] = 0.
    # Empty clusters are skipped above, so the factor count is re-derived from W_init
    N_factors = W_init.shape[1]
    nmf = NMF(n_components=N_factors, init='custom', random_state=42)
    # Shift the data so it is strictly positive before fitting
    # NOTE(review): whether np.min on a DataFrame is per-column or global may depend on the pandas version — confirm.
    data_offset = data - np.min(data) + 1e-7
    data_offset = np.ascontiguousarray(data_offset)
    W_init = np.ascontiguousarray(W_init)
    H_init = np.ascontiguousarray(H_init)
    # `W` from fit_transform is not used afterwards; the returned factors come from transform()
    W = nmf.fit_transform(data_offset, W=W_init, H=H_init)
    nmf_result = nmf.transform(data_offset)
    nmfdf = pd.DataFrame(nmf_result, index=data.index)
    factor_to_lipid = nmf.components_
    return nmfdf, factor_to_lipid, N_factors, nmf

def _continuity_check(spat, spat_columns=['zccf','yccf','Section'], 
                       min_val_threshold=10, min_nonzero_sections=3, gaussian_sigma=1.8, default_peak_ratio=10): 
    """
    Check whether clusters are continuous along the AP axis.
    Parameters
    ----------
    spat : np.array or DataFrame
        Array or DataFrame with spatial coordinates.
    spat_columns : list, optional
        List of column names to use for the spatial coordinates. (Default is ['zccf','yccf','Section'])
    min_val_threshold : int, optional
        Minimum value threshold to consider a section count significant. (Default is 10)
    min_nonzero_sections : int, optional
        Minimum number of sections with nonzero counts required. (Default is 3)
    gaussian_sigma : float, optional
        Sigma value for Gaussian smoothing. (Default is 1.8)
    default_peak_ratio : float, optional
        Default peak ratio if insufficient peaks are found. (Default is 10)
    Returns
    -------
    Tuple of continuity flags and peak info.
    """
    dd2 = pd.DataFrame(spat, columns=spat_columns)
    dd2['Section'] = dd2['Section'].astype(int)
    vcnorm = dd2['Section'].astype(int).value_counts()
    vcnorm.index = vcnorm.index.astype(int)
    vcnorm = vcnorm.sort_index()
    enough_sectionss = []
    number_of_peakss = []
    peak_ratios = []
    # For each unique cluster color we expect two values
    for cluster in [0, 1]:
        test = dd2  # assuming test is subset for a given cluster
        value_counts = test['Section'].value_counts().sort_index()
        ap = value_counts.values.copy()
        ap[ap < min_val_threshold] = 0
        ap_nonnull = np.sum(ap > 0) > min_nonzero_sections
        apflag = any(ap[i] != 0 and ap[i+1] != 0 for i in range(len(ap)-1))
        enough_sections = ap_nonnull and apflag
        ap_norm = value_counts / vcnorm.loc[value_counts.index].values
        ap_norm = np.array(ap_norm)
        zero_padded_ap = np.pad(ap_norm, pad_width=1, mode='constant', constant_values=0)
        smoothed_ap = gaussian_filter1d(zero_padded_ap, sigma=gaussian_sigma)
        peaks, properties = find_peaks(smoothed_ap, height=0)
        number_of_peaks = len(peaks)
        if number_of_peaks > 1:
            peak_heights = properties['peak_heights']
            top_peaks = np.sort(peak_heights)[-2:]
            peak_ratio = top_peaks[1] / top_peaks[0]
        else:
            peak_ratio = default_peak_ratio
        enough_sectionss.append(enough_sections)
        number_of_peakss.append(number_of_peaks)
        peak_ratios.append(peak_ratio)
    return enough_sectionss[0], enough_sectionss[1], number_of_peakss[0], number_of_peakss[1], peak_ratios[0], peak_ratios[1]

def _differential_lipids(lipidata, kmeans_labels, min_fc=0.2, pthr=0.05):
    """
    Compare two groups (assumed binary) for differential lipids.

    Parameters
    ----------
    lipidata : array-like, shape (n_voxels, n_lipids)
        Lipid values, one column per lipid.
    kmeans_labels : array-like, shape (n_voxels,)
        Group label per voxel; rows labelled 0 are compared with rows labelled 1.
    min_fc : float, optional
        Minimum absolute log2 fold change of the group means. (Default is 0.2)
    pthr : float, optional
        Maximum Benjamini-Hochberg corrected p-value. (Default is 0.05)

    Returns
    -------
    alteredlips : int
        Number of lipids passing both thresholds.
    promoted : pd.DataFrame
        Rows of the result table (lipid column index, log2 fold change, raw and
        corrected p-value) for those lipids.
    """
    lipidata = np.asarray(lipidata)
    labels = np.asarray(kmeans_labels)
    a = lipidata[labels == 0, :]
    b = lipidata[labels == 1, :]

    results = []
    for rrr in range(lipidata.shape[1]):
        groupA = a[:, rrr]
        groupB = b[:, rrr]
        meanA = np.mean(groupA)
        meanB = np.mean(groupB)
        # Fold change is only defined for strictly positive means
        log2fold_change = np.abs(np.log2(meanB/meanA)) if meanA > 0 and meanB > 0 else np.nan
        try:
            _, p_value = mannwhitneyu(groupA, groupB, alternative='two-sided')
        except ValueError:
            # Test not computable for this lipid; it is left out of the correction below
            p_value = np.nan
        results.append({'lipid': rrr, 'log2fold_change': log2fold_change, 'p_value': p_value})
    results_df = pd.DataFrame(results, columns=['lipid', 'log2fold_change', 'p_value'])

    # Correct only the p-values that could be computed: a single NaN passed to the
    # FDR procedure turns every corrected p-value into NaN, which silently reports 0 hits.
    pvals = results_df['p_value'].to_numpy(dtype=float)
    testable = np.isfinite(pvals)
    pvals_corrected = np.full(pvals.shape, np.nan)
    if testable.any():
        pvals_corrected[testable] = multipletests(pvals[testable], alpha=0.05, method='fdr_bh')[1]
    results_df['p_value_corrected'] = pvals_corrected

    # Comparisons with NaN are False, so untestable lipids are never promoted
    significant = (results_df['log2fold_change'] > min_fc) & (results_df['p_value_corrected'] < pthr)
    promoted = results_df[significant]
    alteredlips = np.sum(significant)
    return alteredlips, promoted

def _rank_features_by_combined_score(tempadata):
    """
    Rank features by combining variance-of-variances and mean variances.
    """
    sections = tempadata.obsm['spatial'][:, 2]
    unique_sections = np.unique(sections)
    var_of_vars = []
    mean_of_vars = []
    for i in range(tempadata.X.shape[1]):
        feature_values = tempadata.X[:, i]
        section_variances = []
        for sec in unique_sections:
            section_values = feature_values[sections == sec]
            section_variance = np.var(section_values)
            section_variances.append(section_variance)
        var_of_vars.append(np.var(section_variances))
        mean_of_vars.append(np.mean(section_variances))
    var_of_vars = np.array(var_of_vars) / np.mean(var_of_vars)
    mean_of_vars = np.array(mean_of_vars) / np.mean(mean_of_vars)
    combined_score = -var_of_vars/2 + mean_of_vars
    ranked_indices = np.argsort(combined_score)[::-1]
    return ranked_indices

def _find_elbow_point(values):
    """
    Locate the elbow of the cumulative share of absolute loadings.

    The absolute values are sorted from largest to smallest, accumulated and
    divided by their total; the elbow of that concave, increasing curve (as
    reported by ``KneeLocator``) is returned as a 1-based rank.
    """
    magnitudes_desc = np.sort(np.abs(values))[::-1]
    cumulative_share = np.cumsum(magnitudes_desc) / np.sum(magnitudes_desc)
    ranks = range(1, len(cumulative_share)+1)
    locator = KneeLocator(ranks, cumulative_share, curve='concave', direction='increasing')
    return locator.elbow

def _generate_combinations(n, limit=200):
    """
    Generate sorted combinations (of component indices) to try for splitting.
    """
    all_combinations = []
    for r in range(n, 0, -1):
        for comb in itertools.combinations(range(n), r):
            all_combinations.append(comb)
            if len(all_combinations) >= limit:
                return all_combinations
    return all_combinations

def _leidenalg_clustering(inputdata, Nneigh=40, Niter=5):
    """
    Leiden clustering of a k-nearest-neighbour graph, using leidenalg directly.

    A ``Nneigh``-neighbour graph of ``inputdata`` is built with scikit-learn,
    converted to igraph through networkx and partitioned by modularity with a
    fixed seed. Returns the membership of every row as a numpy array.
    """
    neighbour_model = NearestNeighbors(n_neighbors=Nneigh, n_jobs=4).fit(inputdata)
    adjacency = neighbour_model.kneighbors_graph(inputdata)
    graph = ig.Graph.from_networkx(nx.Graph(adjacency))
    partition = leidenalg.find_partition(
        graph,
        leidenalg.ModularityVertexPartition,
        n_iterations=Niter,
        seed=230598,
    )
    return np.array(partition.membership)

def _undersample(X, y, sampling_strategy='auto'):
    """
    Randomly under-sample ``X``/``y`` with a fixed seed.

    Thin wrapper around ``RandomUnderSampler(random_state=42)``; returns the
    resampled features and labels.
    """
    sampler = RandomUnderSampler(random_state=42, sampling_strategy=sampling_strategy)
    resampled_X, resampled_y = sampler.fit_resample(X, y)
    return resampled_X, resampled_y

In [None]:
# Unrolled clustering function: its parameters and body are laid out as notebook cells below for step-by-step debugging.

In [26]:
# Settings read as globals by the cells below and by _dosplit.
K=60  # number of KMeans micro-clusters per split attempt
min_voxels=150  # branches with fewer voxels are not split; also used in the split size criterion
min_diff_lipids=2  # a split needs more than this many differential lipids
min_fc=0.2  # minimum |log2 fold change| passed to _differential_lipids
pthr=0.05  # corrected p-value threshold passed to _differential_lipids
thr_signal=1e-10  # NMF factors whose mean absolute value is not above this are dropped
penalty1=1.5  # divisor applied to the embeddings inherited from the parent split
penalty2=2  # divisor applied to the global standardized embeddings
ACCTHR=0.6  # minimum test accuracy of the XGBoost classifier for a split to be kept
max_depth=15  # number of level_<i> columns created in the clustering log
ds_factor=1  # row stride applied to the clustering log and to the root inputs
spat_columns=['zccf','yccf','Section']  # column names handed to _continuity_check
min_val_threshold=10  # _continuity_check: section counts below this are treated as 0
min_nonzero_sections=3  # _continuity_check: number of populated sections that must be exceeded
gaussian_sigma=1.8  # _continuity_check: sigma of the Gaussian smoothing
default_peak_ratio=10  # _continuity_check: ratio reported when fewer than two peaks exist
peak_count_threshold=3  # a cluster passes with fewer peaks than this...
peak_ratio_threshold=1.4  # ...or with a peak ratio above this
combinations=200  # maximum number of factor subsets tried per split (limit of _generate_combinations)
xgb_n_estimators=1000  # XGBClassifier n_estimators
xgb_max_depth=8  # XGBClassifier max_depth
xgb_learning_rate=0.02  # XGBClassifier learning_rate
xgb_subsample=0.8  # XGBClassifier subsample
xgb_colsample_bytree=0.8  # XGBClassifier colsample_bytree
xgb_gamma=0.5  # XGBClassifier gamma
xgb_random_state=42  # XGBClassifier random_state
xgb_n_jobs=6  # XGBClassifier n_jobs
early_stopping_rounds=7  # only referenced by the commented-out early-stopping callback in _dosplit

In [10]:
import scanpy as sc
import anndata as ad

# Location of the AnnData file produced by the embedding step. Set the
# EUCLID_ADATA_PATH environment variable to run the notebook on another machine
# without editing this cell; the previous hardcoded location remains the default.
ADATA_PATH = os.environ.get('EUCLID_ADATA_PATH',
                            '/data/luca/lipidatlas/euclid/euclid_msi/my_adata.h5ad')

adata = ad.read_h5ad(filename=ADATA_PATH)
adata

AnnData object with n_obs × n_vars = 122246 × 151
    obs: 'SectionID', 'x', 'y', 'Path', 'Sample', 'Sex', 'Condition', 'Section', 'BadSection', 'xccf', 'yccf', 'zccf', 'x_index', 'y_index', 'z_index', 'boundary', 'acronym', 'id', 'name', 'structure_id_path', 'structure_set_ids', 'rgb_triplet', 'allencolor', 'division', 'SectionPlot'
    var: 'old_feature_names'
    uns: 'feature_selection_scores'
    obsm: 'X_01norm', 'X_Harmonized', 'X_NMF', 'X_TSNE', 'X_approximated'

In [11]:
DS = 1  # row stride used when building the global embeddings, metadata and coordinates (1 keeps every pixel); separate from ds_factor

In [12]:
from sklearn.preprocessing import StandardScaler

# z-score every harmonized embedding dimension, then keep every DS-th pixel
standardized_embeddings_GLOBAL = pd.DataFrame(
    StandardScaler().fit_transform(adata.obsm['X_Harmonized']),
    index=adata.obs_names,
)[::DS]
metadata = adata.obs.copy()[::DS]

# `metadata` is already strided by DS. Striding `coordinates` a second time kept only
# every DS*DS-th pixel and broke the alignment with the embeddings whenever DS > 1.
# 'SectionID' is used for both the 'Section' and the 'xccf' column.
coordinates = metadata[['x','y','SectionID', 'SectionID']].copy()
coordinates.columns = ["zccf","yccf","Section","xccf"]
assert coordinates.index.equals(standardized_embeddings_GLOBAL.index), \
    "coordinates and global embeddings must describe the same pixels"
coordinates

Unnamed: 0,zccf,yccf,Section,xccf
section1_pixel18_104,18,104,1,1
section1_pixel18_105,18,105,1,1
section1_pixel18_106,18,106,1,1
section1_pixel18_110,18,110,1,1
section1_pixel18_112,18,112,1,1
...,...,...,...,...
section3_pixel229_142,229,142,3,3
section3_pixel229_143,229,143,3,3
section3_pixel229_144,229,144,3,3
section3_pixel229_145,229,145,3,3


In [16]:
# Matrix stored under 'X_approximated' in the loaded AnnData; used below as the lipid data to cluster.
# NOTE(review): the name suggests a DataFrame, but obsm entries may be plain arrays — confirm the type.
reconstructed_data_df = adata.obsm['X_approximated']

In [24]:
# Second assignment of a subset of the clustering settings.
# NOTE(review): when the notebook is run top to bottom these values replace the ones set in the
# earlier settings cell (max_depth 15 -> 3, min_nonzero_sections 3 -> 1, combinations 200 -> 3);
# the saved clustering-log output shows 15 levels, so the cells were presumably run in another order — confirm which values are intended.
K=60
min_voxels=150
min_diff_lipids=2
min_fc=0.2
pthr=0.05
ACCTHR=0.6
max_depth=3 ######
min_nonzero_sections=1 ######
gaussian_sigma=1.8
peak_count_threshold=3
peak_ratio_threshold=1.4
combinations=3

In [18]:
# Split the pixels into training / validation / test sets.
# With at least 3 sections whole sections are held out; otherwise rows are split 60-20-20.
SPLIT_SEED = 42  # seed of the row-wise fallback split, so that re-runs give the same split

unique_sections = coordinates['Section'].unique()
print("Unique sections:", unique_sections)
num_sections = len(unique_sections)

# Check if we have enough sections for the original splitting approach
if num_sections >= 3:
    if num_sections >= 5:
        # Proportional rule: around every 5th section, hold out the following two sections.
        # The candidates are computed arithmetically, so keep only ids that really exist.
        valsec = np.intersect1d((unique_sections[::5] + 2)[:-1], unique_sections)
        testsec = np.intersect1d((unique_sections[::5] + 1)[:-1], unique_sections)
    else:
        # For fewer sections but still >=3, assign at least one to each group
        valsec = np.array([unique_sections[0]])
        testsec = np.array([unique_sections[1]])

    # Double check that validation and test sections are not empty
    if len(valsec) == 0:
        valsec = np.array([unique_sections[0]])
    if len(testsec) == 0:
        # Avoid overlap with validation
        for sec in unique_sections:
            if sec not in valsec:
                testsec = np.array([sec])
                break

    # The rest go to training
    trainsec = np.setdiff1d(unique_sections, np.union1d(valsec, testsec))
    if len(trainsec) == 0:
        # Give the last section of the larger held-out group back to training
        if len(valsec) > len(testsec):
            trainsec = np.array([valsec[-1]])
            valsec = valsec[:-1]
        else:
            trainsec = np.array([testsec[-1]])
            testsec = testsec[:-1]

    print("Validation sections (valsec):", valsec)
    print("Test sections (testsec):", testsec)
    print("Train sections (trainsec):", trainsec)

    # Identify point indices for each group
    valpoints = coordinates.loc[coordinates['Section'].isin(valsec),:].index
    testpoints = coordinates.loc[coordinates['Section'].isin(testsec),:].index
    trainpoints = coordinates.loc[coordinates['Section'].isin(trainsec),:].index

else:
    # Classic 60-20-20 split on the rows (ignoring sections)
    print("Less than 3 unique sections found. Using 60-20-20 split on rows.")

    # Shuffle a COPY of the labels: shuffling `coordinates.index.values` in place
    # permutes the buffer backing the index itself and detaches labels from their rows.
    split_rng = np.random.default_rng(SPLIT_SEED)
    all_indices = split_rng.permutation(coordinates.index.to_numpy())

    # Calculate split sizes
    n_samples = len(all_indices)
    n_train = int(0.6 * n_samples)
    n_val = int(0.2 * n_samples)

    # Split indices
    trainpoints = all_indices[:n_train]
    valpoints = all_indices[n_train:n_train+n_val]
    testpoints = all_indices[n_train+n_val:]

    # For consistency with the section-based approach
    trainsec = np.array([])
    valsec = np.array([])
    testsec = np.array([])

# The three groups must not share pixels
assert len(np.intersect1d(trainpoints, valpoints)) == 0
assert len(np.intersect1d(trainpoints, testpoints)) == 0
assert len(np.intersect1d(valpoints, testpoints)) == 0

Unique sections: [1 2 3]
Validation sections (valsec): [1]
Test sections (testsec): [2]
Train sections (trainsec): [3]


In [19]:
# Prepare data for clustering
# Wrap the reconstructed lipid matrix in a DataFrame indexed like the global embeddings.
# NOTE(review): if `reconstructed_data_df` is already a DataFrame, passing `index=` selects rows by
# label instead of relabelling them — confirm its index matches standardized_embeddings_GLOBAL.index.
data = pd.DataFrame(reconstructed_data_df.copy(), index=standardized_embeddings_GLOBAL.index)
print("Data shape:", data.shape)
# Independent copy of the same values; it is the input of the percentile normalisation below
rawlips = data.copy()
print("Raw lipids data shape:", rawlips.shape)

Data shape: (122246, 151)
Raw lipids data shape: (122246, 151)


In [20]:
# Normalize raw data using percentiles (2% and 98%)
# Each lipid is rescaled so its 2nd percentile maps to 0 and its 98th to 1, then clipped to [0, 1].
p2 = rawlips.quantile(0.02)
p98 = rawlips.quantile(0.98)
print("2nd percentile values:\n", p2)
print("98th percentile values:\n", p98)
# A lipid whose two percentiles coincide would be divided by zero, leaving NaN that
# np.clip does not remove. Use a span of 1 for such columns (the usual min-max
# convention), so they stay finite and sit at ~0.
percentile_span = (p98 - p2).values
safe_span = np.where(percentile_span > 0, percentile_span, 1.0)
normalized_values = (rawlips.values - p2.values) / safe_span
print("Normalized values shape:", normalized_values.shape)
clipped_values = np.clip(normalized_values, 0, 1)
normalized_datemp = pd.DataFrame(clipped_values, columns=rawlips.columns, index=rawlips.index)
print("Normalized and clipped data shape:", normalized_datemp.shape)

2nd percentile values:
 0      0.004364
1      0.004416
2      0.005842
3      0.004443
4      0.005088
         ...   
146    0.004375
147    0.004620
148    0.004348
149    0.004357
150    0.004332
Name: 0.02, Length: 151, dtype: float64
98th percentile values:
 0      0.004405
1      0.004563
2      0.009400
3      0.004828
4      0.006201
         ...   
146    0.004590
147    0.005765
148    0.004569
149    0.004740
150    0.005182
Name: 0.98, Length: 151, dtype: float64
Normalized values shape: (122246, 151)
Normalized and clipped data shape: (122246, 151)


In [21]:
# Prepare a Scanpy object with raw lipids for differential testing.
# NOTE(review): this rebinds `adata`, replacing the AnnData loaded from disk earlier; the new object
# only carries X, obsm['spatial'] and obsm['lipids'], so earlier cells that read
# adata.obsm['X_Harmonized'] / ['X_approximated'] cannot be re-run after this one without reloading.
adata = sc.AnnData(X=data)
# Spatial coordinates in the column order (zccf, yccf, Section), aligned to the rows of `data`
adata.obsm['spatial'] = coordinates[['zccf','yccf','Section']].loc[data.index].values
# Percentile-normalized, clipped lipid values (a DataFrame), read later for differential testing
adata.obsm['lipids'] = normalized_datemp
print("Created AnnData with shape:", adata.shape)
print("Sample spatial coordinates:\n", adata.obsm['spatial'][:5])
print("Sample lipid data:\n", normalized_datemp.iloc[:5])

Created AnnData with shape: (122246, 151)
Sample spatial coordinates:
 [[ 18 104   1]
 [ 18 105   1]
 [ 18 106   1]
 [ 18 110   1]
 [ 18 112   1]]
Sample lipid data:
                            0         1         2         3         4    \
section1_pixel18_104  0.411354  0.684223  1.000000  0.322268  0.828163   
section1_pixel18_105  0.097681  0.634793  0.761229  0.422297  0.468835   
section1_pixel18_106  0.105777  0.601587  0.856403  0.490948  0.554111   
section1_pixel18_110  0.205442  0.569963  0.751861  0.259524  0.548884   
section1_pixel18_112  0.199275  0.719065  0.810193  0.456313  0.536024   

                           5         6    7         8         9    ...  141  \
section1_pixel18_104  0.716551  0.013703  0.0  0.721167  0.562022  ...  0.0   
section1_pixel18_105  0.342687  0.000000  0.0  0.886240  0.606591  ...  0.0   
section1_pixel18_106  0.423469  0.000000  0.0  1.000000  0.646003  ...  0.0   
section1_pixel18_110  0.371520  0.000000  0.0  0.452979  0.541941  ...  

In [27]:
# Initialize a log DataFrame for clustering history.
# One column per tree level ("level_1" .. "level_<max_depth>"), one row per pixel, filled with 0;
# _dosplit later writes the label (1 or 2) each pixel receives at the level where it is split.
column_names = [f"level_{i}" for i in range(1, max_depth+1)]
# Keep every ds_factor-th row, matching the stride applied to the root inputs of _dosplit
clusteringLOG = pd.DataFrame(0, index=data.index, columns=column_names)[::ds_factor]
print("Initialized clustering log with shape:", clusteringLOG.shape)

Initialized clustering log with shape: (122246, 15)


In [None]:









# Define the recursive splitting function.
def _dosplit(current_adata, embds, path=None, splitlevel=0):
    """
    Recursively split ``current_adata`` into two clusters and return the tree node.

    One call tries to find a single accepted binary split of the voxels:

    1. fit a local seeded NMF, drop near-empty factors and rank the rest;
    2. for successive factor subsets, run KMeans with ``K`` micro-clusters on the
       local + parent + global embeddings, order the micro-cluster centroids with
       backSPIN and cut them into two groups;
    3. accept the first bipartition passing the size, differential-lipid and
       continuity criteria;
    4. train an XGBoost classifier on that bipartition and reject the split when
       its test accuracy is below ``ACCTHR``;
    5. write the classifier's labels (1 or 2) into ``clusteringLOG`` and recurse
       into both halves.

    The function reads the notebook-level settings (``K``, ``min_voxels``,
    ``max_depth``, ``penalty1``/``penalty2``, the ``xgb_*`` values, ...), the
    ``trainpoints``/``valpoints``/``testpoints`` index sets and
    ``standardized_embeddings_GLOBAL``, and it mutates the global ``clusteringLOG``.

    Parameters
    ----------
    current_adata : AnnData
        Voxels of the current branch; needs ``X``, ``obs_names``, ``var_names``,
        ``obsm['spatial']`` and ``obsm['lipids']`` (a DataFrame).
    embds : np.ndarray
        Embedding of the same voxels inherited from the parent split (row-aligned
        with ``current_adata``).
    path : list, optional
        Branch choices (0/1) leading from the root to this call. (Default: empty path)
    splitlevel : int, optional
        Depth of this call; the split is logged in column ``level_<splitlevel+1>``.

    Returns
    -------
    Node or None
        The node describing the accepted split (children may be ``None``), or
        ``None`` when the branch cannot or should not be split further.
    """
    # A fresh list per call: the previous `path=[]` default was one shared object
    path = [] if path is None else path

    n_voxels = current_adata.X.shape[0]
    print("\n=== Entering _dosplit at level:", splitlevel, "with", n_voxels, "voxels ===")
    if n_voxels < min_voxels:
        print("Branch exhausted due to low voxel count:", n_voxels)
        return None
    # The clustering log only has columns level_1..level_<max_depth>; without this
    # guard the recursion ignored max_depth altogether.
    if splitlevel >= max_depth:
        print("Branch exhausted: maximum depth reached:", max_depth)
        return None

    # Compute a local NMF on current data
    nmfdf, loadings, N_factors, nmf_model = _compute_seeded_NMF(pd.DataFrame(current_adata.X, index=current_adata.obs_names))
    nmf_result = nmfdf.values
    print("Computed NMF. nmfdf shape:", nmfdf.shape, "N_factors:", N_factors)
    print("Loadings shape:", loadings.shape)
    print("Mean absolute values per factor:", np.abs(nmf_result).mean(axis=0))

    # Drop factors that carry (almost) no signal
    filter1 = np.abs(nmf_result).mean(axis=0) > thr_signal
    loadings_sel = loadings[filter1, :]
    nmf_result = nmf_result[:, filter1]
    original_nmf_indices = np.arange(N_factors)[filter1]
    print("Selected", filter1.sum(), "factors after filtering out of", len(filter1))
    print("Shape of filtered nmf_result:", nmf_result.shape)

    tempadata = sc.AnnData(X=nmf_result)
    tempadata.obsm['spatial'] = current_adata.obsm['spatial']

    # Rank features by combined score
    goodpcs = _rank_features_by_combined_score(tempadata).astype(int)
    print("Ranked features (goodpcs):", goodpcs)
    goodpcs_indices = original_nmf_indices[goodpcs]
    top_pcs_data = nmf_result[:, goodpcs]
    loadings_sel = loadings_sel[goodpcs, :]
    print("Top PCs data shape:", top_pcs_data.shape)
    print("Indices of good principal components:", goodpcs_indices)

    multiplets = _generate_combinations(len(goodpcs), limit=combinations)
    print("Generated multiplets count:", len(multiplets))
    flag = False

    # Begin iterative search for acceptable split: try factor subsets until one passes
    for aaa, bestpcs in enumerate(multiplets):
        print("\nIteration", aaa, "using bestpcs indices:", bestpcs)
        factor_subset = list(bestpcs)
        embeddings_local = top_pcs_data[:, factor_subset]
        loadings_current = loadings_sel[factor_subset, :]
        selected_nmf_indices = goodpcs_indices[factor_subset]
        scaler_local = StandardScaler()
        standardized_embeddings = scaler_local.fit_transform(embeddings_local)
        print("Standardized embeddings shape:", standardized_embeddings.shape)

        # Combine with previous split and global embeddings (both down-weighted)
        globembds = standardized_embeddings_GLOBAL.loc[current_adata.obs_names].values / penalty2
        embspace = np.concatenate((standardized_embeddings, embds/penalty1, globembds), axis=1)
        print("Combined embedding space shape:", embspace.shape)

        kmeans = KMeans(n_clusters=K, random_state=230598)
        micro_labels = kmeans.fit_predict(embspace)
        print("KMeans labels distribution:", np.bincount(micro_labels))

        # Reaggregate via backSPIN (using its API): order the micro-cluster centroids
        # and cut the ordering into two groups
        data_for_clustering = pd.DataFrame(current_adata.X, index=current_adata.obs_names, columns=current_adata.var_names)
        data_for_clustering['label'] = micro_labels
        centroids = data_for_clustering.groupby('label').mean()
        centroids = pd.DataFrame(StandardScaler().fit_transform(centroids), columns=centroids.columns, index=centroids.index).T
        print("Centroids shape after standardization and transpose:", centroids.shape)
        row_ix, columns_ix = backSPIN.SPIN(centroids, widlist=4)
        centroids = centroids.iloc[row_ix, columns_ix]
        print("Centroids shape after backSPIN reordering:", centroids.shape)
        _, _, _, gr1, gr2, _, _, _, _ = backSPIN._divide_to_2and_resort(sorted_data=centroids.values, wid=5)
        gr1 = np.array(centroids.columns)[gr1]
        gr2 = np.array(centroids.columns)[gr2]
        print("Division groups sizes: gr1 =", len(gr1), "gr2 =", len(gr2))
        data_for_clustering['lab'] = 1
        data_for_clustering.loc[data_for_clustering['label'].isin(gr2), 'lab'] = 2

        # Binary split (0/1) per voxel. The acceptance tests below must look at THIS
        # bipartition; they previously received the raw KMeans labels and therefore
        # compared micro-clusters 0 and 1 out of K.
        kmeans_labels = data_for_clustering['lab'].astype(int) - 1

        # Check continuity along AP axis using coordinates
        # NOTE(review): this call passes the coordinates of the whole branch, not of each
        # candidate cluster, so both result sets describe the same voxels — confirm intent.
        enough_sections0, enough_sections1, num_peaks0, num_peaks1, peak_ratio0, peak_ratio1 = _continuity_check(
            current_adata.obsm['spatial'], 
            spat_columns=spat_columns, 
            min_val_threshold=min_val_threshold,
            min_nonzero_sections=min_nonzero_sections, 
            gaussian_sigma=gaussian_sigma, 
            default_peak_ratio=default_peak_ratio 
        )
        print("Continuity check results:",
              "enough_sections0 =", enough_sections0,
              "enough_sections1 =", enough_sections1,
              "num_peaks0 =", num_peaks0,
              "num_peaks1 =", num_peaks1,
              "peak_ratio0 =", peak_ratio0,
              "peak_ratio1 =", peak_ratio1)

        # Differential lipids between the two halves, on the normalized lipids in obsm['lipids']
        alteredlips, promoted = _differential_lipids(current_adata.obsm['lipids'].values, kmeans_labels.values, min_fc, pthr)
        print("Differential lipids count:", alteredlips, "Promoted:", promoted)

        n_group0 = int((kmeans_labels == 0).sum())
        n_group1 = int((kmeans_labels == 1).sum())
        # NOTE(review): the size criterion only requires ONE half to exceed min_voxels
        # (`or`, as before) — confirm that `and` was not intended.
        flag = bool((n_group1 > min_voxels or n_group0 > min_voxels)
                    and (alteredlips > min_diff_lipids)
                    and enough_sections0 and enough_sections1
                    and ((num_peaks0 < peak_count_threshold) or (peak_ratio0 > peak_ratio_threshold)) 
                    and ((num_peaks1 < peak_count_threshold) or (peak_ratio1 > peak_ratio_threshold)))

        print(n_group1 > min_voxels)
        print(n_group0 > min_voxels)
        print(alteredlips > min_diff_lipids)

        print("Flag condition evaluated to:", flag)
        if flag:
            break

    if not flag:
        print("Branch exhausted due to failure of continuity or differential criteria.")
        return None

    # Train an XGB classifier on the embeddings of the accepted iteration
    embeddings_df = pd.DataFrame(embspace, index=current_adata.obs_names)
    print("Embeddings dataframe shape:", embeddings_df.shape)

    X_train = embeddings_df.loc[embeddings_df.index.isin(trainpoints), :]
    X_val = embeddings_df.loc[embeddings_df.index.isin(valpoints), :]
    X_test = embeddings_df.loc[embeddings_df.index.isin(testpoints), :]
    print("Training set shape:", X_train.shape)
    print("Validation set shape:", X_val.shape)
    print("Test set shape:", X_test.shape)

    y_train = kmeans_labels.loc[X_train.index]
    y_val = kmeans_labels.loc[X_val.index]
    y_test = kmeans_labels.loc[X_test.index]

    # A branch may hold no voxel of the held-out sections, or only one class in the
    # training sections; the classifier cannot be trained or evaluated then.
    if min(len(X_train), len(X_val), len(X_test)) == 0 or y_train.nunique() < 2:
        print("Branch exhausted: empty train/validation/test partition or single-class training set.")
        return None

    X_train_sub, y_train_sub = _undersample(X_train, y_train)
    print("After undersampling, training set shape:", X_train_sub.shape)

    xgb_model = XGBClassifier(  
        n_estimators=xgb_n_estimators, 
        max_depth=xgb_max_depth, 
        learning_rate=xgb_learning_rate,
        subsample=xgb_subsample, 
        colsample_bytree=xgb_colsample_bytree, 
        gamma=xgb_gamma, 
        random_state=xgb_random_state,  
        n_jobs=xgb_n_jobs  
    )
    print("Training XGB classifier...")
    xgb_model.fit(
        X_train_sub,
        y_train_sub,
        eval_set=[(X_val, y_val)],
        # callbacks=[xgb.callback.EarlyStopping(rounds=early_stopping_rounds)],  #### CURRENTLY FROZEN DUE TO PACKAGE INCOMPATIBILITIES THAT ARE NON TRIVIAL TO FIX.
        verbose=False
    )
    test_pred = xgb_model.predict(X_test)
    test_acc = accuracy_score(y_test, test_pred)
    print(f"Test accuracy: {test_acc}")
    if test_acc < ACCTHR:
        print("Branch exhausted due to poor classifier generalization.")
        return None

    # Overwrite cluster labels with classifier predictions (for consistency)
    new_labels = pd.concat([pd.Series(xgb_model.predict(X_train), index=X_train.index),
                             pd.Series(xgb_model.predict(X_val), index=X_val.index),
                             pd.Series(xgb_model.predict(X_test), index=X_test.index)])
    new_labels = new_labels.loc[embeddings_df.index]
    new_labels = new_labels + 1  # logged labels are 1 and 2; 0 stays "not split at this level"
    print("New labels distribution:\n", new_labels.value_counts())

    # Update clustering log
    clusteringLOG.loc[new_labels.index, f"level_{splitlevel+1}"] = new_labels.values
    print("Updated clustering log for level", splitlevel+1)

    # Create a Node for this split
    node = Node(splitlevel, path=path)
    node.scaler = scaler_local
    node.nmf = nmf_model
    node.xgb_model = xgb_model
    node.feature_importances = xgb_model.feature_importances_
    node.factors_to_use = selected_nmf_indices
    print("Created node at level", splitlevel, "with factors:", selected_nmf_indices)

    # Recursively split the two branches; each child inherits this split's embedding space
    idx0 = embeddings_df.index[new_labels == 1]
    idx1 = embeddings_df.index[new_labels == 2]
    print("Branch indices - group 1:", idx0, "\nBranch indices - group 2:", idx1)
    adata0 = current_adata[current_adata.obs_names.isin(idx0)]
    adata1 = current_adata[current_adata.obs_names.isin(idx1)]
    embd0 = embeddings_df.loc[idx0].values
    embd1 = embeddings_df.loc[idx1].values
    print("Recursing on child 0 with", adata0.X.shape[0], "voxels")
    child0 = _dosplit(adata0, embd0, path + [0], splitlevel + 1)
    print("Recursing on child 1 with", adata1.X.shape[0], "voxels")
    child1 = _dosplit(adata1, embd1, path + [1], splitlevel + 1)
    node.children[0] = child0
    node.children[1] = child1
    print("Completed _dosplit at level", splitlevel)
    return node

# Start the recursive clustering from the root
print("\n=== Starting recursive clustering ===")
root_node = _dosplit(adata[::ds_factor], standardized_embeddings_GLOBAL[::ds_factor].values, path=[], splitlevel=0)
print("Recursive clustering complete.")

# Save the clustering log and tree to file.
clusteringLOG.to_parquet("tree_clustering.parquet")
with open("rootnode_clustering.pkl", "wb") as f:
    pickle.dump(root_node, f)
print("Clustering log and root node saved to file.")

# `return` is only valid inside a function; at cell level it is a SyntaxError that
# prevents the whole cell (including the _dosplit definition above) from running.
# `root_node` and `clusteringLOG` remain available as notebook variables; show a preview.
clusteringLOG.head()