In [None]:
import pandas as pd
import anndata

# Path to the imputed MERFISH dataset; the file has to be downloaded from the
# Allen Institute beforehand (the file name suggests log2-transformed values).
imputed_h5ad_path ="./C57BL6J-638850-imputed-log2.h5ad"
# backed=None: no on-disk backing mode is requested for the AnnData object.
adata = anndata.read_h5ad(imputed_h5ad_path, backed=None)

In [None]:
# Inspect the observation names of the imputed matrix.
adata.obs_names

In [None]:
# Dimensions of the expression matrix.
adata.X.shape

In [None]:
# Export the gene_symbol column of adata.var (with its index) to a CSV file.
adata.var['gene_symbol'].to_csv("genes_implist.csv")

In [None]:
# Show the adata.var entry whose gene symbol is "Mbp" (its index label is the variable name).
adata.var['gene_symbol'][adata.var['gene_symbol'] == "Mbp"]

In [None]:
import pandas as pd

# Allen MERFISH coronal atlas after basic preprocessing.
merfish = pd.read_parquet("./zenodo/multimodal/cell_filtered_w500genes.parquet")
datavignettes = pd.read_parquet("./zenodo/maindata_2.parquet")
# Coordinates of the rows of datavignettes that have all three CCF values.
lipidsinallen = datavignettes[['xccf','yccf','zccf']].dropna()
# rename() returns an independent frame with the column names used on the
# datavignettes side; unlike re-labelling a column slice of `merfish`, columns
# can later be added to it without chained-assignment warnings.
merfishinallen = merfish[['x_ccf', 'y_ccf', 'z_ccf']].rename(
    columns={'x_ccf': 'xccf', 'y_ccf': 'yccf', 'z_ccf': 'zccf'}
)
merfishinallen

In [None]:
# Compare the upper coordinate bound of the two datasets on every CCF axis.
# For each axis the MERFISH maximum is printed first, then the lipid maximum.
for xx in ('xccf', 'yccf', 'zccf'):
    print(merfishinallen[xx].max())
    print(lipidsinallen[xx].max())
# author's note on the printed values: perfect

## Match the two datasets by constrained neighbor search

In [None]:
from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache
mcc = MouseConnectivityCache(manifest_file='mouse_connectivity_manifest.json')
annotation, _ = mcc.get_annotation_volume()
# Turn every CCF coordinate into an integer voxel index (coordinate * 40, truncated).
# NOTE(review): the factor 40 presumably corresponds to the voxel size of the
# annotation volume — confirm against the resolution used by the cache.
for axis in ('x', 'y', 'z'):
    merfish[axis + '_index'] = (merfish[axis + '_ccf']*40).astype(int)
# Annotation value found at each cell's voxel.
merfish['id'] = annotation[merfish['x_index'], merfish['y_index'], merfish['z_index']]
merfish['id']

In [None]:
# Keep only lipid rows that carry an annotation id and store the ids as strings
# (cast through int first), matching the string ids assigned to the MERFISH cells below.
datavignettes = datavignettes.dropna(subset=['id'])
datavignettes['id'] = datavignettes['id'].astype(int).astype(str)
merfish['id'] = merfish['id'].astype(str)
merfishinallen['id'] = merfish['id'].values
#drop vascular and immune cells first...
merfishinallen['division'] = merfish['division'].values
merfishinallen = merfishinallen.loc[~merfishinallen['division'].isin(['6 Vascular', '7 Immune']),:]
# NOTE(review): the division filter above only affects merfishinallen; the
# KD-tree cell further down groups `merfish` itself — confirm that is intended.
datavignettes =datavignettes.dropna(subset=['xccf'])
# Restrict to the 53 section ids listed below; the result is an independent copy.
datavignettess = datavignettes.copy().loc[datavignettes['SectionID'].isin([76.,  82., 106.,   2., 131.,  88.,  63., 112.,  60.,  62., 118.,
     21.,  45., 123.,  58., 100.,  83.,  61.,  59.,  98.,  28.,  19.,
     43.,  18., 107.,  29., 104., 124.,  52., 129.,  14.,  78.,  15.,
     65.,  89.,  41., 117., 111.,  68.,  70., 125.,  92.,  16., 122.,
    114.,  91.,  11.,  24.,  71.,  46.,  57., 120.,  75.]),:]# focus on preselected good sections...

datavignettess

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.spatial import cKDTree
from threadpoolctl import threadpool_limits, threadpool_info
threadpool_limits(limits=8)
import os
os.environ['OMP_NUM_THREADS'] = '6'

# 1) Pre-group the MERFISH cells by annotation id and build one KD-tree per id.
#    The tree of an id and its matrix in `feats` must describe the same cells in
#    the same order: the neighbour positions returned by the tree are later used
#    directly as row numbers into feats[id_].

trees = {}
feats = {}

for id_, sub in tqdm(merfish.groupby('id')):
    # Only cells that also exist in the imputed AnnData object can contribute
    # expression values, so both the tree and the matrix are restricted to them.
    valid_idx = sub.index[sub.index.isin(adata.obs_names)]
    if len(valid_idx) == 0:
        # No usable cell for this id; the matching loop skips ids without a tree.
        continue
    coords = sub.loc[valid_idx, ['x_ccf','y_ccf','z_ccf']].to_numpy()
    trees[id_] = cKDTree(coords)
    feats[id_] = np.asarray(adata[valid_idx, :].X)

In [None]:
thr = 0.075   # search radius, in the units of the CCF coordinates
idxs = []     # row labels of lipid pixels that found at least one neighbour
means = []    # mean expression vector of those neighbours, parallel to idxs

for iiii in tqdm(datavignettess['SectionID'].unique()): ### reduced to 53 sensible sections
    # A dedicated name is used for the per-section slice: rebinding
    # `datavignettes` here would leave the notebook-level table pointing at the
    # last processed section only.
    section_pixels = datavignettess.loc[datavignettess['SectionID'] == iiii,:]

    # Neighbours are only searched among cells sharing the pixel's annotation id.
    for id_, dsub in section_pixels.groupby('id'):
        tree = trees.get(id_)
        if tree is None:
            continue
        qpts = dsub[['xccf','yccf','zccf']].to_numpy()
        nbrs_list = tree.query_ball_point(qpts, r=thr)
        arr = feats[id_]
        for i, nbrs in enumerate(nbrs_list):
            if nbrs:
                idxs.append(dsub.index[i])
                means.append(arr[nbrs].mean(axis=0))

In [None]:
import pickle
# Checkpoint the list of mean expression vectors.
# NOTE(review): `idxs` is not stored alongside, so the pickle alone cannot be
# mapped back to pixels.
with open('means.pkl', 'wb') as f:
    pickle.dump(means, f)

In [None]:
# One row per matched lipid pixel (labelled by its original row label),
# one column per variable of adata.
result = pd.DataFrame(
    data=np.array(means),
    index=idxs,
    columns=adata.var['gene_symbol'].index,
)

In [None]:
from threadpoolctl import threadpool_limits, threadpool_info
# Cap the native thread pools at 8 threads.
threadpool_limits(limits=8)
import os
# NOTE(review): libraries that are already loaded may not re-read this variable,
# and 6 differs from the limit of 8 set just above — confirm which is intended.
os.environ['OMP_NUM_THREADS'] = '6'

# Persist the per-pixel expression table; the training section reloads this file.
result.to_parquet("spatialgoodgexpr_WHOLETRANSCRIPTOME.parquet")
result

In [None]:
# Use the section-filtered table under the name `datavignettes` for the lookups that follow.
datavignettes = datavignettess
datavignettes.shape

In [None]:
# Attach the section id, the CCF coordinates and the boundary column of every
# matched pixel, looked up in datavignettes by the row labels of result.
for col in ('SectionID', 'xccf', 'yccf', 'zccf', 'boundary'):
    result[col] = datavignettes.loc[result.index, col]

## Check imputation quality visually

In [None]:
r = result
# Order the sections by their mean xccf and keep every 5th one
# (equispaced rostrocaudally among the manually selected good sections).
sections_top10_fast = r['xccf'].groupby(r['SectionID']).mean().sort_values()[::5].index
sections_top10_fast

import matplotlib.pyplot as plt

# One figure per selected section.
for xxx in sections_top10_fast:
    mer = result.loc[result['SectionID'] == xxx,:]

    # Rows whose boundary column equals 1 are overlaid in black.
    cont = mer.loc[mer['boundary'] == 1,:]

    # Colour encodes the imputed value of one variable.
    # NOTE(review): ENSMUSG00000041607 is presumably the Mbp entry looked up earlier — confirm.
    plt.scatter(mer['zccf'], -mer['yccf'], c=mer['ENSMUSG00000041607'], cmap="Reds", s=0.1, rasterized=True)
    plt.scatter(cont['zccf'], -cont['yccf'],
                     c='black', s=0.01, alpha=1.0, rasterized=True)

    plt.show()

## Train gene-to-lipid XGBoost models

In [None]:
import pandas as pd
import anndata

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.spatial import cKDTree
from threadpoolctl import threadpool_limits, threadpool_info
threadpool_limits(limits=8)
import os
os.environ['OMP_NUM_THREADS'] = '6'

# Reload the unfiltered lipid table and the imputed expression table written earlier in this notebook.
datavignettes = pd.read_parquet("./zenodo/maindata_2.parquet")
result = pd.read_parquet("spatialgoodgexpr_WHOLETRANSCRIPTOME.parquet") # computed just above

In [None]:
# ---- DATA SETUP ----
# Predictors: the first 8460 columns of result.
# NOTE(review): 8460 is presumably the number of imputed genes — confirm.
genes        = result.iloc[:,:8460]                    
# Targets: the first 173 columns of datavignettes for the same rows.
# NOTE(review): assumes the lipid columns come first in datavignettes — confirm.
lipids       = datavignettes.loc[genes.index, :].iloc[:, :173]
lipids2learn = datavignettes.columns[:173]
# Section id of every row, used below to pick the training sections.
sids         = datavignettes.loc[genes.index, 'SectionID']

# ---- HYPERPARAM GRID ----
# Every entry holds a single value, so only one combination can be sampled.
param_dist = {
    "n_estimators":  [300],
    "learning_rate": [0.05],
    "max_depth":     [6],
    "subsample":     [0.6]
}

In [None]:
# downsample for feasibility
# (this line only lists the distinct section ids; the selection itself is made in the following cells)
sections = sids.unique()

In [None]:
# Mean xccf of the matched pixels per section, sorted ascending.
matched = datavignettes.loc[genes.index, ['xccf', 'SectionID']]
meanccf = matched.groupby('SectionID')['xccf'].mean().sort_values()
meanccf

In [None]:
# Keep sections whose id is below 110 (in the xccf order of meanccf) and take every 4th of them.
sections_to_train_on = meanccf.loc[meanccf.index < 110][::4].index.values # remove pregnancy as well, it's a condition
len(sections_to_train_on) # i'll start small from only 11 sections...

In [None]:
# Restrict section ids, predictors and targets to the rows of the selected sections.
sids         = sids.loc[sids.isin(sections_to_train_on)]
genes        = genes.loc[sids.index,:]                     
lipids       = lipids.loc[sids.index,:] 
lipids.shape

In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split, ParameterSampler
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
import xgboost as xgb
import shap
from joblib import dump
from threadpoolctl import threadpool_limits

# limit threads for reproducibility
threadpool_limits(limits=8)
os.environ['OMP_NUM_THREADS'] = '6'

# Number of hyper-parameter combinations to draw from param_dist.
n_iter     = 1
param_list = list(ParameterSampler(param_dist, n_iter=n_iter, random_state=42))

def pearson_scorer(y_true, y_pred):
    """Pearson correlation between targets and predictions.

    Returns 0.0 when either input has zero standard deviation; otherwise the
    first element of the result of scipy.stats.pearsonr.
    """
    has_spread = np.std(y_true) != 0 and np.std(y_pred) != 0
    if not has_spread:
        return 0.0
    corr = pearsonr(y_true, y_pred)[0]
    return corr

def process_lipid(lipid):
    """Fit, evaluate and explain an XGBoost regressor for a single lipid.

    Reads the notebook-level `genes` (predictors), `lipids` (targets) and
    `param_list` (hyper-parameter candidates). Rows are split 60/20/20 into
    train/validation/test, predictors are standardised with statistics of the
    training rows, and every candidate is scored on the test rows by MSE and
    Pearson r. The candidate with the highest test r is kept.

    Files written to the working directory (prefix = lipid name with spaces
    and slashes replaced by underscores):
        <prefix>_xgb_results.csv          one row per candidate
        <prefix>_xgb_model.joblib         the best model
        <prefix>_xgb_shap_values.parquet  SHAP values of the test rows

    Returns:
        (lipid, results_df), where results_df is the per-candidate table.
    """
    X = genes
    y = lipids[lipid]

    # Hold out 20% for testing, then 25% of the remainder for validation.
    X_tmp, X_test, y_tmp, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)
    X_train, X_val, y_train, y_val = train_test_split(
        X_tmp, y_tmp, test_size=0.25, random_state=42)

    # Scaling statistics come from the training rows only.
    scaler = StandardScaler().fit(X_train)
    X_tr_s, X_val_s, X_te_s = (
        scaler.transform(part).astype(np.float32)
        for part in (X_train, X_val, X_test)
    )

    print("alive")

    records = []
    best_model, best_score = None, -np.inf

    for params in param_list:
        model = xgb.XGBRegressor(
            tree_method='hist', grow_policy='lossguide',
            n_jobs=2, random_state=42, **params)

        # NOTE(review): recent XGBoost releases expect early_stopping_rounds on
        # the estimator constructor rather than in fit(); confirm against the
        # installed version.
        model.fit(X_tr_s, y_train,
                  eval_set=[(X_val_s, y_val)],
                  early_stopping_rounds=10,
                  verbose=True)

        y_pred = model.predict(X_te_s)
        mse = mean_squared_error(y_test, y_pred)
        r = pearson_scorer(y_test, y_pred)
        records.append({**params, 'Test_MSE': mse, 'Test_R': r})

        # Model selection uses the test-set correlation.
        if r > best_score:
            best_score, best_model = r, model

    results_df = pd.DataFrame(records)
    file_prefix = lipid.replace(' ','_').replace('/','_')
    results_df.to_csv(f"{file_prefix}_xgb_results.csv", index=False)

    dump(best_model, f"{file_prefix}_xgb_model.joblib")

    # SHAP values of the scaled test rows, one column per predictor.
    explainer = shap.TreeExplainer(best_model)
    shap_vals = explainer.shap_values(X_te_s)
    df_shap = pd.DataFrame(shap_vals, columns=X.columns)
    df_shap.to_parquet(f"{file_prefix}_xgb_shap_values.parquet")

    return lipid, results_df


# NOTE(review): the slice skips the first entry of lipids2learn — confirm this is intentional.
lipids_subset = lipids2learn[1:] 
# Per-lipid table of test metrics, keyed by lipid name.
results_by_lipid = {}
for lip in tqdm(lipids_subset):
    lipid_name, df = process_lipid(lip)
    results_by_lipid[lipid_name] = df

## Extract the best cell type markers from the imputed data

In [None]:
# Supertype label of every MERFISH cell, indexed like `merfish`.
cell_types = merfish['labels_supertype']
import numpy as np
import pandas as pd
from scipy import sparse

def _column_mean_and_detection(X, n_rows):
    """Per-column mean of X and per-column fraction of detected entries.

    "Detected" means stored non-zero for sparse input and > 0 for dense input.
    """
    if sparse.issparse(X):
        detected = X.astype(bool).sum(axis=0)
    else:
        detected = (X > 0).sum(axis=0)
    mean = np.asarray(np.ravel(X.mean(axis=0)), dtype=float)
    freq = np.asarray(np.ravel(detected), dtype=float) / n_rows
    return mean, freq


def _safe_ratio(numer, denom):
    """Element-wise numer / denom, yielding 0 wherever denom is 0."""
    out = np.zeros_like(numer, dtype=float)
    np.divide(numer, denom, out=out, where=denom != 0)
    return out


def compute_enrichment_scores(adata, cell_types: pd.Series,
                              alpha: float = 1.0, beta: float = 1.0) -> pd.DataFrame:
    """
    Compute E(i,g) = (mean_i_g / mean_g)^alpha * (freq_i_g / freq_g)^beta

    mean_g / freq_g are the mean and the detection fraction of gene g over all
    labelled cells; mean_i_g / freq_i_g are the same statistics within cell
    type i.

    Args:
        adata: object exposing obs_names, var_names, X and label-based row
            selection via adata[labels] (e.g. an AnnData object).
        cell_types: cell type label per cell, indexed by cell name. Cells
            missing from adata.obs_names, or without a label, are ignored.
        alpha: exponent of the mean ratio.
        beta: exponent of the detection-frequency ratio.

    Returns:
        DataFrame of shape (number of cell types, number of genes), indexed by
        cell type with adata.var_names as columns. A ratio whose denominator
        is zero (gene never detected) counts as 0 instead of producing NaN.
    """
    ct = cell_types.reindex(adata.obs_names).dropna()
    ad = adata[ct.index]

    X = ad.X
    n_cells, n_genes = X.shape
    types = ct.unique()

    mean_g, freq_g = _column_mean_and_detection(X, n_cells)

    enrich = np.zeros((len(types), n_genes), dtype=float)
    for i, t in enumerate(types):
        mask = (ct == t).values
        mean_i, freq_i = _column_mean_and_detection(X[mask], mask.sum())
        enrich[i, :] = (_safe_ratio(mean_i, mean_g)**alpha
                        * _safe_ratio(freq_i, freq_g)**beta)

    return pd.DataFrame(enrich, index=types, columns=adata.var_names)

def rank_genes_by_max_enrichment(enrich_df: pd.DataFrame, top_n: int = 20) -> pd.Series:
    """Return the top_n genes ranked by their highest enrichment over all cell types."""
    peak_per_gene = enrich_df.max(axis=0)
    return peak_per_gene.nlargest(top_n)

def assign_genes_to_cell_types(enrich_df: pd.DataFrame,
                               threshold: float, top_x: int = 20) -> pd.DataFrame:
    """
    Assign every gene to the cell type in which its enrichment is highest.

    Genes whose best enrichment is below `threshold` are discarded; of the
    remaining genes, at most `top_x` per cell type are returned, strongest
    first. The result has the columns gene, cell_type and enrichment, is
    ordered by cell type and carries a fresh 0..n-1 index.
    """
    peak = enrich_df.max(axis=0)
    winner = enrich_df.idxmax(axis=0)

    table = pd.DataFrame({
        'gene':        enrich_df.columns,
        'cell_type':   winner.values,
        'enrichment':  peak.values
    })
    kept = table.loc[table['enrichment'] >= threshold]

    ordered = kept.sort_values(['cell_type','enrichment'], ascending=[True,False])
    per_type_top = ordered.groupby('cell_type').head(top_x)
    return per_type_top.reset_index(drop=True)

enrichment_scores = compute_enrichment_scores(adata, cell_types)
# Supertypes that belong to the vascular / immune divisions are removed from the table.
detect = merfish[['labels_supertype', 'labels_division']].drop_duplicates()
badtypes = detect.loc[detect['labels_division'].isin(['6 Vascular', '7 Immune']), 'labels_supertype']
# errors='ignore': a supertype none of whose cells is present in adata never
# received a row in enrichment_scores, and a plain drop() would raise KeyError for it.
enrichment_scores = enrichment_scores.drop(index=badtypes, errors='ignore')
enrichment_scores.to_parquet("enrichment_scores_full_transcriptome.parquet")

In [None]:
import numpy as np
import pandas as pd
from statsmodels.stats.multitest import multipletests
from statsmodels.stats.multitest import multipletests

B = 2000           # number of bootstrap draws
threshold = 1e-6   # minimum value for an entry to be considered in the enrichment test
N = 20             # number of top genes per region
# Keep only cell types that occur more than 150 times in cell_types.
type_counts = cell_types.value_counts().sort_values()
abundant_types = type_counts.index[type_counts > 150]
enrichment_scores = enrichment_scores.loc[enrichment_scores.index.isin(abundant_types), :]

# Align the gene sets of region_matrix (columns) and genemarkerish (index).
# NOTE(review): neither `region_matrix` nor `genemarkerish` is defined in the
# visible part of this notebook before this point; on a fresh kernel this cell
# raises NameError — confirm where they are meant to come from.
common_genes = region_matrix.columns.intersection(genemarkerish.index)
if len(common_genes) < len(region_matrix.columns):
    print(f"Warning: Dropping {len(region_matrix.columns) - len(common_genes)} genes "
          f"from region_matrix that are not in genemarkerish.")
if len(common_genes) < len(genemarkerish.index):
    print(f"Warning: Dropping {len(genemarkerish.index) - len(common_genes)} genes "
          f"from genemarkerish that are not in region_matrix.")

region_matrix = region_matrix.loc[:, common_genes]
genemarkerish = genemarkerish.loc[common_genes]


# Ascending average rank, then flipped so that the largest genemarkerish value
# gets the smallest number.
# NOTE(review): the constant 8460 assumes exactly 8460 common genes — confirm.
marker_ranks = genemarkerish.rank(method="average", ascending=True)
marker_ranks = 8460 - marker_ranks 
all_genes = list(common_genes)
# Position of every gene within all_genes.
gene2idx = { g: i for i, g in enumerate(all_genes) }

import numpy as np

G = len(all_genes)   # number of genes that are actually ranked (previously hard-coded as 8460)
B = 2000             # bootstraps
N = 20               # draw size per bootstrap

# Marker ranks as a positional array ordered like all_genes, so that the
# positions stored in gene2idx index it directly.
# NOTE(review): this name was used without being defined in the visible
# notebook; the definition here assumes marker_ranks holds one value per gene — confirm.
marker_ranks_array = np.asarray(marker_ranks.loc[all_genes], dtype=float).ravel()
assert marker_ranks_array.shape == (G,), "expected exactly one marker rank per gene"

rng = np.random.default_rng(seed=0)
samples = np.empty((B, N), dtype=int)

# Each bootstrap draws N distinct gene positions.
for b in range(B):
    samples[b, :] = rng.choice(G, size=N, replace=False)

# Null distribution: mean marker rank of N random genes, sorted for searchsorted.
null_means_N20 = marker_ranks_array[samples].mean(axis=1)
null_means_N20.sort()

import pickle

# NOTE(review): neither of the two files read here is written in the visible
# part of this notebook — confirm where they come from.
filename = "all_shap_norms_controlling_outliers.pkl"

# pickle.load can execute arbitrary code contained in the file; only load files you trust.
with open(filename, "rb") as f:
    all_shap_norms = pickle.load(f)

# One entry per lipid, iterated in parallel with lipid_names below.
region_matrices = all_shap_norms

lipid_names = np.load("lipids4xgbimpo.npy")
lipid_names

# One record per (lipid, region) pair that has at least N entries >= threshold.
records = []
for lipid_name, region_matrix in tqdm(zip(lipid_names, region_matrices)):
    
    # NOTE(review): each region_matrix is presumably a regions x genes table of SHAP norms — confirm.
    for region in region_matrix.index:  
        row = region_matrix.loc[region] 
        
        filtered = row[row >= threshold]
        if len(filtered) < N:
            continue

        # Positions (within all_genes) of the N largest entries of this row.
        top20 = filtered.nlargest(N).index.tolist() 
        top20_idx = [gene2idx[g] for g in top20]
        
        # Test statistic: mean marker rank of those genes.
        T_obs = marker_ranks_array[top20_idx].mean()

        # empirical p: fraction of null_means_N20 <= T_obs
        # (with +1 in numerator and denominator, so p is never exactly 0)
        pos = np.searchsorted(null_means_N20, T_obs, side="right")
        p_emp = (pos + 1) / (B + 1)

        records.append({
            "lipid":       lipid_name,
            "region":      region,
            "T_obs":       T_obs,
            "p_value_raw": p_emp,
        })

results_df = pd.DataFrame.from_records(records)

# Benjamini-Hochberg correction across all tested (lipid, region) pairs.
reject, p_adj, _, _ = multipletests(results_df["p_value_raw"].values,
                                    alpha=0.05, method="fdr_bh")
results_df["p_value_fdr"] = p_adj
results_df["significant_FDR05"] = reject

# Share of tested pairs that remain significant after correction.
num_tests = len(results_df)
num_sig   = results_df["significant_FDR05"].sum()
frac_sig  = 100.0 * num_sig / num_tests
print(f"{num_sig} / {num_tests} = {frac_sig:.1f}% of region‐lipid pairs are marker‐enriched at FDR<0.05.")

import matplotlib.pyplot as plt
import matplotlib
# Font type 42 stores text in the PDF as TrueType fonts.
matplotlib.rcParams['pdf.fonttype'] = 42

# Histogram of the adjusted p-values with the 0.05 cut-off marked.
plt.hist(results_df['p_value_fdr'], bins=100, color="gray")

plt.axvline(x=0.05, color='red', linestyle='--', linewidth=2, label='FDR = 0.05')

# Only the range 0-0.2 is shown.
plt.xlim(0, 0.2)

plt.legend()
plt.savefig("pvalues_xgb_shap_celltypemarkers.pdf")
plt.show()