## Sphingolipid score

In [None]:
import pandas as pd

# Load the pixel-level table and keep only the nine samples compared here
# (3 male, 3 female, 3 pregnant).
data = pd.read_parquet("./zenodo/maindata_2.parquet")
sample_list = ['Male1', 'Male2', 'Male3', 
               'Female1', 'Female2', 'Female3', 
               'Pregnant1', 'Pregnant2', 'Pregnant4']

data = data.loc[data['Sample'].isin(sample_list),:]

import matplotlib
# Use Type 42 (TrueType) fonts in PDF output.
matplotlib.rcParams['pdf.fonttype'] = 42

# NOTE(review): `data` is already restricted to sample_list above, so this
# re-filter is a no-op; data_oligo is just an alias.
data_subset = data[data['Sample'].isin(sample_list)]
data_oligo = data_subset
# Sphingolipid columns, selected by name prefix (HexCer, Cer, SM).
sphingo_cols = [col for col in data_subset.columns 
                if col.startswith('HexCer') or col.startswith('Cer') or col.startswith('SM')]
data_sphingo = data_oligo.loc[:, data_oligo.columns.isin(sphingo_cols)]
data_lipids = data_sphingo
# Z-score every sphingolipid column over all retained pixels.
data_lipids_z = data_lipids.apply(lambda col: (col - col.mean()) / col.std(), axis=0)
# Attach grouping labels (aligned on the pixel index).
data_lipids_z[['supertype', 'Sample']] = data_subset[['supertype', 'Sample']]
# Mean z-score per (supertype, sample) and lipid, then averaged over lipids into
# one "sphingolipid score": rows = supertypes, columns = samples.
grouped_data = data_lipids_z.groupby(['supertype', 'Sample']).mean()
supertypexsample = grouped_data.mean(axis=1).unstack()
# Supertype -> class lookup; keep only supertypes whose class is "111".
corr = data[['supertype', 'class']].drop_duplicates()
corr.index = corr.supertype
grouped_data = supertypexsample.loc[corr.index[corr["class"] == "111"],:]

import pandas as pd
import matplotlib.pyplot as plt

# Row-centre: subtract each supertype's mean across samples so only
# between-sample differences remain.
df_centered = grouped_data.sub(grouped_data.mean(axis=1), axis=0)

# Average the centred scores within each condition, matching sample-name prefixes.
new_df = pd.DataFrame({
    'Male'    : df_centered.filter(regex='^Male').mean(axis=1),
    'Female'  : df_centered.filter(regex='^Female').mean(axis=1),
    'Pregnant': df_centered.filter(regex='^Pregnant').mean(axis=1),
})

# Sanity check: with three equally sized groups, the three condition means of a
# centred row average to ~0, so this should print a value close to zero.
max_dev = new_df.mean(axis=1).abs().max()
print("Max absolute deviation from zero:", max_dev)

fig, ax = plt.subplots(figsize=(6,4))

order = ['Male','Female','Pregnant']
ddd  = [new_df[col].values for col in order]
positions = [1, 2, 3]
colors    = ['blue','pink','purple']

# One violin per condition (median marked, mean hidden).
vp = ax.violinplot(ddd, positions=positions, showmeans=False, showmedians=True)

for body, color in zip(vp['bodies'], colors):
    body.set_facecolor(color)
    body.set_edgecolor('black')
    body.set_alpha(0.7)
# Black summary lines (set_edgecolor after set_color is redundant but harmless).
for part in ('cmedians', 'cmins', 'cmaxes', 'cbars'):
    vp[part].set_color('black')
    vp[part].set_edgecolor('black')
# A faint gray line per supertype links its value across the three conditions.
for vals in new_df[order].values:
    ax.plot(positions, vals, color='gray', alpha=0.3, linewidth=0.8)

ax.set_xticks(positions)
ax.set_xticklabels(order)
ax.set_ylabel('Centered mean by supertype')
ax.set_title('Category means after row-centering')
plt.tight_layout()
plt.savefig("violins.pdf")
plt.show()

## Define comodulation

In [None]:
import numpy as np
import pandas as pd
import networkx as nx

def compute_edge_modulation_scores(
    foldchanges: pd.DataFrame,
    metabolicmodule: pd.DataFrame,
    comodulation_clusters: dict,
    min_component_size: int = 2
) -> dict:
    """
    Score every metabolic edge by how strongly its two lipids are modulated
    together, averaged over the supertypes of each co-modulation cluster.

    Steps:
    1. Restrict to lipids present both in ``foldchanges`` and the network.
    2. Drop lipids whose fold change never varies across supertypes (their
       z-score would be 0/0).
    3. Keep only lipids inside connected components (of the network viewed as
       undirected) with at least ``min_component_size`` members.
    4. Z-score each lipid across supertypes and take |z|, forcing |z| = 0
       wherever the raw fold change is exactly 0.
    5. For each supertype the edge score is |z_i| * |z_j| * adj[i, j]. Because
       ``adj`` is used as given, a directed adjacency yields non-zero scores
       only at the [source, target] cell of each edge.
    6. Average the per-supertype matrices over each cluster's members.

    Parameters
    ----------
    foldchanges : pd.DataFrame
        log2FC values, rows = supertypes, columns = lipids.
    metabolicmodule : pd.DataFrame
        0/1 adjacency matrix (lipids x lipids) of metabolic edges.
    comodulation_clusters : dict
        cluster label -> list of supertypes. Supertypes absent from
        ``foldchanges`` are ignored.
    min_component_size : int
        Minimum number of lipids a connected component needs to be kept.

    Returns
    -------
    dict
        cluster label -> DataFrame (lipids x lipids, sorted lipid names) of
        averaged edge scores.

    Raises
    ------
    ValueError
        If none of a cluster's supertypes is present after filtering.
    """
    # 1. Align both inputs on the lipids they share.
    shared = foldchanges.columns.intersection(metabolicmodule.index)
    fc = foldchanges[shared]
    adj = metabolicmodule.loc[shared, shared].astype(int)

    # 2. A lipid with zero spread cannot be standardised; remove it everywhere.
    spread = fc.std(axis=0)
    constant = spread.index[(spread == 0).to_numpy()]
    if len(constant):
        fc = fc.drop(columns=constant)
        adj = adj.drop(index=constant, columns=constant)

    # 3. Retain only lipids that belong to a sufficiently large component.
    graph = nx.from_pandas_adjacency(adj)
    kept = set()
    for component in nx.connected_components(graph):
        if len(component) >= min_component_size:
            kept |= component
    ordered = sorted(kept)
    fc = fc[ordered]
    adj = adj.loc[ordered, ordered]

    # 4. Per-lipid z-scores across supertypes; magnitudes laid out lipids x supertypes.
    z = (fc - fc.mean(axis=0)) / fc.std(axis=0)
    magnitude = z.abs().T.mask(fc.T == 0, 0.0)
    lipid_index = magnitude.index

    # 5. Outer product of magnitudes, masked by the adjacency, per supertype.
    per_supertype = {}
    for supertype in z.index:
        weights = magnitude[supertype].to_numpy()
        outer = pd.DataFrame(np.outer(weights, weights), index=lipid_index, columns=lipid_index)
        per_supertype[supertype] = outer * adj

    # 6. Cluster-level average of the per-supertype edge matrices.
    modulation_scores = {}
    for label, members in comodulation_clusters.items():
        present = [m for m in members if m in per_supertype]
        if len(present) == 0:
            raise ValueError(f"No valid supertypes in cluster {label} after filtering.")
        total = per_supertype[present[0]]
        for member in present[1:]:
            total = total + per_supertype[member]
        modulation_scores[label] = total / len(present)

    return modulation_scores

In [None]:
import matplotlib.pyplot as plt

# Build the lipid x lipid reaction network from the curated reaction list.
# Every species seen as a reagent or a product becomes a row/column; the cell
# [reagent, product] is set to 1 for each reaction (the matrix is directed).
reactions = pd.read_csv("./zenodo/csv/corereactions_wenzymes.csv", index_col=0)
unique_reagents = set(reactions['reagent'].unique())
unique_products = set(reactions['product'].unique())
unique_species = sorted(unique_reagents | unique_products)

metabolicmodule = pd.DataFrame(0, index=unique_species, columns=unique_species, dtype=int)

for reagent, product in zip(reactions['reagent'], reactions['product']):
    metabolicmodule.at[reagent, product] = 1

# Quick visual check of the adjacency pattern.
plt.imshow(metabolicmodule)

## Prepare the log2FCs, keeping only shifts that are both expressed and significant (98% Bayesian posterior probability)

In [None]:
# Posterior summaries per supertype (rows) and lipid (columns), plus the
# expression / significance masks. The masks are transposed on load, presumably
# because they are stored lipids x supertypes — confirm.
shift = pd.read_parquet("shift_pregnancy.parquet")
baseline = pd.read_parquet("baseline_pregnancy.parquet")
significance = pd.read_parquet("sign_significance_pregnancy.parquet")
exp = pd.read_parquet("exp_pregnancy.parquet").T
shif = pd.read_parquet("shif_pregnancy.parquet").T

# log2 fold change of (baseline + shift) over baseline; NaN results become 0.
shifts = np.log2((shift+baseline)/baseline).fillna(0)

# Keep a fold change only where the lipid is expressed AND the shift is significant.
shifts[~(exp & shif)] = 0.0

# Supertypes excluded from the analysis (reason not recorded here — TODO document).
todrop = ['11211222', '11222222', '11221121', '11221111', '11212121',
       '11221112', '22112212']
shifts = shifts.drop(todrop)

shifts = shifts.drop(['TG 72:9', 'TG 67:2', 'HexCer 36:1:O2'], axis=1) # remove corrupted lipids
shifts

## Prepare the pixel data and lipizone centroids

In [None]:
from threadpoolctl import threadpool_limits, threadpool_info
# Cap native thread pools (BLAS/OpenMP) at 8 threads for the rest of the session.
threadpool_limits(limits=8)
import os
# NOTE(review): numpy is already imported by earlier cells, so this env var may
# not affect thread pools that are already initialised — confirm it is needed.
os.environ['OMP_NUM_THREADS'] = '6'

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from jax import random
import jax
import jax.numpy as jnp
import jax.nn as jnn
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import optax
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
import statsmodels.formula.api as smf
import arviz as az
from sklearn.preprocessing import LabelEncoder
from scipy.spatial import cKDTree
import scipy.stats as stats
from jax.ops import segment_sum
from numba import njit
import matplotlib.gridspec as gridspec
import random as py_random
from scipy.stats import norm

# Project helpers; presumably provides load_data, LipidAnalysisConfig and
# normalize_lipid_column used below — verify.
from euclid_hbda import *

import warnings
# Silence all UserWarnings from here on.
warnings.filterwarnings("ignore", category=UserWarning)

# Pixel-level table; later cells read lipid columns, x/y/SectionID, tsne1/tsne2,
# Condition, Sample, SectionPlot, supertype, class, subclass and colour columns.
sub_alldata = load_data()
coords      = sub_alldata[['x','y','SectionID']]
sub_alldata

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Lipid whose signal colours the t-SNE embedding in both panels.
lipid_to_plot = 'HexCer 42:2;O2'

x = sub_alldata['tsne1'].values
y = sub_alldata['tsne2'].values

c_all = sub_alldata[lipid_to_plot].values

# Shared colour limits (2nd-98th percentile over all pixels) so both panels use
# one scale; nanpercentile keeps missing values from turning the limits into NaN.
vmin, vmax = np.nanpercentile(c_all, [2, 98])
mask_preg   = sub_alldata['Condition'] == 'pregnant'
mask_not    = ~mask_preg

fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)

for ax, mask, title in zip(axes,
                           [mask_preg, mask_not],
                           ['Pregnant', 'Not pregnant']):
    keep = mask.to_numpy()
    sc = ax.scatter(
        x[keep], y[keep],
        c=c_all[keep],
        cmap='plasma',
        vmin=vmin, vmax=vmax,
        s=0.001,
        alpha=1.0,
        rasterized=True
    )
    ax.set_title(title)
    ax.set_aspect('equal')

# Both scatters share vmin/vmax, so the last one represents the colour scale.
cbar = fig.colorbar(sc, ax=axes, location='right', fraction=0.02, pad=0.04)
# The colour encodes the lipid's signal, not a t-SNE coordinate.
cbar.set_label(lipid_to_plot)
plt.tight_layout()
plt.savefig("tsnes.pdf")
plt.show()

In [None]:
import pandas as pd

# Per-supertype centroids computed on naive animals only.
df = sub_alldata.copy().loc[sub_alldata['Condition'] == "naive",:]
# Presumably the first 173 columns are the lipid features — confirm against load_data().
features = df.columns[:173]
# Robust min-max scaling: clip each lipid to its 0.5th-99.5th percentile,
# then rescale that range to [0, 1].
lower = df[features].quantile(0.005)
upper = df[features].quantile(0.995)
df_clipped = df.copy()
df_clipped[features] = df_clipped[features].clip(lower=lower, upper=upper, axis=1)
df_clipped[features] = (df_clipped[features] - lower) / (upper - lower)
# Mean scaled profile per supertype (rows = supertypes, columns = lipids).
centroids = df_clipped.groupby('supertype')[features].mean()

## Inspect the overall log2FC matrix across supertypes x lipids and compute comodulation clusters

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import pdist
from matplotlib.colors import ListedColormap, to_rgba
from mpl_toolkits.axes_grid1 import make_axes_locatable
import re
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42


def _first_mode(values, fallback):
    """Return the first modal value of *values*, or *fallback* when there is none."""
    modes = values.mode()
    return fallback if modes.empty else modes.iat[0]


# Per-pixel class colour: the most common lipizone colour within each class.
sub_alldata['class_color'] = (
    sub_alldata.groupby('class')['lipizone_color']
    .transform(lambda s: _first_mode(s, np.nan))
)

# supertype -> most common Allen-atlas colour.
allen_color_dic = (
    sub_alldata.groupby('supertype')['allencolor']
    .agg(lambda s: _first_mode(s, None))
    .to_dict()
)

# supertype -> most common class colour.
class_color_dic = (
    sub_alldata.groupby('supertype')['class_color']
    .agg(lambda s: _first_mode(s, None))
    .to_dict()
)

# Continuous purple-to-red palette indexed by subclass code.
# NOTE(review): '21120' is the only code containing a 0 — confirm it matches a real subclass label.
subclass_color_dict_tmp = {'11111': '#8000ff',
 '11112': '#7a09ff',
 '11121': '#7216ff',
 '11122': '#6a22fe',
 '11211': '#4c50fc',
 '11212': '#445cfb',
 '11221': '#386df9',
 '11222': '#3176f8',
 '12111': '#09a9ee',
 '12112': '#04b9ea',
 '12121': '#22d6e0',
 '12122': '#2fe0db',
 '12211': '#42edd3',
 '12212': '#4df3ce',
 '12221': '#64fbc3',
 '12222': '#72febb',
 '21111': '#94fda8',
 '21112': '#a2f9a0',
 '21120': '#b6f193',
 '21211': '#cbe486',
 '21212': '#ded579',
 '21221': '#f0c46c',
 '21222': '#feb562',
 '22111': '#ff9850',
 '22112': '#ff8e4a',
 '22121': '#ff7e41',
 '22122': '#ff703a',
 '22211': '#ff4d27',
 '22212': '#ff3e1f',
 '22221': '#ff2613',
 '22222': '#ff0000'}

sub_alldata['continuum'] = sub_alldata['subclass'].map(subclass_color_dict_tmp)

# supertype -> most common subclass-continuum colour.
subclass_color_dict = (
    sub_alldata.groupby('supertype')['continuum']
    .agg(lambda s: _first_mode(s, None))
    .to_dict()
)
subclass_color_dict

def cosine_clean(u, v):
    """Cosine distance between numeric arrays *u* and *v* that tolerates zero vectors.

    Positions where both vectors are zero are ignored. Two vectors that are
    zero everywhere are treated as identical (distance 0.0). If exactly one of
    them is zero on the remaining positions the distance is 1.0. Otherwise the
    result is ``1 - cos(u, v)``, which ranges over [0, 2].
    """
    keep = np.logical_or(u != 0, v != 0)
    if not np.any(keep):
        return 0.0
    a = u[keep]
    b = v[keep]
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 1.0
    return 1.0 - np.dot(a, b) / (norm_a * norm_b)

def optimal_reorder_dataframe_cosine_clean(df, method='weighted', thresh=0.8):
    """Reorder rows and columns of *df* by hierarchical clustering on ``cosine_clean``.

    Rows and columns whose fraction of non-zero entries exceeds *thresh* are
    "dense". They are left out of the clustering and appended, in their
    original order, after the clustered rows/columns. The remaining sub-matrix
    is clustered on both axes with ``scipy.cluster.hierarchy.linkage``
    (*method*, optimal leaf ordering).

    Returns
    -------
    tuple
        ``(df_reordered, row_L, col_L, ordered_rows, ordered_cols,
        rows_dense, cols_dense, ordered_rows_main)``. ``row_L``/``col_L`` are
        the linkage matrices of the non-dense part. ``rows_dense``/``cols_dense``
        are boolean Series. ``ordered_rows_main`` lists only the clustered
        (non-dense) row labels.
    """
    n_rows, n_cols = df.shape
    nonzero = df != 0
    rows_dense = nonzero.sum(axis=1) / n_cols > thresh
    cols_dense = nonzero.sum(axis=0) / n_rows > thresh
    df_main = df.loc[~rows_dense, ~cols_dense]

    def _cluster(matrix):
        # Linkage over pairwise cosine_clean distances, plus its leaf order.
        linkage = sch.linkage(pdist(matrix, metric=cosine_clean),
                              method=method, optimal_ordering=True)
        return linkage, sch.leaves_list(linkage)

    col_L, col_leaves = _cluster(df_main.T.values)
    row_L, row_leaves = _cluster(df_main.values)

    ordered_cols_main = df_main.columns[col_leaves].tolist()
    ordered_rows_main = df_main.index[row_leaves].tolist()
    ordered_cols = ordered_cols_main + df.columns[cols_dense.to_numpy()].tolist()
    ordered_rows = ordered_rows_main + df.index[rows_dense.to_numpy()].tolist()
    df_reordered = df.loc[ordered_rows, ordered_cols]
    return (df_reordered, row_L, col_L, ordered_rows, ordered_cols,
            rows_dense, cols_dense, ordered_rows_main)

def generate_distinct_colors(n):
    """Return *n* visually distinct RGBA colours.

    Up to 20 colours are sampled evenly from the qualitative ``tab20`` colormap
    (returned as an (n, 4) array). Above 20 they are evenly spaced hues from
    ``hsv``, excluding the end point so the first and last hue differ
    (returned as a list of RGBA tuples).
    """
    if n > 20:
        return [plt.cm.hsv(hue) for hue in np.linspace(0, 1, n, endpoint=False)]
    return plt.cm.tab20(np.linspace(0, 1, n))

# Cluster the supertype x lipid log2FC matrix on both axes (cosine_clean distance).
df = shifts.copy()
(df_opt, row_L, col_L,
 ordered_rows, ordered_cols,
 rows_dense, cols_dense,
 ordered_rows_main) = optimal_reorder_dataframe_cosine_clean(df, method='weighted', thresh=0.5)

# Cut the row dendrogram into at most k_row clusters. Rows set aside as "dense"
# get the extra label k_row + 1, which is drawn in gray below.
k_row = 16
clusters_main = sch.fcluster(row_L, t=k_row, criterion='maxclust')
df_main = df.loc[~rows_dense, ~cols_dense]
clusters_main_ordered = [clusters_main[df_main.index.get_loc(lbl)] for lbl in ordered_rows_main]
n_dense_rows = rows_dense.sum()
row_clusters_full = np.concatenate([clusters_main_ordered, np.full(n_dense_rows, k_row+1, dtype=int)])

# Per-lipid annotations parsed from the name: class (text before the first space
# or "("), carbons ("NN:"), unsaturations (":N"). "_uncertain" lipids are blanked.
ddf = pd.DataFrame(df.columns, columns=["lipid_name"]).fillna('')
ddf["class"] = ddf["lipid_name"].apply(lambda x: re.split(r'[ \(]', x)[0])
ddf["carbons"] = ddf["lipid_name"].str.extract(r'(\d+):').astype(float)
ddf["insaturations"] = ddf["lipid_name"].str.extract(r':(\d+)').astype(float)
ddf["insaturations_per_Catom"] = ddf["insaturations"] / ddf["carbons"]
ddf["broken"] = ddf["lipid_name"].str.endswith('_uncertain')
ddf.loc[ddf["broken"], ['class','carbons','insaturations','insaturations_per_Catom']] = np.nan
colors = pd.read_hdf("lipidclasscolors.h5ad", key="table")
ddf['color'] = ddf['class'].map(colors['classcolors']).fillna("#000000")
ddf.loc[ddf["broken"], 'color'] = "#888888"
ddf.index = ddf['lipid_name']
# One colour per heatmap COLUMN (lipid), in heatmap column order.
lipid_colors = [ddf.loc[col, 'color'] if col in ddf.index else "#888888" for col in df_opt.columns]

# One colour per heatmap ROW (supertype) for the row annotation strips.
allen_rgba = [to_rgba(allen_color_dic.get(r, '#888888')) for r in df_opt.index]
class_rgba  = [to_rgba(class_color_dic.get(r, '#888888')) for r in df_opt.index]
subclass_rgba = [to_rgba(subclass_color_dict.get(r, '#888888')) for r in df_opt.index]
nonzero_counts = (df_opt != 0).sum(axis=1).values

fig, ax = plt.subplots(figsize=(16, 10))
im = ax.imshow(df_opt.values, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
ax.set_xticks([])
ax.set_yticks([])
divider = make_axes_locatable(ax)

# Row-aligned annotation strips on the left.
cax_allen = divider.append_axes("left", size="2%", pad=0.05)
cax_allen.imshow(np.array(allen_rgba)[:, None, :], aspect='auto')
cax_allen.set_xticks([])
cax_allen.set_yticks([])
cax_class = divider.append_axes("left", size="2%", pad=0.05)
cax_class.imshow(np.array(class_rgba)[:, None, :], aspect='auto')
cax_class.set_xticks([])
cax_class.set_yticks([])
cax_subclass = divider.append_axes("left", size="2%", pad=0.05)
cax_subclass.imshow(np.array(subclass_rgba)[:, None, :], aspect='auto')
cax_subclass.set_xticks([])
cax_subclass.set_yticks([])
colors_row = generate_distinct_colors(k_row+1)
colors_row[-1] = to_rgba('gray')
cmap_row = ListedColormap(colors_row)
cax_clusters = divider.append_axes("left", size="2%", pad=0.05)
cax_clusters.imshow(row_clusters_full[:, None], aspect='auto', cmap=cmap_row, vmin=1, vmax=k_row+1)
cax_clusters.set_xticks([])
cax_clusters.set_yticks([])

# Lipid-class strip: lipid_colors has one entry per heatmap column, so it goes
# ABOVE the heatmap (sharing its x-extent) as a single row of cells.
rgba_lip = [to_rgba(c) for c in lipid_colors]
cax_top = divider.append_axes("top", size="2%", pad=0.05)
cax_top.imshow(np.array(rgba_lip)[None, :, :], aspect='auto')
cax_top.set_xticks([])
cax_top.set_yticks([])

# Right: number of non-zero log2FCs per row (clipped at 50), coloured by cluster.
cax_bar = divider.append_axes("right", size="15%", pad=0.1)
y_pos = np.arange(len(nonzero_counts))
clipped = np.clip(nonzero_counts, 0, 50)
bar_colors = [colors_row[c-1] for c in row_clusters_full]
cax_bar.barh(y_pos, clipped, height=0.8, color=bar_colors, alpha=0.8)
cax_bar.set_ylim(-0.5, len(nonzero_counts)-0.5)
cax_bar.set_xlim(0,50)
cax_bar.invert_yaxis()
cax_bar.set_yticks([])
cax_bar.grid(True, alpha=0.3, axis='x')

# White separators between consecutive row clusters (clustered rows only).
bounds = np.where(np.diff(row_clusters_full[:len(ordered_rows_main)]) != 0)[0] + 0.5
for b in bounds:
    for axis in [ax, cax_allen, cax_class, cax_subclass, cax_clusters, cax_bar]:
        axis.axhline(b, color='white', linewidth=1.0, alpha=0.7)
plt.tight_layout()
plt.savefig("overview_with_subclass.pdf")
plt.show()

import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Stand-alone legend: one swatch per lipid class, using the first colour listed
# for that class in ddf.
class_colors = ddf[['class', 'color']].drop_duplicates(subset='class').set_index('class')['color']
patches = [
    Patch(facecolor=swatch, edgecolor='none', label=lipid_class)
    for lipid_class, swatch in zip(class_colors.index, class_colors.values)
]

fig, ax = plt.subplots()
ax.legend(handles=patches, title='Lipid Class', frameon=False, loc='center')
ax.axis('off')
plt.tight_layout()
plt.savefig("classcolors.pdf")
plt.show()

def rgba_to_hex(rgba_array):
    """Convert an RGB(A) colour with channels in [0, 1] to a ``#rrggbb`` string.

    Each channel is scaled to 0-255, ROUNDED to the nearest integer (truncation
    would turn e.g. 31/255, stored as 30.999..., into 30) and clamped to
    [0, 255]. A colour parsed from a hex string therefore round-trips unchanged.
    Any alpha channel (4th element) is ignored.
    """
    r, g, b = (
        min(255, max(0, int(round(float(rgba_array[i]) * 255))))
        for i in range(3)
    )
    return f"#{r:02x}{g:02x}{b:02x}"

# Hex string for each heatmap row's cluster colour (bar_colors), in heatmap row order.
bar_colors_hex_manual = [rgba_to_hex(rgba) for rgba in bar_colors]

In [None]:
# Cluster label and hex colour per supertype, aligned to the heatmap row order.
colors = pd.Series(bar_colors_hex_manual, index=df_opt.index)
clusters_series = pd.Series(row_clusters_full, index=df_opt.index)

# nmod = number of lipids with |log2FC| > 0.2 in each supertype.
loool = pd.DataFrame((np.abs(shifts) > 0.2).sum(axis=1).sort_values(), columns = ['nmod'])
loool['linkage'] = loool.index.map(clusters_series)
loool['color'] = loool.index.map(colors)

# Clusters with <= 4 supertypes are pooled into the miscellaneous label, the same
# k_row + 1 label the heatmap gives to dense rows.
misc_label = k_row + 1
cluster_sizes = loool['linkage'].value_counts()
small_clusters = cluster_sizes.index[cluster_sizes <= 4]
loool.loc[loool['linkage'].isin(small_clusters), 'linkage'] = misc_label

# The miscellaneous bucket is excluded; the rest define the co-modulation clusters.
keep = loool['linkage'] != misc_label
clusters_series = loool.loc[keep, 'linkage']
colors_series = loool.loc[keep, 'color']
metadata = sub_alldata

# Pair each cluster with its colour from the same rows, so the pairing cannot drift.
cluster_palette = (
    pd.DataFrame({'Linkage Cluster': clusters_series, 'Color': colors_series})
    .drop_duplicates(subset='Linkage Cluster')
)
Linkage_clusters = cluster_palette['Linkage Cluster'].values
Linkage_colors = cluster_palette['Color'].values

color_df = pd.DataFrame({
    'Linkage Cluster': Linkage_clusters,
    'Color': Linkage_colors
})

color_df.to_csv("color_df_pregnancy.csv")
# Supertypes outside the kept clusters get "lightgray", which has no entry in
# color_df, so their colour is NaN and they are drawn as '#CCCCCC' below.
metadata['cluster_variation'] = metadata['supertype'].map(clusters_series).fillna("lightgray")
color_df.index = color_df['Linkage Cluster']
metadata['cluster_variation_color'] = metadata['cluster_variation'].map(color_df['Color'])

unique_samples = sorted(metadata['Sample'].unique())
unique_sections = sorted(metadata['SectionPlot'].unique())

fig, axes = plt.subplots(6, 6, figsize=(20, 12))
# Hide every panel up front so grid cells without a sample/section stay blank.
for ax in axes.flat:
    ax.axis('off')

for sample_idx, sample in enumerate(unique_samples[:6]):
    for section_idx, section in enumerate(unique_sections[:6]):
        ax = axes[sample_idx, section_idx]

        # Use a dedicated name so the lipid annotation table `ddf` is not overwritten.
        section_df = metadata[
            (metadata['Sample'] == sample) &
            (metadata['SectionPlot'] == section)
        ]

        ax.scatter(
            section_df['y'],
            -section_df['x'],
            c=section_df['cluster_variation_color'].astype(object).fillna('#CCCCCC').tolist(),
            s=0.5,
            rasterized=True
        )

        ax.set_aspect('equal')
        ax.set_title(f'Sample {sample}, Section {section}', fontsize=8)

plt.tight_layout(rect=[0, 0, 0.9, 1])
plt.savefig("comodulation_clusters.pdf")
plt.show()

## Prepare for 3d rendering

In [None]:
# One row per supertype with its co-modulation cluster and colour (both are
# derived from the supertype), exported for 3D rendering. reset_index().iloc[:, 1:]
# drops the former index column — assumes a single-level index; confirm.
xxx = metadata[['supertype', 'cluster_variation', 'cluster_variation_color']].drop_duplicates().reset_index().iloc[:, 1:]
xxx.columns = ['supertype', 'comodulation_cluster', 'comodulation_cluster_color']
xxx.index = xxx['supertype']
xxx.to_csv("comodclcol.csv")
xxx

## Comodulation network and thumbnails

In [None]:
# Cluster label (as a string) -> supertypes assigned to that cluster, in order of appearance.
comodulation_clusters = {
    str(label): clusters_series.index[(clusters_series == label).to_numpy()].tolist()
    for label in clusters_series.unique()
}

In [None]:
# Per-cluster lipid x lipid edge co-modulation scores on the reaction network
# (keys are the string cluster labels of comodulation_clusters).
mod_scores = compute_edge_modulation_scores(foldchanges=shifts,
                                            metabolicmodule=metabolicmodule,
                                            comodulation_clusters=comodulation_clusters)


In [None]:
# Peek at one cluster's score matrix; raises KeyError if no kept cluster is labelled '11'.
mod_scores['11']

In [None]:
# Restrict the network to the lipids kept by compute_edge_modulation_scores
# (every cluster's matrix shares the same index, so '11' serves as the template).
metabolicmodule = metabolicmodule.loc[mod_scores['11'].index, mod_scores['11'].columns]
# Cluster label -> hex colour.
Linkage_color_map = dict(zip(color_df['Linkage Cluster'], color_df['Color']))

In [None]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

def plot_modulation_network_separated_v2(modulation_scores,
                                         metabolicmodule,
                                         k=0.2,
                                         tie_rad=0.1,
                                         color_map=None):
    """
    Plot the lipid–lipid network, colouring each edge by the cluster that
    modulates it most.

    - An edge with a single best cluster is a straight line in that cluster's colour.
    - An edge where two or more clusters tie for the best score is drawn as two
      opposite arcs coloured by two of the tied clusters. The arc curvature is
      ``tie_rad / edge length``, so short edges bend more.
    - An edge on which no cluster has a positive score is a thin light-gray
      line. It is not reported as a tie.
    - An edge with a NaN score for any cluster is skipped.

    Parameters
    ----------
    modulation_scores : dict[str, pandas.DataFrame]
        cluster → square DataFrame of modulation scores (indexed & columned by
        lipid name). A score may be stored at only one orientation of an edge
        (e.g. the [reagent, product] cell of a directed adjacency). Both cells
        are read and the larger non-NaN value is used.
    metabolicmodule : pandas.DataFrame
        0/1 adjacency of the lipid network; the drawing treats it as undirected.
    k : float
        spring_layout spacing parameter.
    tie_rad : float
        base curvature for tie edges; actual rad = tie_rad / edge length.
    color_map : dict[str, str], optional
        cluster → colour. Clusters missing from it are drawn '#CCCCCC'. If None,
        an 8-colour default palette is cycled.
    """
    common = set.intersection(*(set(df.index) for df in modulation_scores.values()))
    lipids = [L for L in metabolicmodule.index if L in common]
    adj = metabolicmodule.loc[lipids, lipids].astype(int)

    G = nx.from_pandas_adjacency(adj)
    pos = nx.spring_layout(G, seed=42, k=k, iterations=50)

    clusters = list(modulation_scores.keys())
    if color_map is None:
        palette = ['#1b9e77','#d95f02','#7570b3','#e7298a',
                   '#66a61e','#e6ab02','#a6761d','#666666']
        color_map = {c: palette[i % len(palette)] for i, c in enumerate(clusters)}
    else:
        color_map = {c: color_map.get(c, '#CCCCCC') for c in clusters}

    single_edges = {c: [] for c in clusters}
    tie_edges = []
    unmodulated_edges = []

    for u, v in G.edges():
        # G is undirected, so (u, v) may be the reverse of the stored orientation;
        # read both cells and keep the larger (fmax ignores a single NaN).
        scores = np.array([
            np.fmax(modulation_scores[c].loc[u, v], modulation_scores[c].loc[v, u])
            for c in clusters
        ])
        if np.isnan(scores).any():
            continue

        best = scores.max()
        if best <= 0:
            # No cluster modulates this edge: background, not a tie.
            unmodulated_edges.append((u, v))
        elif (scores == best).sum() == 1:
            single_edges[clusters[scores.argmax()]].append((u, v))
        else:
            # Ascending argsort puts the tied maxima last, so both picks are tied at best.
            top2 = np.argsort(scores)[-2:]
            tie_edges.append(((u, v), clusters[top2[0]], clusters[top2[1]]))

    fig, ax = plt.subplots(figsize=(12, 12))
    nx.draw_networkx_nodes(G, pos,
                           node_size=100,
                           node_color='lightgray',
                           ax=ax)

    if unmodulated_edges:
        nx.draw_networkx_edges(G, pos,
                               edgelist=unmodulated_edges,
                               edge_color='lightgray',
                               width=0.5,
                               ax=ax)

    for c, eds in single_edges.items():
        nx.draw_networkx_edges(G, pos,
                               edgelist=eds,
                               edge_color=color_map[c],
                               width=2,
                               ax=ax)

    # Tie edges: two mirrored arcs, one per tied cluster.
    for (u, v), c1, c2 in tie_edges:
        x1, y1 = pos[u]
        x2, y2 = pos[v]
        span = np.hypot(x2 - x1, y2 - y1) + 1e-6
        rad = tie_rad / span
        for colour, curvature in ((color_map[c1], rad), (color_map[c2], -rad)):
            ax.add_patch(FancyArrowPatch((x1, y1), (x2, y2),
                                         connectionstyle=f"arc3,rad={curvature}",
                                         color=colour,
                                         linewidth=2,
                                         arrowstyle='-'))

    # Labels pushed radially outward from the layout centre, with a thin leader line.
    xs = np.array([p[0] for p in pos.values()])
    ys = np.array([p[1] for p in pos.values()])
    cx, cy = xs.mean(), ys.mean()
    offset = 0.04 * min(xs.max() - xs.min(), ys.max() - ys.min())

    for node, (x, y) in pos.items():
        vx, vy = x - cx, y - cy
        length = np.hypot(vx, vy)
        if length == 0:
            lx, ly = x, y
        else:
            lx = x + (vx / length) * offset
            ly = y + (vy / length) * offset
        ax.text(lx, ly, node, fontsize=8, zorder=5)
        ax.annotate('', xy=(x, y), xytext=(lx, ly),
                    arrowprops=dict(arrowstyle='-', color='gray', lw=0.5))

    for spine in ax.spines.values():
        spine.set_visible(False)
    handles = [plt.Line2D([0], [0], color=color_map[c], lw=2)
               for c in clusters]
    ax.legend(handles, clusters, title="Clusters", loc='upper right')

    plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Cluster label -> hex colour, used as the palette for the network plots.
cmap_comod = dict(zip(color_df['Linkage Cluster'], color_df['Color']))
cmap_comod

In [None]:
# Stringify keys on both dicts so score matrices and palette entries match by label.
mod_scores_str = {str(k): v for k, v in mod_scores.items()}

cmap_comod_str = {str(k): v for k, v in cmap_comod.items()}

plot_modulation_network_separated_v2(
    mod_scores_str, 
    metabolicmodule, 
    k=0.2, 
    tie_rad=0.01, 
    color_map=cmap_comod_str
)

In [None]:
# Mean number of strongly shifted lipids (|log2FC| > 0.2) per supertype, averaged
# within each co-modulation cluster (supertypes outside clusters_series are dropped).
(np.abs(shifts) > 0.2).sum(axis=1).groupby(clusters_series).mean().sort_values()

In [None]:
# Per-lipid annotation table, one row per column of df. Fields:
#   class   - text before the first space or "(" in the name
#   carbons - the "NN:" number
#   insaturations - the ":N" number, plus its ratio to carbons
#   color   - display colour for the class
# Names ending in "_uncertain" get no parsed fields and a neutral gray.
lipid_table = pd.DataFrame(df.columns, columns=["lipid_name"]).fillna('')
names = lipid_table["lipid_name"]
lipid_table["class"] = names.apply(lambda name: re.split(r'[ \(]', name)[0])
lipid_table["carbons"] = names.str.extract(r'(\d+):').astype(float)
lipid_table["insaturations"] = names.str.extract(r':(\d+)').astype(float)
lipid_table["insaturations_per_Catom"] = lipid_table["insaturations"] / lipid_table["carbons"]
is_broken = names.str.endswith('_uncertain')
lipid_table["broken"] = is_broken
lipid_table.loc[is_broken, ['class','carbons','insaturations','insaturations_per_Catom']] = np.nan

colors = pd.read_hdf("./zenodo/csv/lipidclasscolors.h5ad", key="table")
lipid_table['color'] = lipid_table['class'].map(colors['classcolors']).fillna("#000000")
lipid_table.loc[is_broken, 'color'] = "#888888"
lipid_table.index = lipid_table['lipid_name']
ddf = lipid_table

# Colour of each heatmap column's lipid (gray if the lipid is not annotated).
lipid_colors = [ddf.loc[col, 'color'] if col in ddf.index else "#888888"
                for col in df_opt.columns]

ddf

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

def plot_modulation_thumbnails(modulation_scores, metabolicmodule, ddf, thresholds=0.1, k=0.2, palette=None, max_width=5, figsize_per_plot=(3,3)):
    """
    Draw one network thumbnail per cluster on a 3-column grid and save it to
    "comodclusters.pdf".

    In every thumbnail:
    - nodes share one fixed spring layout (no labels);
    - only edges whose score for that cluster is >= the cluster's threshold are
      drawn, in black, with width proportional to the score (the largest score
      in the thumbnail gets the full width);
    - nodes touching a drawn edge are enlarged and coloured from ``ddf``; all
      other nodes are small and light gray;
    - a circle in the cluster's colour frames the network.

    Parameters
    ----------
    modulation_scores : dict[str, pd.DataFrame]
        For each cluster, a square DataFrame of co-modulation scores indexed and
        columned by lipid name. A score may be stored at only one orientation of
        an edge (e.g. [reagent, product]). Both cells are read and the larger
        non-NaN value is used.
    metabolicmodule : pd.DataFrame
        Adjacency matrix of the underlying network (indexed by lipid), treated
        as undirected.
    ddf : pd.DataFrame
        Lipid annotation table with 'lipid_name' and 'color' columns.
    thresholds : float or dict[str, float]
        If float, same threshold for all clusters; otherwise map cluster→threshold.
    k : float
        spring_layout "k" parameter.
    palette : dict[str, color] or None
        cluster→colour for the framing circles; clusters missing from it use
        '#CCCCCC'. Defaults to a cycled 8-colour palette.
    max_width : float
        Maximum edge width before it is scaled down for grids taller than 8 rows.
    figsize_per_plot : tuple
        Size (w, h) of each thumbnail; total figure size follows from the grid.
    """
    common = set.intersection(*(set(df.index) for df in modulation_scores.values()))
    common_lipids = [L for L in metabolicmodule.index if L in common]
    adj = metabolicmodule.loc[common_lipids, common_lipids].astype(int)
    G = nx.from_pandas_adjacency(adj)

    pos = nx.spring_layout(G, seed=42, k=k, iterations=50)

    clusters = list(modulation_scores.keys())
    if isinstance(thresholds, dict):
        thr_map = thresholds
    else:
        thr_map = {c: thresholds for c in clusters}

    if palette is None:
        base = ['#1b9e77', '#d95f02', '#7570b3', '#e7298a', '#66a61e', '#e6ab02', '#a6761d', '#666666']
        palette = {c: base[i % len(base)] for i, c in enumerate(clusters)}
    else:
        palette = {c: palette.get(c, '#CCCCCC') for c in clusters}

    lipid_to_color = dict(zip(ddf['lipid_name'], ddf['color']))

    n = len(clusters)
    ncols = 3
    nrows = (n + ncols - 1) // ncols
    node_size = max(4, 12 - nrows)
    scaled_max_width = max(1, max_width * (8 / max(8, nrows)))

    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(figsize_per_plot[0]*ncols, figsize_per_plot[1]*nrows),
                             squeeze=False)
    axes_flat = axes.ravel()

    # The layout is shared, so the framing circle is the same for every thumbnail.
    if pos:
        x_coords = [p[0] for p in pos.values()]
        y_coords = [p[1] for p in pos.values()]
        center = ((max(x_coords) + min(x_coords)) / 2, (max(y_coords) + min(y_coords)) / 2)
        radius = max(max(x_coords) - min(x_coords), max(y_coords) - min(y_coords)) / 2 * 1.2

    for ax, c in zip(axes_flat, clusters):
        thr = thr_map[c]
        scores = modulation_scores[c]

        edges = []
        values = []
        for u, v in G.edges():
            # G is undirected; the score may live at either orientation of the edge.
            s = np.fmax(scores.loc[u, v], scores.loc[v, u])
            if not np.isnan(s) and s >= thr:
                edges.append((u, v))
                values.append(s)
        values = np.asarray(values, dtype=float)
        peak = values.max() if len(values) else 0.0
        if peak > 0:
            widths = (values / peak) * scaled_max_width
        else:
            # Nothing to scale against (no edges, or only zero scores with thr <= 0).
            widths = np.ones(len(values))

        connected_nodes = {node for edge in edges for node in edge}

        node_colors = []
        node_sizes = []
        for node in G.nodes():
            if node in connected_nodes and node in lipid_to_color:
                node_colors.append(lipid_to_color[node])
                node_sizes.append(node_size * 8)
            else:
                node_colors.append('lightgray')
                node_sizes.append(node_size)
        nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color=node_colors, ax=ax)
        if edges:
            nx.draw_networkx_edges(G, pos, edgelist=edges, width=widths, edge_color='black', ax=ax)

        if pos:
            circle = plt.Circle(center, radius, fill=False,
                                color=palette[c], linewidth=3, alpha=0.8, clip_on=False)
            ax.add_patch(circle)

        ax.set_aspect('equal')
        ax.set_axis_off()

    # Hide grid cells beyond the last cluster.
    for ax in axes_flat[n:]:
        ax.set_visible(False)

    plt.tight_layout(pad=1.0, w_pad=0.5, h_pad=0.5)
    plt.savefig("comodclusters.pdf")
    plt.show()
    
# One thumbnail per co-modulation cluster; edges need a score >= 0.01 to be drawn.
plot_modulation_thumbnails(
    mod_scores_str,         
    metabolicmodule, ddf,
    thresholds=0.01,k=0.3,
    palette=cmap_comod_str 
)

## Study which supertypes are changed the most

In [None]:
cfg = LipidAnalysisConfig()
cfg.normalize_percentiles = (0.5, 99.5)
# Percentile-normalise each lipid column with the percentiles configured on cfg
# just above (the original referenced an undefined/other `config` object).
for lipid_name in sub_alldata.columns.values[:173]:
    sub_alldata = normalize_lipid_column(
        sub_alldata,
        lipid_name,
        lower_percentile=cfg.normalize_percentiles[0],
        upper_percentile=cfg.normalize_percentiles[1]
    )

# Column labels of the per-supertype parameter vectors. This presumably matches
# the sorted supertype order used when the per-lipid models were fit — confirm.
stindex = np.sort(sub_alldata['supertype'].unique())

# Presumably the first 173 columns are the lipid features — confirm against load_data().
lipid_columns = sub_alldata.columns.values[:173]

xxx = lipid_columns[0]
wholebrainmean = sub_alldata.loc[sub_alldata['Condition']=="naive", xxx].mean()
wholebrainmean

from scipy.special import expit

n_samples = 1000          # posterior draws per lipid
prob_threshold = 0.98     # posterior probability required to call an effect
rel_shift = 0.2           # shift must exceed 20% of the baseline
min_baseline = 0.05       # baseline must exceed this to count as expressed

lipid_upreg_xsupertype = []
lipid_downreg_xsupertype = []
lipid_expressed_xsupertype = []

for xxx in tqdm(lipid_columns):
    # NOTE: allow_pickle unpickles arbitrary objects — only load locally generated files.
    params = np.load(xxx+"_model_params.npy", allow_pickle=True).item()

    loc_susc   = params["alpha_supertype_susceptibility_loc"]
    scale_susc = params["alpha_supertype_susceptibility_scale"]

    loc_unconst   = params["alpha_supertype_unconst_loc"]
    scale_unconst = params["alpha_supertype_unconst_scale"]

    # One seeded generator per lipid keeps results reproducible. Drawing BOTH
    # parameters from it (instead of two generators with the same seed) gives
    # independent samples rather than identical standard-normal noise, which
    # presumably matches the independent Normal posteriors implied by the
    # *_loc/*_scale parameters (e.g. an AutoNormal guide) — confirm.
    rng = np.random.default_rng(1234)
    samples_shift = rng.normal(
        loc=loc_susc[None, :],
        scale=scale_susc[None, :],
        size=(n_samples, loc_susc.shape[0])
    )
    samples_unconst = rng.normal(
        loc=loc_unconst[None, :],
        scale=scale_unconst[None, :],
        size=(n_samples, loc_unconst.shape[0])
    )

    # Baseline lives on the unconstrained scale; map it into (0, 1).
    samples_baseline = expit(samples_unconst)

    upregulation = samples_shift > rel_shift*samples_baseline
    downregulation = -samples_shift > rel_shift*samples_baseline
    expressed = samples_baseline > min_baseline #1.5*wholebrainmean

    # Per supertype: is the posterior probability of each event above the threshold?
    lipid_upreg_xsupertype.append(np.mean(upregulation, axis=0) > prob_threshold)
    lipid_downreg_xsupertype.append(np.mean(downregulation, axis=0) > prob_threshold)
    lipid_expressed_xsupertype.append(np.mean(expressed, axis=0) > prob_threshold)

# Boolean tables: rows = lipids, columns = supertypes.
lipid_upreg_xsupertypes = pd.DataFrame(lipid_upreg_xsupertype, columns = stindex, index = lipid_columns)
lipid_downreg_xsupertypes = pd.DataFrame(lipid_downreg_xsupertype, columns = stindex, index = lipid_columns)
lipid_expressed_xsupertypes = pd.DataFrame(lipid_expressed_xsupertype, columns = stindex, index = lipid_columns)
shifted_xsupertypes = lipid_upreg_xsupertypes | lipid_downreg_xsupertypes

lipid_expressed_xsupertypes

In [None]:
# For each supertype, bootstrap (resampling lipids with replacement) the
# fraction of expressed lipids that are also significantly shifted, and record
# the mean and a percentile 95% confidence interval.
N_BOOT = 10000

mean_score = []
ci_lowers = []
ci_uppers = []

n_lipids, n_supertypes = lipid_expressed_xsupertypes.shape
for yyy in tqdm(range(n_supertypes)):
    # Plain boolean arrays: fancy-indexing a label-indexed Series with integer
    # positions (``exp[idx]``) relies on a deprecated pandas fallback.
    exp = lipid_expressed_xsupertypes.iloc[:, yyy].to_numpy(dtype=bool)
    shif = shifted_xsupertypes.iloc[:, yyy].to_numpy(dtype=bool)

    # Same seed and draw shape as before, so the resamples are unchanged;
    # each row of `idx` is one bootstrap replicate.
    rng = np.random.default_rng(42)
    idx = rng.integers(0, n_lipids, size=(N_BOOT, n_lipids))
    exp_boot = exp[idx]
    n_expressed = exp_boot.sum(axis=1)
    n_both = (exp_boot & shif[idx]).sum(axis=1)

    # A replicate that drew no expressed lipid has an undefined ratio (NaN).
    # Drop those instead of letting a single one turn the estimate into NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        scores = n_both / n_expressed
    scores = scores[~np.isnan(scores)]

    if scores.size == 0:
        # No expressed lipid in this supertype at all.
        mean_score.append(np.nan)
        ci_lowers.append(np.nan)
        ci_uppers.append(np.nan)
        continue

    # point estimate and 95% CI
    mean_score.append(scores.mean())
    ci_lower, ci_upper = np.percentile(scores, [2.5, 97.5])
    ci_lowers.append(ci_lower)
    ci_uppers.append(ci_upper)


In [None]:
sub_alldata['comodulation_cluster'] = sub_alldata['supertype'].map(clusters_series)
sub_alldata['comodulation_cluster_color'] = sub_alldata['cluster_variation_color']
sub_alldata = sub_alldata.loc[~sub_alldata['supertype'].isin(todrop), :]
tmp = sub_alldata[['supertype', 'comodulation_cluster_color']]
supertype_to_color = tmp.set_index('supertype')['comodulation_cluster_color'].to_dict()

# Bootstrap estimates and CI bounds per supertype, without the dropped ones.
score_series = pd.Series(mean_score, index=stindex).drop(todrop)
cilowsseries = pd.Series(ci_lowers, index=stindex).drop(todrop)
ciupsseries = pd.Series(ci_uppers, index=stindex).drop(todrop)

# Asymmetric error bars must come from the Series. A DataFrame minus a Series
# aligns on *columns* and yields all-NaN. For a single Series pandas expects
# a 2 x N array.
lower_err = score_series - cilowsseries
upper_err = ciupsseries - score_series
yerr = np.vstack([lower_err.to_numpy(), upper_err.to_numpy()])

# Later cells expect a DataFrame with columns ["dff", "comodulation_cluster"].
scoresseries = score_series.to_frame("dff")
scoresseries["comodulation_cluster"] = scoresseries.index.map(clusters_series)

# Bar colour = supertype's co-modulation cluster colour; unknown/NaN -> light gray.
colors = [supertype_to_color.get(idx, 'lightgray') for idx in scoresseries.index]
colors = ['lightgray' if isinstance(c, float) and np.isnan(c) else c for c in colors]

y0 = 10 / 173  # reference line: 10 of 173 lipids are modulated in pregnancy

fig, (ax_high, ax_low) = plt.subplots(
    2, 1,
    sharex=True,
    figsize=(14, 6),
    gridspec_kw={'height_ratios': [1, 2]},
)

# Broken y-axis: the same bars are drawn on both axes with different limits.
for ax in (ax_low, ax_high):
    # Plot the score Series, not the two-column DataFrame. For a Series the
    # colour list is applied bar by bar, and the cluster-label column cannot
    # end up drawn as a second set of bars.
    scoresseries['dff'].plot(
        kind='bar',
        color=colors,
        yerr=yerr,
        capsize=1,
        error_kw={'elinewidth': 0.2},
        ax=ax,
    )
    baseline_line = ax.axhline(
        y=y0,
        color='gray',
        linestyle='--',
        linewidth=2,
        alpha=0.8,
        label='10/173 modulated lipids in pregnancy',
    )
    # Legend only for the reference line (not the bar series).
    ax.legend(handles=[baseline_line], loc='upper right')

ax_low.set_ylim(0, 0.4)
ax_high.set_ylim(0.7, scoresseries['dff'].max() * 1.05)

ax_low.spines['top'].set_visible(False)
ax_high.spines['bottom'].set_visible(False)
ax_low.tick_params(labeltop=False)

# Diagonal "break" marks, in axes coordinates, on both sides of the gap.
d = .015
kwargs = dict(transform=ax_high.transAxes, color='k', clip_on=False)
ax_high.plot((-d, +d), (-d, +d), **kwargs)
ax_high.plot((1 - d, 1 + d), (-d, +d), **kwargs)

kwargs.update(transform=ax_low.transAxes)
ax_low.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_low.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)

ax_high.set_ylabel('Score')
ax_low.set_ylabel('Score')
ax_low.set_xlabel('Supertype')
ax_low.set_xticks([])  # shared x-axis: clears the ticks on both panels
fig.tight_layout()
fig.savefig("barplot_preg_st.pdf")
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

scoresseries = pd.DataFrame(scoresseries)
scoresseries['comodulation_cluster'] = scoresseries.index.map(clusters_series)
scoresseries.columns = ["dff", "comodulation_cluster"]

# Highest-scoring supertype in each co-modulation cluster (one per cluster,
# despite the historical "top2" name).
top2_list = (
    scoresseries
    .sort_values('dff', ascending=False)
    .groupby('comodulation_cluster')
    .head(1)
    .index
    .tolist()
)
top_dff = scoresseries.groupby('comodulation_cluster')['dff'].max()

# Drop only the columns still present, so the cell can be re-run, and report
# any that are missing rather than raising or silently ignoring them.
# NOTE(review): other lipid names in this notebook use ';O2' (e.g.
# 'HexCer 42:2;O2'); confirm 'HexCer 36:1:O2' is the intended column.
lipids_to_exclude = ['TG 72:9', 'TG 67:2', 'HexCer 36:1:O2']
missing = [name for name in lipids_to_exclude if name not in shifts.columns]
if missing:
    print(f"Not found in shifts (already dropped or misspelled?): {missing}")
shifts = shifts.drop(columns=[name for name in lipids_to_exclude if name in shifts.columns])

matplotlib.rcParams['pdf.fonttype'] = 42

for sample in top2_list:
    s = shifts.loc[sample]

    # Up to four most negative and four most positive shifts for this supertype.
    neg = s[s < 0].nsmallest(4).sort_values(ascending=False)
    pos = s[s > 0].nlargest(4).sort_values(ascending=True)
    combined = pd.concat([neg, pos])
    if combined.empty:
        # Nothing to draw; max() on an empty array would raise.
        print(f"{sample}: no non-zero shifts, skipping")
        continue
    values = combined.values
    colors = [ddf.loc[name, 'color'] for name in combined.index]

    fig, ax = plt.subplots(figsize=(8, 3.6))
    ax.barh(combined.index, values, color=colors)
    ax.axvline(0, color='gray', linewidth=1)

    # Symmetric x-range around zero.
    max_abs = np.abs(values).max()
    ax.set_xlim(-max_abs * 1.05, max_abs * 1.05)

    ax.invert_yaxis()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_xlabel('Shift value')
    ax.set_ylabel('')
    ax.set_xticks(np.linspace(-max_abs, max_abs, 5))

    fig.tight_layout()
    fig.savefig(f"shift_st_{sample}.pdf")
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from skimage import measure
from scipy import ndimage
from scipy.signal import savgol_filter
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42

def smooth_contours_savgol(input_array, window_length=15, polyorder=3, passes=1):
    """Extract the 0.5 iso-contours of a blurred mask and smooth them.

    The input is cast to float, blurred with a Gaussian (sigma=0.75) and
    contoured at level 0.5. The x and y coordinates of each contour are then
    run through a Savitzky-Golay filter ``passes`` times.

    The window and order are adapted per contour:

    - Contours with 3 points or fewer are returned untouched.
    - For contours no longer than ``window_length``, the window becomes
      ``len - 2`` (never below 5).
    - The window is forced odd.
    - The polynomial order drops to ``window - 1`` (minimum 1) when it does
      not fit.

    If filtering a contour fails, the raw contour is kept.

    Returns a list of (N, 2) arrays in (row, col) order.
    """
    blurred = ndimage.gaussian_filter(input_array.astype(float), sigma=0.75)
    smoothed_contours = []
    for contour in measure.find_contours(blurred, 0.5):
        npts = len(contour)
        if npts <= 3:
            smoothed_contours.append(contour)
            continue

        if npts > window_length:
            win = window_length
        else:
            win = max(5, npts - 2)
        # Savitzky-Golay needs an odd window.
        win = win - 1 if win % 2 == 0 else win

        if win > polyorder:
            order = polyorder
        else:
            order = max(1, win - 1)

        col, row = contour[:, 1], contour[:, 0]
        try:
            for _ in range(passes):
                col = savgol_filter(col, win, order)
                row = savgol_filter(row, win, order)
            result = np.column_stack((row, col))
        except Exception:
            # Best effort: e.g. a window longer than a short contour.
            result = contour
        smoothed_contours.append(result)
    return smoothed_contours

# Eroded annotation volume; a slice at each panel's mean x_index gives the
# section outline. The path is relative to the repository root, like
# "./zenodo/maindata_2.parquet" at the top of this notebook.
ann = np.load("./zenodo/mixed/eroded_annot.npy")

samples = sub_alldata['Sample'].unique()
n_samples = len(samples)
n_sections = 6

# Outlines depend only on (sample, section), not on the supertype being
# plotted, so compute them once instead of once per figure.
section_outlines = {}
for samp in samples:
    in_sample = sub_alldata['Sample'] == samp
    for j in range(1, n_sections + 1):
        x_idx = sub_alldata.loc[in_sample & (sub_alldata['SectionPlot'] == j), 'x_index']
        if x_idx.empty:
            section_outlines[(samp, j)] = []
        else:
            section_outlines[(samp, j)] = smooth_contours_savgol(
                ann[int(x_idx.mean()), :, :],
                window_length=15,
                polyorder=3,
                passes=2,
            )

plot_cols = ['Sample', 'SectionPlot', 'z_index', 'y_index', 'comodulation_cluster_color']
for mmii in top2_list:
    st_cells = sub_alldata.loc[sub_alldata['supertype'] == mmii, plot_cols]

    # squeeze=False keeps `axes` 2-D even when there is a single sample.
    fig, axes = plt.subplots(n_samples, n_sections, figsize=(18, 20),
                             sharex=True, sharey=True, squeeze=False)
    for i, samp in enumerate(samples):
        for j in range(1, n_sections + 1):
            ax = axes[i, j - 1]

            for cnt in section_outlines[(samp, j)]:
                ax.fill(cnt[:, 1], -cnt[:, 0],
                        facecolor='white',
                        edgecolor='lightgray',
                        linewidth=2)

            sel_mmii = st_cells[(st_cells['Sample'] == samp) & (st_cells['SectionPlot'] == j)]
            ax.scatter(sel_mmii['z_index'], -sel_mmii['y_index'],
                       c=sel_mmii['comodulation_cluster_color'], s=0.5,
                       alpha=0.7, zorder=2)

            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if j == 1:
                ax.set_ylabel(samp, rotation=0, labelpad=40, va='center')
            if i == 0:
                ax.set_title(f"Section {j}")

    fig.tight_layout()
    fig.suptitle(f"Supertype: {mmii}", fontsize=16, y=1.02)
    # bbox_inches='tight' keeps the suptitle (placed above the figure, y > 1)
    # in the saved PDF instead of clipping it.
    fig.savefig(f"preg_{mmii}.pdf", bbox_inches='tight')
    plt.show()


## Example: focus on changed lipizones

In [None]:
# Focus lipizone (supertype "21222111"). Keep the regions (acronyms) where it
# has more than 500 cells in pregnant animals. Then take the five most
# frequent supertypes in those regions, separately for pregnant and
# non-pregnant animals.
pregnant_mask = sub_alldata['Condition'] == "pregnant"
focus_counts = sub_alldata.loc[
    (sub_alldata['supertype'] == "21222111") & pregnant_mask, 'acronym'
].value_counts()
yyy = focus_counts[focus_counts > 500].index.values

in_regions = sub_alldata['acronym'].isin(yyy)
intheseacronyms_preg = (
    sub_alldata.loc[in_regions & pregnant_mask, 'supertype'].value_counts().head(5).index.values
)
intheseacronyms_nopreg = (
    sub_alldata.loc[in_regions & ~pregnant_mask, 'supertype'].value_counts().head(5).index.values
)

tocheck = np.concatenate((intheseacronyms_preg, intheseacronyms_nopreg))

# Cells of the selected supertypes, and their pregnant (True) / other (False) label.
selected_labels = sub_alldata.index[sub_alldata['supertype'].isin(tocheck)]
sec = sub_alldata.loc[selected_labels, :]
labs = sub_alldata.loc[sec.index, 'Condition'].eq("pregnant")

from scipy.stats import mannwhitneyu, entropy
import matplotlib.pyplot as plt
from tqdm import tqdm
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm

def differential_lipids(lipidata, kmeans_labels, min_fc=0.2, pthr=0.05):
    """Per-column two-group differential test with BH correction.

    Rows where ``kmeans_labels == 0`` form group A and rows where
    ``kmeans_labels == 1`` form group B; with a boolean mask, False -> A and
    True -> B.

    Parameters
    ----------
    lipidata : pandas.DataFrame
        Rows are observations, columns are the features to test.
    kmeans_labels : array-like or pandas.Series
        Group label per row of ``lipidata``.
    min_fc, pthr : float
        Unused; kept for call compatibility. No filtering happens here;
        callers filter the returned table.

    Returns
    -------
    pandas.DataFrame
        One row per column of ``lipidata``, with these columns:

        - ``lipid``: the column position.
        - ``log2fold_change``: log2(mean B / mean A), NaN unless both means
          are positive.
        - ``p_value``: two-sided Mann-Whitney U, NaN if the test raised
          ValueError.
        - ``p_value_corrected``: Benjamini-Hochberg over the finite p-values
          only, NaN where ``p_value`` is NaN.
    """
    eps = 0.00000000001  # avoids log2(0) when one group's mean is exactly 0
    results = []

    a = lipidata.loc[kmeans_labels == 0, :]
    b = lipidata.loc[kmeans_labels == 1, :]

    for pos in range(lipidata.shape[1]):
        groupA = a.iloc[:, pos]
        groupB = b.iloc[:, pos]

        # log2 fold change
        meanA = np.mean(groupA) + eps
        meanB = np.mean(groupB) + eps
        log2fold_change = np.log2(meanB / meanA) if meanA > 0 and meanB > 0 else np.nan

        # Wilcoxon rank-sum (Mann-Whitney U) test
        try:
            _, p_value = mannwhitneyu(groupA, groupB, alternative='two-sided')
        except ValueError:
            p_value = np.nan

        results.append({'lipid': pos, 'log2fold_change': log2fold_change, 'p_value': p_value})

    # Explicit columns so an input without columns still yields a valid table.
    results_df = pd.DataFrame(results, columns=['lipid', 'log2fold_change', 'p_value'])

    # Correct for multiple testing over the finite p-values only. A single NaN
    # propagates through BH's running minimum and would turn every corrected
    # value into NaN.
    pvals = results_df['p_value'].to_numpy(dtype=float)
    finite = np.isfinite(pvals)
    pvals_corrected = np.full(pvals.shape, np.nan)
    if finite.any():
        pvals_corrected[finite] = multipletests(pvals[finite], alpha=0.05, method='fdr_bh')[1]
    results_df['p_value_corrected'] = pvals_corrected

    return results_df

# Pregnant vs non-pregnant cells of the selected supertypes, tested over the
# first 173 columns (presumably the lipid intensities).
lipid_block = sec.iloc[:, :173]
diff = differential_lipids(lipid_block, labs, min_fc=0.2, pthr=0.05)
diff.index = lipid_block.columns

# Keep FDR-significant lipids, ordered by decreasing log2 fold change.
diff = diff.loc[diff['p_value_corrected'] < 0.05, :]
diff = diff.sort_values(by="log2fold_change").iloc[::-1]

def plot_lipid_grid(sub_alldata, coeffmap, tocheck, lipids, section=3, sample='Pregnant4'):
    """Plot per-lipid spatial maps of ``coeffmap`` on one section of one sample.

    Each panel draws every cell of ``sample`` / ``section`` as a faint gray
    background shaded by lipizone. It overlays the cells whose supertype is in
    ``tocheck``, coloured by that lipid's ``coeffmap`` value. All panels share
    one coolwarm scale. Its limits are the 2nd and 98th percentiles of the
    requested lipids' values over all cells of the section.

    Parameters
    ----------
    sub_alldata : pandas.DataFrame
        Cell table with 'Sample', 'SectionPlot', 'supertype',
        'lipizone_names', 'x' and 'y' columns.
    coeffmap : pandas.DataFrame
        Indexed like ``sub_alldata``, one column per lipid.
    tocheck : array-like
        Supertypes to highlight.
    lipids : list of str
        Lipid columns to draw; only the first 16 fit the 4x4 grid.
    section : int
        Value of 'SectionPlot' to draw.
    sample : str
        Sample to draw. Defaults to 'Pregnant4', as before.

    Returns
    -------
    matplotlib.figure.Figure
    """
    data = sub_alldata[(sub_alldata['Sample'] == sample) &
                       (sub_alldata['SectionPlot'] == section)]
    # Highlighted cells come from the `tocheck` argument. Previously a
    # notebook-global `sec` was read here and `tocheck` was ignored.
    highlighted = data[data['supertype'].isin(tocheck)]

    section_vals = coeffmap.loc[data.index, list(lipids)].to_numpy(dtype=float).ravel()
    vmin, vmax = np.nanpercentile(section_vals, [2, 98])

    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    cmap = plt.cm.coolwarm
    norm = Normalize(vmin=vmin, vmax=vmax)
    background_codes = data['lipizone_names'].astype("category").cat.codes

    flat_axes = axes.flatten()
    for ax, lipid in zip(flat_axes, lipids):
        ax.scatter(
            data['y'], -data['x'], c=background_codes,
            cmap='Greys', s=0.005, alpha=0.2, rasterized=True
        )
        ax.scatter(
            highlighted['y'], -highlighted['x'],
            c=coeffmap.loc[highlighted.index, lipid], cmap=cmap, norm=norm,
            s=0.5, alpha=0.6, rasterized=True
        )
        ax.set_title(lipid, fontsize=8)
        ax.axis('off')
        ax.set_aspect('equal')

    # Hide panels left empty when fewer than 16 lipids are given.
    for ax in flat_axes[len(lipids):]:
        ax.set_visible(False)

    cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    sm = ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(sm, cax=cax)

    fig.tight_layout(rect=[0, 0, 0.9, 1])
    return fig


# 16 lipids with the largest log2 fold change. `coeffmap` is built in the
# "General figures" cell further down; run that cell first.
ranked = diff.sort_values(by='log2fold_change', ascending=False)
top16 = list(ranked.index[:16])
fig = plot_lipid_grid(sub_alldata, coeffmap, tocheck, top16, section=3)
fig.savefig('grid_spatial_distribution_pregnant4.pdf', dpi=300, bbox_inches='tight')
plt.show()

## A "new" lipizone in pregnancy

In [None]:
# Share of each sample's cells assigned to each supertype (columns sum to 1).
supersample = pd.crosstab(sub_alldata['supertype'], sub_alldata['Sample'])
supersample = supersample / supersample.sum()
from scipy.stats import ttest_ind

female_cols = ['Female1', 'Female2', 'Female3']
pregnant_cols = ['Pregnant1', 'Pregnant2', 'Pregnant4']

# Welch's t-test per supertype: female vs pregnant abundance shares.
p_values = supersample.apply(
    lambda row: ttest_ind(row[female_cols], row[pregnant_cols], equal_var=False).pvalue,
    axis=1,
)
supersample['p_value'] = p_values

# Pregnant / female mean share; inf (or NaN) where the supertype is absent in females.
FCs = supersample[pregnant_cols].mean(axis=1) / supersample[female_cols].mean(axis=1)
FCs.sort_values()[::-1][:20]

FOCUS_SUPERTYPE = "21222111"
N_COLS = 3

for aaa in range(1, 7):
    sec1 = sub_alldata.loc[sub_alldata['SectionPlot'] == aaa, :]

    samples = sec1['Sample'].unique()
    num_samples = len(samples)
    # At least the original 2 x 3 grid, growing by rows when there are more
    # than six samples (a fixed 2 x 3 grid raised IndexError there).
    n_rows = max(2, int(np.ceil(num_samples / N_COLS)))
    fig, axes = plt.subplots(n_rows, N_COLS, figsize=(15, 5 * n_rows))
    axes = axes.flatten()

    for idx, samp in enumerate(samples):
        ax = axes[idx]
        cells = sec1.loc[sec1['Sample'] == samp, :]

        # All cells in grayscale by supertype, with the focus supertype in red on top.
        ax.scatter(
            cells['y'], -cells['x'],
            c=cells['supertype'].astype("category").cat.codes,
            s=0.05,
            alpha=0.7,
            cmap="Greys",
        )

        focus = cells.loc[cells['supertype'] == FOCUS_SUPERTYPE, :]
        ax.scatter(
            focus['y'], -focus['x'],
            c="red",
            s=0.05,
            alpha=0.7,
        )

        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_title(samp)

    for idx in range(num_samples, len(axes)):
        fig.delaxes(axes[idx])
    fig.tight_layout()
    plt.show()
    
## This lipizone touches several portions of layer 1, which was found to be altered.
# It is essentially a "new" lipizone: even if the label transfer may be imperfect, it is spatially very consistent, meaning no existing lipizone was a better match.
# At a minimum, it highlights changes in layer 1.

## General figures

In [None]:
coords = sub_alldata[['xccf', 'yccf', 'zccf', 'SectionID', 'Sample', 'SectionPlot', 'x', 'y']]
shift = pd.read_parquet("shift_pregnancy.parquet")
baseline = pd.read_parquet("baseline_pregnancy.parquet")
significance = pd.read_parquet("sign_significance_pregnancy.parquet")
significance = significance.loc[shift.index, shift.columns]
# Zero out non-significant shifts. astype(bool) makes ~ a logical NOT even if
# the mask was stored with a non-bool dtype.
shift[~significance.astype(bool)] = 0.0

relshift = shift / baseline

# !!we decide to work all in LOG FOLD CHANGES space to avoid confusion!!
susc_df = np.log2((shift + baseline) / baseline)

susc_df

import matplotlib.pyplot as plt
plt.hist(susc_df.max(), bins=10, color="black")
plt.savefig("maxsusperlip.pdf")
plt.show()

plt.hist(susc_df.min(), bins=10, color="black")
plt.savefig("minsusperlip.pdf")
plt.show()

plt.hist(susc_df.values.flatten(), bins=100, color="black")
plt.savefig("allsuspreg.pdf")
plt.show()

susc_df['supertype'] = susc_df.index

# Per-cell map of the supertype's log2 fold changes (one column per lipid).
coeffmap = pd.merge(
    sub_alldata[['supertype']],
    susc_df,
    on='supertype',
    how='left'
)
coeffmap

coeffmap = coeffmap.iloc[:, 1:]
coeffmap.index = sub_alldata.index

# Check that the Bayesian shifts are a "regularized" version of the trivial
# centroid differences (pregnant minus non-pregnant mean per supertype).
is_pregnant = sub_alldata['Condition'] == "pregnant"
centroids = sub_alldata.loc[~is_pregnant, shift.columns].groupby(sub_alldata["supertype"]).mean()
centroids2 = sub_alldata.loc[is_pregnant, shift.columns].groupby(sub_alldata["supertype"]).mean()
# Align to shift's row/column order before flattening. The pairing below is
# positional, and groupby returns supertypes sorted, which need not match the
# order stored in the parquet files. Supertypes missing from one group -> NaN.
delll = (centroids2 - centroids).reindex(index=shift.index, columns=shift.columns)

plt.scatter(shift.values.flatten(), delll.values.flatten(), s=0.05, c="gray", alpha=0.7, rasterized=True)
plt.savefig("bayesrecapdelta.pdf")
plt.show()

susc_flat = susc_df.reindex(index=shift.index, columns=shift.columns).to_numpy(dtype=float).ravel()
delta_flat = delll.to_numpy(dtype=float).ravel()
finite = np.isfinite(susc_flat) & np.isfinite(delta_flat)  # drop NaN/inf (e.g. zero baseline)
print(np.corrcoef(susc_flat[finite], delta_flat[finite]))
# A strong correlation is expected: the model shifts are a regularized version
# of the raw differences. What the model adds is separating out the
# non-significant entries, which were set to 0 above and form the vertical line
# at shift = 0 in the scatter.
# we need it to partition the vertical line = nonsignificant stuff

In [None]:
filtered_data = sub_alldata
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

checklip = ["HexCer 42:2;O2", "PA 34:1"]  # check the winners spatially

cmap = plt.cm.plasma
unique_samples = sorted(filtered_data['Sample'].unique())
unique_sections = sorted(filtered_data['SectionPlot'].unique())

for currentPC in checklip:
    print(currentPC)

    # Colour limits robust to outliers and section effects: the median, across
    # sections, of each section's 2nd and 98th percentile.
    per_section = filtered_data.groupby('SectionID')[currentPC]
    percentile_df = pd.DataFrame({
        '2-perc': per_section.quantile(0.02),
        '98-perc': per_section.quantile(0.98),
    }).reset_index()
    med2p = percentile_df['2-perc'].median()
    med98p = percentile_df['98-perc'].median()

    fig, axes = plt.subplots(6, 6, figsize=(20, 12))
    # Panels start hidden and are shown only when drawn, so grids with fewer
    # than six samples/sections have no empty frames.
    for ax in axes.flat:
        ax.set_visible(False)

    for sample_idx, sample in enumerate(unique_samples[:6]):
        for section_idx, section in enumerate(unique_sections[:6]):
            ax = axes[sample_idx, section_idx]
            # Named `panel` (not `ddf`) so the lipid table `ddf`, used for bar
            # colours earlier, is not overwritten. Errors are no longer
            # swallowed by a bare `except`.
            panel = filtered_data[
                (filtered_data['Sample'] == sample) &
                (filtered_data['SectionPlot'] == section)
            ]

            ax.set_visible(True)
            ax.scatter(
                panel['y'],
                -panel['x'],
                c=panel[currentPC],
                cmap=cmap,
                s=0.5,
                rasterized=True,
                vmin=med2p,
                vmax=med98p
            )
            ax.axis('off')
            ax.set_aspect('equal')
            ax.set_title(f'Sample {sample}, Section {section}', fontsize=8)

    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    norm = Normalize(vmin=med2p, vmax=med98p)
    sm = ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(sm, cax=cbar_ax)

    fig.tight_layout(rect=[0, 0, 0.9, 1])
    plt.show()

In [None]:
# look at total production and total consumption and absolute balance
activity = shift.copy()

activitypersupertype = activity.sum(axis=1).sort_values()
plt.hist(activitypersupertype, bins = 50)
plt.show()

activitypersupertype = pd.DataFrame(activitypersupertype)
activitypersupertype['supertype'] = activitypersupertype.index
actmap = pd.merge(
    sub_alldata[['supertype']], 
    activitypersupertype,
    on='supertype',
    how='left'
)
actmap.index = sub_alldata.index
actmap = actmap.iloc[:, 1:]
actmap.columns = ['actmap']
filtered_data2 = pd.concat([coords, actmap],axis=1)

for currentPC in ["actmap"]:
    
    print(currentPC)

    unique_samples = sorted(filtered_data2['Sample'].unique())
    unique_sections = sorted(filtered_data2['SectionPlot'].unique())

    fig, axes = plt.subplots(6, 6, figsize=(20, 12))

    for sample_idx, sample in enumerate(unique_samples[:6]):
        for section_idx, section in enumerate(unique_sections[:6]):
            ax = axes[sample_idx, section_idx]

            try:
                ddf = filtered_data2[
                    (filtered_data2['Sample'] == sample) & 
                    (filtered_data2['SectionPlot'] == section)
                ]

                ax.scatter(
                    ddf['y'], 
                    -ddf['x'], 
                    c=ddf[currentPC], 
                    cmap="coolwarm", 
                    s=0.5, 
                    rasterized=True, 
                    vmin= -20.0, ########################
                    vmax= 35.0 ########################
                )

                ax.axis('off')
                ax.set_aspect('equal')

                ax.set_title(f'Sample {sample}, Section {section}', fontsize=8)

            except:
                continue

    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.show()
    
# Lookup from 5-digit cluster code to a descriptive zone name and colour.
namingtable = {
    "cluster": [
        11111, 11112, 11121, 11122, 11211, 11212, 11221, 11222, 12111, 12112, 
        12121, 12122, 12211, 12212, 12221, 12222, 21111, 21112, 21120, 21211, 
        21212, 21221, 21222, 22111, 22112, 22121, 22122, 22211, 22212, 22221, 22222
    ],
    "zone": [
        "Mixed and hindbrain white matter", "Core callosal white matter", 
        "Callosal and cerebellar white matter", "Ventral white matter", 
        "Boundary white matter", "Thalamic and mid/hindbrain white matter", 
        "Mid/hindbrain white matter", "Mixed white matter", 
        "Choroid plexus and ventricles", "Ventricular linings", 
        "Thalamic and midbrain regions", "White and gray matter boundary", 
        "Thalamic mixed gray and white matter", "Thalamic mixed gray and white matter #2", 
        "Neuron-rich lateral white matter", "Neuron-rich lateral white matter #2", 
        "Pallidum and projections", "Cortical layer 4", 
        "Subcortical plate, hippocampus and hypothalamus", 
        "GABA-ergic Purkinje cells of the cerebellum", "Cortical layers 2-3 and 4", 
        "Piriform cortex", "Cortical layers 1 and 2-3", "Cortical layer 5", 
        "Cortical layer 6, dentate gyrus", "Striatum, hypothalamus and hippocampus", 
        "Striatum, hypothalamus and hippocampus #2", 
        "Retrosplenial, cortical, cerebellar", "Cortical layer 6 and cerebellar Y", 
        "Cerebellar glutamatergic neurons", "Cortical layer 6 and thalamic"
    ],
    "color": [
        "#360064", "#980053", "#170b3b", "#ac2f5c", "#2a3f6d", "#002657", 
        "#21366b", "#3e4b6c", "#f75400", "#ef633e", "#a5d4e6", "#6399c6", 
        "#853a00", "#edeef4", "#fdbf71", "#ce710e", "#940457", "#a2d36c", 
        "#d5edb5", "#0065d6", "#bcf18b", "#a68d68", "#79e47e", "#2f0097", 
        "#47029f", "#7500a8", "#d70021", "#ca99c9", "#d4b9da", "#e00085", 
        "#f6f3f8"
    ]
}

namingtable = pd.DataFrame(namingtable)
namingtable.index = namingtable['cluster'].astype(str)

# similar results as doing abs prod and abs degr independently
activity = susc_df.iloc[:, :-1].copy()

# Sum of log2 fold changes over lipids per supertype, most negative first.
overallproduction = activity.sum(axis=1).sort_values()
subclasses_tocheck = [x[:5] for x in list(overallproduction.index)]
# reindex (not .loc) so a 5-digit prefix absent from the table gives NaN
# instead of raising KeyError. NOTE(review): the table lists "21120" rather
# than 2112x codes; confirm how those supertypes should be named.
namingtable['zone'].reindex(subclasses_tocheck).head(20)  # ventricular linings seem to be degrading membrane...

ACT_VMIN = -20.0  # hand-tuned colour limits, as in the activity grid
ACT_VMAX = 35.0

for currentPC in ["actmap"]:
    print(currentPC)

    unique_sections = sorted(filtered_data2['SectionPlot'].unique())

    fig = plt.figure(figsize=(20, 4))
    scatter_plots = []  # collected for the shared colorbar

    # Panels overlap laterally: each is `width` wide and starts
    # `width - overlap` after the previous one.
    left_start = 0.05
    width = 0.2
    overlap = 0.1      # 50% of the panel width

    # The first six sections are drawn in reverse order, so the highest of
    # them is leftmost and later panels are drawn on top.
    for section_idx, section in enumerate(unique_sections[:6][::-1]):
        left = left_start + section_idx * (width - overlap)
        ax = fig.add_axes([left, 0.2, width, 0.6])  # [left, bottom, width, height]

        try:
            # `panel`, not `ddf`, so the lipid table `ddf` is not overwritten.
            panel = filtered_data2[
                (filtered_data2['Sample'] == 'Pregnant2') &
                (filtered_data2['SectionPlot'] == section)
            ]

            scatter = ax.scatter(
                panel['y'],
                -panel['x'],
                c=panel[currentPC],
                cmap="coolwarm",
                s=0.5,
                rasterized=True,
                vmin=ACT_VMIN,
                vmax=ACT_VMAX
            )
            scatter_plots.append(scatter)

            ax.axis('off')
            ax.set_aspect('equal')

        except Exception as e:
            print(f"Error with section {section}: {e}")
            continue

    # Guard: scatter_plots[0] raised IndexError when no panel could be drawn.
    if scatter_plots:
        cbar = fig.colorbar(scatter_plots[0], ax=fig.axes, orientation='vertical', fraction=0.02, pad=0.04)
        cbar.set_label('Activity Map', fontsize=10)
    else:
        print("No section could be drawn; skipping colorbar.")
    fig.savefig("pregnancy_membraneactmap.pdf")
    plt.show()


## PCA males-females-pregnancy

In [None]:
import pandas as pd

import numpy as np
import pandas as pd

# Same relative path as the first cell of this notebook (no user-specific
# absolute path).
data = pd.read_parquet("./zenodo/maindata_2.parquet")
sample_list = ['Male1', 'Male2', 'Male3',
               'Female1', 'Female2', 'Female3',
               'Pregnant1', 'Pregnant2', 'Pregnant4']

data = data.loc[data['Sample'].isin(sample_list), :]
data_subset = data[data['Sample'].isin(sample_list)]
data_oligo = data_subset
sphingo_cols = [col for col in data_subset.columns
                if col.startswith('HexCer') or col.startswith('Cer') or col.startswith('SM')]
data_sphingo = data_oligo.loc[:, data_oligo.columns.isin(sphingo_cols)]
data_lipids = data_sphingo
data_lipids_z = data_lipids.apply(lambda col: (col - col.mean()) / col.std(), axis=0)
data_lipids_z[['supertype', 'Sample']] = data_subset[['supertype', 'Sample']]
grouped_data = data_lipids_z.groupby(['supertype', 'Sample']).mean()
supertypexsample = grouped_data.mean(axis=1).unstack()
corr = data[['supertype', 'class']].drop_duplicates()
corr.index = corr.supertype
grouped_data = supertypexsample.loc[corr.index[corr["class"] == "111"], :]

# Per-lipid min-max scaling to the 0.5th-99.5th percentiles, clipped to [0, 1].
# NOTE(review): assumes the first 173 columns are the lipids.
lips = data.iloc[:, :173]
datemp = lips.copy()
p2 = datemp.quantile(0.005)
p98 = datemp.quantile(0.995)
# A lipid with zero percentile range carries no signal. Map it to NaN (filled
# with 0 after averaging) instead of dividing by zero.
span = (p98 - p2).replace(0, np.nan)
normalized_datemp = ((datemp - p2) / span).clip(lower=0, upper=1)

# Mean profile per (sample, supertype), kept for reference; the PCA below uses
# per-section means.
centroids = normalized_datemp.groupby([data['Sample'], data['supertype']]).mean()
centroids = centroids.unstack()
centroidsOLD = centroids.copy()
centroids = normalized_datemp.groupby([data['SectionID'], data['supertype']]).mean()
centroids = centroids.unstack()
centroids = centroids.fillna(0.0)

# Condition per section: the sample name without its trailing replicate digit.
data['colors'] = [x[:-1] for x in data['Sample']]
mdnow = data[['SectionID', 'colors']].drop_duplicates().reset_index()
mdnow.index = mdnow['SectionID']
mdnow = mdnow.loc[centroids.index, :]
mdnow['colors'] = mdnow['colors'].replace({"Female": "pink", "Male": "blue", "Pregnant": "purple"})

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms

scaler = StandardScaler()
scaled_data = pd.DataFrame(scaler.fit_transform(centroids),
                           index=centroids.index,
                           columns=centroids.columns)

pca = PCA(n_components=2)
pca_result = pca.fit_transform(scaled_data)

colors = mdnow['colors']
var_explained = pca.explained_variance_ratio_ * 100

df_pca = pd.DataFrame({
    'PC1': pca_result[:, 0],
    'PC2': pca_result[:, 1],
    'Color': colors
})

def plot_confidence_ellipse(x, y, ax, color, n_std=2.0, **kwargs):
    """Draw the ``n_std``-sigma covariance ellipse of (x, y) on ``ax``.

    Nothing is drawn (and None is returned) when fewer than 3 points are
    given or when the 2x2 sample covariance has a determinant of exactly 0.

    Construction:

    1. Build an ellipse of radii sqrt(1 + r) and sqrt(1 - r) at the origin,
       where r is the Pearson correlation.
    2. Rotate it by 45 degrees.
    3. Scale each axis by ``n_std`` standard deviations.
    4. Translate it to the means of x and y.

    Extra keyword arguments are forwarded to ``Ellipse``.
    """
    if len(x) < 3:
        return

    cov = np.cov(x, y)
    if np.linalg.det(cov) == 0:
        return

    r = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])

    ellipse = Ellipse(
        (0, 0),
        width=np.sqrt(1 + r) * 2,
        height=np.sqrt(1 - r) * 2,
        facecolor='none',
        edgecolor=color,
        linewidth=2,
        alpha=0.6,
        **kwargs,
    )

    placement = (
        transforms.Affine2D()
        .rotate_deg(45)
        .scale(np.sqrt(cov[0, 0]) * n_std, np.sqrt(cov[1, 1]) * n_std)
        .translate(np.mean(x), np.mean(y))
    )
    ellipse.set_transform(placement + ax.transData)
    ax.add_patch(ellipse)

# Correlation of PC1 with each section's mean xccf coordinate (presumably the
# anterior-posterior axis, hence "AP"). Reordered explicitly to the PCA row
# order (centroids.index) rather than relying on both groupbys sorting alike.
AP = data['xccf'].groupby(data['SectionID']).mean().loc[centroids.index]

np.corrcoef(AP, pca_result[:, 0])