In [None]:
import pandas as pd
from scipy.spatial.distance import pdist, squareform

# Load the main data table and keep only the rows of the "ReferenceAtlas" sample.
# .copy() makes the filtered frame independent of the one read from disk, so the
# column assignments performed on it in later cells do not trigger
# SettingWithCopyWarning.
datavignettes = pd.read_parquet("./zenodo/maindata_2.parquet")
datavignettes = datavignettes.loc[datavignettes['Sample'] == "ReferenceAtlas",:].copy()

In [None]:
# Coarse branch label: the string forms of the first two hierarchy levels,
# concatenated without a separator.
datavignettes['branches4'] = datavignettes['level_1'].astype(str) + datavignettes['level_2'].astype(str)

In [None]:
import numpy as np

# Robust min-max scaling of the first 173 columns
# (presumably the lipid features — TODO confirm the column layout of the parquet file).
datemp = datavignettes.iloc[:, :173]
# Column-wise 2nd and 98th percentiles serve as the lower/upper scaling bounds.
p2 = datemp.quantile(0.02)
p98 = datemp.quantile(0.98)
datemp_values = datemp.values
p2_values = p2.values
p98_values = p98.values
# NOTE(review): a column whose two percentiles are equal makes the denominator zero.
normalized_values = (datemp_values - p2_values) / (p98_values - p2_values)
# Saturate everything outside the percentile window at 0 or 1.
clipped_values = np.clip(normalized_values, 0, 1)
lips01 = pd.DataFrame(clipped_values, columns=datemp.columns, index=datemp.index)
lips01

In [None]:
# Rows whose top hierarchy level (level_1) equals 1; later cells treat this
# subset as the white matter.
wm = datavignettes.loc[datavignettes['level_1'] == 1,:]
wm

In [None]:
# Mean 'HexCer + hexosylceramides' value per division within wm.
# The grouping key is taken from datavignettes and aligned to wm by index.
hexcer_score = wm['HexCer + hexosylceramides'].groupby(datavignettes['division']).mean()
# Rescale the division means to the [0, 1] range.
hexcer_score = (hexcer_score - np.min(hexcer_score)) / (np.max(hexcer_score) - np.min(hexcer_score))
# Remove three divisions after scaling (they still contributed to the min/max above).
# drop() raises KeyError if any of these labels is missing.
hexcer_score = hexcer_score.drop(["Olfactory areas", "General", "ventricular systems"])
# Order from highest to lowest score.
hexcer_score = hexcer_score.sort_values()[::-1]
hexcer_score

In [None]:
# Two-column view (division, allencolor) with a fresh positional index.
tmp = datavignettes[['division', 'allencolor']]
tmp = tmp.reset_index(drop=True)
def most_frequent_color(group):
    """Return the first row of `group` whose 'allencolor' is the most common one.

    `group` is a DataFrame with an 'allencolor' column; the returned value is
    a single row (a Series) of that frame.
    """
    color_counts = group['allencolor'].value_counts()
    modal_color = color_counts.index[0]  # value_counts() sorts by descending frequency
    has_modal_color = group['allencolor'].eq(modal_color)
    return group.loc[has_modal_color].iloc[0]

# One representative row per division, carrying that division's most common color.
result = (tmp.groupby('division')
             .apply(most_frequent_color)
             .reset_index(drop=True))
# Index by division name so colors can be looked up by label.
result.index = result.division.values
# Remove the same three divisions that were dropped from hexcer_score.
result = result.drop(["Olfactory areas", "General", "ventricular systems"], axis=0)
result

In [None]:
# Division colors, ordered like the sorted hexcer_score entries.
cols = result.loc[hexcer_score.index,'allencolor'].values
cols

In [None]:
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
import matplotlib.pyplot as plt

# Bar chart of the scaled HexCer score per division, each bar in its division color.
plt.figure(figsize=(10, 6)) 
plt.bar(hexcer_score.index, hexcer_score, color=cols)  

# Hide all four axis spines.
ax = plt.gca() 
ax.spines['top'].set_visible(False)  
ax.spines['right'].set_visible(False)  
ax.spines['left'].set_visible(False) 
ax.spines['bottom'].set_visible(False)  

# Division names are drawn vertically.
plt.xticks(rotation=90)

plt.show()

In [None]:
# Two-column view (branches4, lipizone_color) with a fresh positional index.
tmp = datavignettes[['branches4', 'lipizone_color']]
tmp = tmp.reset_index(drop=True)
def most_frequent_color(group):
    """Return the first row of `group` whose 'lipizone_color' is the most common one.

    `group` is a DataFrame with a 'lipizone_color' column; the returned value
    is a single row (a Series) of that frame.
    """
    color_counts = group['lipizone_color'].value_counts()
    modal_color = color_counts.index[0]  # value_counts() sorts by descending frequency
    has_modal_color = group['lipizone_color'].eq(modal_color)
    return group.loc[has_modal_color].iloc[0]

# One representative row per branch label, carrying that branch's most common lipizone color.
result = (tmp.groupby('branches4')
             .apply(most_frequent_color)
             .reset_index(drop=True))
# Index by branch label, then rename the two columns positionally.
result.index = result['branches4'].values
result.columns = ['branch', 'branch4color']
# Give every row of the main table the color of its branch.
datavignettes['branch4color'] = datavignettes['branches4'].map(result['branch4color'])
datavignettes['branch4color']

In [None]:
# Plot every row of each section, colored by the modal color of its 'branches4'
# group (the 'branch4color' column), on a 4 x 8 grid of panels.

import os
dot_size = 0.3
sections_to_plot = range(1, 33)
dd2 = datavignettes

# Shared axis limits so that all sections are drawn at the same scale;
# y is negated for display.
global_min_z = dd2['zccf'].min()
global_max_z = dd2['zccf'].max()
global_min_y = -dd2['yccf'].max()
global_max_y = -dd2['yccf'].min()
# NOTE(review): unique_lev4cols is not read anywhere in the visible part of the notebook.
unique_lev4cols = np.sort(dd2['class'].unique())

fig, axes = plt.subplots(4, 8, figsize=(40, 20))
axes = axes.flatten()
for i, section_num in enumerate(sections_to_plot):
    ax = axes[i]
    xx = dd2[dd2["Section"] == section_num]
    sc1 = ax.scatter(xx['zccf'], -xx['yccf'], c=xx['branch4color'],
                      s=dot_size * 3, alpha=1.0, rasterized=True)
    ax.axis('off')
    ax.set_aspect('equal')
    ax.set_xlim(global_min_z, global_max_z)
    ax.set_ylim(global_min_y, global_max_y)
# Remove panels that received no section (none when all 32 sections are plotted).
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])
plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# Per-subclass color: the most frequent lipizone_color within each subclass
# (None when a subclass has no color value at all).
datavignettes['color'] = datavignettes.groupby('subclass')['lipizone_color'].transform(lambda x: x.mode().iloc[0] if not x.mode().empty else None)
# White-matter subset (level_1 == 1) without the two ventricle-related lipizones.
wm = datavignettes.loc[datavignettes['level_1'] == 1,:]
# .copy() makes wm an independent frame: later cells add columns to it, which
# would otherwise raise SettingWithCopyWarning on a filtered slice.
wm = wm.loc[~wm['old_lipizone_names'].isin(["Choroid plexus and ventricles", "Ventricular linings"]),:].copy()
wm

In [None]:
# Number of distinct lipizone names present in the white-matter subset.
len(wm['lipizone_names'].unique()) # 245 lipizones in the white matter!

In [None]:
# Highlight the white-matter rows (colored by their per-subclass 'color') on top
# of a faint grayscale rendering of every row, one panel per section (4 x 8 grid).

import os
dot_size = 0.3
sections_to_plot = range(1, 33)
dd2 = datavignettes

# Shared axis limits so that all sections are drawn at the same scale;
# y is negated for display.
global_min_z = dd2['zccf'].min()
global_max_z = dd2['zccf'].max()
global_min_y = -dd2['yccf'].max()
global_max_y = -dd2['yccf'].min()
unique_lev4cols = np.sort(dd2['class'].unique())

fig, axes = plt.subplots(4, 8, figsize=(40, 20))
axes = axes.flatten()
for i, section_num in enumerate(sections_to_plot):
    ax = axes[i]
    xx = dd2[dd2["Section"] == section_num]
    # Background layer. 'Greys' is the long-standing Matplotlib colormap name;
    # the 'Grays' spelling is only accepted by recent Matplotlib releases.
    sc1 = ax.scatter(xx['zccf'], -xx['yccf'], c=xx['lipizone_names'].astype("category").cat.codes,cmap='Greys',
                      s=dot_size * 2, alpha=0.2, rasterized=True)
    # Foreground layer: white-matter rows of this section.
    xx_highlight = wm[wm["Section"] == section_num]
    sc2 = ax.scatter(xx_highlight['zccf'], -xx_highlight['yccf'],
                     c=xx_highlight['color'], s=dot_size, alpha=1, rasterized=True)
    ax.axis('off')
    ax.set_aspect('equal')
    ax.set_xlim(global_min_z, global_max_z)
    ax.set_ylim(global_min_y, global_max_y)
# Remove panels that received no section (none when all 32 sections are plotted).
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])
plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# Within the white matter the branch label uses three hierarchy levels
# (the column keeps the name 'branches4'). The right-hand side is computed on
# datavignettes and aligned to wm by index on assignment.
wm['branches4'] = datavignettes['level_1'].astype(str) + datavignettes['level_2'].astype(str) + datavignettes['level_3'].astype(str)
# Two-column view (branches4, lipizone_color) with a fresh positional index.
tmp = wm[['branches4', 'lipizone_color']]
tmp = tmp.reset_index(drop=True)
def most_frequent_color(group):
    """Return the first row of `group` whose 'lipizone_color' is the most common one.

    `group` is a DataFrame with a 'lipizone_color' column; the returned value
    is a single row (a Series) of that frame.
    """
    color_counts = group['lipizone_color'].value_counts()
    modal_color = color_counts.index[0]  # value_counts() sorts by descending frequency
    has_modal_color = group['lipizone_color'].eq(modal_color)
    return group.loc[has_modal_color].iloc[0]

# One representative row per white-matter branch, carrying its most common lipizone color.
result = (tmp.groupby('branches4')
             .apply(most_frequent_color)
             .reset_index(drop=True))
# Index by branch label, then rename the two columns positionally.
result.index = result['branches4'].values
result.columns = ['branch', 'branch4color']
# Give every white-matter row the color of its branch.
wm['branch4color'] = wm['branches4'].map(result['branch4color'])
wm['branch4color']

In [None]:
# Highlight the white-matter rows, colored by the modal color of their branch
# ('branch4color'), on top of a faint grayscale rendering of every row,
# one panel per section (4 x 8 grid).

import os
dot_size = 0.3
sections_to_plot = range(1, 33)
dd2 = datavignettes

# Shared axis limits so that all sections are drawn at the same scale;
# y is negated for display.
global_min_z = dd2['zccf'].min()
global_max_z = dd2['zccf'].max()
global_min_y = -dd2['yccf'].max()
global_max_y = -dd2['yccf'].min()
unique_lev4cols = np.sort(dd2['class'].unique())

fig, axes = plt.subplots(4, 8, figsize=(40, 20))
axes = axes.flatten()
for i, section_num in enumerate(sections_to_plot):
    ax = axes[i]
    xx = dd2[dd2["Section"] == section_num]
    # Background layer. 'Greys' is the long-standing Matplotlib colormap name;
    # the 'Grays' spelling is only accepted by recent Matplotlib releases.
    sc1 = ax.scatter(xx['zccf'], -xx['yccf'], c=xx['lipizone_names'].astype("category").cat.codes,cmap='Greys',
                      s=dot_size * 2, alpha=0.2, rasterized=True)
    # Foreground layer: white-matter rows of this section.
    xx_highlight = wm[wm["Section"] == section_num]
    sc2 = ax.scatter(xx_highlight['zccf'], -xx_highlight['yccf'],
                     c=xx_highlight['branch4color'], s=dot_size, alpha=1, rasterized=True)
    ax.axis('off')
    ax.set_aspect('equal')
    ax.set_xlim(global_min_z, global_max_z)
    ax.set_ylim(global_min_y, global_max_y)
# Remove panels that received no section (none when all 32 sections are plotted).
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])
plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
from matplotlib.colors import ListedColormap

# Per-section composition of the white matter: for every section, the fraction
# of white-matter rows that belongs to each branch, drawn as stacked bars.

# Rows per section over the whole white matter. This does not depend on the
# branch, so it is computed once instead of once per loop iteration.
vcnorm = wm['Section'].value_counts()
vcnorm.index = vcnorm.index.astype(int)
vcnorm = vcnorm.sort_index()

stacked_columns = {}
for iii in wm['branches4'].unique():
    wmn = wm.loc[wm['branches4'] == iii]
    value_counts = wmn['Section'].value_counts()
    value_counts.index = value_counts.index.astype(int)
    aap = value_counts.sort_index()

    # A branch that is absent from a section must contribute 0, not NaN:
    # a NaN would propagate into the running `bottom` below and hide every
    # bar stacked above it for that section.
    normalized_counts = aap.reindex(vcnorm.index, fill_value=0) / vcnorm
    stacked_columns[f'{iii}'] = normalized_counts
stacked_data = pd.DataFrame(stacked_columns)

# One color per branch, in the same order as the stacked columns.
tmp = wm[['branches4','branch4color']].drop_duplicates()

tmp.index = tmp['branches4']
tmp = tmp.loc[stacked_data.columns,:]
colors = tmp['branch4color'].values
custom_cmap = ListedColormap(colors)

fig, ax = plt.subplots(figsize=(12, 6))
# Running top of the stack. Kept as a separate array and rebuilt on every step,
# so that no column of stacked_data is modified in place.
bottom = np.zeros(len(stacked_data.index))
for i, column in enumerate(stacked_data.columns):
    heights = stacked_data[column].to_numpy()
    ax.bar(range(len(stacked_data.index)), heights, bottom=bottom,
           label=column, color=custom_cmap(i))
    bottom = bottom + heights

ax.set_xlabel('Section', fontsize=12)
ax.set_ylabel('Normalized Count', fontsize=12)
ax.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.tick_params(axis='both', which='both', length=0)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_xlim(-0.5, len(stacked_data.index) - 0.5)
plt.tight_layout()
plt.show()

In [None]:
# Subset of the white matter whose three-level branch label is '1.01.01.0';
# later cells refer to it as the callosal white matter.
# .copy() makes it an independent frame, because columns are added to it below
# and in later cells (this avoids SettingWithCopyWarning on a filtered slice).
callosalwm = wm.loc[wm['branches4'] == '1.01.01.0',:].copy()
# Within this subset the branch color is replaced by the per-subclass color.
callosalwm['branch4color'] = callosalwm['color']
callosalwm

In [None]:
# Highlight the callosal white-matter rows (colored by their per-subclass
# 'color') on top of a faint grayscale rendering of every row,
# one panel per section (4 x 8 grid).

import os
dot_size = 0.3
sections_to_plot = range(1, 33)
dd2 = datavignettes

# Shared axis limits so that all sections are drawn at the same scale;
# y is negated for display.
global_min_z = dd2['zccf'].min()
global_max_z = dd2['zccf'].max()
global_min_y = -dd2['yccf'].max()
global_max_y = -dd2['yccf'].min()
unique_lev4cols = np.sort(dd2['class'].unique())

fig, axes = plt.subplots(4, 8, figsize=(40, 20))
axes = axes.flatten()
for i, section_num in enumerate(sections_to_plot):
    ax = axes[i]
    xx = dd2[dd2["Section"] == section_num]
    # Background layer. 'Greys' is the long-standing Matplotlib colormap name;
    # the 'Grays' spelling is only accepted by recent Matplotlib releases.
    sc1 = ax.scatter(xx['zccf'], -xx['yccf'], c=xx['lipizone_names'].astype("category").cat.codes,cmap='Greys',
                      s=dot_size * 2, alpha=0.2, rasterized=True)
    # Foreground layer: callosal white-matter rows of this section.
    xx_highlight = callosalwm[callosalwm["Section"] == section_num]
    sc2 = ax.scatter(xx_highlight['zccf'], -xx_highlight['yccf'],
                     c=xx_highlight['color'], s=dot_size, alpha=1, rasterized=True)
    ax.axis('off')
    ax.set_aspect('equal')
    ax.set_xlim(global_min_z, global_max_z)
    ax.set_ylim(global_min_y, global_max_y)
# Remove panels that received no section (none when all 32 sections are plotted).
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])
plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
from matplotlib.colors import ListedColormap

# Per-section composition of the callosal white matter: for every section, the
# fraction of callosal rows that belongs to each subclass, drawn as stacked bars.

# Rows per section over the whole callosal subset. This does not depend on the
# subclass, so it is computed once instead of once per loop iteration.
vcnorm = callosalwm['Section'].value_counts()
vcnorm.index = vcnorm.index.astype(int)
vcnorm = vcnorm.sort_index()

stacked_columns = {}
for iii in callosalwm['subclass'].unique():
    callosalwmn = callosalwm.loc[callosalwm['subclass'] == iii]
    value_counts = callosalwmn['Section'].value_counts()
    value_counts.index = value_counts.index.astype(int)
    aap = value_counts.sort_index()

    # A subclass that is absent from a section must contribute 0, not NaN:
    # a NaN would propagate into the running `bottom` below and hide every
    # bar stacked above it for that section.
    normalized_counts = aap.reindex(vcnorm.index, fill_value=0) / vcnorm
    stacked_columns[f'{iii}'] = normalized_counts
stacked_data = pd.DataFrame(stacked_columns)

# One color per subclass, in the same order as the stacked columns.
tmp = callosalwm[['subclass','color']].drop_duplicates()

tmp.index = tmp['subclass']
tmp = tmp.loc[stacked_data.columns,:]
colors = tmp['color'].values
custom_cmap = ListedColormap(colors)

fig, ax = plt.subplots(figsize=(12, 6))
# Running top of the stack. Kept as a separate array and rebuilt on every step,
# so that no column of stacked_data is modified in place.
bottom = np.zeros(len(stacked_data.index))
for i, column in enumerate(stacked_data.columns):
    heights = stacked_data[column].to_numpy()
    ax.bar(range(len(stacked_data.index)), heights, bottom=bottom,
           label=column, color=custom_cmap(i))
    bottom = bottom + heights

ax.set_xlabel('Section', fontsize=12)
ax.set_ylabel('Normalized Count', fontsize=12)
ax.legend(title='Group', bbox_to_anchor=(1.05, 1), loc='upper left')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.tick_params(axis='both', which='both', length=0)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_xlim(-0.5, len(stacked_data.index) - 0.5)
plt.tight_layout()
plt.show()

In [None]:
# Number of rows per lipizone name within the callosal white matter.
callosalwm['lipizone_names'].value_counts() # 65 callosal "core wm" lipizones

In [None]:
# Distinct lipizone names found in the callosal white matter.
lipizones = callosalwm['lipizone_names'].unique()

In [None]:
# Headerless two-column CSV: an automated annotation and its manual counterpart.
annotated = pd.read_csv("./zenodo/csv/callosalwm_annotated.csv", header=None)
annotated.columns = ['automatedanno', 'manualanno']
# Index by the automated annotation minus its last four characters
# (presumably a file extension such as ".pdf" or ".png" — TODO confirm).
annotated.index = [annotated['automatedanno'][i][:-4] for i in range(annotated.shape[0])]
# Make repeated manual names unique: the first occurrence is kept as is,
# later ones get " 1", " 2", ... appended.
counts = annotated.groupby('manualanno').cumcount()
annotated['manualanno'] = annotated['manualanno'].where(counts == 0, 
                                    annotated['manualanno'] + ' ' + counts.astype(str))

# Number of distinct manual names after disambiguation.
len(annotated['manualanno'].unique())

In [None]:
import os
import numpy as np
# Highlight the callosal white-matter rows, colored by their original
# 'lipizone_color', on top of a faint grayscale rendering of every row,
# one panel per section (4 x 8 grid).
dot_size = 0.3
sections_to_plot = range(1, 33)
dd2 = datavignettes

# Shared axis limits so that all sections are drawn at the same scale;
# y is negated for display.
global_min_z = dd2['zccf'].min()
global_max_z = dd2['zccf'].max()
global_min_y = -dd2['yccf'].max()
global_max_y = -dd2['yccf'].min()
unique_lev4cols = np.sort(dd2['class'].unique())

fig, axes = plt.subplots(4, 8, figsize=(40, 20))
axes = axes.flatten()
for i, section_num in enumerate(sections_to_plot):
    ax = axes[i]
    xx = dd2[dd2["Section"] == section_num]
    # Background layer. 'Greys' is the long-standing Matplotlib colormap name;
    # the 'Grays' spelling is only accepted by recent Matplotlib releases.
    sc1 = ax.scatter(xx['zccf'], -xx['yccf'], c=xx['lipizone_names'].astype("category").cat.codes,cmap='Greys',
                      s=dot_size * 2, alpha=0.2, rasterized=True)
    # Foreground layer: callosal white-matter rows of this section.
    xx_highlight = callosalwm[callosalwm["Section"] == section_num]
    sc2 = ax.scatter(xx_highlight['zccf'], -xx_highlight['yccf'],
                     c=xx_highlight['lipizone_color'], s=dot_size, alpha=1, rasterized=True)
    ax.axis('off')
    ax.set_aspect('equal')
    ax.set_xlim(global_min_z, global_max_z)
    ax.set_ylim(global_min_y, global_max_y)
# Remove panels that received no section (none when all 32 sections are plotted).
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])
plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# Translate each old lipizone name into its (disambiguated) manual annotation;
# names without an entry in `annotated` become NaN.
callosalwm['cleanlipizone'] = callosalwm['old_lipizone_names'].map(annotated['manualanno'])
callosalwm['cleanlipizone']

In [None]:
# 'base' label: hierarchy levels 1 to 5 as integers, concatenated into one string.
callosalwm['base'] = callosalwm['level_1'].astype(int).astype(str)+callosalwm['level_2'].astype(int).astype(str)+callosalwm['level_3'].astype(int).astype(str)+callosalwm['level_4'].astype(int).astype(str)+callosalwm['level_5'].astype(int).astype(str)
callosalwm['base'].unique()

In [None]:
# reassign color to help with readability
# This cell prepares the table used for the recoloring: row-standardized values
# of the first 173 columns plus coordinates and labels.

from matplotlib import colors as mcolors
data = callosalwm.iloc[:,:173]
# Standardize each row (axis=1): subtract the row mean and divide by the row
# standard deviation; the small constant avoids division by zero.
data_std = (data.values - np.mean(data.values, axis=1)[:, None]) / (np.std(data.values, axis=1)[:, None] + 1e-8)
data_std = pd.DataFrame(data_std, index=data.index, columns= np.array(data.columns)) 
# From here on dd2 is this callosal table (it no longer refers to datavignettes).
dd2 = pd.concat([data_std, callosalwm[['Section', 'xccf','yccf', 'zccf']], callosalwm[['base', 'cleanlipizone', 'lipizone']]], axis=1) #_std

In [None]:
from tqdm import tqdm

# Assign every lipizone a color such that lipizones with similar mean profiles
# inside the same 'base' group receive neighbouring colors of one colormap.
lipid_columns = np.array(data.columns)

# One colormap per 'base' group.
# NOTE(review): zip() stops at the shorter input, so only the first four 'base'
# values are processed by the loop; any further group keeps NaN in R/G/B.
divisions = dd2['base'].unique()
colormaps = [ "terrain", "plasma", "cividis", "PuOr"]

# Placeholder color channels, filled in group by group below.
dd2['R'] = np.nan
dd2['G'] = np.nan
dd2['B'] = np.nan

# NOTE(review): dfs is not used in the visible part of the notebook.
dfs = []

for division, colormap_name in tqdm(zip(divisions, colormaps)):

    # Graded colors are only computed when the group holds more than one lipizone.
    if len(dd2.loc[dd2['base'] == division, 'cleanlipizone'].unique()) > 1:

        print(division)

        datasubaqueo = dd2[dd2['base'] == division]

        clusters = datasubaqueo['cleanlipizone'].unique()

        # Mean profile of every lipizone, one row per cluster.
        lipid_df = pd.DataFrame(columns=lipid_columns)

        for i in range(len(clusters)):
            datasub = datasubaqueo[datasubaqueo['cleanlipizone'] == clusters[i]] 
            lipid_data = datasub.loc[:,lipid_columns].mean(axis=0)
            lipid_row = pd.DataFrame([lipid_data], columns=lipid_columns)
            lipid_df = pd.concat([lipid_df, lipid_row], ignore_index=True)

        # Profiles divided by the column means.
        # NOTE(review): normalized_lipid_df is not read again in this cell.
        column_means = lipid_df.mean()
        normalized_lipid_df = lipid_df.div(column_means, axis='columns')

        normalized_lipid_df.index = clusters
        normalized_lipid_df = normalized_lipid_df.T

        # Centroid of every lipizone and the Euclidean distances between centroids.
        pca_columns = datasubaqueo.loc[:, lipid_columns]
        grouped_data = datasubaqueo[['cleanlipizone']].join(pca_columns)
        centroids = grouped_data.groupby('cleanlipizone').mean()

        distance_matrix = squareform(pdist(centroids, metric='euclidean'))
        distance_df = pd.DataFrame(distance_matrix, index=centroids.index, columns=centroids.index)

        # Greedy nearest-neighbour chain: start from the closest pair of
        # centroids, then keep appending the not-yet-used lipizone that is
        # closest to the one added last.
        np.fill_diagonal(distance_df.values, np.inf)
        initial_min_index = np.unravel_index(np.argmin(distance_df.values), distance_df.shape)
        ordered_elements = [distance_df.index[initial_min_index[0]], distance_df.columns[initial_min_index[1]]]
        distances = [0, distance_df.iloc[initial_min_index]]

        while len(ordered_elements) < len(distance_df):
            last_added = ordered_elements[-1]
            remaining_distances = distance_df.loc[last_added, ~distance_df.columns.isin(ordered_elements)]
            next_element = remaining_distances.idxmin()
            ordered_elements.append(next_element)
            distances.append(remaining_distances[next_element])

        ordered_elements

        leaf_sequence = ordered_elements

        sequential_distances = distances

        # Position of every lipizone along the chain, rescaled so that the last
        # one sits at 1, then looked up in this group's colormap.
        cumulative_sequential_distances = np.cumsum(sequential_distances)

        normalized_distances = cumulative_sequential_distances / cumulative_sequential_distances[-1]
        colormap = plt.get_cmap(colormap_name)
        colors = [colormap(value) for value in normalized_distances]

        hsv_colors = [mcolors.rgb_to_hsv(rgb[:3]) for rgb in colors] 

        # Raise the saturation of every other entry (0-based even positions)
        # by 70 %, capped at 1.
        modified_hsv_colors = []
        for i, (h, s, v) in enumerate(hsv_colors):
            if (i + 1) % 2 != 0: 
                s = min(1, s + 0.7 * s)
            modified_hsv_colors.append((h, s, v))

        modified_rgb_from_hsv = [mcolors.hsv_to_rgb(hsv) for hsv in modified_hsv_colors]

        rgb_list = [list(rgb) for rgb in modified_rgb_from_hsv]

        lipocolor = pd.DataFrame(rgb_list, index=leaf_sequence, columns=['R', 'G', 'B'])

        lipocolor_reset = lipocolor.reset_index().rename(columns={'index': 'cleanlipizone'})
        print(lipocolor_reset)
        indices = datasubaqueo.index

        # Drop the last three columns before merging the new colors in.
        # This relies on the R/G/B placeholders being the last three columns.
        datasubaqueo = datasubaqueo.iloc[:,:-3]
        datasubaqueo = pd.merge(datasubaqueo, lipocolor_reset, on='cleanlipizone', how='left')

        # merge() resets the index; restore it so update() aligns with dd2.
        datasubaqueo.index = indices

        dd2.update(datasubaqueo[['R', 'G', 'B']])

    else:
        # Groups without several distinct lipizones get R = G = B = 0.
        datasubaqueo = dd2[dd2['base'] == division]
        datasubaqueo['R'] = 0
        datasubaqueo['G'] = 0
        datasubaqueo['B'] = 0
        dd2.update(datasubaqueo[['R', 'G', 'B']])

def rgb_to_hex(r, g, b):
    """Convert RGB components in the 0-1 range to a '#rrggbb' hex string.

    Each component is multiplied by 255 and truncated to an integer.

    Returns np.nan when a component cannot be converted (NaN, infinity, None or
    another non-numeric value), so that callers can substitute a fallback color.
    """
    try:
        r, g, b = [int(255 * x) for x in [r, g, b]]  # scale to 0-255
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are absorbed; anything else (for example
        # KeyboardInterrupt) propagates instead of being silently swallowed.
        return np.nan
    return f'#{r:02x}{g:02x}{b:02x}'

# One hex color string per row, built from the R/G/B channels.
dd2['lipizone_color'] = dd2.apply(lambda row: rgb_to_hex(row['R'], row['G'], row['B']), axis=1)

# Rows without a usable color fall back to gray. The result is assigned back
# to the column: calling fillna(inplace=True) on a selected column is chained
# assignment, which pandas warns about and which does not modify dd2 when
# copy-on-write is active.
dd2['lipizone_color'] = dd2['lipizone_color'].fillna('gray')

In [None]:
# Plot the rows of dd2 for each section with their reassigned 'lipizone_color',
# one panel per section (4 x 8 grid).
fig, axes = plt.subplots(4, 8, figsize=(40, 20))
axes = axes.flatten()
dot_size = 0.3

sections_to_plot = range(1,33)

# Shared axis limits, here derived from dd2 itself; y is negated for display.
global_min_z = dd2['zccf'].min()
global_max_z = dd2['zccf'].max()
global_min_y = -dd2['yccf'].max() 
global_max_y = -dd2['yccf'].min()  

for i, section_num in enumerate(sections_to_plot):
    ax = axes[i]
    xx = dd2[dd2["Section"] == section_num]
    sc1 = ax.scatter(xx['zccf'], -xx['yccf'],
                     c=np.array(xx['lipizone_color']), s=dot_size, alpha=1, rasterized=True)

    ax.axis('off')
    ax.set_aspect('equal')  
    ax.set_xlim(global_min_z, global_max_z)
    ax.set_ylim(global_min_y, global_max_y)

# Remove panels that received no section (none when all 32 sections are plotted).
for j in range(i+1, len(axes)):
    fig.delaxes(axes[j])
plt.tight_layout()
plt.show()

In [None]:
# make the tree
# Lookup table from the 'lipizone' code (as a string index) to its clean name,
# sorted by code; the index strings are the input of the dendrogram cell.

clusterxname = callosalwm[['cleanlipizone', 'lipizone']].drop_duplicates()
clusterxname = clusterxname.sort_values('lipizone')
clusterxname.index = clusterxname['lipizone'].astype(str)
clusterxname = clusterxname[['cleanlipizone']]
clusterxname

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram

def hierarchical_distance(s1, s2, weighting='linear'):
    """Position-weighted mismatch distance between two equal-length sequences.

    Every position at which `s1` and `s2` differ adds that position's weight;
    earlier positions weigh more than later ones.

    Parameters
    ----------
    s1, s2 : sequences of equal length (e.g. strings)
    weighting : 'linear' for weights n, n-1, ..., 1, or
        'exponential' for weights 2**(n-1), ..., 2, 1

    Returns
    -------
    The sum of the weights of all differing positions (0 for identical inputs).

    Raises
    ------
    ValueError
        If the inputs differ in length or `weighting` is not a known scheme.
    """
    if len(s1) != len(s2):
        raise ValueError("Strings must be of equal length.")
    n = len(s1)
    if weighting == 'linear':
        weights = list(range(n, 0, -1))
    elif weighting == 'exponential':
        weights = [2**(n - i - 1) for i in range(n)]
    else:
        # Fail here with a clear message instead of later with an unbound
        # `weights` variable.
        raise ValueError(f"Unknown weighting: {weighting!r} (expected 'linear' or 'exponential').")
    return sum(w for c1, c2, w in zip(s1, s2, weights) if c1 != c2)

# Dendrogram of the lipizone codes, based on the position-weighted mismatch
# distance between their code strings.
# NOTE(review): this rebinds `data`, which an earlier cell used for the
# 173-column slice of callosalwm; re-running that earlier material out of order
# would pick up this index instead.
data = clusterxname.index[::-1]
n = len(data)
# Symmetric pairwise distance matrix.
dist_matrix = np.zeros((n, n))
for i in range(n):
    for j in range(i+1, n):
        dist = hierarchical_distance(data[i], data[j])
        dist_matrix[i, j] = dist
        dist_matrix[j, i] = dist

# Upper triangle in row-major order: the condensed form expected by linkage().
condensed_dist = []
for i in range(n):
    for j in range(i+1, n):
        condensed_dist.append(dist_matrix[i, j])
condensed_dist = np.array(condensed_dist)

Z = linkage(condensed_dist, method='average')

plt.figure(figsize=(4, 9))

# color_threshold=0 together with above_threshold_color draws every link in black.
dendro = dendrogram(
    Z,labels=data,
    orientation='left',
    color_threshold=0,
    above_threshold_color='black',
)

plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['bottom'].set_visible(False)
plt.gca().spines['left'].set_visible(False)
plt.tick_params(axis='both', which='both', length=0)
plt.xticks([])
plt.yticks(fontsize=12)
plt.tight_layout()

plt.show()

In [None]:
# Append callosalwm['lipizone'] to dd2 as an extra column.
# NOTE(review): dd2 already received a 'lipizone' column when it was built, so
# every run of this cell adds one more column with the same name; code that
# addresses dd2 columns by position depends on how often this cell was run.
dd2 = pd.concat([dd2, callosalwm['lipizone']],axis=1)

In [None]:
# Distinct (clean name, lipizone code, color) combinations.
# dd2 can hold several columns named 'lipizone' (one more for each run of the
# cell that concatenates callosalwm['lipizone']). Keeping only the first
# occurrence of every column name and then selecting by name gives the same
# three columns no matter how often that cell was executed, instead of relying
# on 'lipizone_color' sitting at positional index 4.
clusterxname = (dd2.loc[:, ~dd2.columns.duplicated()]
                   [['cleanlipizone', 'lipizone', 'lipizone_color']]
                   .drop_duplicates())
clusterxname

In [None]:
# Index the name/color table by lipizone code (as string) and arrange the rows
# in the reverse of the dendrogram's leaf order (dendro['ivl']).
clusterxname = clusterxname.sort_values('lipizone')
clusterxname.index = clusterxname['lipizone'].astype(str)
clusterxname = clusterxname[['cleanlipizone', 'lipizone_color']].loc[dendro['ivl'],:][::-1]
clusterxname

In [None]:
# Stand-alone legend: one colored marker per lipizone name.
# Colors and names are taken pairwise from the same rows. De-duplicating the
# two columns independently and zipping them would shift the pairing whenever
# two lipizones share a color (for example the fallback colors assigned during
# recoloring).
legend_pairs = clusterxname[['lipizone_color', 'cleanlipizone']].drop_duplicates()
unique_colors = legend_pairs['lipizone_color'].values
unique_labels = legend_pairs['cleanlipizone'].values

# Figure height grows with the number of legend entries.
fig, ax = plt.subplots(figsize=(5, len(legend_pairs)*0.5))

ax.axis('off')

legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                               markerfacecolor=color, 
                               markersize=10, 
                               label=label)
                   for color, label in zip(unique_colors, unique_labels)]

ax.legend(handles=legend_elements, loc='center left')
    
plt.tight_layout()

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
import xgboost as xgb
import matplotlib.pyplot as plt

# Train a multi-class XGBoost classifier that predicts the clean lipizone name
# from the first 173 columns of dd2.

# Downsampling stride: every DS-th row is used (1 = all rows).
DS = 1
basespace = dd2.iloc[:,:173]
target = dd2['cleanlipizone']
X = basespace[::DS]
y = target[::DS]
# Encode the lipizone names as integer class labels.
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Stratified 80/20 split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded, 
    test_size=0.2, 
    random_state=42, 
    stratify=y_encoded
)

# Oversample the training split only, so that all classes are equally represented.
smote = SMOTE(random_state=42)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)

print("Original class distribution:", np.bincount(y_train))
print("Resampled class distribution:", np.bincount(y_train_res))

# NOTE(review): use_label_encoder here and early_stopping_rounds in fit() below
# are accepted only by some xgboost versions — confirm against the installed one.
xgb_clf = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=len(np.unique(y_train)),
    eval_metric='mlogloss',
    use_label_encoder=False,
    random_state=42
)

# NOTE(review): the test split doubles as the early-stopping evaluation set, so
# the metrics reported below are not computed on data unseen during training.
xgb_clf.fit(
    X_train_res, 
    y_train_res,
    eval_set=[(X_test, y_test)],
    early_stopping_rounds=10,
    verbose=True
)

# Predict on the test split and map the integer labels back to lipizone names.
y_pred = xgb_clf.predict(X_test)
y_test_decoded = label_encoder.inverse_transform(y_test)
y_pred_decoded = label_encoder.inverse_transform(y_pred)

print("Classification Report:")
print(classification_report(y_test_decoded, y_pred_decoded))

print("Confusion Matrix:")
print(confusion_matrix(y_test_decoded, y_pred_decoded))

xgb.plot_importance(xgb_clf)
plt.show()

In [None]:
# Confusion matrix of the test-split predictions, drawn as an image.
plt.imshow(confusion_matrix(y_test_decoded, y_pred_decoded), cmap="Reds")
plt.show()

In [None]:
import shap

# SHAP values of the classifier on the test split, summarized as the mean
# absolute value over the first axis.
# NOTE(review): the DataFrame construction below expects the averaged array to
# have shape (features, classes), i.e. shap_values of shape
# (samples, features, classes); the layout returned by shap depends on its
# version — confirm.
explainer = shap.TreeExplainer(xgb_clf)
shap_values = explainer.shap_values(X_test)
feature_importance = np.abs(shap_values).mean(axis=0)

# After the transpose: one row per lipizone name, one column per feature.
feature_importance = pd.DataFrame(feature_importance, columns = label_encoder.classes_, index = X_test.columns).T
feature_importance

In [None]:
# For every lipizone take its three most important features; the union of
# those features forms the set that is displayed.
top3_columns = feature_importance.apply(lambda row: row.nlargest(3).index.tolist(), axis=1)
palette = top3_columns.explode().unique().tolist()
print(len(palette))

# From here on `palette` is the importance table restricted to those features.
palette = feature_importance.loc[:, palette]

plt.imshow(palette)
plt.show()

# Re-index the name/color table by clean name so that colors follow the row
# order of `palette`. This changes clusterxname's index for all later cells.
clusterxname.index = clusterxname['cleanlipizone']
colors = clusterxname.loc[palette.index, 'lipizone_color'].values

toplot = palette

num_rows, num_columns = toplot.shape

# One narrow bar panel per selected feature, with one bar per lipizone.
# NOTE(review): with a single selected feature plt.subplots returns one Axes
# object rather than an array, and the loop below would fail.
fig, axes = plt.subplots(1, num_columns, figsize=(10, 7), sharey=True)

for i, ax in enumerate(axes):
   
    ax.bar(
        np.arange(num_rows),
        toplot.iloc[:, i],
        color= colors
    )
    
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(toplot.columns[i], rotation=90, ha='center', va='top')

plt.tight_layout()
plt.show()

In [None]:
import scipy.cluster.hierarchy as sch

# Heatmap of the selected feature importances (rows: lipizones, columns: features).
normalized_df = palette
# First arrange the columns by hierarchical clustering with optimal leaf ordering.
# The result is stored as `column_linkage`, not `linkage`: the latter name is
# the scipy function imported for the dendrogram cell, and overwriting it with
# an array makes that cell fail when it is run again.
column_linkage = sch.linkage(sch.distance.pdist(normalized_df.T), method='weighted', optimal_ordering=True)
order = sch.leaves_list(column_linkage)
normalized_df = normalized_df.iloc[:, order]
# Then sort the columns by the row in which each one reaches its maximum.
order = np.argmax(normalized_df.T.values, axis=1)
order = np.argsort(order)
normalized_df = normalized_df.iloc[:, order]
# Discard rows that contain missing or infinite values.
normalized_df = normalized_df.dropna().replace([np.inf, -np.inf], np.nan).dropna()

import seaborn as sns

fig, ax1 = plt.subplots(figsize=(20, 10))
# Color scale from 0.2 up to the 98th percentile of all displayed values.
sns.heatmap(normalized_df, cmap="Oranges", ax=ax1, cbar_kws={'label': 'Enrichment'},
            xticklabels=True, yticklabels=True, vmin = 0.2,vmax=np.percentile(normalized_df, 98))

ax1.tick_params(axis='x', which='both', bottom=False, top=False)
ax1.tick_params(axis='y', which='both', left=False, right=False, pad=20)
plt.show()

## Study the mediolateral pattern

In [None]:
# look at the most abundant lipizones in the anterior region to find (if any) mediolateral patterns
# One figure per lipizone among the ten most frequent ones in sections 4 and 5:
# two section panels with that lipizone highlighted, plus a histogram of its
# zccf coordinate.
# NOTE(review): the figures created in this loop are never closed.

for testlev in dd2.loc[dd2['Section'].isin([4, 5]), 'cleanlipizone'].value_counts()[:10].index:

    fig, axes = plt.subplots(1,3, figsize=(10, 5))
    axes = axes.flatten()
    dot_size = 0.3

    sections_to_plot = [4, 5]

    # Axis limits come from the full table, so panels are comparable with the
    # whole-section plots.
    global_min_z = datavignettes['zccf'].min()
    global_max_z = datavignettes['zccf'].max()
    global_min_y = -datavignettes['yccf'].max() 
    global_max_y = -datavignettes['yccf'].min()  

    for i, section_num in enumerate(sections_to_plot):
        ax = axes[i]
        xx = dd2[dd2["Section"] == section_num]
        # Background: all rows of the section in grayscale.
        sc1 = ax.scatter(xx['zccf'], -xx['yccf'],
                         c=np.array(xx['cleanlipizone'].astype("category").cat.codes), cmap="Greys", s=dot_size*2, alpha=1, rasterized=True)

        # Foreground: rows of the lipizone under inspection.
        highlight = xx.loc[xx['cleanlipizone'] == testlev,:]
        
        ax.scatter(highlight['zccf'], -highlight['yccf'],
                         c=highlight['lipizone_color'], s=dot_size*4, alpha=1, rasterized=True)

        ax.axis('off')
        ax.set_aspect('equal')  
        ax.set_xlim(global_min_z, global_max_z)
        ax.set_ylim(global_min_y, global_max_y)

        
    # NOTE(review): `highlight` still holds the rows of the last section plotted
    # above (section 5), so the histogram covers that section only, and
    # .iloc[0] raises IndexError if the lipizone does not occur in it.
    axes[2].hist(highlight['zccf'], bins=10, density=True, color=highlight['lipizone_color'].iloc[0])
    axes[2].set_xlim((global_min_z, global_max_z))
    axes[2].set_ylim((0., 0.5))
    axes[2].axis('off')
    axes[2].set_aspect('equal')
    plt.suptitle(testlev)
    plt.tight_layout()
    plt.show()

In [None]:
# Rows of the full table that belong to the callosal white matter, with the
# clean lipizone name and a lipizone color attached.
data_withmeta = datavignettes.loc[callosalwm.index,:]
data_withmeta['cleanlipizone'] = callosalwm['cleanlipizone']
# NOTE(review): at this point callosalwm['lipizone_color'] is presumably still
# the original color column; the recolored values are copied into callosalwm
# only in a later cell — confirm which colors are intended here.
data_withmeta['lipizone_color'] = callosalwm['lipizone_color']

In [None]:
# Rows of sections 4 and 5 that belong to one of the ten most frequent
# lipizones of those two sections.
# .copy() makes checkn an independent frame, because a column is added to it
# below (this avoids SettingWithCopyWarning on a filtered slice).
checkn = data_withmeta.loc[(data_withmeta['Section'].isin([4, 5])) & (data_withmeta['cleanlipizone'].isin(data_withmeta.loc[data_withmeta['Section'].isin([4, 5]), 'cleanlipizone'].value_counts()[:10].index)),:].copy()

# Manual assignment of lipizones to a medial or a lateral group; lipizones in
# neither list keep the placeholder "nope" and are removed afterwards.
checkn['mediolateral'] = "nope"
checkn.loc[checkn['cleanlipizone'].isin(["Anterior medial callosum and internal capsule linings", "Ventricular lining WM", "Anterior ventral WM", "Internal capsule and optic tract"]), 'mediolateral'] = "medial"
checkn.loc[checkn['cleanlipizone'].isin(["Anterior lateral callosum", "Mixed WM 8", "Anterior lateral callosum 1"]), 'mediolateral'] = "lateral"
checkn = checkn.loc[checkn['mediolateral'] != "nope",:]
checkn

In [None]:
# Number of rows in the medial and in the lateral group.
checkn['mediolateral'].value_counts()

In [None]:
# a function to check for differential lipids between two groups

from scipy.stats import mannwhitneyu, entropy
import matplotlib.pyplot as plt
from tqdm import tqdm
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm

def differential_lipids(lipidata, kmeans_labels, min_fc=0.2, pthr=0.05):
    """Compare every column of `lipidata` between two groups of rows.

    Parameters
    ----------
    lipidata : DataFrame with one row per observation and one column per lipid.
    kmeans_labels : array-like aligned with the rows of `lipidata`; rows labelled
        0 form group A, rows labelled 1 form group B, all other rows are ignored.
    min_fc : accepted for compatibility; currently not used by the function.
    pthr : significance level handed to the FDR procedure. It only influences
        the (discarded) reject decisions, not the corrected p-values returned.

    Returns
    -------
    DataFrame with one row per column of `lipidata` and the columns
    'lipid' (positional column index), 'log2fold_change' (log2 of mean B over
    mean A), 'p_value' (two-sided Mann-Whitney U test) and 'p_value_corrected'
    (Benjamini-Hochberg). Lipids whose test produced no p-value keep NaN in
    both p-value columns.
    """
    results = []

    a = lipidata.loc[kmeans_labels == 0,:]
    b = lipidata.loc[kmeans_labels == 1,:]

    for rrr in range(lipidata.shape[1]):

        groupA = a.iloc[:,rrr]
        groupB = b.iloc[:,rrr]

        # log2 fold change; the small constant guards against a zero mean in A
        meanA = np.mean(groupA)
        meanB = np.mean(groupB)
        log2fold_change = np.log2(meanB / (meanA + 1e-7))

        # rank-based two-sample test
        try:
            _, p_value = mannwhitneyu(groupA, groupB, alternative='two-sided')
        except ValueError:
            p_value = np.nan

        results.append({'lipid': rrr, 'log2fold_change': log2fold_change, 'p_value': p_value})

    # Explicit columns keep the frame well-formed even when lipidata has no columns.
    results_df = pd.DataFrame(results, columns=['lipid', 'log2fold_change', 'p_value'])

    # Correct for multiple testing using only the tests that produced a p-value:
    # the correction is not reliable when NaN entries are passed in, and one
    # failed test should not affect the results for the other lipids.
    pvals = results_df['p_value'].to_numpy(dtype=float)
    pvals_corrected = np.full(pvals.shape, np.nan)
    testable = np.isfinite(pvals)
    if testable.any():
        pvals_corrected[testable] = multipletests(pvals[testable], alpha=pthr, method='fdr_bh')[1]
    results_df['p_value_corrected'] = pvals_corrected

    return results_df


In [None]:
# Encode the two groups as 0 (medial) and 1 (lateral), the labels that
# differential_lipids() compares.
# NOTE(review): replacing strings by integers in place presumably relies on the
# array having object dtype — confirm.
kmeans_labels = np.array(checkn['mediolateral'])
kmeans_labels[kmeans_labels == "medial"] = 0
kmeans_labels[kmeans_labels == "lateral"] = 1

# Lateral-vs-medial comparison over the first 173 columns.
difflips = differential_lipids(checkn.iloc[:, :173], kmeans_labels)
difflips

In [None]:
# Label the result rows with the column names (the 'lipid' column holds positions only).
difflips.index = checkn.columns[:173]

In [None]:
# Ten columns with the highest log2 fold change (lateral over medial).
difflips.sort_values("log2fold_change")[::-1][:10] # otherwise low fc

In [None]:
# Three columns with the lowest log2 fold change (the [:10] slice is narrowed to 3).
difflips.sort_values("log2fold_change")[:10][:3] # otherwise low fc

In [None]:
# Candidate columns: the three lowest followed by the ten highest fold changes.
# `hits` is reassigned a few cells later from the correlation ranking.
hits = np.concatenate((difflips.sort_values("log2fold_change")[:10][:3].index.values, difflips.sort_values("log2fold_change")[::-1][:10].index.values))
hits

In [None]:
# Keep only rows whose zccf lies below the mean zccf of checkn
# (presumably one hemisphere — TODO confirm).
checkn = checkn.loc[checkn['zccf'] < checkn['zccf'].mean(),:]
checkn

In [None]:
# Absolute Pearson correlation between each of the first 173 columns and zccf,
# sorted ascending; the display shows the 20 strongest.
corrs = [np.abs(np.corrcoef(checkn[x], checkn['zccf'])[0,1]) for x in checkn.columns[:173]]
corrs = pd.DataFrame(corrs, index = checkn.columns[:173], columns=['pearson'])
corrs = corrs.sort_values('pearson')
corrs[::-1][:20]

In [None]:
# The 25 columns most strongly correlated with zccf; replaces the earlier
# fold-change based `hits`.
hits = corrs[::-1][:25].index

In [None]:
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

# For every selected column draw sections 4 and 5 side by side, colored by that
# column's value, and save each figure as a PDF.
results = []

filtered_data = checkn

for currentPC in hits:

    # Disabled draft of a per-section percentile summary, kept as an inert
    # string literal.
    """
    for section in filtered_data['Section'].unique():
        subset = filtered_data[filtered_data['Section'] == section]

        perc_2 = subset[currentPC].quantile(0.02)
        perc_98 = subset[currentPC].quantile(0.98)

        results.append([section, perc_2, perc_98])
    percentile_df = pd.DataFrame(results, columns=['Section', '2-perc', '98-perc'])
    med2p = percentile_df['2-perc'].median()
    med98p = percentile_df['98-perc'].median()
    """
    cmap = plt.cm.plasma

    fig, axes = plt.subplots(1,2, figsize=(10, 5))
    axes = axes.flatten()
    sections_to_plot = [4, 5]
    
    # Panels follow the order of sections_to_plot. Indexing the axes with
    # `section - 1 - 4` gave -1 for section 4, which placed it on the last
    # panel and section 5 on the first one.
    for ax, section in zip(axes, sections_to_plot):
        ddf = filtered_data[(filtered_data['Section'] == section)]
        # Draw low values first so that high values end up on top.
        ddf_sorted = ddf.sort_values(by=currentPC)
        # Color range spans the 5th to 85th percentile of this section's values.
        ax.scatter(ddf_sorted['zccf'], -ddf_sorted['yccf'], c=ddf_sorted[currentPC], cmap="plasma", s=110,alpha=0.2,edgecolors='none',rasterized=True, vmin=np.percentile(ddf_sorted[currentPC], 5), vmax=np.percentile(ddf_sorted[currentPC], 85)) 
        ax.axis('off')
        ax.set_aspect('equal')

    # Axes reserved for a colorbar; the colorbar itself is currently disabled.
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    #norm = Normalize(vmin=med2p, vmax=med98p)
    #sm = ScalarMappable(norm=norm, cmap="plasma")
    #fig.colorbar(sm, cax=cbar_ax)
    plt.title(currentPC)
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    # NOTE(review): the column name becomes part of the file name; names that
    # contain path separators would make savefig fail.
    plt.savefig("mediolateral_lipids_"+currentPC+".pdf")
    plt.show()
    # Release the figure; the loop creates one figure per selected column.
    plt.close(fig)


In [None]:
# ML axis - nice, it really seems there's a medial and there's a lateral cluster overall

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Attach the lipizone colour from `dd2` (index-aligned). `assign` builds a new
# frame instead of writing into what may be a slice of another frame.
# NOTE(review): this cell reads callosalwm['cleanlipizone'], which is assigned
# in a cell further down the notebook — confirm the intended execution order.
callosalwm = callosalwm.assign(lipizone_color=dd2['lipizone_color'])

# Keep only the rows whose zccf is below the mean zccf.
callosalwm2 = callosalwm.loc[callosalwm['zccf'] < np.mean(callosalwm['zccf']), :]

# One colour per zone: the colour of the first row seen for each zone.
palettez = (
    callosalwm2
    .drop_duplicates(subset='cleanlipizone')
    .set_index('cleanlipizone')['lipizone_color']
    .to_dict()
)

fig, ax = plt.subplots(figsize=(10, 6))
sns.kdeplot(
    data=callosalwm2,
    x='zccf',
    hue='cleanlipizone',
    palette=palettez,
    multiple='layer',
    fill=False,
    linewidth=0.4,
    ax=ax,
)
# seaborn builds the hue legend itself. Calling plt.legend() afterwards rebuilds
# the legend from labelled artists only and loses the hue entries, so retitle
# the existing legend instead.
hue_legend = ax.get_legend()
if hue_legend is not None:
    hue_legend.set_title('Zone')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Overlaid Histograms by Zone')
sns.despine(ax=ax, top=True, right=True, left=True, bottom=True)
ax.tick_params(left=False, bottom=False)
plt.show()

## Study the bundle in the middle of the anterior callosum

In [None]:
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42

# Section 12 only.
dv = datavignettes[datavignettes['Section'].eq(12)]
# Crop to 4 < zccf < 7.5 and 1 < yccf < 2.2, keeping rows with level_1 == 1.
inside_window = (
    dv['zccf'].gt(4) & dv['zccf'].lt(7.5)
    & dv['yccf'].gt(1) & dv['yccf'].lt(2.2)
)
dv = dv[inside_window & dv['level_1'].eq(1)]
# y is negated for display.
plt.scatter(dv['zccf'], -dv['yccf'], c=dv['subclass_color'], s=10)
plt.show()

In [None]:
wm_voxels_sub = dv

# Binary labels: 1 for the target subclass, 0 for every other row.
# Object dtype is kept because that is what this cell has always passed to
# differential_lipids (defined elsewhere) — NOTE(review): confirm whether a
# plain integer array would be acceptable there.
target_subclass = "Neuron-rich lateral white matter #2"
is_target = (wm_voxels_sub['subclass_name'] == target_subclass).to_numpy()
kmeans_labels = np.where(is_target, 1, 0).astype(object)

# Differential analysis on the first 173 columns, re-indexed by column name.
lipid_columns = wm_voxels_sub.columns[:173]
difflips = differential_lipids(wm_voxels_sub.loc[:, lipid_columns], kmeans_labels)
difflips.index = lipid_columns

# Sort once and reuse; the table was previously re-sorted four times.
difflips_sorted = difflips.sort_values('log2fold_change')
print(difflips_sorted[:9])
print(difflips_sorted[-9:])

# The 9 lowest followed by the 9 highest log2 fold changes.
toplot = np.concatenate([
    np.array(difflips_sorted[:9].index),
    np.array(difflips_sorted[-9:].index),
])

# Grid-size names are deliberately not `rows, cols`: `cols` is used elsewhere
# in the notebook for the bar-plot colours and was being overwritten here.
n_grid_rows, n_grid_cols = 6, 3
fig, axes = plt.subplots(n_grid_rows, n_grid_cols, figsize=(5, 8))
axes = axes.flatten()

for ax, lipid in zip(axes, toplot):
    ax.scatter(
        wm_voxels_sub['z_index'], -wm_voxels_sub['y_index'],
        c=wm_voxels_sub[lipid], cmap='inferno', s=2, alpha=0.8,
        # Cap the colour scale at the 90th percentile of the plotted column.
        vmax=np.percentile(wm_voxels_sub[lipid], 90), rasterized=True,
    )
    ax.set_title(lipid, fontsize=8)
    ax.set_aspect('equal')
    ax.autoscale()  # This will automatically set the axis limits
    ax.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
    ax.axis('off')

# Blank any panels left over when fewer lipids than grid cells are available.
for ax in axes[len(toplot):]:
    ax.axis('off')

plt.tight_layout()
plt.savefig("axons.pdf")
plt.show()

## Is it oligodendrocyte or axonal heterogeneity? [XGBC]

In [None]:
atlas = datavignettes

# Cell-type lipidome table; the first two rows are skipped.
ctl = pd.read_csv(
    "/data/luca/lipidatlas/ManuscriptGithub/zenodo/csv/celltype_lipidomes.csv",
    index_col=0,
).iloc[2:, :]
fitznerlips = ctl.index

# Columns averaged (NaN-aware) into one oligodendrocyte and one neuron profile.
oligo_samples = [
    'OligodendrocytesDIV11', 'OligodendrocytesDIV12', 'OligodendrocytesDIV13',
    'OligodendrocytesDIV251', 'OligodendrocytesDIV252', 'OligodendrocytesDIV253',
    'OligodendrocytesDIV41', 'OligodendrocytesDIV42', 'OligodendrocytesDIV43',
]
neuron_samples = [
    'NeuronsDIV101', 'NeuronsDIV102', 'NeuronsDIV103',
    'NeuronsDIV161', 'NeuronsDIV162', 'NeuronsDIV163',
    'NeuronsDIV51', 'NeuronsDIV52', 'NeuronsDIV53',
]
oli = np.nanmean(ctl[oligo_samples].astype(float), axis=1)
neu = np.nanmean(ctl[neuron_samples].astype(float), axis=1)

# One row per lipid, columns "oli" and "neu"; missing averages become 0.
aveprof = pd.DataFrame({"oli": oli, "neu": neu}, index=ctl.index).fillna(0)

# Relabel the rows with the 'Species Name' column of the name table.
# This is a positional assignment: the name table must have one row per
# aveprof row, in the same order.
names = pd.read_csv("/data/luca/lipidatlas/ManuscriptGithub/zenodo/csv/goslinfitzner_celltyoes.tsv", sep="\t")
namesgoslin = names[['Original Name', 'Species Name']].set_index('Original Name', drop=False)
aveprof.index = namesgoslin.loc[:, 'Species Name']

# extract oligo and neuro-specific lipids as having > 10x enrichment in quantitative LCMS
measured_here = aveprof.index.isin(atlas.columns[:173])
aveprof = aveprof.loc[measured_here, :]
# The small constant keeps the denominator away from zero.
neu_over_oli = (aveprof['neu'] / (aveprof['oli'] + 0.00001)).sort_values().dropna()
neuronalmarkers = neu_over_oli[neu_over_oli > 10].index.unique()
ratio = (aveprof['oli'] / (aveprof['neu'] + 0.00001)).sort_values().dropna()
oligomarkers = ratio[ratio > 10].index.unique()

# extract the callosal WM from our data
in_callosal_wm = atlas[['level_1', 'level_2', 'level_3']].eq(1).all(axis=1)
callosalwm = atlas.loc[in_callosal_wm, :]
featurespace_oligo = callosalwm.loc[:, oligomarkers]
featurespace_neuro = callosalwm.loc[:, neuronalmarkers]
clusterlabels = callosalwm[['old_lipizone_names']]

In [None]:
# use the manually curated lipizone names
annotated = pd.read_csv("/data/luca/lipidatlas/ManuscriptGithub/zenodo/csv/callosalwm_annotated.csv", header=None)
annotated.columns = ['automatedanno', 'manualanno']
# Index by the automated name with its last four characters removed
# (vectorised; `.to_numpy()` keeps the index unnamed, as before).
annotated.index = annotated['automatedanno'].str[:-4].to_numpy()

# Make repeated manual names unique: the first occurrence is kept as is,
# later occurrences get " 1", " 2", ... appended.
counts = annotated.groupby('manualanno').cumcount()
annotated['manualanno'] = annotated['manualanno'].where(
    counts == 0,
    annotated['manualanno'] + ' ' + counts.astype(str),
)

# This count used to be a bare expression in the middle of the cell, so it was
# never displayed; print it explicitly.
print("unique manual annotations:", len(annotated['manualanno'].unique()))

# Map every voxel's automated name to its curated name. `assign` avoids
# writing into what may be a slice of the atlas frame.
callosalwm = callosalwm.assign(
    cleanlipizone=callosalwm['old_lipizone_names'].map(annotated['manualanno'])
)
callosalwm['cleanlipizone']

In [None]:
# Train an XGBoost classifier predicting the curated lipizone ('cleanlipizone')
# from the first 173 columns of callosalwm.
# (The earlier header spoke of "keeping only oligo and neuro lipids", but the
# feature matrix below is not restricted to the marker sets.)

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
import xgboost as xgb
import matplotlib.pyplot as plt
from threadpoolctl import threadpool_limits, threadpool_info
threadpool_limits(limits=8)
import os
# NOTE(review): this is set after the numeric libraries were imported and
# disagrees with the limit of 8 above — confirm which value is intended.
os.environ['OMP_NUM_THREADS'] = '6'

# DS is a row-subsampling stride; 1 keeps every row.
DS = 1
basespace = callosalwm.iloc[:,:173]
target = callosalwm['cleanlipizone']
X = basespace[::DS]
y = target[::DS]
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Stratified 80/20 split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded,
    test_size=0.2,
    random_state=42,
    stratify=y_encoded
)

# Oversample the training split only; the test split is left untouched.
smote = SMOTE(random_state=42)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)

print("Original class distribution:", np.bincount(y_train))
print("Resampled class distribution:", np.bincount(y_train_res))

xgb_clf = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=len(np.unique(y_train)),
    eval_metric='mlogloss',
    use_label_encoder=False,
    random_state=42
)

# NOTE(review): the held-out split drives early stopping here and is also the
# split scored below, so the reported metrics are not from fully unseen data.
xgb_clf.fit(
    X_train_res,
    y_train_res,
    eval_set=[(X_test, y_test)],
    early_stopping_rounds=10,
    verbose=True
)

y_pred = xgb_clf.predict(X_test)
y_test_decoded = label_encoder.inverse_transform(y_test)
y_pred_decoded = label_encoder.inverse_transform(y_pred)

print("Classification Report:")
print(classification_report(y_test_decoded, y_pred_decoded))

# Computed once and reused for both the printout and the image.
conf_mat = confusion_matrix(y_test_decoded, y_pred_decoded)
print("Confusion Matrix:")
print(conf_mat)

xgb.plot_importance(xgb_clf)
plt.show()

plt.imshow(conf_mat, cmap="Reds")
plt.show()

In [None]:
# Persist the classifier currently bound to `xgb_clf` to a JSON file in the working directory.
xgb_clf.save_model("xgb_model_WM.json")

In [None]:
# use shap to extract feature importances
import shap

explainer = shap.TreeExplainer(xgb_clf)
shap_values = explainer.shap_values(X_test)
# shap may return one 2-D array per class as a list; stack those along a
# trailing class axis so the mean over axis 0 below is taken over samples.
# (Same normalisation the GSEA cell of this notebook applies.)
if isinstance(shap_values, list):
    shap_values = np.stack(shap_values, axis=2)
# Mean absolute SHAP value per (feature, class).
feature_importance = np.abs(shap_values).mean(axis=0)

# Rows = classes, columns = features.
feature_importance = pd.DataFrame(feature_importance, columns = label_encoder.classes_, index = X_test.columns).T
feature_importance

In [None]:
# GSEA analysis
# This first part rebuilds the held-out split and the per-class SHAP importances
# that the rest of the cell ranks and tests for marker enrichment.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import gseapy as gp
from gseapy.plot import gseaplot
import xgboost as xgb
import shap
from sklearn.preprocessing import LabelEncoder

# DS is a row-subsampling stride; 1 keeps every row.
DS = 1
X = callosalwm.iloc[:, :173][::DS]
y = callosalwm["cleanlipizone"][::DS]

label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Same test_size / random_state / stratify arguments as the training cell, so
# X_test is the same held-out split provided callosalwm has not changed since.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
)

# NOTE(review): `xgb_clf` is whatever model is currently in the kernel; it is
# presumably the one trained on these 173 columns — confirm before re-running.
explainer    = shap.TreeExplainer(xgb_clf)
raw_shap     = explainer.shap_values(X_test)
# A list of per-class 2-D arrays is stacked along a trailing class axis;
# anything else is used unchanged.
if isinstance(raw_shap, list):
    shap_arr_3d = np.stack(raw_shap, axis=2)
else:
    shap_arr_3d = raw_shap

# Mean |SHAP| over axis 0, then transposed so rows = classes, columns = features.
feature_importance = np.abs(shap_arr_3d).mean(axis=0)
feature_importance = pd.DataFrame(
    feature_importance,
    columns=label_encoder.classes_,
    index = X_test.columns
).T

# Cell-type lipidome table; the first two rows are skipped.
ctl = pd.read_csv(
    "/data/luca/lipidatlas/ManuscriptGithub/zenodo/csv/celltype_lipidomes.csv",
    index_col=0,
).iloc[2:, :]
fitznerlips = ctl.index

# Columns averaged (NaN-aware) into one oligodendrocyte and one neuron profile.
oligo_samples = [
    'OligodendrocytesDIV11', 'OligodendrocytesDIV12', 'OligodendrocytesDIV13',
    'OligodendrocytesDIV251', 'OligodendrocytesDIV252', 'OligodendrocytesDIV253',
    'OligodendrocytesDIV41', 'OligodendrocytesDIV42', 'OligodendrocytesDIV43',
]
neuron_samples = [
    'NeuronsDIV101', 'NeuronsDIV102', 'NeuronsDIV103',
    'NeuronsDIV161', 'NeuronsDIV162', 'NeuronsDIV163',
    'NeuronsDIV51', 'NeuronsDIV52', 'NeuronsDIV53',
]
oli = np.nanmean(ctl[oligo_samples].astype(float), axis=1)
neu = np.nanmean(ctl[neuron_samples].astype(float), axis=1)

# One row per lipid, columns "oli" and "neu"; missing averages become 0.
aveprof = pd.DataFrame({"oli": oli, "neu": neu}, index=ctl.index).fillna(0)

# Relabel the rows with the 'Species Name' column of the name table.
# This is a positional assignment: the name table must have one row per
# aveprof row, in the same order.
names = pd.read_csv("/data/luca/lipidatlas/ManuscriptGithub/zenodo/csv/goslinfitzner_celltyoes.tsv", sep="\t")
namesgoslin = names[['Original Name', 'Species Name']].set_index('Original Name', drop=False)
aveprof.index = namesgoslin.loc[:, 'Species Name']

# extract oligo and neuro-specific lipids as having > 10x enrichment in quantitative LCMS
measured_here = aveprof.index.isin(callosalwm.columns[:173])
aveprof = aveprof.loc[measured_here, :]
# The small constant keeps the denominator away from zero.
neu_over_oli = (aveprof['neu'] / (aveprof['oli'] + 0.00001)).sort_values().dropna()
neuronalmarkers = neu_over_oli[neu_over_oli > 10].index.unique()
ratio = (aveprof['oli'] / (aveprof['neu'] + 0.00001)).sort_values().dropna()
oligomarkers = ratio[ratio > 10].index.unique()

# Per-feature score: the 95th percentile, across classes, of the per-class
# importance (this is a quantile, not a mean).
toplot = feature_importance.quantile(q=0.95).sort_values()
# Keep only features that belong to a marker set ...
in_either_set = toplot.index.isin(oligomarkers) | toplot.index.isin(neuronalmarkers)
toplot = toplot.loc[in_either_set]
# ... and drop any feature listed in both sets.
toplot = toplot.loc[~toplot.index.isin(np.intersect1d(oligomarkers, neuronalmarkers))]
# `toplot` is already restricted to the two marker sets, so re-applying the
# membership filter is a no-op; keep an independent copy under the old name.
toplot2 = toplot.copy()

ranked = toplot.sort_values(ascending=False)

# Two-column ranking table (Term, Score), highest score first.
rnk_df = ranked.reset_index()
rnk_df.columns = ['Term','Score']
rnk_df = rnk_df.sort_values('Score', ascending=False)

# Control sets taken from the two ends of the ranking itself.
top_n = 20
bot_n = 20
all_terms = ranked.index.tolist()
pos_control = all_terms[:top_n]
neg_control = all_terms[-bot_n:]
gene_sets_ctrl = {
    'oligomarkers': list(oligomarkers),
    'neuronalmarkers': list(neuronalmarkers),
    'positive_control': pos_control,
    'negative_control': neg_control,
}

# Pre-ranked enrichment of the four sets against the importance ranking.
pre_res2 = gp.prerank(
    rnk=rnk_df,
    gene_sets=gene_sets_ctrl,
    processes=2,
    permutation_num=2000,
    seed=42,
    outdir=None,
)

# Summary table indexed by set name, with shortened column names.
res2 = (
    pre_res2.res2d
    .set_index('Term')[['ES', 'NES', 'NOM p-val', 'FDR q-val']]
    .rename(columns={'NOM p-val': 'pval', 'FDR q-val': 'fdr'})
)

# For each set: the running score divided by ES/NES, plus its summary statistics.
profiles = {}
for term in gene_sets_ctrl:
    running_score = np.array(pre_res2.results[term]['RES'])
    es, nes = res2.loc[term, ['ES', 'NES']]
    scale = es / nes
    profiles[term] = dict(
        nes_profile=running_score / scale,
        NES=nes,
        pval=res2.loc[term, 'pval'],
        fdr=res2.loc[term, 'fdr'],
    )

# x positions along the ranked list (length taken from the oligomarker profile).
positions = np.arange(len(profiles['oligomarkers']['nes_profile']))
plt.figure(figsize=(5, 4))

# Marker sets get a coloured, filled curve; control sets get a plain line.
colors = {
    'oligomarkers': 'tab:red',
    'neuronalmarkers': 'tab:blue',
}
control_line_colors = {
    'positive_control': 'black',
    'negative_control': 'gray',
}

for term, props in profiles.items():
    prof = props['nes_profile']
    nes = props['NES']
    p = props['pval']
    label = f"{term} (NES={nes:.2f}, p={p:.3f})"

    if term in control_line_colors:
        plt.plot(positions, prof, color=control_line_colors[term], label=label, alpha=1.0)
        continue

    color = colors[term]
    plt.plot(positions, prof, color=color, label=label)
    # Shade only where the profile is non-negative.
    plt.fill_between(positions, prof, 0, where=(prof >= 0), interpolate=True, alpha=0.2, color=color)

plt.axhline(0, color='gray', linestyle='--')
plt.xlabel("Ranked List Position")
plt.ylabel("NES profile (scaled to peak)")
plt.gca().set_ylim(bottom=-1)

plt.title("GSEApy prerank: NES profiles with positive/negative controls and p-value")
plt.legend(loc='best')
plt.tight_layout()
plt.savefig("gsea.pdf")
plt.show()

from scipy.stats import mannwhitneyu

# Box colours, in the order the groups are drawn (oligo, neuronal). Named
# separately so the term -> colour dict `colors` used by the GSEA profile plot
# is not overwritten by a list.
group_colors = ['tab:red', 'tab:blue']

# Importance scores of the features in each marker set. Selected once by
# membership mask and reused for the boxes, the top-5 text and the test.
oligo_scores = toplot2.loc[toplot2.index.isin(oligomarkers)]
neuro_scores = toplot2.loc[toplot2.index.isin(neuronalmarkers)]
oligo_vals = oligo_scores.values
neuro_vals = neuro_scores.values

fig, ax = plt.subplots(figsize=(3, 6))

bplot = ax.boxplot(
    [oligo_vals, neuro_vals],
    labels=["Oligo markers", "Neuronal markers"],
    patch_artist=True,
    whis=1.5,
    widths=0.6,
    boxprops=dict(edgecolor="none", linewidth=0),
    whiskerprops=dict(linewidth=1.5),
    capprops=dict(linewidth=1.5),
    medianprops=dict(color="white", linewidth=2),
    flierprops=dict(marker='o', markersize=5, linestyle='none')
)

for i, box in enumerate(bplot['boxes']):
    box.set_facecolor(group_colors[i])

# Each box owns two whiskers and two caps, hence the integer division.
for i, whisker in enumerate(bplot['whiskers']):
    whisker.set_color(group_colors[i // 2])

for i, cap in enumerate(bplot['caps']):
    cap.set_color(group_colors[i // 2])

# Medians keep the white colour set through medianprops.

for i, flier in enumerate(bplot['fliers']):
    flier.set_markerfacecolor(group_colors[i])
    flier.set_markeredgecolor(group_colors[i])

ax.set_ylabel("Mean |SHAP value|", fontsize=12)
ax.set_title("Feature-importance distributions by marker group", fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=11)
ax.grid(True, axis='y', linestyle='--', alpha=0.4)

# Text box listing the five highest-scoring features of each group.
top5_oligo = oligo_scores.nlargest(5)
top5_neuro = neuro_scores.nlargest(5)
textstr = (
    "Top-5 Oligo:\n" + "\n".join(top5_oligo.index) + "\n\n"
    "Top-5 Neuro:\n" + "\n".join(top5_neuro.index)
)
bbox_props = dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.8)
ax.text(1.02, 0.95, textstr, transform=ax.transAxes,
        fontsize=10, va='top', ha='left', bbox=bbox_props)

# Two-sided Mann-Whitney U test between the groups, drawn as a bracket
# 5% above the largest value.
stat, pval = mannwhitneyu(oligo_vals, neuro_vals, alternative='two-sided')
y_max = max(np.nanmax(oligo_vals), np.nanmax(neuro_vals))
h = 0.05 * y_max
x1, x2 = 1, 2
ax.plot([x1, x1, x2, x2], [y_max, y_max+h, y_max+h, y_max],
        lw=1.5, color='black')
ax.text((x1+x2)*0.5, y_max+h, f"p = {pval:.3g}",
        ha='center', va='bottom', fontsize=12)

plt.tight_layout()
plt.savefig("featimp_boxplots_colored.pdf", dpi=300)
plt.show()

## Is it oligodendrocyte or axonal heterogeneity? [XGBC on programs]

In [None]:
# Column names selected from callosalwm as the feature space of the
# "programs" classifier trained later in this notebook.
programs = ['HexCer + hexosylceramides', 'PC',
       'PA + diacylglycerophosphates [GP1001]', 'PE',
       'PS + diacylglycerophosphoserines [GP0301]',
       'PI + diacylglycerophosphoinositols [GP0601]',
       'LPC + monoacylglycerophosphocholines [GP0105]',
       'diacylglycerophosphocholines [GP0101]',
       '1-alkyl-2-acylglycerophosphocholines [GP0102]',
       'diacylglycerophosphoethanolamines [GP0201]',
       '1-alkyl-2-acylglycerophosphoethanolamines [GP0202]',
       'monoacylglycerophosphoethanolamines [GP0205] + LPE',
       'diacylglycerophosphoglycerols [GP0401] + PG',
       'headgroup with negative charge', 'headgroup with neutral charge',
       'headgroup with positive charge / zwitter-ion',
       'simple glc series [SP0501]', 'negative intrinsic curvature',
       'neutral intrinsic curvature', 'positive intrinsic curvature',
       'contains ether-bond', 'lysoglycerophospholipids',
       'very low transition temperature', 'low transition temperature',
       'average transition temperature', 'high transition temperature',
       'very high transition temperature', 'fatty acid with 18 carbons',
       'fatty acid with 20 carbons', 'fatty acid with 22 carbons',
       'saturated fatty acid', 'monounsaturated fatty acid',
       'fatty acid with 4 double bonds', 'lipid-mediated signalling',
       'endoplasmic reticulum (ER)', 'mitochondrion', 'plasma membrane',
       'endosome/lysosome + SM + ceramide phosphocholines (sphingomyelins) [SP0301] + golgi apparatus',
       'N-acylsphingosines (ceramides) [SP0201] + Cer',
       'very low bilayer thickness', 'low bilayer thickness',
       'average bilayer thickness', 'high bilayer thickness',
       'very high bilayer thickness', 'very low lateral diffusion',
       'low lateral diffusion', 'average lateral diffusion',
       'high lateral diffusion', 'very high lateral diffusion']

In [None]:
# Train an XGBoost classifier predicting the curated lipizone ('cleanlipizone')
# from the `programs` columns of callosalwm.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
import xgboost as xgb
import matplotlib.pyplot as plt
from threadpoolctl import threadpool_limits, threadpool_info
threadpool_limits(limits=8)
import os
# NOTE(review): this is set after the numeric libraries were imported and
# disagrees with the limit of 8 above — confirm which value is intended.
os.environ['OMP_NUM_THREADS'] = '6'

# NOTE(review): this file is read but not written anywhere in the visible part
# of the notebook — confirm where it is produced.
callosalwm = pd.read_parquet("callosalwm.parquet")

# DS is a row-subsampling stride; 1 keeps every row.
DS = 1
basespace = callosalwm.loc[:, programs]
target = callosalwm['cleanlipizone']
X = basespace[::DS]
y = target[::DS]
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Stratified 80/20 split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded,
    test_size=0.2,
    random_state=42,
    stratify=y_encoded
)

# Oversample the training split only; the test split is left untouched.
smote = SMOTE(random_state=42)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)

print("Original class distribution:", np.bincount(y_train))
print("Resampled class distribution:", np.bincount(y_train_res))

xgb_clf = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=len(np.unique(y_train)),
    eval_metric='mlogloss',
    use_label_encoder=False,
    random_state=42
)


def _xgb_safe_name(name):
    """Replace the characters '[', ']' and '<' in a feature name with '_'."""
    return name.replace('[', '_').replace(']', '_').replace('<', '_')


# Several program names contain square brackets; sanitise train and test
# columns with the same function so the two frames stay aligned.
X_train_res.columns = [_xgb_safe_name(name) for name in X_train_res.columns]
X_test.columns = [_xgb_safe_name(name) for name in X_test.columns]

# NOTE(review): the held-out split drives early stopping here and is also the
# split scored below, so the reported metrics are not from fully unseen data.
xgb_clf.fit(
    X_train_res,
    y_train_res,
    eval_set=[(X_test, y_test)],
    early_stopping_rounds=10,
    verbose=True
)

y_pred = xgb_clf.predict(X_test)
y_test_decoded = label_encoder.inverse_transform(y_test)
y_pred_decoded = label_encoder.inverse_transform(y_pred)

print("Classification Report:")
print(classification_report(y_test_decoded, y_pred_decoded))

# Computed once and reused for both the printout and the image.
conf_mat = confusion_matrix(y_test_decoded, y_pred_decoded)
print("Confusion Matrix:")
print(conf_mat)

xgb.plot_importance(xgb_clf)
plt.show()

plt.imshow(conf_mat, cmap="Reds")
plt.show()

import shap
explainer = shap.TreeExplainer(xgb_clf)
shap_values = explainer.shap_values(X_test)
# shap may return one 2-D array per class as a list; stack those along a
# trailing class axis so the mean over axis 0 below is taken over samples.
# (Same normalisation the GSEA cell of this notebook applies.)
if isinstance(shap_values, list):
    shap_values = np.stack(shap_values, axis=2)
# Mean absolute SHAP value per (feature, class); rows = classes after transposing.
feature_importance = np.abs(shap_values).mean(axis=0)
feature_importance = pd.DataFrame(feature_importance, columns=label_encoder.classes_, index=X_test.columns).T
# True where an importance exceeds the 90th percentile of all entries.
feature_importance_binarized = feature_importance > np.percentile(feature_importance.values.flatten(), 90)

def create_rank_dataframe(feature_importance_df):
    """Rank the values of every row, largest value first.

    Parameters
    ----------
    feature_importance_df : pandas.DataFrame
        Numeric table; each row is ranked independently across its columns.

    Returns
    -------
    pandas.DataFrame
        New float frame with the same index and columns. Rank 1.0 marks the
        largest value in a row; tied values share the average of their ranks.
        The input frame is not modified.
    """
    # Vectorised row-wise ranking. Unlike a per-row `.loc` assignment it does
    # not write floats back into the input's dtype, and it does not depend on
    # the index labels being unique.
    return feature_importance_df.rank(axis=1, ascending=False)

# Turn ranks into scores where larger = more important: with n features the
# top-ranked feature of a class scores n and the last one scores 1.
# The offset is derived from the table instead of the previous hard-coded 65,
# which did not match the number of feature columns.
n_features = feature_importance.shape[1]
ranks_df = (n_features + 1) - create_rank_dataframe(feature_importance)

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Five features with the highest median score across classes.
medians = ranks_df.median().sort_values(ascending=False)
top_features = medians.head(5).index
sorted_df = ranks_df[top_features]
# Long format: one row per (class, feature) pair.
df_melted = sorted_df.melt(var_name="Feature", value_name="Importance")

sns.set_style("white")
sns.set_context("notebook", font_scale=1.2)

plt.figure(figsize=(10, 6))
sns.boxplot(
    data=df_melted,
    x="Feature",
    y="Importance",
    boxprops=dict(facecolor='none', edgecolor='black', linewidth=2),
    medianprops=dict(color='black', linewidth=2),
    whiskerprops=dict(color='black', linewidth=2),
    capprops=dict(color='black', linewidth=2),
    fliersize=1
)
plt.xticks(rotation=90)
plt.title("Top 5 Feature Importance Distribution (sorted by median)")
plt.tight_layout()
plt.show()

In [None]:
# Persist the classifier currently bound to `xgb_clf` to a JSON file in the working directory.
xgb_clf.save_model("xgb_model_PROGRAMS_WM.json")

In [None]:
import shap
explainer = shap.TreeExplainer(xgb_clf)
shap_values = explainer.shap_values(X_test)
# shap may return one 2-D array per class as a list; stack those along a
# trailing class axis so the mean over axis 0 below is taken over samples.
# (Same normalisation the GSEA cell of this notebook applies.)
if isinstance(shap_values, list):
    shap_values = np.stack(shap_values, axis=2)
# Mean absolute SHAP value per (feature, class); rows = classes after transposing.
feature_importance = np.abs(shap_values).mean(axis=0)
feature_importance = pd.DataFrame(feature_importance, columns=label_encoder.classes_, index=X_test.columns).T
feature_importance

In [None]:
# Column-wise mean of feature_importance (one value per column), sorted ascending.
toplot = feature_importance.mean().sort_values()

In [None]:
import matplotlib.pyplot as plt

# Vertical bar chart of `toplot`, one black bar per entry, saved as a PDF.
fig, ax = plt.subplots(figsize=(8, 5))
toplot.plot.bar(ax=ax, color="black")
ax.set(
    xlabel='Index',
    ylabel='Value',
    title='Vertical Barplot of toplot Series',
)
fig.tight_layout()
fig.savefig("featimp_programs.pdf")
plt.show()

## Reclustering

In [None]:
atlas = datavignettes
# Rows where level_1, level_2 and level_3 all equal 1.
callosalwm = atlas.loc[(atlas['level_1'] == 1) & (atlas['level_2'] == 1) & (atlas['level_3'] == 1),:]

# Two feature spaces: the oligodendrocyte-marker and the neuronal-marker columns.
featurespace_oligo = callosalwm.loc[:, oligomarkers]
featurespace_neuro = callosalwm.loc[:, neuronalmarkers]
clusterlabels = callosalwm[['lipizone_names']]

# Shift every column so that its minimum is 0.
featurespace_neuro = featurespace_neuro - featurespace_neuro.min()
featurespace_oligo = featurespace_oligo - featurespace_oligo.min()

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, adjusted_rand_score, normalized_mutual_info_score
import scanpy as sc
import anndata
import numpy as np
import os
import joblib
from threadpoolctl import threadpool_limits
threadpool_limits(limits=4)
# NOTE(review): these variables are set after the numeric libraries were
# imported — confirm they still take effect in this environment.
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'
os.environ['NUMEXPR_NUM_THREADS'] = '4'
# (A bare `joblib.Parallel(n_jobs=1)` used to sit here; it only built an object
# and discarded it, so it has been dropped.)

# Encode the lipizone names as integers. The column is passed as a 1-D Series
# rather than as a one-column DataFrame, which is the shape LabelEncoder expects.
le = LabelEncoder()
clusterlabels_numeric = le.fit_transform(clusterlabels['lipizone_names'])

sc.settings.n_jobs = 1
sc.settings.verbosity = 1

def pca_leiden_analysis(features, labels, name, n_components=10):
    """Cluster a feature space with PCA + Leiden and compare against reference labels.

    The features are standardised, reduced with PCA, turned into a 15-neighbour
    graph and clustered with Leiden. Agreement with `labels` is printed.

    Parameters
    ----------
    features : array-like, rows = observations, columns = features
    labels : integer reference labels, one per observation
    name : str
        Label used in the printed report.
    n_components : int, default 10
        Requested number of principal components. It is capped at the number
        of observations and of features, because PCA cannot extract more
        components than that (a marker set may have fewer than 10 columns).

    Returns
    -------
    (leiden_labels_numeric, ari, nmi)
        Integer Leiden labels, adjusted Rand index and normalised mutual
        information against `labels`.
    """
    features_scaled = StandardScaler().fit_transform(features)
    n_components = min(n_components, *features_scaled.shape)
    pca = PCA(n_components=n_components, random_state=42)
    W = pca.fit_transform(features_scaled)

    adata = sc.AnnData(W)
    sc.pp.neighbors(adata, n_neighbors=15, random_state=42)
    sc.tl.leiden(adata, key_added='leiden', random_state=42)

    # The cluster ids are converted to strings and then parsed as integers.
    leiden_labels = adata.obs['leiden'].astype(str).values
    leiden_labels_numeric = np.array(leiden_labels, dtype=int)

    cm = confusion_matrix(labels, leiden_labels_numeric)
    ari = adjusted_rand_score(labels, leiden_labels_numeric)
    nmi = normalized_mutual_info_score(labels, leiden_labels_numeric)
    print(f"\n{name} Feature Space - PCA + Leiden Clustering:")
    print(f"  ARI = {ari:.4f}, NMI = {nmi:.4f}")
    print("  Confusion Matrix:")
    print(cm)
    return leiden_labels_numeric, ari, nmi

import scipy.cluster.hierarchy as sch

# Cluster the neuronal-marker space first, then the oligodendrocyte-marker space.
print("\nNeuronal Feature Space:")
leiden_neuro, ari_neuro, nmi_neuro = pca_leiden_analysis(
    featurespace_neuro, clusterlabels_numeric, "Neuronal"
)
featurespace_oligo_backup = featurespace_oligo.copy()

print("\n=== PCA + Leiden Clustering Analysis ===")
print("Oligodendrocyte Feature Space:")
leiden_oligo, ari_oligo, nmi_oligo = pca_leiden_analysis(
    featurespace_oligo, clusterlabels_numeric, "Oligodendrocyte"
)

# Contingency table (rows = encoded lipizones, columns = Leiden clusters),
# with each column scaled to sum to 1.
df = pd.crosstab(clusterlabels_numeric, leiden_oligo)
df = df.div(df.sum(), axis='columns')

# Columns follow the leaf order of a hierarchical clustering of the columns;
# rows are then sorted by the position of their largest entry.
linkage = sch.linkage(sch.distance.pdist(df.T), method='weighted', optimal_ordering=True)
column_order = sch.leaves_list(linkage)
normalized_df = df.iloc[:, column_order]
peak_column = np.argmax(normalized_df.values, axis=1)
order = np.argsort(peak_column)
normalized_df = normalized_df.iloc[order, :]

In [None]:
import seaborn as sns


def _order_enrichment(enrichment):
    """Reorder columns by hierarchical-clustering leaf order, then rows by the position of their largest entry."""
    column_linkage = sch.linkage(sch.distance.pdist(enrichment.T), method='weighted', optimal_ordering=True)
    enrichment = enrichment.iloc[:, sch.leaves_list(column_linkage)]
    row_order = np.argsort(np.argmax(enrichment.values, axis=1))
    return enrichment.iloc[row_order, :]


def _plot_enrichment_heatmap(enrichment):
    """Draw one grey-scale heatmap of a column-normalised contingency table, colour range 0 to 0.1."""
    plt.figure(figsize=(10,10))
    # "Greys" is the long-standing name of this colormap; the "Grays" spelling
    # is presumably only accepted by newer Matplotlib releases — TODO confirm.
    sns.heatmap(enrichment, cmap="Greys", cbar_kws={'label': 'Enrichment'}, xticklabels=True, yticklabels=False, vmin=0.0, vmax=0.1)
    plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    plt.tick_params(axis='y', which='both', left=False, right=False)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()


# First heatmap: `normalized_df` as left by the previous cell, which builds it
# from the oligodendrocyte-marker clustering.
_plot_enrichment_heatmap(normalized_df)

# Second heatmap: same table for the neuronal-marker clustering
# (columns scaled to sum to 1, then reordered).
df = pd.crosstab(clusterlabels_numeric, leiden_neuro)
df = df / df.sum()
normalized_df = _order_enrichment(df)
_plot_enrichment_heatmap(normalized_df)