In [None]:
import os
import warnings
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import average_precision_score, auc, classification_report, accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
import torch

import anndata as ad
import lightning as L
import modlyn as mn
import lamindb as ln
import scanpy as sc
from scipy.stats import spearmanr

import seaborn as sns
sns.set_theme()
%config InlineBackend.figure_formats = ['svg']

warnings.filterwarnings('ignore')

project = ln.Project(name="Modlyn")
project.save()

ln.track(project="Modlyn")

run = ln.track()

In [None]:
# Fetch the benchmark dataset from the shared lamindb instance and load it
# into memory; the result is used as an AnnData throughout this notebook.
benchmark_artifacts = ln.Artifact.using("laminlabs/arrayloader-benchmarks")
artifact = benchmark_artifacts.get("D21D2K8697CY8tHE0001")
adata = artifact.load()
adata

In [None]:
# Alternative data source (disabled): a different arrayloader-benchmarks
# artifact, cached under a custom lamindb cache directory.
# NOTE(review): LAMIN_CACHE_DIR presumably has to be set before lamindb is
# initialized to take effect — confirm before re-enabling this cell.
# import os
# os.environ["LAMIN_CACHE_DIR"] = "/data/.lamindb-cache"

# artifact = ln.Artifact.using("laminlabs/arrayloader-benchmarks").get("bzX5jvxDmcqoJVJg0000")
# adata = artifact.load()  # previously `adata.load()`, which ignores the artifact fetched above

In [None]:
# Alternative data source (disabled): read one Tahoe-100M zarr chunk straight
# from the local lamindb cache, attach a gene table as .var, and integer-encode
# obs["cell_line"] into obs["y"].
# NOTE(review): hardcoded absolute path, only valid on a machine with this cache
# layout; also assumes the parquet rows match the chunk's genes in order — TODO confirm.
# store_path = Path(
#     "/data/.lamindb-cache/lamin-us-west-2/"
#     "wXDsTYYd/tahoe100M_shuffled_zarr_store_2025-05-07/chunk_30.zarr"
# )
# adata = ad.read_zarr(str(store_path))

# var = pd.read_parquet("var_subset_tahoe100M.parquet")
# adata.var = var
# adata.obs["y"] = (
#     adata.obs["cell_line"]
#     .astype("category")
#     .cat.codes
#     .astype("i8")
# )

# adata


In [None]:
# Keep only cell lines with at least two cells: scanpy's rank_genes_groups
# (used below) refuses to compute statistics for one-sample groups.
MIN_CELLS_PER_LINE = 2
counts = adata.obs["cell_line"].value_counts()
keep = counts.index[counts >= MIN_CELLS_PER_LINE]
adata = adata[adata.obs["cell_line"].isin(keep)].copy()

# sc.pp.log1p transforms adata.X in place and records itself in
# adata.uns["log1p"]; checking that key keeps a re-run of this cell from
# log-transforming the data a second time.
if "log1p" not in adata.uns:
    sc.pp.log1p(adata)

In [None]:
# Subset (disabled): first n_train cells for training, the rest for validation.
# NOTE(review): n_val is assigned but unused (adata_val takes every cell after
# n_train), and the split is positional, so it is only random if rows are shuffled.
# n = adata.n_obs

# # n_train = int(n * 0.8)
# n_train = 5000
# # n_val = n - n_train
# n_val = 2000

# adata_train = adata[:n_train]
# adata_val = adata[n_train:]
# adata_train

In [None]:
# Train on all filtered, log1p-transformed cells (no held-out split). The copy
# keeps `adata` untouched by later writes into adata_train.uns (rank_genes_groups).
adata_train = adata.copy()

In [None]:
# Seed python, numpy and torch RNGs before building the model so weight
# initialization and batch shuffling are reproducible across runs.
SEED = 0
L.seed_everything(SEED)

# Validation cells are taken from the training data itself, so the validation
# loss is only a sanity check, not an estimate of generalization.
N_VAL_CELLS = 20

logreg = mn.models.SimpleLogReg(
    adata=adata_train,
    label_column="cell_line",
    learning_rate=1e-1,
)
logreg.fit(
    adata_train=adata_train,
    adata_val=adata_train[:N_VAL_CELLS],
    train_dataloader_kwargs={"batch_size": 8},
    max_epochs=4,
)

In [None]:
# Plot the loss curves modlyn recorded during logreg.fit (presumably train and val).
logreg.plot_losses()


In [None]:
# Rank genes per cell line with two scanpy baselines: logistic regression
# (stored under the default key "rank_genes_groups") and Wilcoxon ("wilcoxon").
sc.tl.rank_genes_groups(adata_train, 'cell_line', method='logreg')
sc.tl.rank_genes_groups(adata_train, 'cell_line', method='wilcoxon', key_added='wilcoxon')


def rank_genes_scores(adata, key):
    """Return a (group x gene) DataFrame of scores stored in ``adata.uns[key]``.

    ``adata.uns[key]["scores"]`` and ``["names"]`` hold one field per group.
    Each group's scores are indexed by that group's gene names, so the
    resulting columns are aligned by gene across groups.
    """
    result = adata.uns[key]
    groups = result['scores'].dtype.names
    per_group = {g: pd.Series(result['scores'][g], index=result['names'][g]) for g in groups}
    return pd.DataFrame(per_group).T


df_lr = rank_genes_scores(adata_train, 'rank_genes_groups')
df_wl = rank_genes_scores(adata_train, 'wilcoxon')

# modlyn weights from the linear layer, one row per class, one column per gene.
# NOTE(review): assumes logreg.linear is a torch nn.Linear (weight is
# (n_classes, n_genes)) whose rows follow label_encoder.classes_ — confirm in modlyn.
# .cpu() keeps this working if the module was left on an accelerator after fit.
weights = logreg.linear.weight.detach().cpu().numpy()
df_ml = pd.DataFrame(
    weights,
    index=logreg.datamodule.label_encoder.classes_,
    columns=adata_train.var_names,
)

In [None]:
# Restrict all three score tables to the cell lines and genes they share, so
# every comparison below runs over identical (cell line, gene) pairs.
common_genes = sorted(set(df_lr.columns) & set(df_wl.columns) & set(df_ml.columns))
common_cells = sorted(set(df_lr.index) & set(df_wl.index) & set(df_ml.index))
assert common_genes and common_cells, "score tables share no genes or no cell lines"

df_lr = df_lr.loc[common_cells, common_genes]
df_wl = df_wl.loc[common_cells, common_genes]
df_ml = df_ml.loc[common_cells, common_genes]

# Per cell line, compare each method against scanpy's logreg: overlap of the
# top-N genes by |score|, and Spearman correlation of the full score vectors.
N = 50
records = []
for cl in common_cells:
    ref = df_lr.loc[cl]
    top_lr = set(ref.abs().nlargest(N).index)
    top_wl = set(df_wl.loc[cl].abs().nlargest(N).index)
    top_ml = set(df_ml.loc[cl].abs().nlargest(N).index)

    # Unpacking works for both old (SpearmanrResult) and new scipy result types.
    rho_lr_ml, _ = spearmanr(ref, df_ml.loc[cl])
    rho_lr_wl, _ = spearmanr(ref, df_wl.loc[cl])

    records.append({
        'cell_line': cl,
        'overlap_logreg_modlyn': len(top_lr & top_ml),
        'overlap_logreg_wilcox': len(top_lr & top_wl),
        'spearman_logreg_modlyn': rho_lr_ml,
        'spearman_logreg_wilcox': rho_lr_wl,
    })

comparison_df = pd.DataFrame(records).set_index('cell_line')

# Plot overlap and correlation heatmaps; labels follow N so they stay truthful.
fig, axes = plt.subplots(1, 2, figsize=(14, 8))

overlap_df = comparison_df[['overlap_logreg_modlyn', 'overlap_logreg_wilcox']]
sns.heatmap(overlap_df, annot=True, fmt="d", cmap="Blues",
            cbar_kws={"label": f"Shared top-{N} genes"}, ax=axes[0])
axes[0].set_title(f"Top-{N} Gene Overlap")

rho_df = comparison_df[['spearman_logreg_modlyn', 'spearman_logreg_wilcox']]
sns.heatmap(rho_df, annot=True, fmt=".2f", cmap="vlag", center=0,
            cbar_kws={"label": "Spearman ρ"}, ax=axes[1])
axes[1].set_title("Spearman Correlation")

plt.tight_layout()
plt.show()

In [None]:
# AUPR analysis
# Per-method effect-size cutoff: the p-th percentile of |score| across every
# (cell line, gene) entry, so only the largest 40% of effects are eligible.
p = 60  # top 40% threshold
effect_thresh = {
    method: np.percentile(frame.abs().to_numpy().ravel(), p)
    for method, frame in (("logreg", df_lr), ("wilcox", df_wl), ("modlyn", df_ml))
}

def build_sigsets(df, effect_thresh):
    """Set-based reference: nested "significant" (row, column) sets at every cut.

    A pair is eligible when ``|df| >= effect_thresh``. For each distinct ``|df|``
    value ``c`` (ascending) the returned list holds ``(c, pairs)``, where
    ``pairs`` are the eligible pairs with ``|df| >= c``; the first entry is the
    most liberal set. One set is materialized per distinct value, i.e.
    O(#distinct values x #pairs) time and memory, so below it is only run on a
    tiny slice to validate ``pr_curve``.
    """
    mask = df.abs() >= effect_thresh
    df_eff = df.where(mask).stack().dropna()
    # Absolute values serve as confidence.
    conf_eff = df.abs().stack()
    cuts = np.unique(conf_eff.values)  # np.unique already returns sorted values

    sigsets = []
    for c in cuts:
        sel = df_eff[conf_eff >= c]
        pairs = set(zip(sel.index.get_level_values(0), sel.index.get_level_values(1)))
        sigsets.append((c, pairs))
    return sigsets


def pr_auc(truth_sets, test_sets):
    """Set-based reference PR curve of ``test_sets`` against ``truth_sets[0]``.

    The ground truth is the most liberal truth set. Each test set yields one
    (recall, precision) point; an empty test set has precision 1.0 and an empty
    ground truth gives recall 0.0. Returns ``(recalls, precisions, auc)``.
    """
    gt = truth_sets[0][1]  # most liberal set
    precisions, recalls = [], []
    for _, test in test_sets:
        tp = len(gt & test)
        fp = len(test) - tp
        fn = len(gt) - tp
        precisions.append(tp / (tp + fp) if tp + fp > 0 else 1.0)
        recalls.append(tp / (tp + fn) if tp + fn > 0 else 0.0)
    return np.array(recalls), np.array(precisions), auc(recalls, precisions)


def pr_curve(df_truth, df_test, thresh_truth, thresh_test):
    """Vectorized equivalent of
    ``pr_auc(build_sigsets(df_truth, thresh_truth), build_sigsets(df_test, thresh_test))``.

    Ground truth: pairs with ``|df_truth| >= thresh_truth``. For every distinct
    ``|df_test|`` value ``c`` (ascending), the test set is the pairs with
    ``|df_test| >= thresh_test`` and ``|df_test| >= c``. Pairs are matched by
    (row, column) label; labels are assumed unique. Instead of building one set
    per cut, counts per cut come from binary searches over sorted confidences.
    Returns ``(recalls, precisions, auc)``.
    """
    truth_abs = df_truth.abs()
    test_abs = df_test.abs()

    # Same pandas comparisons as build_sigsets, so boundary values match exactly.
    n_gt = int((truth_abs >= thresh_truth).to_numpy().sum())
    # Test pairs that are in the ground truth (labels absent from df_truth are not).
    in_gt = (truth_abs.reindex(index=test_abs.index, columns=test_abs.columns)
             >= thresh_truth).to_numpy()
    eligible = (test_abs >= thresh_test).to_numpy()

    conf = test_abs.to_numpy()
    valid = ~np.isnan(conf)
    cuts = np.unique(conf[valid])  # every distinct test confidence, ascending

    eligible_conf = np.sort(conf[eligible & valid])
    hit_conf = np.sort(conf[eligible & valid & in_gt])
    # size - (#values < c) == #values >= c, evaluated for all cuts at once.
    n_test = eligible_conf.size - np.searchsorted(eligible_conf, cuts, side="left")
    tp = hit_conf.size - np.searchsorted(hit_conf, cuts, side="left")

    precisions = np.where(n_test > 0, tp / np.maximum(n_test, 1), 1.0)
    recalls = tp / n_gt if n_gt > 0 else np.zeros(cuts.size)
    return recalls, precisions, auc(recalls, precisions)


# Cheap self-check: the vectorized curve must match the set-based reference.
_lr_small, _ml_small = df_lr.iloc[:3, :100], df_ml.iloc[:3, :100]
_r_ref, _p_ref, _ = pr_auc(build_sigsets(_lr_small, effect_thresh['logreg']),
                           build_sigsets(_ml_small, effect_thresh['modlyn']))
_r_fast, _p_fast, _ = pr_curve(_lr_small, _ml_small,
                               effect_thresh['logreg'], effect_thresh['modlyn'])
np.testing.assert_allclose(_r_fast, _r_ref)
np.testing.assert_allclose(_p_fast, _p_ref)

# Compute PR curves of each method against scanpy logreg as ground truth.
r_wl, p_wl, auc_wl = pr_curve(df_lr, df_wl, effect_thresh['logreg'], effect_thresh['wilcox'])
r_ml, p_ml, auc_ml = pr_curve(df_lr, df_ml, effect_thresh['logreg'], effect_thresh['modlyn'])

# Plot PR curves
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(r_wl, p_wl, label=f'Wilcoxon vs LogReg (AUPR={auc_wl:.3f})')
ax.plot(r_ml, p_ml, label=f'Modlyn vs LogReg (AUPR={auc_ml:.3f})')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Comparison')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Close the lamindb tracking session started with ln.track() in the setup cell.
ln.finish()