In [None]:
import os
import warnings
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import anndata as ad
import lightning as L
import lamindb as ln
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.metrics import average_precision_score
import torch

from modlyn.io.loading import read_lazy
from modlyn.io.datamodules import ClassificationDataModule
from modlyn.models.linear import Linear ## should move to modlyn not to arrayloader - cp to folder - maintain API structure - name the folders and sub-modules

# NOTE(review): this silences ALL warnings for the session (including sklearn/scanpy
# data-quality warnings); consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Start tracking this notebook run in LaminDB under the "Modlyn" project.
project = ln.Project(name="Modlyn")
project.save()

# Track exactly once, linked to the project. Calling ln.track() a second time without
# `project` re-initialises tracking, so `run` would not come from the project-linked call.
run = ln.track(project="Modlyn")


In [None]:
# artifact = ln.Artifact.filter(key="tahoe100M_shuffled_zarr_store_2025-05-07").one()
# path = artifact.cache()

In [None]:
import os

# Point the LaminDB cache at the large /data volume.
# NOTE(review): lamindb was imported (and ln.track called) in an earlier cell; confirm the
# cache location is still honoured when set this late — it may need to be set before
# `import lamindb`.
os.environ.update(LAMIN_CACHE_DIR="/data/.lamindb-cache")


In [None]:
# !export LAMIN_CACHE_DIR=/data/.lamindb-cache
# from pathlib import Path

# os.environ["LAMIN_CACHE_DIR"] = "/data/.lamindb-cache"

# artifact = ln.Artifact.filter(key="tahoe100M_shuffled_zarr_store_2025-05-07").one()

# # Cache it to /data
# store_path = artifact.cache()


In [None]:
from pathlib import Path
from modlyn.io import read_lazy

## Multiple chunks

from anndata import concat

# Cached Tahoe-100M shuffled zarr store: one `chunk_<i>.zarr` per chunk.
base_path = Path("/data/.lamindb-cache/lamin-us-west-2/wXDsTYYd/tahoe100M_shuffled_zarr_store_2025-05-07")
N_CHUNKS = 1  # number of chunks to load; increase to train on more cells

# Sort numerically by chunk index (chunk_2 before chunk_10). A plain lexicographic sort
# would select the wrong chunks as soon as N_CHUNKS > 1.
chunk_paths = sorted(base_path.glob("chunk_*.zarr"), key=lambda p: int(p.stem.split("_")[-1]))
if not chunk_paths:
    raise FileNotFoundError(f"No chunk_*.zarr stores found under {base_path}")

adatas = [read_lazy(p) for p in chunk_paths[:N_CHUNKS]]
adata = concat(adatas, axis=0, join="outer", merge="same")

# Gene metadata for the gene subset stored in the chunks.
adata.var = pd.read_parquet("var_subset_tahoe100M.parquet")
# Integer class labels. The codes follow the categorical order of `cell_line`
# (pandas sorts inferred categories), so class index i <-> categories[i].
adata.obs["y"] = adata.obs["cell_line"].astype("category").cat.codes.astype("i8")
adata

In [None]:
# adata

In [None]:
# Split the (pre-shuffled) cells: the leading 80% for training, the trailing 20% for validation.
n = adata.n_obs
n_train = int(0.8 * n)
n_val = n - n_train
adata_train, adata_val = adata[:n_train], adata[n_train:]
adata_train

In [None]:
# Old load data
# store_path = Path("/home/ubuntu/tahoe100M_chunk_1")
# adata = read_lazy(store_path)
# var = pd.read_parquet("var_subset_tahoe100M.parquet")
# adata.var = var
# adata.obs["y"] = adata.obs["cell_line"].astype("category").cat.codes.astype("i8")

# # Subset
# adata_train = adata[:80000]
# adata_val = adata[80000:100000]

TODO: move this code block into the `modlyn` package
(it currently lives in `arrayloader`).


In [None]:
class LossTracker(L.Callback):
    """Record the epoch-level train/val loss reported by the LightningModule.

    Reads ``train_loss`` / ``val_loss`` from ``trainer.callback_metrics`` at the end of
    each epoch, so the module must log its metrics under exactly those names.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        loss = trainer.callback_metrics.get("train_loss")
        if loss is not None:  # metric may be absent if nothing was logged this epoch
            self.train_losses.append(loss.item())

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip Lightning's sanity-check pass. It runs on the untrained model before
        # training and would add an extra leading point to the validation curve.
        if trainer.sanity_checking:
            return
        loss = trainer.callback_metrics.get("val_loss")
        if loss is not None:
            self.val_losses.append(loss.item())

# Seed Python/NumPy/torch RNGs so weight initialisation and batch order are reproducible.
L.seed_everything(0)

datamodule = ClassificationDataModule(
    adata_train=adata_train,
    adata_val=adata_val,
    label_column="y",
    train_dataloader_kwargs={"batch_size": 2048, "drop_last": True},
    val_dataloader_kwargs={"batch_size": 2048, "drop_last": False},
)

linear = Linear(
    n_genes=adata.n_vars,
    n_covariates=adata.obs["y"].nunique(),
    learning_rate=1e-2,
)

loss_tracker = LossTracker()
trainer = L.Trainer(
    max_epochs=3,
    max_steps=3000,  # training stops at whichever limit is reached first
    log_every_n_steps=100,
    callbacks=[loss_tracker],
)
trainer.fit(linear, datamodule)



Plot the elbow curve for the classification loss.

Add automatic early stopping (as sklearn provides).

Compare the scalable deep-learning approach with the sklearn implementation: the weights should agree once both have converged.
Monitor the loss and the classification accuracy on the held-out (validation/test) set.

Make sure the weights make sense: predicting cell lines should be an easy task for a proof of concept, so the model should be predictive on the validation set.

Move the code into the modlyn package.

In [None]:
# Epoch-level loss curves collected by the LossTracker callback.
fig, ax = plt.subplots()
ax.plot(loss_tracker.train_losses, marker="o", label="train_loss")
ax.plot(loss_tracker.val_losses, marker="x", label="val_loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Learned weight matrix: one row per class (label code), one column per gene.
weights = linear.linear.weight.detach().cpu().numpy()

# Row i belongs to label code i. The codes were built with
# `.astype("category").cat.codes`, i.e. in categorical order, so label the rows in that
# same order. value_counts() sorts by frequency and would attach the wrong cell-line
# name to most rows.
cell_line_categories = adata.obs["cell_line"].astype("category").cat.categories
top_cell_lines = cell_line_categories[: weights.shape[0]].tolist()

weights_df = pd.DataFrame(
    weights,
    columns=adata.var_names,
    index=top_cell_lines,
)


In [None]:
# Score every validation cell with the trained model, then compute per-class
# (one-vs-rest) AUPR from the raw model outputs.
preds, labels_val = [], []
linear.eval()
model_device = next(linear.parameters()).device
with torch.no_grad():
    for x, y in datamodule.val_dataloader():
        preds.append(linear(x.to(model_device)).detach().cpu().numpy())
        # Take labels from the same batches as the predictions so rows stay aligned
        # even if the dataloader reorders or drops cells.
        labels_val.append(y.detach().cpu().numpy() if torch.is_tensor(y) else np.asarray(y))

y_score = np.vstack(preds)
y_true = np.concatenate(labels_val).ravel()

aupr_scores = []
for i in range(y_score.shape[1]):
    y_bin = (y_true == i).astype(int)
    # AUPR is undefined for a class with no positive validation cells; record NaN.
    aupr_scores.append(average_precision_score(y_bin, y_score[:, i]) if y_bin.any() else np.nan)

# Plot AUPR
fig, ax = plt.subplots()
ax.bar(range(len(aupr_scores)), aupr_scores)
ax.set_xlabel("Cell Line Index")
ax.set_ylabel("AUPR")
ax.set_title("AUPR per Cell Line – Modlyn")
fig.tight_layout()
# fig.savefig("modlyn_aupr.pdf")


1. Uncertainty-Aware Dotplots
A plot where:

Dot size = certainty (e.g. the inverse of the standard error, so larger dots mean lower uncertainty),

Dot color = effect size (e.g. logistic regression weight),
similar to Scanpy/Seurat dotplots or volcano plots.

Comparison Metrics Across Methods

2. Correlation metrics like Kendall's Tau or Spearman's rho to quantify agreement between:

Logistic regression weights (Modlyn)

Wilcoxon test statistics

Scanpy logistic regression results

3. Differential Expression Testing Benchmark

Use sc.tl.rank_genes_groups() (Wilcoxon) as a reference.

Compare results with:

LogReg-derived weights using t-test-style criteria

Associated uncertainty estimates

4. Scale Testing on Small Dataset

Run this script: https://lamin.ai/laminlabs/arrayloader-benchmarks/transform/mVi9vDOMcgir on a small dataset to validate the methodology before running at full scale.

5. Visual Comparison Figure (Figure 1 Style)
A 3-panel figure:

Left: Bulk average expression dotplot

Middle: Dotplot from Wilcoxon

Right: Dotplot from logistic regression with uncertainty (Modlyn)


TODO: Felix: 1M, 10M, 100M datasets

TODO: Subsampled version: make one 1M-cell test dataset.

For now, reproduce the scanpy analysis and the multinomial logistic regression so that some interpretation can be drawn from these data; scVI and limma may give better interpretations.


sklearn for logistic regression, or 

In [None]:
# Top-20 genes (largest positive weight) for every cell line in the model.
top_genes_per_cellline = {
    cell_line: weights_df.loc[cell_line].nlargest(20) for cell_line in weights_df.index
}

# The three most abundant cell lines. Use the SAME three both for the cell subset and
# for the gene-overlap analysis below.
top_cell_lines = adata.obs["cell_line"].value_counts().index[:3].tolist()
adata_subset = adata[adata.obs["cell_line"].isin(top_cell_lines)].copy()

cell_lines = [cl for cl in top_cell_lines if cl in top_genes_per_cellline]

top_sets = [set(top_genes_per_cellline[cl].index[:10]) for cl in cell_lines]
# Sorted so the list is reproducible across kernels (string-set order is hash-randomised).
shared_genes = sorted(set.intersection(*top_sets)) if top_sets else []
if not shared_genes:
    shared_genes = sorted(
        set().union(*(set(top_genes_per_cellline[cl].index[:5]) for cl in cell_lines))
    )

print(f"Analyzing {len(shared_genes)} shared genes")

In [None]:
# Cell lines actually present in the 3-cell-line subset (expected to match `cell_lines`).
cell_lines_unique = adata_subset.obs["cell_line"].unique().tolist()
print(cell_lines_unique)

In [None]:
import seaborn as sns
from matplotlib import cm

def compute_fisher_info_and_se(model, dataloader):
    """Diagonal Fisher-information approximation of per-weight standard errors.

    For the softmax classifier ``logits = model.linear(x)``, the diagonal Fisher entry of
    weight ``W[k, j]`` is ``sum_n p_nk * (1 - p_nk) * x_nj**2``. Standard errors are
    ``1 / sqrt(fisher_diag)``. Off-diagonal Fisher terms are ignored, so these only
    approximate the SEs from the full inverse Fisher matrix.

    Args:
        model: module exposing ``model.linear`` (the gene -> class linear layer).
        dataloader: iterable of ``(x, y)`` batches; ``y`` is not used.

    Returns:
        ``(se, confidence)`` numpy arrays shaped like ``model.linear.weight``, where
        ``confidence = 1 / se**2`` (the Fisher diagonal plus a small epsilon).
    """
    model.eval()
    weight = model.linear.weight
    fisher_diag = torch.zeros_like(weight)

    with torch.no_grad():
        for x, _ in tqdm(dataloader, desc="Computing Fisher Information"):
            # Match the parameter's device/dtype so this also works for GPU-resident models.
            x = x.to(device=weight.device, dtype=weight.dtype)
            probs = torch.softmax(model.linear(x), dim=1)  # (batch, n_classes)
            # All classes at once: (n_classes, batch) @ (batch, n_genes).
            fisher_diag += (probs * (1 - probs)).T @ (x ** 2)

    se = torch.sqrt(1.0 / (fisher_diag + 1e-8))  # epsilon guards zero-information weights
    confidence = 1.0 / se**2
    return se.cpu().numpy(), confidence.cpu().numpy()

# NOTE(review): SEs are computed on the validation loader; they are usually derived from
# the data the model was fit on (training set) — confirm this is intended.
se, confidence = compute_fisher_info_and_se(linear, datamodule.val_dataloader())
confidence_df = pd.DataFrame(
    confidence,
    columns=adata.var_names,
    index=weights_df.index,  # real cell line names
)

In [None]:
# Union of all cell lines' top genes, de-duplicated in first-seen order, keeping the first
# 300. A set of strings iterates in a hash-randomised, per-process order, so
# dict.fromkeys is used to make the selection reproducible across kernel restarts.
top_genes = list(dict.fromkeys(
    gene for genes in top_genes_per_cellline.values() for gene in genes.index
))[:300]


### expression_cutoff=-2

In [None]:
from anndata import AnnData
from sklearn.preprocessing import minmax_scale
import numpy as np
import pandas as pd
import scanpy as sc

# Subset weights and confidence matrices (cell_line x gene) to the selected genes.
weights_sub = weights_df[top_genes]
confidence_sub = confidence_df[top_genes]

# Clip high-confidence outliers at the 99th percentile so a few genes don't dominate.
conf_clipped = np.clip(confidence_sub, 0, np.percentile(confidence_sub, 99))

# Min-max scale confidence within each cell line (row) to [0, 1] for dot sizes.
conf_scaled = pd.DataFrame(
    minmax_scale(conf_clipped, axis=1),
    index=confidence_sub.index,
    columns=confidence_sub.columns,
)

# One observation per cell line: .X = scaled confidence, layer "weights" = effect sizes.
adata_dot = AnnData(
    X=conf_scaled.values,
    obs=pd.DataFrame(index=conf_scaled.index),
    var=pd.DataFrame(index=conf_scaled.columns),
)
adata_dot.obs["cell_line"] = conf_scaled.index
adata_dot.obs_names = conf_scaled.index
adata_dot.var_names = conf_scaled.columns
adata_dot.layers["weights"] = weights_sub.loc[adata_dot.obs_names, adata_dot.var_names].values

# Pass colour and size explicitly. With only `layer=`, scanpy derives BOTH the colour
# (group mean) and the dot size (fraction of values > expression_cutoff) from that single
# layer. .X would then never affect the dots, and with one observation per group the
# "fraction" is just 0 or 1.
sc.pl.dotplot(
    adata=adata_dot,
    var_names=adata_dot.var_names.tolist(),
    groupby="cell_line",
    dot_color_df=weights_sub.loc[adata_dot.obs_names, adata_dot.var_names],
    dot_size_df=conf_scaled.loc[adata_dot.obs_names, adata_dot.var_names],
    title="Uncertainty-Aware DotPlot (Color = Effect Size, Size = Confidence)",
    colorbar_title="Effect Size (Weight)",
    size_title="Scaled confidence",
    cmap="RdBu_r",
    vcenter=0.0,
    show=True,
)


In [None]:
# Standard error per weight (confidence is the Fisher diagonal, i.e. 1 / SE^2).
stderr_df = 1 / np.sqrt(confidence_df)

weights_sub = weights_df[top_genes]
stderr_sub = stderr_df[top_genes]

# Certainty = 1 / SE; entries with SE == 0 become 0 instead of dividing by zero.
certainty = 1 / stderr_sub.replace(0, np.nan)
certainty = certainty.fillna(0)

certainty_clipped = np.clip(certainty, 0, np.percentile(certainty, 99))
certainty_scaled = pd.DataFrame(
    minmax_scale(certainty_clipped, axis=1),
    index=certainty.index,
    columns=certainty.columns,
)

adata_dot = AnnData(
    X=certainty_scaled.values,
    obs=pd.DataFrame(index=certainty_scaled.index),
    var=pd.DataFrame(index=certainty_scaled.columns),
)
adata_dot.obs["cell_line"] = certainty_scaled.index
adata_dot.obs_names = certainty_scaled.index
adata_dot.var_names = certainty_scaled.columns
adata_dot.layers["weights"] = weights_sub.loc[adata_dot.obs_names, adata_dot.var_names].values

# Colour = weight, size = scaled certainty, both passed explicitly. With `layer="weights"`
# alone, dot size would be the fraction of weights above `expression_cutoff`, i.e. just
# the sign of each weight.
sc.pl.dotplot(
    adata=adata_dot,
    var_names=adata_dot.var_names.tolist(),
    groupby="cell_line",
    dot_color_df=weights_sub.loc[adata_dot.obs_names, adata_dot.var_names],
    dot_size_df=certainty_scaled.loc[adata_dot.obs_names, adata_dot.var_names],
    cmap="RdBu_r",
    vcenter=0.0,
    dot_min=0,
    dot_max=1,
    smallest_dot=0.1,
    show=True,
)

### expression_cutoff=0 

In [None]:
# Weights (colour values) in .X; certainty (scaled inverse stderr) kept in .raw.
adata_dot.X = weights_sub.loc[adata_dot.obs_names, adata_dot.var_names].values
adata_dot.raw = AnnData(
    X=certainty_scaled.values,
    obs=adata_dot.obs.copy(),
    var=adata_dot.var.copy(),
)

# With use_raw=True scanpy takes BOTH colour and size from .raw.X (certainty), so the
# weights in .X would never be shown. Pass the two matrices explicitly instead.
sc.pl.dotplot(
    adata=adata_dot,
    var_names=adata_dot.var_names.tolist(),
    groupby="cell_line",
    use_raw=False,
    dot_color_df=pd.DataFrame(adata_dot.X, index=adata_dot.obs_names, columns=adata_dot.var_names),
    dot_size_df=pd.DataFrame(certainty_scaled.values, index=adata_dot.obs_names, columns=adata_dot.var_names),
    cmap="RdBu_r",
    vcenter=0.0,
    dot_min=0.2,
    dot_max=1.0,
    smallest_dot=0.5,
    show=True,
)


### z scores

In [None]:
# z = weight / SE = weight * certainty (certainty = 1 / SE).
z_scores = weights_sub / (1 / certainty)
adata_dot.X = certainty.values
adata_dot.layers["z"] = z_scores.values

# Colour = z-score (diverging around 0), size = scaled certainty. With `layer="z"` alone,
# dot size would only reflect the sign of z.
sc.pl.dotplot(
    adata=adata_dot,
    var_names=adata_dot.var_names.tolist(),
    groupby="cell_line",
    dot_color_df=z_scores.loc[adata_dot.obs_names, adata_dot.var_names],
    dot_size_df=certainty_scaled.loc[adata_dot.obs_names, adata_dot.var_names],
    cmap="RdBu_r",
    vcenter=0,
    dot_min=0.2,
    dot_max=1.0,
    smallest_dot=0.5,
    show=True,
)


### Normalize weights between -1,1

In [None]:
# Scale weights into [-1, 1] by the 99th percentile of |weight|, clipping the tail.
w_abs_max = np.percentile(np.abs(weights_sub.values), 99)
if w_abs_max == 0:
    w_abs_max = 1.0  # all-zero weights: avoid 0 / 0
weights_scaled = (weights_sub / w_abs_max).clip(-1, 1)

# Keep .X = scaled certainty and the scaled weights layer (both read by later cells).
adata_dot.X = certainty_scaled.values
adata_dot.layers["weights_scaled"] = weights_scaled.loc[adata_dot.obs_names, adata_dot.var_names].values

# Colour = scaled weight, size = scaled certainty, passed explicitly (see scanpy DotPlot).
sc.pl.dotplot(
    adata=adata_dot,
    var_names=adata_dot.var_names.tolist(),
    groupby="cell_line",
    dot_color_df=weights_scaled.loc[adata_dot.obs_names, adata_dot.var_names],
    dot_size_df=certainty_scaled.loc[adata_dot.obs_names, adata_dot.var_names],
    cmap="RdBu_r",
    vcenter=0.0,
    dot_min=0.2,
    dot_max=1.0,
    smallest_dot=0.5,
    show=True,
)



### Correlation between weights and uncertainty

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Pair every weight with its standard error (SE = 1 / certainty) as flat vectors.
w = weights_sub.to_numpy().ravel()
se = (1 / certainty.to_numpy()).ravel()
# np.isfinite is False for NaN as well as for +/-inf.
valid = np.isfinite(w) & np.isfinite(se)

fig, ax = plt.subplots()
sns.scatterplot(x=w[valid], y=se[valid], alpha=0.3, ax=ax)
ax.set(xlabel="Weight", ylabel="Standard Error", title="Weight vs. Uncertainty")
ax.grid(True)
plt.show()

from scipy.stats import spearmanr
rho, _ = spearmanr(w[valid], se[valid])
print(f"Spearman correlation: {rho:.2f}")


In [None]:
# Distribution of the per-weight z-statistic (weight divided by its standard error).
z_stat = w[valid] / se[valid]
sns.histplot(z_stat, bins=100).set_xlabel("Weight / SE")


# Scanpy

## logreg

In [None]:
from scipy import sparse

# In-memory, log1p-transformed copy of the training cells for the scanpy baselines.
# NOTE(review): counts are not library-size normalised before log1p (no
# sc.pp.normalize_total); confirm this matches the input Modlyn was trained on.
adata_sc = adata_train.copy()
sc.pp.log1p(adata_sc)
X_mem = adata_sc.X.compute() if hasattr(adata_sc.X, "compute") else adata_sc.X
# np.array() on a scipy sparse matrix yields a 0-d object array, not a dense matrix,
# so sparse results are densified explicitly.
adata_sc.X = X_mem.toarray() if sparse.issparse(X_mem) else np.asarray(X_mem)


In [None]:
# Scanpy's logistic-regression marker ranking per cell line, stored in .uns["logreg"].
logreg_kwargs = {"groupby": "cell_line", "method": "logreg", "key_added": "logreg"}
sc.tl.rank_genes_groups(adata_sc, **logreg_kwargs)


In [None]:
# sc.pl.rank_genes_groups_dotplot(
#     adata_sc,
#     key="logreg",
#     n_genes=10,
#     groupby="cell_line",
#     cmap="RdBu_r",
#     vcenter=0,
#     show=True
# )

In [None]:
# ### Same genes as modlyn
# sc.pl.dotplot(
#     adata_sc,
#     var_names=top_genes,
#     groupby="cell_line",
#     standard_scale="var",
#     cmap="RdBu_r",
#     vcenter=0.0,
#     dot_min=0.2,
#     dot_max=1.0,
#     smallest_dot=0.5,
#     show=True
# )


## wilcoxon

In [None]:
warnings.filterwarnings("ignore")
# Wilcoxon rank-sum markers per cell line (scanpy's default reference), stored in
# .uns["wilcoxon"].
wilcoxon_kwargs = {"groupby": "cell_line", "method": "wilcoxon", "key_added": "wilcoxon"}
sc.tl.rank_genes_groups(adata_sc, **wilcoxon_kwargs)


In [None]:
# Scanpy's own view of the top-10 Wilcoxon markers per cell line.
sc.pl.rank_genes_groups_dotplot(
    adata_sc,
    groupby="cell_line",
    key="wilcoxon",
    n_genes=10,
    vcenter=0,
    cmap="RdBu_r",
    show=True,
)

In [None]:
display(adata_sc)

In [None]:
# adata_sc.write("adata_chunk30_processed.h5ad")
# adata_dot.write("adata_dask_chunk30_processed.h5ad")

In [None]:
# sc.pl.dotplot(
#     adata_sc,
#     var_names=top_genes,
#     groupby="cell_line",
#     standard_scale="var",      
#     cmap="RdBu_r",
#     vcenter=0.0,
#     dot_min=0.2,
#     dot_max=1.0,
#     smallest_dot=0.5,
#     show=True
# )


### Correlation (Spearman's rho) to quantify agreement between:
#### Logistic regression weights (Modlyn) vs Wilcoxon test statistics vs Scanpy logistic regression results

In [None]:
# Plain-text summary of the scanpy AnnData, rich repr of the dotplot AnnData.
print(adata_sc)
display(adata_dot)

In [None]:
genes = top_genes  # alias: Modlyn-selected genes compared across methods
groups = adata_sc.obs["cell_line"].unique()
# Group names scanpy stored for the logreg ranking (fields of the structured array).
cell_lines = adata_sc.uns["logreg"]["names"].dtype.names
display(groups)

In [None]:
# (cell_line x gene) score frames for the Modlyn-selected `genes`.
# rank_genes_groups stores each group's genes sorted by THAT group's own score, so each
# group's `scores` column must be indexed by the same group's `names` column. Using the
# first group's gene order for every group would attach scores to the wrong genes.
wilcoxon_res = adata_sc.uns["wilcoxon"]
wilcoxon_scores = pd.DataFrame(
    {cl: pd.Series(wilcoxon_res["scores"][cl], index=wilcoxon_res["names"][cl]) for cl in cell_lines}
).T[genes]  # T = transpose → cell_line x gene

# Scanpy logreg scores
logreg_res = adata_sc.uns["logreg"]
scanpy_logreg_scores = pd.DataFrame(
    {cl: pd.Series(logreg_res["scores"][cl], index=logreg_res["names"][cl]) for cl in cell_lines}
).T[genes]

# print(scanpy_logreg_scores)

In [None]:
# Modlyn weights (cell_line x gene) as stored on the dotplot AnnData.
modlyn_df = pd.DataFrame(
    data=adata_dot.layers["weights"],
    columns=adata_dot.var_names,
    index=adata_dot.obs_names,
)
display(modlyn_df.head())


In [None]:
# Cell lines scored by both scanpy and Modlyn. Sorted, because iterating a set of strings
# gives a per-process (hash-randomised) order and downstream results would reorder.
common_cls = sorted(set(cell_lines) & set(modlyn_df.index))

print(f"Shared cell lines: {len(common_cls)}")
common_cls

In [None]:
from scipy.stats import spearmanr

# Per cell line: Spearman rho between Modlyn weights and each scanpy method's scores.
results = []
for cl in common_cls:
    for name, scores in [
        ("scanpy_logreg", scanpy_logreg_scores),
        ("wilcoxon", wilcoxon_scores),
    ]:
        if cl not in scores.index:
            continue
        modlyn_row = modlyn_df.loc[cl]
        # spearmanr ignores pandas labels, so align genes explicitly to Modlyn's order.
        other_row = scores.loc[cl].reindex(modlyn_row.index)
        rho, _ = spearmanr(modlyn_row, other_row, nan_policy="omit")
        results.append({"cell_line": cl, "vs": name, "rho": rho})


In [None]:
# import seaborn as sns
# import matplotlib.pyplot as plt

# sns.boxplot(data=results, x="vs", y="rho")
# plt.axhline(0, color="gray", linestyle="--")
# plt.title("Spearman Correlation with Modlyn Weights")
# plt.show()


### Differential Expression Testing Benchmark

Use sc.tl.rank_genes_groups() (Wilcoxon) as a reference.

Compare results with:

LogReg-derived weights using t-test-style criteria

Associated uncertainty estimates


In [None]:
# sc.tl.rank_genes_groups(
#     adata_sc,
#     groupby="cell_line",
#     method="wilcoxon",
#     key_added="wilcoxon"
# )


In [None]:
# Top-100 Wilcoxon genes for each cell line, in scanpy's ranking order.
wilcoxon_names = adata_sc.uns["wilcoxon"]["names"]
top_wilcoxon = {}
for group in wilcoxon_names.dtype.names:
    top_wilcoxon[group] = list(wilcoxon_names[group][:100])


In [None]:
# Modlyn z = weight / SE, with non-finite values set to 0 and clipped to [-20, 20].
# NOTE(review): this rebinds `weights` (earlier the raw numpy weight matrix) to a DataFrame.
weights = pd.DataFrame(adata_dot.layers["weights"], index=adata_dot.obs_names, columns=adata_dot.var_names)
stderr = 1 / certainty
z_scores = (
    (weights / stderr)
    .replace([np.inf, -np.inf], np.nan)
    .fillna(0)
    .clip(-20, 20)
)


In [None]:
# Top-100 genes per cell line ranked by |Modlyn z-score|.
top_logreg = {}
for cell_line_name in z_scores.index:
    ranked = z_scores.loc[cell_line_name].abs().sort_values(ascending=False)
    top_logreg[cell_line_name] = ranked.head(100).index.tolist()


In [None]:
def overlap(set1, set2):
    """Fraction of the genes in ``set2`` that also appear in ``set1``.

    Args:
        set1, set2: iterables of gene names; duplicates are ignored.

    Returns:
        ``|set1 ∩ set2| / |set2|`` as a float, or ``0.0`` when ``set2`` is empty.
    """
    reference = set(set2)
    if not reference:
        return 0.0
    return len(set(set1) & reference) / len(reference)


# Only cell lines that have both a Wilcoxon ranking and Modlyn z-scores can be compared.
# NOTE(review): top_logreg ranks only the Modlyn-selected `top_genes`, whereas
# top_wilcoxon ranks all genes. The overlap is therefore bounded by how many Wilcoxon top
# genes fall inside top_genes.
benchmark_cls = [cl for cl in z_scores.index if cl in top_wilcoxon]
benchmark = pd.DataFrame({
    cl: {"overlap_wilcoxon_vs_logregZ": overlap(top_wilcoxon[cl], top_logreg[cl])}
    for cl in benchmark_cls
}).T


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Bar per cell line: overlap between Wilcoxon top genes and Modlyn |z| top genes.
benchmark = benchmark.sort_values("overlap_wilcoxon_vs_logregZ")
fig, ax = plt.subplots()
sns.barplot(data=benchmark, x=benchmark.index, y="overlap_wilcoxon_vs_logregZ", ax=ax)
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylabel("DE overlap (Wilcoxon vs. LogReg Z)")
ax.set_title("DE Benchmark: Wilcoxon vs. LogReg + Uncertainty")
fig.tight_layout()
plt.show()


In [None]:
# Plain-text summary of the scanpy AnnData, rich repr of the dotplot AnnData.
print(adata_sc)
display(adata_dot)

In [None]:
# ln.finish()


In [None]:
# First 30 Modlyn-selected genes, used for the side-by-side comparison panels.
genes = list(top_genes)[:30]
genes

In [None]:
# fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# sc.pl.dotplot(adata_sc, var_names=genes, groupby="cell_line", ax=axes[0],
#               cmap="RdBu_r", vcenter=0.0, dot_min=0.2, dot_max=1.0,
#               standard_scale="var", show=False)
# axes[0].set_title("Scanpy LogReg")

# sc.pl.dotplot(adata_sc, var_names=genes, groupby="cell_line", ax=axes[1],
#               cmap="RdBu_r", vcenter=0.0, dot_min=0.2, dot_max=1.0, standard_scale="var", show=False)
# axes[1].set_title("Wilcoxon")

# sc.pl.dotplot(adata_dot, var_names=genes, groupby="cell_line", ax=axes[2],
#                cmap="RdBu_r", vcenter=0.0,
#               use_raw=True, dot_min=0.2, dot_max=1.0, show=False)
# axes[2].set_title("Modlyn (weights + uncertainty)")

# plt.tight_layout()
# plt.show()


In [None]:
# Per-group scanpy scores as (cell_line x gene) frames restricted to `genes`.
# rank_genes_groups stores each group's genes sorted by that group's own score, so each
# group's `scores` must be indexed by the SAME group's `names`. Indexing every group by
# the first group's gene order attaches scores to the wrong genes.
logreg_res = adata_sc.uns["logreg"]
logreg_scores = pd.DataFrame(
    {
        cl: pd.Series(logreg_res["scores"][cl], index=logreg_res["names"][cl])
        for cl in logreg_res["scores"].dtype.names
    }
).T[genes]  # shape: cell_line x gene

wilcoxon_res = adata_sc.uns["wilcoxon"]
wilcoxon_scores = pd.DataFrame(
    {
        cl: pd.Series(wilcoxon_res["scores"][cl], index=wilcoxon_res["names"][cl])
        for cl in wilcoxon_res["scores"].dtype.names
    }
).T[genes]

modlyn_weights = pd.DataFrame(
    adata_dot.layers["weights_scaled"],
    index=adata_dot.obs_names,
    columns=adata_dot.var_names,
)[genes]

modlyn_certainty = pd.DataFrame(
    adata_dot.X,  # expected to hold certainty_scaled (set by the weight-normalisation cell)
    index=adata_dot.obs_names,
    columns=adata_dot.var_names,
)[genes]

# (AnnData, extra dotplot kwargs, title). The scanpy panels only override the colour, so
# their dot sizes remain scanpy's per-group fraction of cells above the expression cutoff.
panels = [
    (adata_sc, {"dot_color_df": logreg_scores}, "Scanpy LogReg"),
    (adata_sc, {"dot_color_df": wilcoxon_scores}, "Wilcoxon"),
    (
        adata_dot,
        {"dot_color_df": modlyn_weights, "dot_size_df": modlyn_certainty},
        "Modlyn (weights + uncertainty)",
    ),
]

fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
for ax, (panel_adata, panel_kwargs, title) in zip(axes, panels):
    sc.pl.dotplot(
        panel_adata,
        var_names=genes,
        groupby="cell_line",
        ax=ax,
        cmap="RdBu_r",
        vcenter=0,
        show=False,
        **panel_kwargs,
    )
    ax.set_title(title)

plt.tight_layout()
plt.show()

## Scale scores for comparison

In [None]:
from sklearn.preprocessing import minmax_scale

# Per-group scanpy scores as (cell_line x gene) frames restricted to `genes`. Each
# group's `scores` is paired with the SAME group's `names`, because rank_genes_groups
# sorts every group's genes by that group's own score.
logreg_res = adata_sc.uns["logreg"]
logreg_scores = pd.DataFrame(
    {
        cl: pd.Series(logreg_res["scores"][cl], index=logreg_res["names"][cl])
        for cl in logreg_res["scores"].dtype.names
    }
).T[genes]  # shape: cell_line x gene

wilcoxon_res = adata_sc.uns["wilcoxon"]
wilcoxon_scores = pd.DataFrame(
    {
        cl: pd.Series(wilcoxon_res["scores"][cl], index=wilcoxon_res["names"][cl])
        for cl in wilcoxon_res["scores"].dtype.names
    }
).T[genes]

modlyn_weights = pd.DataFrame(
    adata_dot.layers["weights_scaled"],
    index=adata_dot.obs_names,
    columns=adata_dot.var_names,
)[genes]

modlyn_certainty = pd.DataFrame(
    adata_dot.X,  # expected to hold certainty_scaled (set by the weight-normalisation cell)
    index=adata_dot.obs_names,
    columns=adata_dot.var_names,
)[genes]


def scale_weights(df):
    """Scale scores to [-1, 1]: clip at the 99th percentile of |score|, then divide by it.

    NaNs are ignored when computing the percentile and are preserved in the output. If the
    percentile is 0 or undefined, values are clipped to 0 instead of being divided by 0.
    """
    vmax = np.nanpercentile(np.abs(df.values), 99)
    if not np.isfinite(vmax) or vmax == 0:
        return df.clip(-0.0, 0.0)
    return df.clip(-vmax, vmax) / vmax

# Colour: every method's scores normalised to [-1, 1].
logreg_scaled = scale_weights(logreg_scores[genes])
wilcoxon_scaled = scale_weights(wilcoxon_scores[genes])
modlyn_scaled = scale_weights(modlyn_weights)

# Size: row-wise min-max of |score| (scanpy methods) or of certainty (Modlyn), in [0, 1].
logreg_size = pd.DataFrame(minmax_scale(logreg_scores[genes].abs(), axis=1),
                           index=logreg_scores.index, columns=genes)
wilcoxon_size = pd.DataFrame(minmax_scale(wilcoxon_scores[genes].abs(), axis=1),
                             index=wilcoxon_scores.index, columns=genes)
modlyn_size = pd.DataFrame(minmax_scale(modlyn_certainty, axis=1),
                           index=modlyn_certainty.index, columns=genes)

# (AnnData, colour df, size df, title) per panel.
panels = [
    (adata_sc, logreg_scaled, logreg_size, "Scanpy LogReg (scaled)"),
    (adata_sc, wilcoxon_scaled, wilcoxon_size, "Wilcoxon (scaled)"),
    (adata_dot, modlyn_scaled, modlyn_size, "Modlyn (scaled)"),
]

fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
for ax, (panel_adata, color_df, size_df, title) in zip(axes, panels):
    sc.pl.dotplot(
        panel_adata,
        var_names=genes,
        groupby="cell_line",
        dot_color_df=color_df,
        dot_size_df=size_df,
        ax=ax,
        cmap="RdBu_r",
        vcenter=0,
        dot_min=0.2,
        dot_max=1.0,
        smallest_dot=0.1,
        show=False,
    )
    ax.set_title(title)

plt.tight_layout()
plt.show()


In [None]:
# Distribution of per-cell-line Spearman rho between Modlyn and each scanpy method.
corr_df = pd.DataFrame(results)
fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(data=corr_df, x="vs", y="rho", ax=ax)
ax.axhline(0, linestyle="--", color="gray")
ax.set_ylabel("Spearman ρ")
ax.set_title("Rank Correlation: Modlyn vs. Other Methods")
fig.tight_layout()
plt.show()


In [None]:
## If top genes overlap across 2+ methods → higher confidence

In [None]:
from collections import Counter

# For each shared cell line, collect the genes that rank in the top 100 by |score| for ALL
# three methods (Modlyn, Wilcoxon, scanpy logreg). Then keep genes that reach this
# three-way consensus in at least two cell lines.
# NOTE(review): the heading above says "2+ methods", but this requires all 3 methods per
# cell line and counts across cell lines — confirm which criterion is intended.
consensus_genes = []
for cl in common_cls:
    top_m = set(modlyn_df.loc[cl].abs().nlargest(100).index)
    top_w = set(wilcoxon_scores.loc[cl].abs().nlargest(100).index)
    top_s = set(scanpy_logreg_scores.loc[cl].abs().nlargest(100).index)
    # Named `shared_top` so the module-level `overlap()` function is not shadowed.
    shared_top = top_m & top_w & top_s
    consensus_genes.extend(sorted(shared_top))  # sorted: reproducible Counter order

shared_counts = Counter(consensus_genes)
top_consensus = [g for g, c in shared_counts.items() if c >= 2]
print(top_consensus[:10])


## Linscvi

In [None]:
import scvi


In [None]:
from scipy import sparse

# Log-transform an in-memory copy of the training cells for LinearSCVI.
adata_scvi = adata_train.copy()
sc.pp.log1p(adata_scvi)
# Materialise the log1p-transformed matrix. Reading adata_train.X here instead would
# discard the log1p above and feed raw counts to the model.
X_mem = adata_scvi.X.compute() if hasattr(adata_scvi.X, "compute") else adata_scvi.X
adata_scvi.X = X_mem.toarray() if sparse.issparse(X_mem) else np.asarray(X_mem)

# Fixed-seed subsample (at most 2,000 cells) so the LinearSCVI fit is reproducible.
N_SCVI_CELLS = min(2000, adata_scvi.n_obs)
rng = np.random.default_rng(0)
adata_sub = adata_scvi[rng.choice(adata_scvi.n_obs, N_SCVI_CELLS, replace=False)].copy()


In [None]:
# Register adata_sub with scvi-tools, using the `cell_line` column as the labels field.
scvi.model.LinearSCVI.setup_anndata(adata_sub, labels_key="cell_line")


In [None]:
# LinearSCVI model on adata_sub.X with a Gaussian likelihood (instead of a count likelihood).
model = scvi.model.LinearSCVI(adata_sub, gene_likelihood="gaussian")
# Display which AnnData fields the model was registered with.
model.view_anndata_setup()

In [None]:
# Fit with scvi-tools' default training arguments (none are passed here).
# NOTE(review): no seed is set (e.g. scvi.settings.seed), so results may vary between runs.
model.train()


In [None]:
# Decoder loadings (gene x latent factor), then the model's summary string.
for item in (model.get_loadings(), model.summary_string):
    print(item)

In [None]:
labels = adata_sub.obs["cell_line"].values

import time

# Time the pass that produces the per-cell latent representation.
t0 = time.time()
Z = model.get_latent_representation(batch_size=64)
elapsed = time.time() - t0
print(f"Elapsed: {elapsed:.2f} seconds")


In [None]:
labels_unique = np.unique(labels)

# Mean latent vector per cell line; rows follow the sorted unique labels.
Z_mean = np.vstack([Z[labels == label].mean(axis=0) for label in labels_unique])

# Project the class means into gene space through the decoder loadings.
W = model.get_loadings().values  # shape: genes × latent
weights = Z_mean @ W.T           # shape: cell_lines × genes

# NOTE(review): this rebinds `weights_df` (previously the Modlyn weights) to the
# LinearSCVI-derived matrix; earlier Modlyn cells will see it if re-run.
weights_df = pd.DataFrame(weights, index=labels_unique, columns=model.adata.var_names)
weights_df

In [None]:
from sklearn.preprocessing import minmax_scale

# Normalise LinearSCVI weights for plotting. Clip at the 99th percentile of |weight| and
# divide by that same threshold, so every value lies in [-1, 1]. Dividing by the
# percentile of the already-clipped values can give a smaller divisor and overshoot ±1.
p99 = np.percentile(np.abs(weights_df.values), 99)
w_scaled = weights_df.clip(-p99, p99) / (p99 if p99 > 0 else 1.0)

# Certainty estimate → use abs(weight) as a proxy (LinearSCVI doesn't output SE directly).
# NOTE(review): this rebinds `certainty` / `certainty_scaled`, which held Modlyn's values.
certainty = weights_df.abs()
certainty_scaled = pd.DataFrame(minmax_scale(certainty, axis=1),
                                index=certainty.index,
                                columns=certainty.columns)


In [None]:
# One observation per LinearSCVI cell line: .X = scaled |weight| (certainty proxy), and
# layer "weights_scaled" = scaled weights used for colour.
adata_dot_lscvi = ad.AnnData(
    X=certainty_scaled.values,
    obs=pd.DataFrame(index=certainty_scaled.index),
    var=pd.DataFrame(index=certainty_scaled.columns),
)
obs_index = adata_dot_lscvi.obs.index
var_index = adata_dot_lscvi.var.index
adata_dot_lscvi.obs["cell_line"] = obs_index
adata_dot_lscvi.obs_names = obs_index
adata_dot_lscvi.var_names = var_index
adata_dot_lscvi.layers["weights_scaled"] = w_scaled.loc[
    adata_dot_lscvi.obs_names, adata_dot_lscvi.var_names
].values


In [None]:
# Size = row-wise min-max scaled |weight|; colour = scaled LinearSCVI weight.
lscvi_size = pd.DataFrame(
    minmax_scale(certainty, axis=1), index=certainty.index, columns=certainty.columns
)
dot_color = pd.DataFrame(
    adata_dot_lscvi.layers["weights_scaled"],
    index=adata_dot_lscvi.obs_names,
    columns=adata_dot_lscvi.var_names,
)

dotplot_style = dict(cmap="RdBu_r", vcenter=0, dot_min=0.2, dot_max=1.0, smallest_dot=0.1, show=True)
sc.pl.dotplot(
    adata_dot_lscvi,
    var_names=top_genes,
    groupby="cell_line",
    dot_color_df=dot_color[top_genes],
    dot_size_df=lscvi_size[top_genes],
    **dotplot_style,
)

## Comparisons

In [None]:
# LinearSCVI scaled weights and certainty proxy restricted to the comparison genes.
lscvi_rows = adata_dot_lscvi.obs_names
lscvi_cols = adata_dot_lscvi.var_names
lscvi_weights = pd.DataFrame(
    adata_dot_lscvi.layers["weights_scaled"], index=lscvi_rows, columns=lscvi_cols
)[genes]
lscvi_certainty = pd.DataFrame(adata_dot_lscvi.X, index=lscvi_rows, columns=lscvi_cols)[genes]

# Same normalisation as the other methods: colour in [-1, 1], size row-wise in [0, 1].
lscvi_scaled = scale_weights(lscvi_weights)
lscvi_size = pd.DataFrame(
    minmax_scale(lscvi_certainty, axis=1),
    index=lscvi_certainty.index,
    columns=lscvi_certainty.columns,
)


In [None]:
# Four side-by-side dotplots over the same genes: (AnnData, colour df, size df, title).
panels = [
    (adata_sc, logreg_scaled, logreg_size, "Scanpy LogReg (scaled)"),
    (adata_sc, wilcoxon_scaled, wilcoxon_size, "Wilcoxon (scaled)"),
    (adata_dot, modlyn_scaled, modlyn_size, "Modlyn (scaled)"),
    (adata_dot_lscvi, lscvi_scaled, lscvi_size, "LinearSCVI (scaled)"),
]

fig, axes = plt.subplots(1, 4, figsize=(24, 6))
for ax, (panel_adata, color_df, size_df, title) in zip(axes, panels):
    sc.pl.dotplot(
        panel_adata,
        var_names=genes,
        groupby="cell_line",
        dot_color_df=color_df,
        dot_size_df=size_df,
        ax=ax,
        cmap="RdBu_r",
        vcenter=0,
        dot_min=0.2,
        dot_max=1.0,
        smallest_dot=0.1,
        show=False,
    )
    ax.set_title(title)

plt.tight_layout()
plt.show()


In [None]:
methods = {
    "LogReg": logreg_scaled,
    "Wilcoxon": wilcoxon_scaled,
    "LinearSCVI": lscvi_scaled,
}

# One Spearman rho per (method, cell line) present in both Modlyn and that method.
results = [
    {
        "cell_line": cl,
        "method": method_name,
        "spearman_rho": spearmanr(modlyn_scaled.loc[cl], df.loc[cl])[0],
    }
    for method_name, df in methods.items()
    for cl in modlyn_scaled.index.intersection(df.index)
]

corr_df = pd.DataFrame(results)

fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(data=corr_df, x="method", y="spearman_rho", ax=ax)
sns.stripplot(data=corr_df, x="method", y="spearman_rho", color="black", alpha=0.4, jitter=0.15, ax=ax)

ax.axhline(0, color="gray", linestyle="--", linewidth=1)
ax.set_ylabel("Spearman ρ with Modlyn")
ax.set_title("Correlation of Modlyn weights vs. other methods")
fig.tight_layout()
plt.show()



In [None]:
methods = {
    "LogReg": logreg_scaled,
    "Wilcoxon": wilcoxon_scaled,
    "LinearSCVI": lscvi_scaled
}

# Rows: cell lines, columns: methods; each entry is Spearman rho vs. Modlyn's scaled weights.
heatmap_data = {}
for cl in modlyn_scaled.index:
    row = {}
    for method_name, df in methods.items():
        if cl in df.index:
            rho, _ = spearmanr(modlyn_scaled.loc[cl], df.loc[cl])
            row[method_name] = rho
    heatmap_data[cl] = row

heatmap_df = pd.DataFrame.from_dict(heatmap_data, orient="index").sort_index()
heatmap_sorted = heatmap_df.loc[heatmap_df.mean(axis=1).sort_values(ascending=False).index]

fig, ax = plt.subplots(figsize=(8, max(2.0, 0.4 * len(heatmap_sorted))))
sns.heatmap(
    heatmap_sorted,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    # Spearman rho lies in [-1, 1]. Fixed limits centred at 0 give a symmetric, comparable
    # colour scale; seaborn's `center` already makes the map diverge at 0, so no separate
    # Normalize object (or its import) is needed.
    center=0,
    vmin=-1,
    vmax=1,
    linewidths=0.4,
    linecolor="white",
    cbar_kws={"label": "Spearman ρ (vs Modlyn)", "shrink": 0.8},
    ax=ax,
)

ax.set_title("Per-Cell Line Correlation with Modlyn", fontsize=14, weight="bold")
ax.set_ylabel("Cell Line")
ax.set_xlabel("Comparison Method")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
plt.setp(ax.get_yticklabels(), rotation=0)
fig.tight_layout()
plt.show()

In [None]:
from upsetplot import from_memberships, UpSet
import matplotlib.pyplot as plt

# Overlap analysis parameters: the cell line to inspect and how many top genes to take per method.
cell_line = "CVCL_0459"
top_k = 20

# For each method, the set of `top_k` column labels (genes) with the largest absolute
# scaled weight in the `cell_line` row. Key order matches the order listed here.
top_genes_per_method = {
    method_name: set(weights.loc[cell_line].abs().nlargest(top_k).index)
    for method_name, weights in (
        ("LogReg", logreg_scaled),
        ("Wilcoxon", wilcoxon_scaled),
        ("Modlyn", modlyn_scaled),
        ("LinearSCVI", lscvi_scaled),
    )
}


In [None]:
# For every gene that is top-k in at least one method, record which methods selected it;
# upsetplot's from_memberships turns these lists into the boolean-indexed input UpSet expects.
# Genes are visited in sorted order: iteration order of a set of strings depends on the
# interpreter's hash seed, so sorting keeps the input identical across kernel restarts.
memberships = []
for gene in sorted(set.union(*top_genes_per_method.values())):
    # Named `gene_methods` (not `methods`) so this cell does not overwrite the
    # method-name → DataFrame dict that the correlation cells bind to `methods`.
    gene_methods = [method for method, genes in top_genes_per_method.items() if gene in genes]
    memberships.append(gene_methods)

# Convert to UpSet data
upset_data = from_memberships(memberships)

In [None]:
# UpSet plot of how the per-method top-k gene sets overlap for `cell_line`.
# The figure is passed to UpSet.plot explicitly: without `fig=`, plot() creates its own
# figure, which leaves this 10x5 figure empty and the requested size unused.
fig = plt.figure(figsize=(10, 5))
UpSet(
    upset_data,
    sort_by="degree",
    show_counts=True,
    min_subset_size=1,
    subset_size="count",
).plot(fig=fig)

fig.suptitle(f"Top {top_k} Gene Overlap for Cell Line {cell_line}", fontsize=14, weight="bold")
fig.tight_layout()
plt.show()

## Questions:

#### Does the Scanpy LogReg method (inspired by Ntranos et al., Nature Methods, 2018) quantify uncertainty? How does that compare to what Modlyn provides?

They do use logistic regression, but primarily for ranking features, not for quantifying uncertainty.

Confirm: Scanpy’s sc.tl.rank_genes_groups(..., method="logreg") does not return confidence intervals or standard errors.

In contrast, Modlyn explicitly computes uncertainty via inverse Fisher Information.

Conclusion:
Only Modlyn returns an interpretable, model-derived estimate of uncertainty per weight. That’s a functional difference worth highlighting in your benchmark.



#### What explains the difference between Scanpy LogReg and Modlyn LogReg?
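Check Scanpy’s LogReg: `sc.tl.rank_genes_groups(..., method="logreg")` fits a single `sklearn.linear_model.LogisticRegression` (default settings unless keyword arguments are passed through) on all selected groups at once, and reads each group’s scores from `coef_`. With sklearn’s defaults (lbfgs solver) and more than two groups this is a multinomial model, not one-vs-rest. Verify against the installed Scanpy/sklearn versions.

That includes L2 regularization (`C=1.0`), no uncertainty, no confidence filtering.

Scores are reported per group, but they come from one joint in-memory fit rather than from independent per-group models.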

Modlyn:

Uses a joint model trained with mini-batches, likely using multinomial logistic regression.

Returns dense weights, filtered or visualized by certainty.

Can learn from the full dataset by streaming mini-batches, rather than being limited to the data that fits in memory for a single sklearn fit.

Conclusion:
Modlyn’s implementation differs mainly in how it is trained and what it returns: it is mini-batch-trained over the full dataset and uncertainty-aware. Scanpy’s is simpler: a single in-memory sklearn fit (in practice often on a subsample of the data) that returns point estimates only.

#### Why do Scanpy LogReg and LinearSCVI not discriminate between conditions?
For Scanpy LogReg:

Check whether the method is underpowered because of label imbalance or small per-group training sets.

Possibly the logistic classifier hits a ceiling with L2 penalty.

For LinearSCVI:

It's a generative model, optimized for reconstruction, not discrimination.

The weights you extract (via decoder projection) are not tuned for group separation.

Suggested analysis:

Compute classification accuracy or AUPR per method using held-out labels.

Visualize group separability using UMAP of latent space (especially for LinearSCVI).

Report dotplot sparsity per method: average number of genes above a certainty/weight threshold per group.




#### Are we comparing all conditions fairly?
Could the signal in Wilcoxon be relative to groups not shown (e.g., 800 cell lines), affecting upregulation interpretation?

Check adata.obs["cell_line"].value_counts() to get full distribution.

Determine which groups are included in each method (e.g., top 50 most frequent? All?).

For Wilcoxon in Scanpy: by default (`reference="rest"`) each group is tested against all remaining cells, so “rest” changes depending on what’s in the dataset.

Suggested analysis:

Run Wilcoxon on full dataset and on top-50 groups and compare results.

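Report how many groups are included in each method’s comparison, e.g. `len(adata.uns["wilcoxon"]["names"].dtype.names)`. `names` is a structured array with one field per group, so a plain `len()` counts genes, not groups. The `"wilcoxon"` key assumes the results were stored with `key_added="wilcoxon"`.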

Conclusion:
If the comparison set is unbalanced or truncated, DE signal may be inflated or suppressed, depending on how the reference is defined.

#### Should we use classification-style metrics instead of correlation?
Do correlation metrics (like Spearman ρ) capture the biological or functional similarity across methods? Or would AUPR / accuracy be better?

Run AUPR per cell line:

Treat the top 100 genes from one method as positives

Use ranked weights from another method as predictions

Use pairwise classification metrics to compare ranked gene sets between methods


Use precision–recall (AUPR) instead of accuracy. Target statement: Modlyn LogReg vs. Scanpy LogReg agree XXX%, and reach YYY AUPR with Wilcoxon, etc. If so, everything is comparable and we can trust the methods.

#### Biology
Drugs vs. genes: which drugs activate pathways XYZ?

Comparison with T cells (pert vs unpert)

Early activation?

Marker genes for cell lines: a proof of concept for a tool that lets us ask “how do 50 cell lines respond to 50k conditions?”. That yields 250k hypotheses (TODO: check the count — 50 × 50k would be 2.5M), and you need to pick the interesting ones.

#### API like scanpy

Mirror Scanpy’s API as closely as possible, so the tool is easy to pick up for Scanpy users.

In [None]:
# zarr v3 is faster 
# shuffle chunks and preshuffle everything
# way faster than merlyn
# arrayloaders 

In [None]:
# End lamindb tracking for this notebook; pairs with the ln.track(...) calls in the first cell
# (see the lamindb docs for exactly what is finalised on the run record).
ln.finish()