In [None]:
import pandas as pd

# Load the connectome feature table from parquet; later cells use it as `data`.
connectome = pd.read_parquet("./connectomic_datasets/connectome2992features.parquet")
# Bare last expression: rendered as the cell output.
connectome

In [None]:
import os
import gc
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from scipy.sparse import csr_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import NMF
from openTSNE import TSNEEmbedding, affinity, initialization
from tqdm import tqdm
from collections import deque
import harmonypy as hm
import networkx as nx
from threadpoolctl import threadpool_limits, threadpool_info

# Configure thread limits for the native thread pools.
# threadpool_limits(...) is called as a plain function here, not as a context manager.
# NOTE(review): the threadpoolctl limit (8) and OMP_NUM_THREADS (6) disagree, and the
# environment variable is assigned after the numeric libraries above were imported —
# presumably it only affects runtimes initialised later; confirm which limit is intended.
threadpool_limits(limits=8)
os.environ['OMP_NUM_THREADS'] = '6'

In [None]:
# Load the atlas table only to extract the per-row section number and CCF coordinates.
atlas = pd.read_parquet("atlas.parquet")

# `.astype` on the column selection returns a new frame, so 'Section' is converted
# without assigning into a slice of the atlas (which raised SettingWithCopyWarning).
coordinates = atlas[['Section', 'xccf', 'yccf', 'zccf']].astype({'Section': int})
del atlas  # only `coordinates` is kept from the atlas table

# From here on `data` is the connectome feature table loaded earlier.
data = connectome
data

In [None]:
# use NMF to decompose the data into factors

def _plot_sweep_curve(values, title):
    """Plot one diagnostic curve of the resolution sweep against the sweep index."""
    plt.plot(np.arange(len(values)), values)
    plt.title(title)
    plt.show()


def _central_feature_per_cluster(dist, clusters, n_clusters):
    """Return the most central feature index of every non-empty cluster 0..n_clusters-1.

    "Most central" means the member with the smallest mean distance to the
    members of its own cluster (including itself, at distance 0).
    """
    representatives = []
    for i in range(n_clusters):
        cluster_members = np.where(clusters == i)[0]
        print(cluster_members)
        if len(cluster_members) > 0:
            within = dist[np.ix_(cluster_members, cluster_members)]
            representatives.append(cluster_members[np.argmin(within.mean(axis=1))])
    return representatives


def compute_seeded_NMF(data):  # data is a dataframe pixels x lipids
    """Run NMF on `data`, seeded from Leiden communities of the feature correlation graph.

    Steps:
      1. Build the absolute feature-feature correlation matrix.
      2. Run Leiden on that weighted graph over a range of resolutions and score
         each partition by modularity.
      3. Pick the resolution maximising ``Q**alpha * log(N_communities + 1)``.
      4. Choose the most central feature of each community as its representative.
      5. Initialise W with the representatives' (offset) columns and H with the
         non-negative part of their correlation rows, then fit NMF.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric table with observations in rows and features in columns.

    Returns
    -------
    nmfdf : pandas.DataFrame
        Factor scores (W), indexed like `data`, one column per factor.
    factor_to_lipid : numpy.ndarray
        Fitted ``nmf.components_`` (H), factors x features.
    N_factors : int
        Number of factors actually used (one per non-empty community).
    nmf : sklearn.decomposition.NMF
        The fitted estimator.

    Side effects: shows three diagnostic plots and prints the selected number of
    communities, the cluster labels and the members of every cluster.
    """
    # Single float64 view of the data so that X, W and H share one dtype
    # (scikit-learn rejects a custom init whose dtype differs from X).
    values = data.to_numpy(dtype=np.float64)

    # 1. calculate the correlation matrix of this dataset
    # Constant features have undefined correlation (NaN); treat them as uncorrelated
    # so they do not poison the graph weights.
    corr = np.nan_to_num(np.corrcoef(values.T), nan=0.0)
    corr_matrix = np.abs(corr)  # anticorrelated lipids convey the same info
    np.fill_diagonal(corr_matrix, 0)

    # 2. expose the correlation matrix as a precomputed neighbour graph for scanpy
    adata = anndata.AnnData(X=np.zeros_like(corr_matrix))
    adata.obsp['connectivities'] = csr_matrix(corr_matrix)
    adata.uns['neighbors'] = {
        'connectivities_key': 'connectivities',
        'distances_key': 'distances',
        'params': {'n_neighbors': 10, 'method': 'custom'}
    }
    G = nx.from_numpy_array(corr_matrix)

    # span reasonable Leiden resolution parameters
    gamma_values = np.linspace(0.8, 1.5, num=6)
    labelings = []
    num_communities = []
    modularity_scores = []
    for gamma in gamma_values:
        key = f'leiden_{gamma}'
        sc.tl.leiden(adata, resolution=gamma, key_added=key)
        labels = adata.obs[key].astype(int).values
        communities = [np.flatnonzero(labels == c) for c in np.unique(labels)]
        labelings.append(labels)
        num_communities.append(len(communities))
        modularity_scores.append(nx.community.modularity(G, communities))

    # 3. pick a number of blocks that is relatively high while preserving good modularity
    epsilon = 1e-10
    alpha = 0.7  # controls the weight of modularity vs pushing higher the number of communities
    # Modularity can be negative; clamp at 0 so the fractional power cannot yield NaN
    # (a NaN would silently win np.argmax).
    objective_values = [
        max(Q, 0.0) ** alpha * np.log(N_c + 1 + epsilon)
        for Q, N_c in zip(modularity_scores, num_communities)
    ]

    _plot_sweep_curve(objective_values, "obj")
    _plot_sweep_curve(modularity_scores, "mod")
    _plot_sweep_curve(num_communities, "ncomms")

    max_index = int(np.argmax(objective_values))
    best_num_comms = num_communities[max_index]
    print(f'Number of communities at best gamma: {best_num_comms}')

    # Reuse the labelling that was actually scored instead of rerunning Leiden, so
    # the reported community count and the clusters used below always agree.
    clusters = labelings[max_index]
    print(clusters)

    # 4. pick a representative lipid from each block, use to initialize W
    dist = 1 - corr_matrix
    np.fill_diagonal(dist, 0)
    dist = np.maximum(dist, dist.T)  # as numerical instability makes it unreasonably asymmetric
    representatives = _central_feature_per_cluster(dist, clusters, best_num_comms)

    # Shift by the global minimum (explicit, independent of how np.min treats a
    # DataFrame) so every entry is strictly positive.
    data_offset = np.ascontiguousarray(values - values.min() + 1e-7)

    # W is seeded from the same offset matrix that is fitted, so it is guaranteed
    # non-negative even when the raw data contain negative values.
    W_init = np.ascontiguousarray(data_offset[:, representatives])

    # 5. initialize H as a subset of the correlation matrix
    # only positive correlated can contribute by def in NMF
    H_init = np.ascontiguousarray(np.clip(corr[representatives, :], 0.0, None))

    # 6. compute the NMF with this initialization and rank N
    N_factors = W_init.shape[1]
    nmf = NMF(
        n_components=N_factors,
        init='custom',
        random_state=42
    )
    nmf_result = nmf.fit_transform(
        data_offset,
        W=W_init,
        H=H_init
    )
    nmfdf = pd.DataFrame(nmf_result, index=data.index)
    factor_to_lipid = nmf.components_

    return nmfdf, factor_to_lipid, N_factors, nmf

# Run the seeded NMF on `data` and unpack: the factor-score DataFrame (indexed like
# `data`), the fitted components matrix, the number of factors and the fitted model.
nmfdf, factor_to_lipid, N_factors, nmf = compute_seeded_NMF(data)

In [None]:
# Display the factor-score table as the cell output.
nmfdf

In [None]:
# Absolute feature-feature correlation of `data`, with the self-correlations zeroed.
corr = np.corrcoef(data.values.T)
corr_matrix = np.absolute(corr)
corr_matrix[np.diag_indices_from(corr_matrix)] = 0

In [None]:
# Plain ndarray of the factor scores (rows follow nmfdf's index, one column per factor).
nmf_result = nmfdf.to_numpy()

In [None]:
# Spatial maps of every NMF factor: one figure per factor, one panel per section.

# Loop-invariant: align the coordinates to the rows of `data` once instead of
# twice per factor.
section_coords = coordinates.loc[data.index, :]
currentPC = "test"
cmap = plt.cm.PuOr

for PC_I in range(0, N_factors):

    # Coordinates plus the scores of the current factor in column `currentPC`.
    filtered_data = section_coords.assign(**{currentPC: nmf_result[:, PC_I]})

    # Per-section 2nd / 98th percentiles; their medians across sections give a
    # colour range that is robust to outlier sections.
    by_section = filtered_data.groupby('Section')[currentPC]
    percentile_df = pd.DataFrame({
        '2-perc': by_section.quantile(0.02),
        '98-perc': by_section.quantile(0.98),
    }).reset_index()
    med2p = percentile_df['2-perc'].median()
    med98p = percentile_df['98-perc'].median()

    # One normalisation shared by every panel and by the colorbar.
    norm = Normalize(vmin=med2p, vmax=med98p)

    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    axes = axes.flatten()

    # Panels are filled with sections 1..32, in grid order.
    for section, ax in enumerate(axes, start=1):
        ddf = filtered_data[filtered_data['Section'] == section]

        ax.scatter(ddf['zccf'], -ddf['yccf'], c=ddf[currentPC], cmap=cmap, norm=norm,
                   s=0.5, rasterized=True)
        ax.axis('off')
        ax.set_aspect('equal')

    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    sm = ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(sm, cax=cbar_ax)

    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow with the number of factors