In [None]:
import os
import warnings
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import anndata as ad
import lightning as L
import lamindb as ln
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.metrics import average_precision_score
import torch

from modlyn.io.loading import read_lazy
from modlyn.io.datamodules import ClassificationDataModule
from modlyn.models.linear import Linear ## should move to modlyn not to arrayloader - cp to folder - maintain API structure - name the folders and sub-modules

# Silence every warning category for the rest of the session. This also hides
# deprecation notices raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Register the project record that the tracked run is attached to.
project = ln.Project(name="Modlyn")
project.save()

# Start tracking exactly once; calling `ln.track()` a second time just to get
# a handle on the run is redundant. The active run is read from the tracking
# context instead.
# NOTE(review): assumes the installed lamindb exposes `ln.context.run` — confirm.
ln.track(project="Modlyn")
run = ln.context.run

In [None]:
import os
# Point the lamindb cache at the shared data volume.
# NOTE(review): this runs after `lamindb` was imported and tracking started in
# the first cell — confirm the variable is still picked up at this point.
os.environ.update({"LAMIN_CACHE_DIR": "/data/.lamindb-cache"})


In [None]:
from pathlib import Path
from modlyn.io import read_lazy

# Open one chunk of the cached zarr store through modlyn's lazy reader.
store_path = Path("/data/.lamindb-cache/lamin-us-west-2/wXDsTYYd/tahoe100M_shuffled_zarr_store_2025-05-07/chunk_30.zarr")
adata = read_lazy(store_path)

# Replace the gene annotation with the table stored next to the notebook.
var = pd.read_parquet("var_subset_tahoe100M.parquet")
adata.var = var

# Integer class label per cell: the category code of its cell line.
cell_line_codes = adata.obs["cell_line"].astype("category").cat.codes
adata.obs["y"] = cell_line_codes.astype("i8")
# adata.var

In [None]:
# Log-transform the expression values; the call is made for its effect on
# `adata` (the return value is discarded).
# NOTE(review): re-running this cell would presumably transform the data a
# second time — confirm it is executed once per kernel.
sc.pp.log1p(adata)

In [None]:
# Positional 80/20 split: the first 80% of rows train, the remainder validates.
n = adata.n_obs

n_train = int(0.8 * n)
n_val = n - n_train

adata_train, adata_val = adata[:n_train], adata[n_train:]
adata_train

In [None]:
class LossTracker(L.Callback):
    """Record one training-loss and one validation-loss value per epoch.

    The values are read from ``trainer.callback_metrics`` under the keys
    ``"train_loss"`` and ``"val_loss"`` and appended to ``train_losses`` and
    ``val_losses`` so they can be plotted after fitting.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Append the current ``train_loss`` metric as a Python float."""
        loss = trainer.callback_metrics["train_loss"]
        self.train_losses.append(loss.item())

    def on_validation_epoch_end(self, trainer, pl_module):
        """Append the current ``val_loss`` metric, ignoring the sanity check."""
        # The trainer runs a short validation pass before the first training
        # epoch (its sanity check) and this hook fires for it as well.
        # Recording that pass would give `val_losses` one more entry than
        # `train_losses` and shift the validation curve by one epoch.
        if trainer.sanity_checking:
            return
        loss = trainer.callback_metrics["val_loss"]
        self.val_losses.append(loss.item())

# Seed Python, NumPy and torch (and dataloader workers) so that weight
# initialisation and batch order are repeatable across kernel restarts.
L.seed_everything(0, workers=True)

# Wrap the two AnnData splits; the integer column "y" is the target.
datamodule = ClassificationDataModule(
    adata_train=adata_train,
    adata_val=adata_val,
    label_column="y",
    train_dataloader_kwargs={"batch_size": 2048, "drop_last": True},
    val_dataloader_kwargs={"batch_size": 2048, "drop_last": False},
)

# One input per gene, one output per distinct label value.
linear = Linear(
    n_genes=adata.n_vars,
    n_covariates=adata.obs["y"].nunique(),
    learning_rate=1e-2,
)

# Training stops at whichever limit is reached first: 3 epochs or 3000 steps.
loss_tracker = LossTracker()
trainer = L.Trainer(
    max_epochs=3,
    max_steps=3000,
    log_every_n_steps=100,
    callbacks=[loss_tracker]
)
trainer.fit(linear, datamodule)



In [None]:
# One curve per recorded loss series, indexed by position in the list.
loss_curves = (
    (loss_tracker.train_losses, 'o', "train_loss"),
    (loss_tracker.val_losses, 'x', "val_loss"),
)
for series, marker_style, curve_label in loss_curves:
    plt.plot(series, marker=marker_style, label=curve_label)

plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Learned coefficient matrix as a NumPy array (one row per output class).
weights = linear.linear.weight.detach().cpu().numpy()

# Row i belongs to the class encoded as label i. The labels in `adata.obs["y"]`
# are `.astype("category").cat.codes` of `cell_line`, so the matching names are
# the categories in category order. `value_counts()` sorts by frequency, which
# would attach the wrong name to every row whose rank differs from its code.
cell_line_categories = adata.obs["cell_line"].astype("category").cat.categories
if len(cell_line_categories) != weights.shape[0]:
    raise ValueError(
        f"weight matrix has {weights.shape[0]} rows but "
        f"{len(cell_line_categories)} cell-line categories were found"
    )
top_cell_lines = cell_line_categories.tolist()

weights_df = pd.DataFrame(
    weights,
    columns=adata.var_names,
    index=top_cell_lines
)


In [None]:
import seaborn as sns
from matplotlib import cm

def compute_fisher_info_and_se(model, dataloader):
    """Accumulate a diagonal Fisher-information estimate for ``model.linear.weight``.

    For every batch the softmax probabilities ``p`` of ``model.linear(x)`` are
    computed and ``sum_over_batch(p * (1 - p) * x**2)`` is added to the entry
    of each (class, feature) weight.

    Parameters
    ----------
    model
        Object exposing ``eval()`` and a ``linear`` layer with a ``weight``
        tensor. The model is left in eval mode.
    dataloader
        Iterable yielding ``(x, y)`` batches; only ``x`` is used.
        NOTE(review): assumes ``x`` is a dense (batch, features) tensor whose
        dtype matches the layer's weights — confirm against the datamodule.

    Returns
    -------
    (se, confidence)
        Two NumPy arrays shaped like the weight matrix: ``confidence`` is the
        accumulated Fisher diagonal plus ``1e-8`` and ``se`` is
        ``sqrt(1 / confidence)``.
    """
    model.eval()
    weight = model.linear.weight
    fisher_diag = torch.zeros_like(weight)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Computing Fisher Information"):
            x, _ = batch
            # Keep the batch on the same device as the accumulator.
            x = x.to(weight.device)
            probs = torch.softmax(model.linear(x), dim=1)

            # (classes, batch) @ (batch, features): sums p*(1-p)*x**2 over the
            # batch for all classes at once instead of building a full
            # batch-by-features temporary per class in a Python loop.
            fisher_diag += (probs * (1 - probs)).T @ (x ** 2)

    # Epsilon keeps the division finite for weights that saw no signal.
    confidence = fisher_diag + 1e-8
    se = torch.sqrt(1.0 / confidence)
    return se.cpu().numpy(), confidence.cpu().numpy()

# Estimate per-weight confidence on the validation batches and label the
# result with the same row index as `weights_df`.
val_loader = datamodule.val_dataloader()
se, confidence = compute_fisher_info_and_se(linear, val_loader)

confidence_df = pd.DataFrame(
    data=confidence,
    index=weights_df.index,  # real cell line names
    columns=adata.var_names,
)

In [None]:
from sklearn.preprocessing import minmax_scale

# ─── 1) Load your precomputed AnnData ──────────────────────────────────────────
# Every comparison cell below reads `adata_pre`; loading it here keeps the
# notebook runnable from a fresh kernel.
adata_pre = sc.read_h5ad('adata_chunk30_processed.h5ad')

# Fail early if the rankings used further down were not stored in the file.
missing_rankings = {'logreg', 'wilcoxon'} - set(adata_pre.uns.keys())
if missing_rankings:
    raise KeyError(f"adata_pre.uns is missing rankings: {sorted(missing_rankings)}")

print(adata_pre)
# inspect what keys were stored
print(adata_pre.uns.keys())

In [None]:
# Same dotplot for both stored rankings; only the `uns` key and title differ.
precomputed_rankings = (
    ('logreg', 'Precomputed Scanpy LogReg (Top 10)'),
    ('wilcoxon', 'Precomputed Scanpy Wilcoxon (Top 10)'),
)
for ranking_key, plot_title in precomputed_rankings:
    sc.pl.rank_genes_groups_dotplot(
        adata_pre,
        groupby='cell_line',
        key=ranking_key,
        n_genes=10,
        title=plot_title
    )

In [None]:
# ─── 3) Build your model’s dotplot as before ───────────────────────────────────

# `weights_df` and `confidence_df` come from the trained Linear model above.
# Pick the top-10 genes (largest weights) per cell_line from the model.
n_top = 10
top_model = {
    cl: weights_df.loc[cl].nlargest(n_top).index.tolist()
    for cl in weights_df.index
}
# Order-preserving de-duplication. Going through a `set` would make the gene
# order depend on string hashing, which changes between kernel restarts and
# would reshuffle the columns of every plot built from `genes_model`.
genes_model = list(dict.fromkeys(
    gene for top_genes in top_model.values() for gene in top_genes
))

# Scale the confidence for dot size: clip at the global 99th percentile, then
# min-max scale each cell line's row to [0, 1].
conf_sub = confidence_df.loc[weights_df.index, genes_model]
conf_clipped = np.clip(conf_sub, 0, np.percentile(conf_sub.values, 99))
conf_scaled = pd.DataFrame(
    minmax_scale(conf_clipped, axis=1),
    index=conf_clipped.index,
    columns=conf_clipped.columns
)

# Build an AnnData for the model: one observation per cell line, X holds the
# scaled confidence and the 'weights' layer holds the raw coefficients.
adata_model = sc.AnnData(
    X=conf_scaled.values,
    obs=pd.DataFrame(index=conf_scaled.index),
    var=pd.DataFrame(index=conf_scaled.columns)
)
adata_model.obs['cell_line'] = adata_model.obs.index
adata_model.layers['weights'] = weights_df.loc[
    adata_model.obs_names, adata_model.var_names
].values

In [None]:
# Display the summary repr of the model AnnData built in the previous cell.
adata_model

In [None]:
# Colour encodes the weight and size encodes the scaled confidence stored in
# `.X`. Both matrices are passed explicitly, the same way the side-by-side
# comparison further down does: with only `layer='weights'`, the dot size
# would be derived from that layer too and the confidence would never be
# drawn, despite the title.
# NOTE(review): relies on `dot_size_df` being forwarded to scanpy's DotPlot —
# confirm against the installed scanpy version.
model_dot_colors = pd.DataFrame(
    adata_model.layers['weights'],
    index=adata_model.obs_names,
    columns=adata_model.var_names
)[genes_model]
model_dot_sizes = pd.DataFrame(
    adata_model.X,
    index=adata_model.obs_names,
    columns=adata_model.var_names
)[genes_model]

sc.pl.dotplot(
    adata_model,
    var_names=genes_model,
    groupby='cell_line',
    use_raw=False,
    layer='weights',
    dot_color_df=model_dot_colors,
    dot_size_df=model_dot_sizes,
    dot_min=0.05,
    dot_max=0.6,
    title='modlyn linear: Weight & Confidence'
)

In [None]:
# adata_model.write("adata_modlyn_chunk30_trained_with_confidence.h5ad")

In [None]:
# Genes shown in the side-by-side comparison below: the first 50 distinct
# names met when reading the stored logreg ranking row by row. The comparison
# cell reads `genes`, so it has to be defined here for a fresh-kernel run.
# NOTE(review): assumes each row of `names` holds one gene name per group —
# confirm against the stored array.
lr_key = 'logreg'
lr_names = adata_pre.uns[lr_key]['names']
max_genes = 50

genes = []
for g in (name for row in lr_names for name in row):
    if g not in genes:
        genes.append(g)
    if len(genes) == max_genes:
        break

print(genes)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.preprocessing import minmax_scale

# ─── Assumptions ────────────────────────────────────────────────────────────────
# • adata_pre: AnnData with `.uns['logreg']['names']`, `.uns['logreg']['scores']`, 
#   and similarly for 'wilcoxon'
# • adata_model: AnnData with .layers['weights'] and .X = scaled certainty
# • genes: list of genes (e.g. top 50 from logreg) you want to compare

# ─── 0) Keep only the genes that exist in the model AnnData ─────────────────────
model_gene_names = set(adata_model.var_names)
common_genes = [g for g in genes if g in model_gene_names]

n_dropped = len(genes) - len(common_genes)
if n_dropped > 0:
    missing = set(genes).difference(common_genes)
    print(f"Warning: dropping {len(missing)} genes not in model: {missing}")
genes = common_genes

# ─── 1) Scanpy logreg scores as a (group × gene) frame ──────────────────────────
logreg_uns = adata_pre.uns['logreg']
scores_lr, names_lr = logreg_uns['scores'], logreg_uns['names']
groups = scores_lr.dtype.names

# One Series per group, indexed by that group's gene names.
lr_dict = {}
for cl in groups:
    lr_dict[cl] = pd.Series(scores_lr[cl], index=names_lr[cl])
logreg_scores = pd.DataFrame(lr_dict).T.loc[:, genes]

# ─── 2) Scanpy Wilcoxon scores, same layout ─────────────────────────────────────
wilcoxon_uns = adata_pre.uns['wilcoxon']
scores_wl, names_wl = wilcoxon_uns['scores'], wilcoxon_uns['names']

wl_dict = {}
for cl in scores_wl.dtype.names:
    wl_dict[cl] = pd.Series(scores_wl[cl], index=names_wl[cl])
wilcoxon_scores = pd.DataFrame(wl_dict).T.loc[:, genes]

# ─── 3) Model weights (layer) and certainty (X), restricted to `genes` ──────────
model_axes = dict(index=adata_model.obs_names, columns=adata_model.var_names)

modlyn_weights = pd.DataFrame(adata_model.layers['weights'], **model_axes).loc[:, genes]
modlyn_certainty = pd.DataFrame(adata_model.X, **model_axes).loc[:, genes]

# ─── 4) Scale values for dot color ([-1,1]) ─────────────────────────────────────
def scale_weights(df):
    """Scale a DataFrame of signed scores into ``[-1, 1]``.

    The 99th percentile of the absolute values (NaN entries ignored) is used
    as the saturation point: values are clipped to ``[-vmax, vmax]`` and then
    divided by ``vmax``.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric scores; index and columns are preserved.

    Returns
    -------
    pandas.DataFrame
        Same shape as ``df``. NaN entries stay NaN. If ``df`` is empty or
        entirely NaN it is returned as floats unchanged; if the percentile is
        zero every finite entry becomes ``0`` instead of ``0 / 0``.
    """
    magnitudes = np.abs(df.to_numpy(dtype=float))
    if magnitudes.size == 0 or np.isnan(magnitudes).all():
        return df.astype(float)

    # nanpercentile: a single missing score must not turn the whole frame NaN.
    vmax = np.nanpercentile(magnitudes, 99)
    if vmax == 0:
        # Clipping to [0, 0] already yields the scaled result; dividing would
        # produce NaN everywhere.
        return df.clip(-vmax, vmax).astype(float)
    return df.clip(-vmax, vmax) / vmax

# Dot colour: each method's scores squeezed into [-1, 1].
logreg_scaled, wilcoxon_scaled, modlyn_scaled = map(
    scale_weights, (logreg_scores, wilcoxon_scores, modlyn_weights)
)

# ─── 5) Dot size: row-wise min-max scaling to [0, 1] ───────────────────────────
def _row_minmax(frame):
    """Min-max scale every row of ``frame`` and relabel the columns with ``genes``."""
    return pd.DataFrame(minmax_scale(frame, axis=1), index=frame.index, columns=genes)

logreg_size = _row_minmax(logreg_scores.abs())
wilcoxon_size = _row_minmax(wilcoxon_scores.abs())
modlyn_size = _row_minmax(modlyn_certainty)

# ─── 6) Plot side-by-side dotplots ─────────────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# One panel per method: (AnnData to group, colour matrix, size matrix, title).
panels = (
    (adata_pre, logreg_scaled, logreg_size, 'Scanpy LogReg (scaled)'),
    (adata_pre, wilcoxon_scaled, wilcoxon_size, 'Scanpy Wilcoxon (scaled)'),
    (adata_model, modlyn_scaled, modlyn_size, 'Modlyn (scaled)'),
)

for panel_ax, (panel_adata, color_df, size_df, panel_title) in zip(axes, panels):
    # Identical styling on every panel so the three methods compare visually.
    sc.pl.dotplot(
        panel_adata,
        var_names=genes,
        groupby='cell_line',
        dot_color_df=color_df,
        dot_size_df=size_df,
        ax=panel_ax,
        cmap='RdBu_r',
        vcenter=0,
        dot_min=0.2,
        dot_max=1.0,
        smallest_dot=0.1,
        show=False
    )
    panel_ax.set_title(panel_title)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

# ─── 1) Put all three methods on the same genes & cell lines ───────────────────

def _scores_by_group(stored):
    """Map each group name to a Series of its scores indexed by gene name."""
    return {
        group: pd.Series(stored['scores'][group], index=stored['names'][group])
        for group in stored['scores'].dtype.names
    }

# A) Scanpy logistic regression scores, (cell_line × gene)
lr_scores = _scores_by_group(adata_pre.uns['logreg'])
df_lr = pd.DataFrame(lr_scores).T

# B) Scanpy Wilcoxon scores
wl_scores = _scores_by_group(adata_pre.uns['wilcoxon'])
df_wl = pd.DataFrame(wl_scores).T

# C) Model weights as stored in the 'weights' layer (no Fisher info)
# NOTE(review): `adata_model` only holds the genes selected when it was built,
# so "all shared genes" below means that subset — confirm this is intended.
df_ml = pd.DataFrame(
    adata_model.layers['weights'],
    index=adata_model.obs_names,
    columns=adata_model.var_names
)

# D) Intersect genes and cell lines, sorted for a stable order
common_genes = sorted(set(df_lr.columns) & set(df_wl.columns) & set(df_ml.columns))
common_cells = sorted(set(df_lr.index) & set(df_wl.index) & set(df_ml.index))

df_lr, df_wl, df_ml = (
    frame.loc[common_cells, common_genes] for frame in (df_lr, df_wl, df_ml)
)

# ─── 2) Top-N genes per cell line by absolute score/weight ──────────────────────
N = 50

def _top_genes(frame):
    """Per cell line, the N genes with the largest absolute value in ``frame``."""
    return {cl: frame.loc[cl].abs().nlargest(N).index.tolist() for cl in common_cells}

top_lr, top_wl, top_ml = map(_top_genes, (df_lr, df_wl, df_ml))

# ─── 3) Overlaps and Spearman correlations, logreg as the reference ─────────────
records = []
for cl in common_cells:
    reference_top = set(top_lr[cl])
    reference_row = df_lr.loc[cl, common_genes]

    records.append({
        'cell_line': cl,
        # size of the top-N intersection
        'overlap_logreg_modlyn': len(reference_top & set(top_ml[cl])),
        'overlap_logreg_wilcox': len(reference_top & set(top_wl[cl])),
        # rank correlation across all shared genes
        'spearman_logreg_modlyn': spearmanr(reference_row, df_ml.loc[cl, common_genes]).correlation,
        'spearman_logreg_wilcox': spearmanr(reference_row, df_wl.loc[cl, common_genes]).correlation
    })

comparison_df = pd.DataFrame(records).set_index('cell_line')
# print(comparison_df)

# All four metrics on one heatmap.
# NOTE(review): overlap counts and Spearman correlations share a single colour
# scale here, so the correlation columns show little contrast.
plt.figure(figsize=(12, 10))
heatmap_options = dict(
    annot=True,
    fmt=".2f",
    cmap="vlag",
    cbar_kws={"label": "Value"},
    linewidths=0.5,
)
sns.heatmap(comparison_df, **heatmap_options)

plt.title("Comparison of Model vs. Scanpy Metrics per Cell Line")
plt.xlabel("Metric")
plt.ylabel("Cell Line")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

# ─── 1) One (cell_line × gene) frame per method ────────────────────────────────
# A) Scanpy logistic regression scores
logreg_uns = adata_pre.uns['logreg']
lr_scores = {}
for cl in logreg_uns['scores'].dtype.names:
    lr_scores[cl] = pd.Series(logreg_uns['scores'][cl], index=logreg_uns['names'][cl])
df_lr = pd.DataFrame(lr_scores).T

# B) Scanpy Wilcoxon scores
wilcoxon_uns = adata_pre.uns['wilcoxon']
wl_scores = {}
for cl in wilcoxon_uns['scores'].dtype.names:
    wl_scores[cl] = pd.Series(wilcoxon_uns['scores'][cl], index=wilcoxon_uns['names'][cl])
df_wl = pd.DataFrame(wl_scores).T

# C) Model weights from the 'weights' layer
df_ml = pd.DataFrame(
    adata_model.layers['weights'],
    index=adata_model.obs_names,
    columns=adata_model.var_names
)

# ─── 2) Keep only genes and cell lines present in all three ────────────────────
common_genes = sorted(set(df_lr.columns).intersection(df_wl.columns, df_ml.columns))
common_cells = sorted(set(df_lr.index).intersection(df_wl.index, df_ml.index))

df_lr = df_lr.loc[common_cells, common_genes]
df_wl = df_wl.loc[common_cells, common_genes]
df_ml = df_ml.loc[common_cells, common_genes]

# ─── 3) Top-N genes by absolute score for every cell line ──────────────────────
N = 50
top_lr, top_wl, top_ml = {}, {}, {}
for cl in common_cells:
    top_lr[cl] = df_lr.loc[cl].abs().nlargest(N).index.tolist()
    top_wl[cl] = df_wl.loc[cl].abs().nlargest(N).index.tolist()
    top_ml[cl] = df_ml.loc[cl].abs().nlargest(N).index.tolist()

# ─── 4) Overlap sizes and Spearman correlations against logreg ─────────────────
def _compare_cell_line(cl):
    """Build the metrics row for one cell line."""
    logreg_top = set(top_lr[cl])
    logreg_row = df_lr.loc[cl, common_genes]
    return {
        'cell_line': cl,
        'overlap_logreg_modlyn':  len(logreg_top & set(top_ml[cl])),
        'overlap_logreg_wilcox':  len(logreg_top & set(top_wl[cl])),
        'spearman_logreg_modlyn': spearmanr(logreg_row, df_ml.loc[cl, common_genes]).correlation,
        'spearman_logreg_wilcox': spearmanr(logreg_row, df_wl.loc[cl, common_genes]).correlation
    }

records = [_compare_cell_line(cl) for cl in common_cells]

comparison_df = pd.DataFrame(records).set_index('cell_line')
# print(comparison_df)

# ─── 5) (Optional) Save or plot comparison_df as needed ───────────────────────
# comparison_df.to_csv('method_comparison_summary.csv')

# comparison_df has columns:
#   ['overlap_logreg_modlyn', 'overlap_logreg_wilcox',
#    'spearman_logreg_modlyn', 'spearman_logreg_wilcox']

# 1) Separate the integer overlap counts from the correlation coefficients
overlap_df = comparison_df.loc[:, ['overlap_logreg_modlyn', 'overlap_logreg_wilcox']]
rho_df     = comparison_df.loc[:, ['spearman_logreg_modlyn', 'spearman_logreg_wilcox']]

# 2) Two panels side by side, each with its own colour scale
fig, axes = plt.subplots(1, 2, figsize=(14, 8))
overlap_ax, rho_ax = axes[0], axes[1]

# 3) Left: sizes of the top-50 intersections
sns.heatmap(
    overlap_df,
    annot=True, fmt="d",
    cmap="Blues",
    cbar_kws={"label": "Number of shared top-50 genes"},
    ax=overlap_ax
)
overlap_ax.set_title("Overlap of Top-50 Gene Sets")
overlap_ax.set_xlabel("Comparison")
overlap_ax.set_ylabel("Cell Line")
overlap_ax.tick_params(axis='x', rotation=45)

# 4) Right: Spearman correlations on a diverging map centred at zero
sns.heatmap(
    rho_df,
    annot=True, fmt=".2f",
    cmap="vlag",
    center=0,
    cbar_kws={"label": "Spearman ρ"},
    ax=rho_ax
)
rho_ax.set_title("Spearman Correlation of Full Rankings")
rho_ax.set_xlabel("Comparison")
rho_ax.set_ylabel("")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
from scipy.stats import spearmanr

# ─── 0) Effect and confidence matrices for each method ─────────────────────────
# Inputs carried over from the previous cell:
#   df_lr — Scanpy logreg scores    (adata_pre.uns['logreg']['scores'])
#   df_wl — Scanpy Wilcoxon scores  (adata_pre.uns['wilcoxon']['scores'])
#   df_ml — Modlyn weights          (adata_model.layers['weights'])

# Restrict all three frames to the genes and cell lines they share.
common_genes = sorted(set(df_lr.columns) & set(df_wl.columns) & set(df_ml.columns))
common_cells = sorted(set(df_lr.index)   & set(df_wl.index)   & set(df_ml.index))
df_lr, df_wl, df_ml = (
    frame.loc[common_cells, common_genes] for frame in (df_lr, df_wl, df_ml)
)

# Confidence proxy for every method: the absolute value of its own score.
# NOTE(review): what each stored score represents (coefficient, test
# statistic, ...) should be confirmed against the scanpy documentation before
# these proxies are compared across methods.
conf_lr, conf_wl, conf_ml = df_lr.abs(), df_wl.abs(), df_ml.abs()

In [None]:
# ─── 1) Build significance sets over confidence thresholds ─────────────────────
def build_sigsets(df, conf, effect_thresh, max_cuts=None):
    """Build nested sets of "significant" (cell, gene) pairs over confidence cutoffs.

    Parameters
    ----------
    df : pandas.DataFrame
        (cells × genes) effect sizes.
    conf : pandas.DataFrame
        Confidence measure with the same labels as ``df``.
    effect_thresh : float
        Pairs with ``abs(effect) < effect_thresh`` are never selected.
    max_cuts : int, optional
        If given, use at most this many cutoffs, taken at evenly spaced
        positions among the sorted unique confidence values (the smallest and
        largest are always included when ``max_cuts >= 2``). ``None`` keeps
        every unique value as a cutoff.

    Returns
    -------
    list of (cutoff, set)
        Sorted by ascending cutoff. Each set holds the ``(cell, gene)`` tuples
        that pass the effect filter and have ``conf >= cutoff``.

    Raises
    ------
    ValueError
        If ``max_cuts`` is given and smaller than 1.
    """
    if max_cuts is not None and max_cuts < 1:
        raise ValueError("max_cuts must be a positive integer or None")

    # Effect filter. The explicit dropna() makes the removal of masked entries
    # independent of whether this pandas version's stack() drops NaN itself.
    mask = df.abs() >= effect_thresh
    kept = df.where(mask).stack().dropna()
    conf_all = conf.stack().dropna()

    # Candidate cutoffs: the unique confidence values (np.unique sorts them).
    cuts = np.unique(conf_all.to_numpy())
    if max_cuts is not None and len(cuts) > max_cuts:
        picks = np.unique(np.linspace(0, len(cuts) - 1, max_cuts).round().astype(int))
        cuts = cuts[picks]

    # Align the confidence values with the surviving pairs once, rather than
    # re-aligning a boolean Series for every cutoff.
    pairs = list(kept.index)
    kept_conf = conf_all.reindex(kept.index).to_numpy()

    sigsets = []
    for c in cuts:
        selected = np.flatnonzero(kept_conf >= c)
        sigsets.append((c, {pairs[i] for i in selected}))
    return sigsets

# thresholds for effect-size filtering
# Earlier fixed-cutoff variant, kept for reference:
# effect_thresh = {
#     'logreg':   1.96,        # only genes with |z| ≥1.96 (~p<0.05)
#     'wilcox':   1.96,        # same for Wilcoxon U test z-score
#     'modlyn':   np.log(4)    # weights ≥log(4) ≈ 2× fold-change in odds
# }


# Percentile of |effect| used as the cutoff for each method. With p = 60 the
# filter keeps the largest 40% of entries (a "top 10%" filter would be p = 90).
p = 60
effect_thresh = {
    # nanpercentile: one missing score must not turn the threshold into NaN,
    # which would filter out every entry.
    'logreg': np.nanpercentile(df_lr.abs().to_numpy(), p),
    'wilcox': np.nanpercentile(df_wl.abs().to_numpy(), p),
    'modlyn': np.nanpercentile(df_ml.abs().to_numpy(), p),
}


sig_lr = build_sigsets(df_lr, conf_lr, effect_thresh['logreg'])
sig_wl = build_sigsets(df_wl, conf_wl, effect_thresh['wilcox'])
sig_ml = build_sigsets(df_ml, conf_ml, effect_thresh['modlyn'])

# ─── 2) Compute PR curves & AUPR ────────────────────────────────────────────────
def pr_auc(truth_sets, test_sets):
    """Precision/recall of each test set against the most liberal truth set.

    truth_sets: list of (threshold, set), most liberal first; only the first
                set is used, as the ground truth.
    test_sets:  list of (threshold, set) to score, in the same order.

    Returns (recalls, precisions, area), where ``area`` is
    ``auc(recalls, precisions)``. An empty test set counts as precision 1.0;
    an empty ground truth gives recall 0.0.
    """
    reference = truth_sets[0][1]
    n_reference = len(reference)

    recalls, precisions = [], []
    for _, candidate in test_sets:
        hits = len(reference & candidate)
        n_called = len(candidate)          # true positives + false positives
        precisions.append(hits / n_called if n_called > 0 else 1.0)
        recalls.append(hits / n_reference if n_reference > 0 else 0.0)

    return np.array(recalls), np.array(precisions), auc(recalls, precisions)

# Score Wilcoxon and Modlyn against the logreg sets as reference.
r_wl, p_wl, auc_wl = pr_auc(sig_lr, sig_wl)
r_ml, p_ml, auc_ml = pr_auc(sig_lr, sig_ml)

# ─── 3) Precision–Recall curves, one line per method ───────────────────────────
plt.figure(figsize=(6,6))
pr_curves = (
    (r_wl, p_wl, f'Wilcoxon vs LogReg  (AUPR={auc_wl:.3f})'),
    (r_ml, p_ml, f'Modlyn vs LogReg    (AUPR={auc_ml:.3f})'),
)
for recall_values, precision_values, curve_label in pr_curves:
    plt.plot(recall_values, precision_values, label=curve_label)

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision–Recall Comparison')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Close the lamindb session opened with `ln.track(...)` in the first cell.
ln.finish()