In [None]:
# load the dataset
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import normalized_mutual_info_score
from sklearn.preprocessing import MinMaxScaler
import pickle
import matplotlib.colors as mcolors
from scipy.ndimage import gaussian_filter1d
import seaborn as sns
import matplotlib as mpl
from tqdm import tqdm
from matplotlib import cm
from scipy.cluster.hierarchy import linkage, dendrogram,fcluster
import networkx as nx
from scipy.spatial import distance
from scipy.spatial.distance import squareform
from scipy.spatial.distance import pdist
import matplotlib.colors as mcolors
from scipy.ndimage import gaussian_filter
from skimage import measure
mpl.rcParams['pdf.fonttype'] = 42
import scipy
import scanpy as sc
import zarr
# import tarfile
import io
import anndata as ad
import os
import pandas as pd
import anndata as ad
import scanpy as sc
import numpy as np
from scipy.spatial.distance import cdist
import os
import re
from tqdm import tqdm

In [None]:
# Load the lipid atlas: voxel coordinates followed by the first 173 columns
# (presumably the lipid intensities — confirm against the parquet schema).
coord_cols = ['Section', 'y_index', 'z_index']
dat = pd.read_parquet("atlas.parquet")
dat = pd.concat([dat.loc[:, coord_cols], dat.iloc[:, :173]], axis=1)
dat

## Impute lipids onto Langlieb et al. cells by colocalization

In [None]:
# Colocalization imputation: for each section, give every Langlieb et al. cell
# the lipid profile of its nearest lipid voxel in (y_index, z_index) space.
# Cells with no voxel within MAX_MATCH_DISTANCE keep NaNs (dropped below).
MAX_MATCH_DISTANCE = 4  # same units as y_index / z_index
all_results = []

for i in tqdm(range(1, 33)):
    try:
        gexpr = pd.read_hdf(f"gexpr_{i}.h5", key="table")

        sec = dat.loc[dat['Section'] == i, :]
        # NOTE(review): only voxels with z_index > 456/2 are matched — presumably
        # one hemisphere of a 456-pixel-wide grid; confirm.
        sec = sec.loc[sec['z_index'] > 456/2, :]
        if sec.empty:
            print(f"Error processing section {i}: no lipid voxels to match against")
            continue

        distance_matrix = cdist(gexpr[['y_index', 'z_index']],
                                sec[['y_index', 'z_index']],
                                metric='euclidean')
        # NaN distances (missing coordinates) must never be selected as the match
        distance_matrix[np.isnan(distance_matrix)] = np.inf

        # Positional argmin instead of DataFrame.idxmin on a NaN-masked copy:
        # idxmin on all-NaN rows (cells with no voxel in range) is deprecated
        # since pandas 2.1 and raises in later versions, which the except below
        # would turn into a silently skipped section. Ties go to the first
        # voxel, exactly as idxmin did, and no second cells x voxels frame is
        # allocated.
        nearest_pos = distance_matrix.argmin(axis=1)
        nearest_dist = distance_matrix[np.arange(distance_matrix.shape[0]), nearest_pos]
        within = nearest_dist <= MAX_MATCH_DISTANCE

        # find closest cells and their indices (NaN where nothing is in range)
        closest_indices = pd.Series(sec.index.to_numpy()[nearest_pos], index=gexpr.index).where(within)
        min_distances = pd.Series(nearest_dist, index=gexpr.index).where(within)

        section_result = pd.DataFrame({
            'closest_cell': closest_indices,
            'distance': min_distances,
            'section': i
        }, index=gexpr.index)

        # impute the lipidome of the closest voxel for simplicity
        section_result = section_result.merge(
            sec,
            left_on='closest_cell',
            right_index=True,
            how='left'
        )

        all_results.append(section_result)
        del gexpr, distance_matrix

    except Exception as e:
        # best-effort: some section files may be missing; report and move on
        print(f"Error processing section {i}: {str(e)}")
        continue

final_df = pd.concat(all_results, axis=0)
# saved before dropna so unmatched cells are kept on disk
final_df.to_parquet("cells_lipidimputed.parquet")
final_df = final_df.dropna()
final_df

## Filter, normalize, prepare the imputed transcriptomics data

In [None]:
# Prepare the transcriptomes: concatenate the per-section tables into one
# cells x genes frame, tagging each row with its section number.
# (Each cell presumably carries its assigned scRNA-seq centroid profile — TODO confirm.)

gexprs = []
sizes = []
for i in tqdm(range(1, 33)):
    try:
        gexpr = pd.read_hdf(f"gexpr_{i}.h5", key="table")
    except Exception as e:
        # Best-effort, as in the imputation cell: some sections may be absent.
        # Report why instead of a bare `except:` that also swallowed
        # KeyboardInterrupt and hid real errors.
        print(f"Error processing section {i}: {str(e)}")
        continue
    gexpr['Section'] = i
    gexprs.append(gexpr)
    sizes.append(gexpr.shape[0])
    del gexpr

gexpr = pd.concat(gexprs)
# Drops rows identical across all *columns* (not index labels); duplicated
# index labels are removed in the next cell.
gexpr = gexpr.drop_duplicates()
gexpr

In [None]:
# Align the transcriptomes with the lipid-imputed cells. final_df went through
# dropna(), so it holds only cells matched to a voxel; intersect the indices
# first (`final_df.loc[gexpr.index]` raises KeyError for the unmatched cells).
gexpr = gexpr[~gexpr.index.duplicated(keep='first')]
final_df = final_df[~final_df.index.duplicated(keep='first')]
shared_cells = gexpr.index.intersection(final_df.index)  # keeps gexpr's order
final_df = final_df.loc[shared_cells, :]
gexpr = gexpr.loc[shared_cells, :]

# normalize the transcriptomes
# preprocess it as if it were common scRNA-seq data (as it fundamentally is)
import scanpy as sc
import anndata

# The last three columns are the non-gene columns (presumably y_index and
# z_index from the h5 tables, plus the Section added earlier) — TODO confirm.
gene_counts = gexpr.iloc[:, :-3]
adata = anndata.AnnData(gene_counts)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Rebuild the frame instead of writing float values back into possibly
# integer columns via iloc (incompatible-dtype setitem is deprecated in pandas 2.1+).
gexpr = pd.concat([
    pd.DataFrame(np.asarray(adata.X, dtype=np.float64),
                 index=gene_counts.index, columns=gene_counts.columns),
    gexpr.iloc[:, -3:],
], axis=1)

# save this pseudo-sc multimodal dataset
gexpr = pd.concat([gexpr, final_df], axis=1)
# Both frames carry y_index / z_index / Section, so each label selects two
# columns here and drop() removes both copies.
coords = gexpr[["y_index", "z_index", "Section"]]
gexpr = gexpr.drop(["y_index", "z_index", "Section"], axis=1)
gexpr.to_parquet("multimodal_on_macoscko.parquet")
gexpr

## Run a multi-omics factor analysis (MOFA)

In [None]:
# read in and prepare the data
gexpr = pd.read_parquet("multimodal_on_macoscko.parquet")
gexpr = gexpr[~gexpr.index.duplicated(keep='first')]
# Column layout written by the previous cell: genes..., closest_cell,
# distance, section, then the 173 lipid columns.
N_LIPIDS = 173
genes = gexpr.iloc[:, :-(N_LIPIDS + 3)]
lipids = gexpr.iloc[:, -N_LIPIDS:]

# downsample or MOFA will take 4ever
# The subsample must be applied to `genes` itself: it used to be assigned to
# `gexpr` only, so MOFA received the full matrices but 1/10 of the sample names.
DOWNSAMPLE_STEP = 10
genes = genes.iloc[::DOWNSAMPLE_STEP]
gexpr = genes
print(gexpr.shape)

# remove (near-)zero-variance features
variance = genes.var()
zero_var_columns = variance[variance < 0.0001].index
print(f"Columns to remove (zero variance): {list(zero_var_columns)}")
genes = genes.drop(columns=zero_var_columns)
lipids = lipids.loc[genes.index, :]
genes.shape

In [None]:
# run a MOFA

import pandas as pd
import numpy as np
from mofapy2.run.entry_point import entry_point

# Both views must describe the same cells in the same order.
assert genes.index.equals(lipids.index), "genes and lipids are not row-aligned"

# Nesting: one inner list per view, one entry per group (a single group here).
data = [
    [genes],
    [lipids]
]

# Create MOFA object
ent = entry_point()

# Set data options
ent.set_data_options(scale_groups=True, scale_views=True)
ent.set_data_matrix(
    data,
    likelihoods=["gaussian", "gaussian"],
    views_names=["gene_expression", "lipid_profiles"],
    # Sample names come from the matrices actually passed in; `gexpr` may hold
    # a different number of rows (e.g. when only one frame was downsampled).
    samples_names=[genes.index.tolist()]
)

# Set model options
ent.set_model_options(
    factors=100,  # number of factors to learn
    spikeslab_weights=True,  # spike and slab sparsity on weights
    ard_weights=True  # Automatic Relevance Determination on weights
)

# Set training options
# NOTE(review): iter=10 (and factors=100) were flagged as values under tuning;
# 10 iterations is very low for convergence — confirm before final runs.
ent.set_train_options(
    iter=10,
    convergence_mode="fast",
    startELBO=1,
    freqELBO=1,
    dropR2=0.001,
    verbose=True
)

# Build and run the model
ent.build()
ent.run()

In [None]:
# Pull posterior expectations out of the trained model: Z gives the per-cell
# factor values, W one loading matrix per view (genes first, then lipids).
model = ent.model
expectations = model.getExpectations()
factors = expectations["Z"]["E"]
weights = [view_w["E"] for view_w in expectations["W"]]

factor_names = [f"Factor_{k}" for k in range(1, factors.shape[1] + 1)]

# cells x factors embedding
factors_df = pd.DataFrame(factors, index=genes.index, columns=factor_names)

# features x factors loadings, one frame per modality
weights_gene = pd.DataFrame(weights[0], index=genes.columns, columns=factors_df.columns)
weights_lipid = pd.DataFrame(weights[1], index=lipids.columns, columns=factors_df.columns)
def get_top_features(weights_df, n=10):
    """Return, for every factor column, its n highest- and n lowest-weighted features.

    Parameters
    ----------
    weights_df : pd.DataFrame
        Features (index) x factors (columns) loading matrix.
    n : int
        Number of features to keep at each extreme.

    Returns
    -------
    dict
        ``{factor: {'positive': [...], 'negative': [...]}}`` with feature labels
        ordered from the most extreme value inward.
    """
    return {
        factor: {
            'positive': weights_df[factor].nlargest(n).index.tolist(),
            'negative': weights_df[factor].nsmallest(n).index.tolist(),
        }
        for factor in weights_df.columns
    }
top_genes = get_top_features(weights_gene)
top_lipids = get_top_features(weights_lipid)

# Sanity check on Factor 1: its top features are expected to be well-known
# white-matter markers in both modalities.
factor1 = 'Factor_1'
print("\nTop features for Factor 1:")
print("Genes (positive):", ", ".join(top_genes[factor1]['positive'][:10]))
print("Lipids (positive):", ", ".join(top_lipids[factor1]['positive'][:5]))

for frame, path in [(factors_df, "minimofa_factors.csv"),
                    (weights_gene, "minimofa_weights_genes.csv"),
                    (weights_lipid, "minimofa_weights_lipids.csv")]:
    frame.to_csv(path)

In [None]:
# Persist the factor embedding as a pandas HDF5 store (despite the .h5ad suffix).
factors_df.to_hdf(path_or_buf="factors_dfMOFA.h5ad", key="table")

## Characterize the embeddings

In [None]:
# do t-SNE on top of MOFA

import os
import gc
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from scipy.sparse import csr_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import NMF
from openTSNE import TSNEEmbedding, affinity, initialization
from tqdm import tqdm
from collections import deque
import harmonypy as hm
import networkx as nx
from threadpoolctl import threadpool_limits, threadpool_info

# configure thread limits
# One thread count for every component (previously 8 here and 6 below).
# threadpool_limits caps the already-loaded BLAS/OpenMP pools. OMP_NUM_THREADS
# is normally read only when those runtimes initialise (already done by the
# imports above), so setting it here mainly affects subprocesses.
N_THREADS = 8
threadpool_limits(limits=N_THREADS)
os.environ['OMP_NUM_THREADS'] = str(N_THREADS)

embds = factors_df

# standardize every factor before computing the kNN affinities
scaler = StandardScaler()
x_train = scaler.fit_transform(embds)

affinities_train = affinity.PerplexityBasedNN(
    x_train,
    perplexity=30,
    metric="euclidean",
    n_jobs=N_THREADS,
    random_state=42,
    verbose=True,
)

# initialize with the first two (standardized) factors; note this affects results a bit.
# NOTE(review): openTSNE suggests rescaling an initialization to a small std
# (initialization.rescale); left unscaled here to keep results unchanged.
init_train = x_train[:, [0, 1]]

embedding_train = TSNEEmbedding(
    init_train,
    affinities_train,
    negative_gradient_method="fft",
    n_jobs=N_THREADS,
    verbose=True,
)

embedding_train_1 = embedding_train.optimize(n_iter=500, exaggeration=1.2)
np.save("minimofageneslipidsembedding_train_1.npy", np.array(embedding_train_1))
embedding_train_N = embedding_train_1.optimize(n_iter=100, exaggeration=2.5)
np.save("minimofageneslipidsembedding_train_N.npy", np.array(embedding_train_N))

In [None]:
# Color the MOFA t-SNE by one lipid, clipping the color scale to its 3rd-97th percentiles.
lipid_name = 'HexCer 42:2;O2'
lipid_values = lipids[lipid_name]
fig, ax = plt.subplots()
ax.scatter(
    embedding_train_N[:, 0], embedding_train_N[:, 1],
    s=0.005, c=lipid_values,
    vmin=np.percentile(lipid_values, 3), vmax=np.percentile(lipid_values, 97),
    cmap="plasma", rasterized=True,
)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)

plt.savefig('mofa_hexcer.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Color the MOFA t-SNE by Mbp expression (autoscaled color limits).
fig, ax = plt.subplots()
ax.scatter(
    embedding_train_N[:, 0], embedding_train_N[:, 1],
    s=0.005, c=genes['Mbp'],
    cmap="plasma", rasterized=True,
)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)

plt.savefig('mofa_mbp.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd

# Cell-type labels aligned with the MOFA cells (same row order as the t-SNE).
# Loaded here so the cell no longer depends on `ct` from a later cell.
ct = pd.read_hdf("celltypesnow.h5ad", key="table")
ct = ct.loc[factors_df.index]

# Coarse class = text after the last '=' up to the first '_' of each label.
celltype_classes = np.array([label.split('=')[-1].split('_')[0] for label in ct])
fig, ax = plt.subplots()
celltype_series = pd.Series(celltype_classes).astype("category")
codes = celltype_series.cat.codes
categories = celltype_series.cat.categories  # Actual names of each category
# The artist is not bound to `sc`, which shadowed the scanpy module alias.
ax.scatter(
    embedding_train_N[:, 0],
    embedding_train_N[:, 1],
    s=0.005,
    c=codes,
    cmap="nipy_spectral",
    rasterized=True
)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)

# Legend for the 7 most frequent classes, using the same code -> color mapping
# as the scatter (codes span 0 .. num_cats-1).
counts = celltype_series.value_counts()
top_7_categories = counts.index[:7]
top_7_codes = [categories.get_loc(cat) for cat in top_7_categories]
cmap = plt.cm.nipy_spectral
num_cats = len(categories)
patches = []
for code, cat in zip(top_7_codes, top_7_categories):
    color = cmap(code / (num_cats - 1) if num_cats > 1 else 0.5)  # handle single-cat edge case
    patches.append(mpatches.Patch(color=color, label=cat))
ax.legend(
    handles=patches,
    bbox_to_anchor=(1.05, 1),
    loc='upper left',
    title="Top 7 Cell Types"
)

plt.savefig('celltypeclasses_top7.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# plot the cell types one at a time in MOFA space to see if some substructure dictated by lipids emerges

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np

# cell-type label per MOFA cell, in factors_df (= embedding) row order
ct = pd.read_hdf("celltypesnow.h5ad", key="table").loc[factors_df.index]
unique_cts = ct.unique()
def clean_cell_type(cell_type):
    """Strip the prefix before the first '=' from a cell-type label.

    Returns the segment between the first and second '=' when the label
    contains '='; otherwise returns the label unchanged.
    """
    if '=' not in cell_type:
        return cell_type
    return cell_type.split('=')[1]

# One PDF page per 100 cell types (10 x 10 grid). Each panel highlights one
# cell type (dark red) over all MOFA cells (light gray) in the t-SNE.
n_rows, n_cols = 10, 10
plots_per_page = n_rows * n_cols
# Derived from the data: the previous hard-coded 17 pages dropped every cell
# type past the 1700th and produced extra pages when there were fewer.
total_pages = (len(unique_cts) + plots_per_page - 1) // plots_per_page

with PdfPages('cell_type_MOFA_plots.pdf') as pdf:
    for page in range(total_pages):
        start_idx = page * plots_per_page
        end_idx = min(start_idx + plots_per_page, len(unique_cts))
        page_cts = unique_cts[start_idx:end_idx]

        fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(20, 20))
        axes = axes.flatten()

        for ax, cell_type in zip(axes, page_cts):
            ax.scatter(
                embedding_train_N[:, 0],
                embedding_train_N[:, 1],
                c='lightgray',
                s=0.005, rasterized=True
            )

            mask = (ct == cell_type).to_numpy()
            ax.scatter(
                embedding_train_N[mask, 0],
                embedding_train_N[mask, 1],
                c='darkred',
                s=0.01, rasterized=True
            )

            ax.set_title(clean_cell_type(cell_type), fontsize=8)
            ax.set_xticks([])
            ax.set_yticks([])

        # Hide the unused panels on the last page. This used to rely on the
        # loop variable `i`, which was stale or undefined on an empty page.
        for ax in axes[len(page_cts):]:
            ax.axis('off')

        plt.tight_layout()

        pdf.savefig(fig)
        plt.close(fig)

In [None]:
# Matched-voxel coordinates and section for each MOFA cell, one row per cell,
# in genes.index order.
metadata = pd.read_parquet("cells_lipidimputed.parquet")
first_rows = metadata[~metadata.index.duplicated(keep='first')]
md = first_rows.loc[genes.index, ['z_index', 'y_index', 'Section']]
md = md[~md.index.duplicated(keep='first')]
md

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42  # Type 42 (TrueType) fonts keep PDF text editable

# first 40 factors, restricted to cells from the three showcase sections
showcase_sections = [7, 11, 18]
in_showcase = md['Section'].isin(showcase_sections)
subfactors_df = factors_df.loc[in_showcase, :].iloc[:, :40]

def compute_medians(factors, md, currentLipid):
    """Median over sections of the per-section 2nd and 98th percentiles of one column.

    Parameters
    ----------
    factors : pd.DataFrame
        Values to summarize; rows are selected with a boolean mask built from
        ``md``, so both should share the same index.
    md : pd.DataFrame
        Per-row metadata with a 'Section' column.
    currentLipid : str
        Column of ``factors`` to summarize.

    Returns
    -------
    tuple
        (median of the 2nd percentiles, median of the 98th percentiles).
    """
    lows, highs = [], []
    for section in md['Section'].unique():
        values = factors[md['Section'] == section][currentLipid]
        lows.append(values.quantile(0.02))
        highs.append(values.quantile(0.98))
    return pd.Series(lows).median(), pd.Series(highs).median()

# One PDF page per block of 10 factors. Each factor gets a horizontal band:
# its values mapped on sections 7, 11 and 18 (shared color limits), a colorbar,
# and its top-6 positive (upper box) / negative (lower box) genes and lipids.
# (`lipid_list` / `currentLipid` actually iterate over factor columns.)
num_per_page = 10
lipid_list = subfactors_df.columns
map_sections = [7, 11, 18]

# The per-section coordinate + factor frames do not depend on the factor
# being drawn, so build them once instead of re-concatenating 3x per factor.
section_frames = {
    section: pd.concat([
        md[md['Section'] == section],
        factors_df[md['Section'] == section]
    ], axis=1)
    for section in map_sections
}

with PdfPages("lipids_10_per_page.pdf") as pdf:
    for chunk_start in range(0, len(lipid_list), num_per_page):
        chunk = lipid_list[chunk_start : chunk_start + num_per_page]

        fig_height = 3 * len(chunk)
        fig_width = 8
        fig = plt.figure(figsize=(fig_width, fig_height))

        for i, currentLipid in enumerate(chunk):
            # shared color limits: median across sections of the 2nd/98th percentiles
            med2p, med98p = compute_medians(factors_df, md, currentLipid)

            # band i (counted from the top) in figure-fraction coordinates
            rowHeight = 1.0 / len(chunk)
            y0 = 1.0 - (i+1)*rowHeight

            # the three section maps share the left 80% of the band
            w_scatter = 0.8
            h_scatter = rowHeight
            single_ax_width = w_scatter / 3.0

            axes = []
            for col_idx in range(3):
                x0 = col_idx * single_ax_width
                ax = fig.add_axes([x0, y0, single_ax_width, h_scatter])
                axes.append(ax)

            for ax, section in zip(axes, map_sections):
                ddf = section_frames[section]
                # The artist is not kept: it used to be bound to `sc`,
                # shadowing the scanpy alias imported earlier.
                ax.scatter(ddf['z_index'], -ddf['y_index'],
                           c=ddf[currentLipid],
                           cmap="PuOr",
                           s=0.5,
                           rasterized=True,
                           vmin=med2p,
                           vmax=med98p)
                ax.axis('off')
                ax.set_aspect('equal')

            cbar_left = 0.82
            cbar_bottom = y0 + 0.15*rowHeight
            cbar_width = 0.015
            cbar_height = 0.7*rowHeight
            cbar_ax = fig.add_axes([cbar_left, cbar_bottom, cbar_width, cbar_height])

            norm = Normalize(vmin=med2p, vmax=med98p)
            sm = ScalarMappable(norm=norm, cmap="PuOr")
            fig.colorbar(sm, cax=cbar_ax)

            # text boxes to the right of the colorbar: positive loadings on top
            top_ax_left = 0.86
            top_ax_bottom = y0 + 0.55*rowHeight
            top_ax_width = 0.12
            top_ax_height = 0.3*rowHeight
            top_ax = fig.add_axes([top_ax_left, top_ax_bottom, top_ax_width, top_ax_height])

            bottom_ax_left = 0.86
            bottom_ax_bottom = y0 + 0.15*rowHeight
            bottom_ax_width = 0.12
            bottom_ax_height = 0.3*rowHeight
            bottom_ax = fig.add_axes([bottom_ax_left, bottom_ax_bottom, bottom_ax_width, bottom_ax_height])

            for ann_ax in [top_ax, bottom_ax]:
                ann_ax.set_xlim(0, 1)
                ann_ax.set_ylim(0, 1)
                ann_ax.axis('off')

            top_genes_list = top_genes[currentLipid]['positive'][:6]
            top_lipids_list = top_lipids[currentLipid]['positive'][:6]
            bottom_genes_list = top_genes[currentLipid]['negative'][:6]
            bottom_lipids_list = top_lipids[currentLipid]['negative'][:6]

            for j, (gene, lipid) in enumerate(zip(top_genes_list, top_lipids_list)):
                y_pos = 0.88 - j * 0.12
                top_ax.text(0.10, y_pos, gene, fontsize=4, va='top')
                top_ax.text(0.65, y_pos, lipid, fontsize=4, va='top')

            for j, (gene, lipid) in enumerate(zip(bottom_genes_list, bottom_lipids_list)):
                y_pos = 0.88 - j * 0.12
                bottom_ax.text(0.10, y_pos, gene, fontsize=4, va='top')
                bottom_ax.text(0.65, y_pos, lipid, fontsize=4, va='top')
            axes[0].text(0.0, 1.02, currentLipid, fontsize=9, transform=axes[0].transAxes)

        pdf.savefig(fig)
        plt.close(fig)

## Alluvial (Sankey) diagram linking the GO terms and lipid classes that drive each MOFA factor

In [None]:
# Genes in the top / bottom 1% of Factor_1 loadings.
f1_weights = weights_gene['Factor_1']
positive_f1 = weights_gene.index[f1_weights > np.percentile(f1_weights, 99)]
negative_f1 = weights_gene.index[f1_weights < np.percentile(f1_weights, 1)]
negative_f1

In [None]:
# do gene ontology on such clusters: does it make sense?

# (by colocalization, it's something)

from __future__ import print_function
import os
import numpy as np
import matplotlib.pyplot as plt
import goatools
from goatools.anno.genetogo_reader import Gene2GoReader
from goatools.base import download_go_basic_obo, download_ncbi_associations
from goatools.obo_parser import GODag
from goatools.test_data.genes_NCBI_10090_ProteinCoding import GENEID2NT as GeneID2nt_mus
from goatools.goea.go_enrichment_ns import GOEnrichmentStudyNS
import collections as cx
import pandas as pd
from goatools.godag_plot import plot_gos, plot_results, plot_goid2goobj
from goatools.associations import read_ncbi_gene2go
from goatools.anno.factory import get_objanno
from goatools.go_enrichment import GOEnrichmentStudy
import mygene
import matplotlib as mpl
import matplotlib.pyplot as plt
from goatools.gosubdag.gosubdag import GoSubDag
mpl.rcParams['pdf.fonttype'] = 42

obo_dag = GODag("go-basic.obo")
# wget http://geneontology.org/ontology/go-basic.obo
obo_fname = download_go_basic_obo()

# dictionary of  gene symbols: Gene Ontology terms
associations = read_ncbi_gene2go('gene2go', taxids=[10090],namespace='MF')  # 10090 is the taxid for mouse
obj_ncbi = get_objanno('gene2go', taxid=10090)
associations = obj_ncbi.get_id2gos(namespace='all')

bp_terms = []
for go_id, go_term in obo_dag.items():
    if go_term.namespace == 'biological_process':
        bp_terms.append(go_id)
    elif go_term.namespace == 'molecular_function':
        bp_terms.append(go_id)
    elif go_term.namespace == 'cellular_component':
        bp_terms.append(go_id)
        
go_subdag = GoSubDag(bp_terms, obo_dag)
bp_associations = {}
for gene, terms in associations.items():
    bp_associations[gene] = [term for term in terms if term in bp_terms]
    
mg = mygene.MyGeneInfo()

population_genes = weights_gene.index.values
def convert_symbols_to_entrez(gene_symbols):
    """Map mouse gene symbols to integer Entrez gene ids via MyGene.info.

    Query results without an 'entrezgene' field are dropped, so the output
    can be shorter than the input.
    """
    hits = mg.querymany(gene_symbols, scopes='symbol', fields='entrezgene', species='mouse')
    return [int(hit['entrezgene']) for hit in hits if 'entrezgene' in hit]
population_genes = convert_symbols_to_entrez(population_genes)

In [None]:
# loop over clusters
# For each MOFA factor, run GO enrichment on the genes in the top 1% (positive)
# and bottom 1% (negative) of its loadings, against all modeled genes.

from itertools import islice

# The study depends only on the population and annotations, so build it once
# instead of twice per factor.
g = GOEnrichmentStudy(
    population_genes,
    associations,
    obo_dag,
    propagate_counts=False,
    alpha=0.05,  # default significance level
    methods=['fdr_bh']  # use FDR Benjamini-Hochberg correction
)


def _go_enrichment_table(gene_symbols, factor):
    """Enrich one gene set; return terms with FDR < 0.1, tagged with `factor`."""
    study_genes = convert_symbols_to_entrez(gene_symbols)
    results = g.run_study(study_genes)
    # Collect rows first: growing a DataFrame one .loc row at a time is quadratic.
    rows = {r.GO: [r.name, r.p_fdr_bh, r.study_count, obo_dag[r.GO].namespace] for r in results}
    results_df = pd.DataFrame.from_dict(
        rows, orient='index', columns=['GO_name', 'p-value', 'number_items', 'category']
    )
    results_df = results_df.loc[results_df['p-value'] < 0.1, :].copy()
    # Previously assigned the undefined name `cl` (NameError on a fresh kernel).
    results_df['cluster'] = factor
    print(results_df['GO_name'][:10])
    return results_df


gosPOS = []
gosNEG = []

for f in tqdm(weights_gene.columns):
    positive_f = weights_gene.index[weights_gene[f] > np.percentile(weights_gene[f], 99)]
    negative_f = weights_gene.index[weights_gene[f] < np.percentile(weights_gene[f], 1)]

    gosPOS.append(_go_enrichment_table(positive_f, f))
    gosNEG.append(_go_enrichment_table(negative_f, f))

In [None]:
import pickle

# Snapshot the trained MOFA model object.
# NOTE(review): pickles are tied to the library versions that wrote them and
# must only be loaded from trusted sources.
pickle_filename = 'mofa.pkl'
with open(pickle_filename, mode='wb') as handle:
    pickle.dump(model, handle)

In [None]:
import pickle

# Persist the per-factor GO enrichment tables (positive, then negative loadings).
for pickle_filename, payload in (('gosPOS.pkl', gosPOS), ('gosNEG.pkl', gosNEG)):
    with open(pickle_filename, 'wb') as handle:
        pickle.dump(payload, handle)

In [None]:
## start by keeping only the very top terms...
# GO-term x factor-sign matrix of -ln(FDR) for biological-process terms with
# FDR < 0.01. Entry k (1-based) of gosPOS / gosNEG corresponds to Factor_k.

def _top_bp_terms(go_table, column_name):
    """One-column frame of -ln(p) for significant biological-process terms."""
    kept = go_table.loc[(go_table['p-value'] < 0.01) & (go_table['category'] == "biological_process"), :]
    return pd.DataFrame(-np.log(kept['p-value']).values, index=kept['GO_name'], columns=[column_name])

xs = [_top_bp_terms(table, f"Factor_{k}_+") for k, table in enumerate(gosPOS, start=1)]
xs += [_top_bp_terms(table, f"Factor_{k}_-") for k, table in enumerate(gosNEG, start=1)]

goandfact = pd.concat(xs, axis=1).fillna(0)
goandfact = goandfact.loc[:, goandfact.sum() > 0]
goandfact # index needs cleaning...

In [None]:
def permutation_test_categorical(
    test_labels, 
    other_labels, 
    n_permutations=10_000, 
    alternative='two-sided', 
    random_state=None
):
    """
    Perform a permutation test to assess whether each category in test_labels 
    is over- or under-represented compared to what we would expect by chance.

    Labels from both sets are pooled and repeatedly shuffled; the first
    ``len(test_labels)`` shuffled labels form each permuted "test" set.
    Missing labels (None / NaN) are ignored.

    Parameters
    ----------
    test_labels : 1D array-like of categorical labels (the "test" set)
    other_labels : 1D array-like of categorical labels (all non-test elements)
    n_permutations : int, optional
        Number of random permutations
    alternative : {'two-sided', 'greater', 'less'}, optional
        - 'two-sided': tests if the proportion differs in either direction
        - 'greater': tests if test_labels has a higher proportion of the category
        - 'less': tests if test_labels has a lower proportion of the category
    random_state : int, optional
        Seed for a private random generator (the global NumPy state is not
        modified). With None, NumPy's global generator is used.

    Returns
    -------
    results : pd.DataFrame
        A DataFrame with columns: 'category', 'observed_count', 'expected_count',
        'observed_proportion', 'expected_proportion', 'p_value'

    Raises
    ------
    ValueError
        If `alternative` is not one of the supported values.
    """
    if alternative not in ('two-sided', 'greater', 'less'):
        raise ValueError(
            f"alternative must be 'two-sided', 'greater' or 'less', got {alternative!r}"
        )

    # A seeded RandomState yields the same shuffle stream as the former
    # np.random.seed(random_state) + np.random.shuffle, without clobbering the
    # global state. With no seed, keep drawing from the global generator.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    test_labels = np.asarray(test_labels, dtype=object)
    other_labels = np.asarray(other_labels, dtype=object)
    # Drop missing labels (e.g. lipids whose class was set to NaN): np.unique
    # cannot sort a mix of strings and float NaN and raised TypeError.
    test_labels = test_labels[~pd.isna(test_labels)]
    other_labels = other_labels[~pd.isna(other_labels)]

    all_labels = np.concatenate([test_labels, other_labels])
    n_test = len(test_labels)
    # integer code per label, indexing into the sorted unique categories
    unique_categories, codes = np.unique(all_labels, return_inverse=True)
    codes = np.asarray(codes).ravel()
    n_categories = len(unique_categories)

    # expected proportions come from the pooled dataset
    total_counts = np.bincount(codes, minlength=n_categories)
    expected_props = total_counts / len(all_labels)

    # observed counts and proportions (the test labels come first in the pool)
    observed_counts = np.bincount(codes[:n_test], minlength=n_categories)
    observed_props = observed_counts / n_test

    # Permutation counts, one row per permutation. Shuffling is cumulative and
    # in place as before; the permutation depends only on the array length, so
    # shuffling the integer codes reproduces the former label shuffles.
    perm_counts = np.empty((n_permutations, n_categories))
    codes = codes.copy()
    for i in range(n_permutations):
        rng.shuffle(codes)
        perm_counts[i] = np.bincount(codes[:n_test], minlength=n_categories)

    results = []
    for k, cat in enumerate(unique_categories):
        observed = observed_counts[k]
        distribution = perm_counts[:, k]
        expected = expected_props[k] * n_test

        if alternative == 'two-sided':
            # count permutations that deviate from expected as much as or more than observed
            observed_dev = abs(observed - expected)
            p_value = np.mean(abs(distribution - expected) >= observed_dev)
        elif alternative == 'greater':
            # count permutations where count >= observed
            p_value = np.mean(distribution >= observed)
        else:  # 'less'
            # count permutations where count <= observed
            p_value = np.mean(distribution <= observed)

        results.append({
            'category': cat,
            'observed_count': observed,
            'expected_count': expected,
            'observed_proportion': observed_props[k],
            'expected_proportion': expected_props[k],
            'p_value': p_value
        })

    return pd.DataFrame(results)

import re

# One row per lipid feature, annotated by parsing its name.
df = pd.DataFrame(weights_lipid.index)
df.columns = ["lipid_name"]

# extract the "class" etc from the lipid_name
# (raw-string regex: the former ' |\(' literal relied on an invalid escape
#  sequence, a SyntaxWarning since Python 3.12; the pattern is unchanged)
df["class"] = df["lipid_name"].apply(lambda x: 
    "PC O" if x.startswith("PC O") else
    "PE O" if x.startswith("PE O") else
    re.split(r' |\(', x)[0]
)
# carbons = digits before the first ':', insaturations = digits after it
df["carbons"] = df["lipid_name"].apply(lambda x: int(re.search(r'(\d+):', x).group(1)) if re.search(r'(\d+):', x) else np.nan)
df["insaturations"] = df["lipid_name"].apply(lambda x: int(re.search(r':(\d+)', x).group(1)) if re.search(r':(\d+)', x) else np.nan)
df["insaturations_per_Catom"] = df["insaturations"] / df["carbons"]
# names ending in '_uncertain' get no class/chain annotation and a gray color
df["broken"] = df["lipid_name"].str.endswith('_uncertain')
df.loc[df["broken"], 'carbons'] = np.nan
df.loc[df["broken"], 'class'] = np.nan
df.loc[df["broken"], 'insaturations'] = np.nan
df.loc[df["broken"], 'insaturations_per_Catom'] = np.nan
colors = pd.read_hdf("lipidclasscolors.h5ad", key="table")
df['color'] = df['class'].map(colors['classcolors'])
df.loc[df["broken"], 'color'] = "gray"
df.index = df['lipid_name']
df = df.drop_duplicates()
df

In [None]:
# For each factor, test which lipid classes are over/under-represented among
# the lipids in the top (+) and bottom (-) 10% of its loadings, compared with
# all other lipids; keep classes with p < 0.1 as -ln(p).
ys = []

for f in tqdm(weights_lipid.columns):
    loadings = weights_lipid[f]
    extremes = [
        ("_+", weights_lipid.index[loadings > np.percentile(loadings, 90)]),
        ("_-", weights_lipid.index[loadings < np.percentile(loadings, 10)]),
    ]
    for suffix, selected in extremes:
        rest = np.setdiff1d(weights_lipid.index.values, selected)
        class_enrichments = permutation_test_categorical(
            df.loc[selected.values, 'class'],
            df.loc[rest, 'class'],
            n_permutations=5000,
            alternative='two-sided',
            random_state=42,
        )
        significant = class_enrichments.loc[(class_enrichments['p_value'] < 0.1), :]
        ys.append(pd.DataFrame(-np.log(significant['p_value']).values,
                               index=significant['category'],
                               columns=[f + suffix]))

In [None]:
# Align lipid-class enrichments with the GO table's factor-sign columns, cap
# -ln(p) at 10 (also turns p == 0 -> inf into 10), and keep only factor-signs
# that have some lipid-class signal.
loandfact = pd.concat(ys, axis=1).fillna(0)
lipidsandfact = loandfact.loc[:, goandfact.columns].clip(upper=10)
goandfact = goandfact.clip(upper=10)
lipidsandfact = lipidsandfact.loc[:, lipidsandfact.sum() > 0]
goandfact = goandfact.loc[:, lipidsandfact.columns]
goandfact

In [None]:
import plotly.express as px
import numpy as np

def build_color_dict(terms, palette=px.colors.qualitative.Dark24):
    """
    Map each term to a palette color, in order, cycling through the palette
    when there are more terms than colors. Colors are returned as given by
    the palette (e.g. "#1f77b4").
    """
    n_colors = len(palette)
    return {term: palette[idx % n_colors] for idx, term in enumerate(terms)}

def to_rgba(hex_or_rgb_str, alpha=1.0):
    """
    Convert a Plotly palette color into an "rgba(r,g,b,alpha)" string.

    Accepts "#rrggbb", the CSS shorthand "#rgb" (each digit doubled), or
    "rgb(r,g,b)". Any other string (e.g. already "rgba(...)" or a named
    color) is returned unchanged.
    """
    if hex_or_rgb_str.startswith("#"):
        hex_val = hex_or_rgb_str.lstrip("#")
        if len(hex_val) == 3:
            # "#abc" means "#aabbcc"; this used to raise ValueError
            hex_val = "".join(ch * 2 for ch in hex_val)
        r = int(hex_val[0:2], 16)
        g = int(hex_val[2:4], 16)
        b = int(hex_val[4:6], 16)
        return f"rgba({r},{g},{b},{alpha})"
    elif hex_or_rgb_str.startswith("rgb("):
        # strip() removes the "rgb(" / ")" characters and spaces at both ends
        inside = hex_or_rgb_str.strip("rgb() ")
        r, g, b = inside.split(",")
        return f"rgba({r.strip()},{g.strip()},{b.strip()},{alpha})"
    else:
        return hex_or_rgb_str

go_terms = goandfact.index.tolist()
factors = goandfact.columns.tolist()
lipid_terms = lipidsandfact.index.tolist()
go_color_dict = build_color_dict(go_terms, px.colors.qualitative.Dark24)
lipid_color_dict = build_color_dict(lipid_terms, px.colors.qualitative.Light24)

# 70th percentile of all GO -> factor scores (zeros included, row-major as
# before), used as the cutoff for which GO links are drawn.
go_factor_values = goandfact.to_numpy().ravel()
p70 = np.percentile(go_factor_values, 70)  # keep flows >= this

import plotly.graph_objects as go

node_labels = go_terms + factors + lipid_terms
label_to_index = {label: idx for idx, label in enumerate(node_labels)}

# Node colors
# if it's a GO term, color = go_color_dict[term]
# if it's a lipid term, color = lipid_color_dict[term]
factor_color = "rgba(160,160,160,1.0)"  # light gray for factors
node_colors = []
for lbl in node_labels:
    if lbl in go_color_dict:
        node_colors.append(to_rgba(go_color_dict[lbl], alpha=1.0))
    elif lbl in lipid_color_dict:
        node_colors.append(to_rgba(lipid_color_dict[lbl], alpha=1.0))
    else:
        node_colors.append(factor_color)

# build links (source, target, value) + link colors
sources = []
targets = []
values = []
link_colors = []

# --- GO -> Factor ---
for go_term in go_terms:
    for factor in factors:
        flow_value = goandfact.loc[go_term, factor]
        # Keep flows in the top 30% (>= p70). A 0 score means "not enriched";
        # on this mostly-zero matrix p70 is often 0, so zero-valued links were
        # emitted before.
        if flow_value > 0 and flow_value >= p70:
            go_idx = label_to_index[go_term]
            factor_idx = label_to_index[factor]
            sources.append(go_idx)
            targets.append(factor_idx)
            values.append(flow_value)
            # color of flow = GO color with partial alpha
            base_col = go_color_dict[go_term]
            link_colors.append(to_rgba(base_col, alpha=0.4))

# --- Factor -> Lipid ---
for lipid_term in lipid_terms:
    for factor in factors:
        flow_value = lipidsandfact.loc[lipid_term, factor]
        # keep every non-zero lipid-class flow
        if flow_value > 0:
            factor_idx = label_to_index[factor]
            lipid_idx = label_to_index[lipid_term]
            sources.append(factor_idx)
            targets.append(lipid_idx)
            values.append(flow_value)
            # color of flow = lipid color
            base_col = lipid_color_dict[lipid_term]
            link_colors.append(to_rgba(base_col, alpha=1.0))

# construct the Sankey figure
fig = go.Figure(data=[go.Sankey(
    arrangement="snap",
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=node_labels,
        color=node_colors
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        color=link_colors
    )
)])

fig.update_layout(
    font_size=10,
    height=2000
)
fig.write_image("sankey.pdf", format="pdf", engine="kaleido")
fig.show()

In [None]:
# Zoom in: keep GO terms linked to more than two factor-signs, then drop
# factor-signs left without any GO signal (unfiltered copy kept in goandfactbu).
goandfactbu = goandfact.copy()
go_hits_per_term = (goandfact > 0).sum(axis=1)
goandfact = goandfact.loc[go_hits_per_term > 2, :]
goandfact = goandfact.loc[:, goandfact.sum() > 0]
lipidsandfact = lipidsandfact.loc[:, goandfact.columns]
goandfact

In [None]:
# visualize a subset zoom in

import plotly.express as px
import numpy as np

def build_color_dict(terms, palette=px.colors.qualitative.Dark24):
    """Assign each term a color from ``palette``, wrapping around when it runs out.

    The k-th term (0-based, in iteration order) gets ``palette[k % len(palette)]``,
    so palettes shorter than ``terms`` repeat colors.

    Parameters
    ----------
    terms : iterable
        Labels to color. If a label repeats, its later color replaces the earlier one.
    palette : sequence of str
        Colors to cycle through (defaults to plotly's Dark24 qualitative palette).

    Returns
    -------
    dict
        Mapping ``{term: color}``.
    """
    n_colors = len(palette)
    return {term: palette[k % n_colors] for k, term in enumerate(terms)}

def to_rgba(hex_or_rgb_str, alpha=1.0):
    """Convert a plotly color string to an ``"rgba(r,g,b,a)"`` string with the given alpha.

    Supported inputs:

    * ``"#RRGGBB"`` hex (for an 8-digit ``"#RRGGBBAA"`` the alpha byte is ignored);
    * ``"#RGB"`` / ``"#RGBA"`` shorthand hex, expanded to ``"#RRGGBB"``;
    * ``"rgb(r, g, b)"``;
    * ``"rgba(r, g, b, a)"``, whose existing alpha is replaced by ``alpha``.

    Any other string (e.g. a CSS color name such as ``"red"``) is returned
    unchanged, so ``alpha`` is NOT applied to it.

    Parameters
    ----------
    hex_or_rgb_str : str
        Input color.
    alpha : float, default 1.0
        Opacity written into the returned string.

    Returns
    -------
    str
        ``"rgba(r,g,b,alpha)"``, or the input unchanged if its format is not recognized.

    Raises
    ------
    ValueError
        If a hex or ``rgb(...)``/``rgba(...)`` string is malformed.
    """
    color = hex_or_rgb_str.strip()
    if color.startswith("#"):
        hex_val = color.lstrip("#")
        if len(hex_val) in (3, 4):
            # CSS shorthand: each digit is doubled ("#abc" -> "#aabbcc").
            hex_val = "".join(ch * 2 for ch in hex_val[:3])
        r, g, b = (int(hex_val[k:k + 2], 16) for k in (0, 2, 4))
        return f"rgba({r},{g},{b},{alpha})"
    if color.lower().startswith(("rgb(", "rgba(")):
        # Keep the first three channels verbatim. Any alpha already present is
        # replaced (previously "rgba(...)" inputs silently ignored `alpha`).
        inside = color.split("(", 1)[1].rstrip(") ")
        r, g, b = (part.strip() for part in inside.split(",")[:3])
        return f"rgba({r},{g},{b},{alpha})"
    return hex_or_rgb_str

# Node groups for the zoomed-in Sankey, taken from the (filtered) matrices.
go_terms = goandfact.index.tolist()
factors = goandfact.columns.tolist()
lipid_terms = lipidsandfact.index.tolist()

# GO terms use the dark palette and lipids the light one, so the two node
# groups are visually distinct.
go_color_dict = build_color_dict(go_terms, px.colors.qualitative.Dark24)
lipid_color_dict = build_color_dict(lipid_terms, px.colors.qualitative.Light24)

# Threshold for GO -> factor links: only flows at or above this percentile of all
# GO x factor values are drawn (i.e. the top 30%). The values are collected with
# one vectorized call instead of a per-cell `.loc` lookup in a nested Python loop.
GO_FLOW_PERCENTILE = 70
go_factor_values = goandfact.to_numpy().ravel().tolist()
p70 = np.percentile(go_factor_values, GO_FLOW_PERCENTILE)

import plotly.graph_objects as go

# Plotly Sankey refers to nodes by position, so lay all nodes out in one list
# (GO terms, then factors, then lipids) and map each label to its index.
node_labels = go_terms + factors + lipid_terms
label_to_index = {label: idx for idx, label in enumerate(node_labels)}

# A label shared by two groups (e.g. a GO term and a lipid with the same name)
# would collapse onto a single index and silently misroute links, so fail loudly.
if len(label_to_index) != len(node_labels):
    _label_index = pd.Index(node_labels)
    raise ValueError(
        "Sankey node labels must be unique across GO terms, factors and lipids; "
        f"duplicated: {_label_index[_label_index.duplicated()].unique().tolist()}"
    )

# Factors are neutral grey; GO terms and lipids keep their palette colors, fully opaque.
factor_color = "rgba(160,160,160,1.0)"

node_colors = []
for lbl in node_labels:
    if lbl in go_color_dict:
        node_colors.append(to_rgba(go_color_dict[lbl], alpha=1.0))
    elif lbl in lipid_color_dict:
        node_colors.append(to_rgba(lipid_color_dict[lbl], alpha=1.0))
    else:
        node_colors.append(factor_color)

# Collect Sankey links as (source index, target index, value, rgba color)
# records, then unzip them into the parallel lists plotly expects.
link_records = []

# GO term -> factor: only flows at or above the p70 threshold, in the GO color
# at partial opacity.
for go_term in go_terms:
    for factor in factors:
        flow_value = goandfact.loc[go_term, factor]
        if flow_value >= p70:
            link_records.append((
                label_to_index[go_term],
                label_to_index[factor],
                flow_value,
                to_rgba(go_color_dict[go_term], alpha=0.4),
            ))

# Factor -> lipid: every strictly positive flow, fully opaque in the lipid color.
for lipid_term in lipid_terms:
    for factor in factors:
        flow_value = lipidsandfact.loc[lipid_term, factor]
        if flow_value > 0:
            link_records.append((
                label_to_index[factor],
                label_to_index[lipid_term],
                flow_value,
                to_rgba(lipid_color_dict[lipid_term], alpha=1.0),
            ))

sources = [record[0] for record in link_records]
targets = [record[1] for record in link_records]
values = [record[2] for record in link_records]
link_colors = [record[3] for record in link_records]

# Render the zoomed-in Sankey. It is exported to its own file: writing to
# "sankey.pdf" (as before) overwrote the full diagram saved by the earlier cell.
SANKEY_SUBSET_PDF = "sankey_subset.pdf"

fig = go.Figure(data=[go.Sankey(
    arrangement="snap",
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=node_labels,
        color=node_colors
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        color=link_colors
    )
)])

fig.update_layout(
    font_size=10,
    height=2000,
)
fig.write_image(SANKEY_SUBSET_PDF, format="pdf", engine="kaleido")
fig.show()