In [None]:
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
import torch
import lightning as L
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split

from scvi.model import LinearSCVI
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder

# Global seed for NumPy and torch so the split, model init and training are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

import scvi
from modlyn.models import SimpleLogReg
# NOTE(review): imported from a private module (leading underscore); the path may
# change between modlyn releases.
from modlyn.models._simple_logreg_datamodule import SimpleLogRegDataModule

import lamindb as ln

# Create the lamindb project this notebook's run is linked to.
# NOTE(review): confirm that re-running this cell does not create duplicate Project
# records with the same name.
project = ln.Project(name="scVI-Comparison")
project.save()
# Track this notebook exactly once, linked to the project. A second bare `ln.track()`
# call is redundant and may re-initialise tracking without the project link.
ln.track(project="scVI-Comparison")
print(f"scvi-tools version: {scvi.__version__}")
# NOTE(review): this silences *all* warnings from here on (e.g. scvi-tools' check for
# non-count input data); consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")


In [None]:
# Data Loading and Model Training
artifact = ln.Artifact.using("laminlabs/arrayloader-benchmarks").get("RymV9PfXDGDbM9ek0000")
adata = artifact.load()

# Filter and preprocess: drop cell lines with too few cells so that every class
# survives the stratified split below.
MIN_CELLS_PER_LINE = 10
cell_line_counts = adata.obs['cell_line'].value_counts()
valid_cell_lines = cell_line_counts[cell_line_counts >= MIN_CELLS_PER_LINE].index
adata = adata[adata.obs['cell_line'].isin(valid_cell_lines)].copy()

# Heuristic: a maximum above 10 is taken to mean X still holds raw counts.
# LinearSCVI models count data, so keep an untransformed copy in a layer before log1p.
if adata.X.max() > 10:
    adata.layers['counts'] = adata.X.copy()
    sc.pp.log1p(adata)

# Drop categories emptied by the filter so the integer codes in `y` are contiguous
# (0..n_classes-1) and index `cat.categories` one-to-one in later cells.
adata.obs['cell_line'] = adata.obs['cell_line'].astype('category').cat.remove_unused_categories()
adata.obs['y'] = adata.obs['cell_line'].cat.codes

# Split data: stratified 80/20 by cell line.
train_ids, val_ids = train_test_split(
    adata.obs.index, test_size=0.2, random_state=42, stratify=adata.obs['y']
)
adata_train = adata[train_ids].copy()
adata_val = adata[val_ids].copy()

# Train Modlyn: each split is served as a single full batch.
datamodule = SimpleLogRegDataModule(
    adata_train, adata_val, "y",
    {"batch_size": len(adata_train), "num_workers": 0},
    {"batch_size": len(adata_val), "num_workers": 0},
)
modlyn_model = SimpleLogReg(adata_train, "y", 1e-2, 0.5)
trainer = L.Trainer(max_epochs=10, enable_progress_bar=False, logger=False, enable_checkpointing=False)
trainer.fit(modlyn_model, datamodule)

# Train LinearSCVI on the same training cells, using raw counts when they were kept.
# NOTE(review): LinearSCVI loadings are presumably per latent factor (n_latent), not
# per cell line; `labels_key` registers labels but does not make rows class-aligned.
adata_scvi = adata_train.copy()
scvi_layer = 'counts' if 'counts' in adata_scvi.layers else None
LinearSCVI.setup_anndata(adata_scvi, layer=scvi_layer, labels_key='cell_line')
scvi_model = LinearSCVI(adata_scvi)
scvi_model.train(max_epochs=10, train_size=1.0, validation_size=None)

print(f"Training complete - Data: {adata.shape}, Cell lines: {adata.obs['y'].nunique()}")


In [None]:
# Extract Weights, Align, and Create Rankings
# Extract weights as (rows, n_genes) NumPy arrays.
modlyn_weights = modlyn_model.linear.weight.detach().cpu().numpy()
scvi_weights = scvi_model.get_loadings().values.T
modlyn_weights_array, scvi_weights_array = np.array(modlyn_weights), np.array(scvi_weights)

# Align shapes: transpose scVI if it is the mirror of Modlyn, otherwise truncate both
# to the common (rows, genes) block and say so, rather than dropping data silently.
if modlyn_weights_array.shape != scvi_weights_array.shape:
    if modlyn_weights_array.shape == scvi_weights_array.shape[::-1]:
        scvi_weights_array = scvi_weights_array.T
    if modlyn_weights_array.shape != scvi_weights_array.shape:
        min_classes = min(modlyn_weights_array.shape[0], scvi_weights_array.shape[0])
        min_features = min(modlyn_weights_array.shape[1], scvi_weights_array.shape[1])
        print(f"WARNING: truncating weights to ({min_classes}, {min_features}); "
              f"Modlyn {modlyn_weights_array.shape} vs scVI {scvi_weights_array.shape}")
        modlyn_weights_array = modlyn_weights_array[:min_classes, :min_features]
        scvi_weights_array = scvi_weights_array[:min_classes, :min_features]

# Setup variables: row i of the weight arrays is labelled with category i.
gene_names = adata.var.index.tolist()
cell_lines = adata.obs['cell_line'].cat.categories.tolist()[:modlyn_weights_array.shape[0]]

# Scanpy rankings.
adata_rank = adata_train.copy()
if 'counts' in adata_rank.layers:
    # Rebuild a library-size-normalised log1p matrix from the raw counts.
    adata_rank.X = adata_rank.layers['counts'].copy()
    adata_rank.uns.pop('log1p', None)  # X is no longer log-transformed
    sc.pp.normalize_total(adata_rank, target_sum=1e4)
    sc.pp.log1p(adata_rank)
# Otherwise X is already log-scale (log1p'd in the loading cell, or it arrived that way);
# normalising and log1p-ing it again here would log-transform twice.
sc.tl.rank_genes_groups(adata_rank, groupby='cell_line', method='logreg', n_genes=len(gene_names), use_raw=False)
names_data = adata_rank.uns['rank_genes_groups']['names']
rank_groups = list(adata_rank.obs['cell_line'].cat.categories)
scanpy_rankings = {}
for cell_line in cell_lines:
    if cell_line not in rank_groups:
        continue
    if hasattr(names_data, 'dtype') and names_data.dtype.names:
        # Structured array: one field per group.
        if cell_line in names_data.dtype.names:
            scanpy_rankings[cell_line] = names_data[cell_line].tolist()
    elif names_data.ndim == 2:
        # Plain 2-D array: one column per group, in category order.
        scanpy_rankings[cell_line] = names_data[:, rank_groups.index(cell_line)].tolist()


def _rank_genes_by_abs_weight(weights):
    """Map each labelled row of `weights` to gene names sorted by |weight|, largest first."""
    n_rows = min(len(cell_lines), weights.shape[0])
    return {
        cell_lines[i]: [gene_names[j] for j in np.argsort(np.abs(weights[i, :]))[::-1]]
        for i in range(n_rows)
    }


# Weight-based rankings.
modlyn_rankings = _rank_genes_by_abs_weight(modlyn_weights_array)
scvi_rankings = _rank_genes_by_abs_weight(scvi_weights_array)

# Known markers, keyed by Cellosaurus accession, using HGNC symbols
# (PSA -> KLK3, PSMA -> FOLH1).
# NOTE(review): verify each CVCL accession -> marker mapping against Cellosaurus.
lit_markers = {'CVCL_0023': ['ESR1', 'PGR', 'GREB1'], 'CVCL_0069': ['EGFR', 'KRAS', 'TP53'], 'CVCL_0131': ['APC', 'CTNNB1', 'TP53'], 'CVCL_0152': ['AFP', 'ALB', 'HNF4A'], 'CVCL_0179': ['BCR', 'ABL1', 'CD34'], 'CVCL_0218': ['AR', 'KLK3', 'FOLH1'], 'CVCL_0292': ['ERBB2', 'TOP2A', 'GRB7'], 'CVCL_0293': ['MITF', 'TYR', 'MLANA'], 'CVCL_0320': ['GFAP', 'EGFR', 'IDH1'], 'CVCL_0332': ['MITF', 'DCT', 'TYR'], 'CVCL_0334': ['ESR1', 'PGR', 'FOXA1'], 'CVCL_0359': ['CDX2', 'MUC2', 'KRT20'], 'CVCL_0366': ['AFP', 'APOB', 'HNF1A'], 'CVCL_0371': ['APC', 'MSH2', 'MLH1'], 'CVCL_0397': ['VIM', 'SNAI1', 'ZEB1']}
gene_set = set(gene_names)
# Keep only markers measured in this dataset. A cell line with none left gets an empty
# list and is skipped by the scoring, rather than being scored against arbitrary genes.
known_markers = {cvcl: [g for g in markers if g in gene_set] for cvcl, markers in lit_markers.items()}
n_with_markers = sum(1 for markers in known_markers.values() if markers)

print(f"Aligned shapes - Modlyn: {modlyn_weights_array.shape}, scVI: {scvi_weights_array.shape}")
print(f"Cell lines with literature markers present: {n_with_markers}/{len(lit_markers)}")


In [None]:
# Gene Specificity + Correlation Analysis with Plots
# Gene specificity
def calc_specificity(weights, genes):
    """Score how row-specific each gene's weights are.

    For gene column i the score is the spread of its weights across rows
    (max - min) divided by its mean absolute weight; 1e-8 is added to the
    denominator to avoid division by zero. Only the first
    ``min(len(genes), weights.shape[1])`` columns are scored.

    Args:
        weights: 2-D array indexed as ``weights[row, gene]``.
        genes: gene names aligned with the columns of ``weights``.

    Returns:
        dict mapping gene name -> specificity score.
    """
    n_scored = min(len(genes), weights.shape[1])
    scores = {}
    for col_idx in range(n_scored):
        column = weights[:, col_idx]
        spread = np.max(column) - np.min(column)
        scale = np.mean(np.abs(column)) + 1e-8
        scores[genes[col_idx]] = spread / scale
    return scores

modlyn_spec = calc_specificity(modlyn_weights_array, gene_names)
scvi_spec = calc_specificity(scvi_weights_array, gene_names)
# Five most specific genes per model. These names are distinct from the heatmap
# cell's weight blocks, so no name refers to two different kinds of object.
modlyn_top_spec = sorted(modlyn_spec.items(), key=lambda kv: kv[1], reverse=True)[:5]
scvi_top_spec = sorted(scvi_spec.items(), key=lambda kv: kv[1], reverse=True)[:5]

print("Top specific genes:")
print("Modlyn:", [g for g, _ in modlyn_top_spec])
print("scVI:  ", [g for g, _ in scvi_top_spec])

# Gene correlation analysis: for each gene, Pearson r between its Modlyn and scVI
# weight vectors over the shared rows.
# NOTE(review): Modlyn rows are cell-line classes while scVI rows come from
# get_loadings() (presumably latent factors); confirm the row pairing is meaningful.
min_cell_lines = min(modlyn_weights_array.shape[0], scvi_weights_array.shape[0])
modlyn_subset = modlyn_weights_array[:min_cell_lines, :]
scvi_subset = scvi_weights_array[:min_cell_lines, :]
n_shared_genes = min(len(gene_names), modlyn_subset.shape[1], scvi_subset.shape[1])
correlations, valid_genes = [], []
for i in range(n_shared_genes):
    modlyn_gene, scvi_gene = modlyn_subset[:, i], scvi_subset[:, i]
    # Correlation is undefined for constant vectors; skip those genes.
    if np.std(modlyn_gene) > 0 and np.std(scvi_gene) > 0:
        corr = np.corrcoef(modlyn_gene, scvi_gene)[0, 1]
        if not np.isnan(corr):
            correlations.append(corr)
            valid_genes.append(gene_names[i])

# Sorted by |r| descending: head = strongest (dis)agreement, tail = weakest.
gene_corrs = sorted(zip(valid_genes, correlations), key=lambda x: abs(x[1]), reverse=True)
# Explicit nan instead of np.mean([]) when no gene had a defined correlation.
mean_correlation = float(np.mean(correlations)) if correlations else float('nan')

# Plot correlations
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
ax1.hist(correlations, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
if correlations:
    ax1.axvline(mean_correlation, color='red', linestyle='--', label=f'Mean: {mean_correlation:.3f}')
    ax1.legend()
ax1.set_xlabel('Gene Correlation'); ax1.set_title('Gene Correlation Distribution')

for ax, genes_subset, title in [(ax2, gene_corrs[:15], 'Top 15'), (ax3, gene_corrs[-15:], 'Bottom 15')]:
    if not genes_subset:
        continue
    genes, corrs = zip(*genes_subset)
    ax.barh(range(len(genes)), corrs, color=['green' if c > 0 else 'red' for c in corrs], alpha=0.7)
    ax.set_yticks(range(len(genes))); ax.set_yticklabels(genes, fontsize=8)
    ax.set_xlabel('Pearson r')
    ax.set_title(f'{title} Correlated Genes'); ax.invert_yaxis()
    # Put each value label just past the end of its bar, on the side the bar points to.
    offset = 0.02 * max(abs(c) for c in corrs)
    for y, corr in enumerate(corrs):
        if corr >= 0:
            ax.text(corr + offset, y, f'{corr:.3f}', va='center', ha='left', fontsize=7)
        else:
            ax.text(corr - offset, y, f'{corr:.3f}', va='center', ha='right', fontsize=7)

plt.tight_layout(); plt.show()
n_high_corr = sum(1 for c in correlations if c > 0.5)
print(f"Mean correlation: {mean_correlation:.3f}, High corr genes (r>0.5): {n_high_corr}")


In [None]:
# Normalized Gene Weight Heatmaps
# Show the union of each model's top-N genes (by mean |weight| over rows) as z-scored
# heatmaps that share one symmetric colour scale.
top_n = 25


def _top_columns_by_mean_abs(weights, n):
    """Indices of the ``n`` columns of ``weights`` with the largest mean |weight|."""
    return np.argsort(np.mean(np.abs(weights), axis=0))[-n:]


def _global_zscore(block):
    """Standardise ``block`` by its overall mean and std (1e-8 guards against std == 0)."""
    return (block - np.mean(block)) / (np.std(block) + 1e-8)


# np.unique returns the union ordered by column index.
top_idx = np.unique(np.concatenate([
    _top_columns_by_mean_abs(modlyn_weights_array, top_n),
    _top_columns_by_mean_abs(scvi_weights_array, top_n),
]))
top_genes = [gene_names[i] for i in top_idx]

modlyn_norm = _global_zscore(modlyn_weights_array[:, top_idx])
scvi_norm = _global_zscore(scvi_weights_array[:, top_idx])
vmax = max(np.abs(modlyn_norm).max(), np.abs(scvi_norm).max())

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
for ax, title, norm_weights in ((ax1, 'Modlyn', modlyn_norm), (ax2, 'scVI', scvi_norm)):
    row_labels = cell_lines[:norm_weights.shape[0]]
    im = ax.imshow(norm_weights, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax)
    ax.set_title(f'{title} Normalized Gene Weights', fontsize=14, fontweight='bold')
    ax.set_xticks(range(len(top_genes)))
    ax.set_xticklabels(top_genes, rotation=90, fontsize=8)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=8)
    plt.colorbar(im, ax=ax, label='Normalized Weight')

plt.tight_layout()
plt.show()


In [None]:
# AUPR Literature Validation + Final Summary
def aupr_score(rankings, markers, max_rank=100):
    """Average precision (the standard AUPR estimate) of each ranking against its markers.

    For every cell line that is a key in both ``rankings`` and ``markers`` and has a
    non-empty marker list, the score is the mean, over its distinct markers, of the
    precision at the rank where that marker is retrieved. Only the top ``max_rank``
    entries of the ranking are searched; a marker not found there contributes 0.

    Args:
        rankings: dict cell line -> list of gene names, best first.
        markers: dict cell line -> list of known marker genes (the positives).
        max_rank: depth of the ranking to search (default 100).

    Returns:
        list[float]: one score in [0, 1] per qualifying cell line, in ``markers`` order.
    """
    scores = []
    for cl, cl_markers in markers.items():
        if cl not in rankings or len(cl_markers) == 0:
            continue
        remaining = set(cl_markers)
        n_positives = len(remaining)
        hits = 0
        precision_sum = 0.0
        for rank, gene in enumerate(rankings[cl][:max_rank], start=1):
            if gene in remaining:
                remaining.remove(gene)  # count each marker once even if repeated
                hits += 1
                precision_sum += hits / rank
        scores.append(precision_sum / n_positives)
    return scores

# Score each method's rankings against the literature markers.
methods = ['Scanpy', 'Modlyn', 'scVI']
method_colors = ['blue', 'red', 'green']
method_rankings = {'Scanpy': scanpy_rankings, 'Modlyn': modlyn_rankings, 'scVI': scvi_rankings}

scanpy_aupr = aupr_score(scanpy_rankings, known_markers)
modlyn_aupr = aupr_score(modlyn_rankings, known_markers)
scvi_aupr = aupr_score(scvi_rankings, known_markers)
method_aupr = [scanpy_aupr, modlyn_aupr, scvi_aupr]  # same order as `methods`

# Plot AUPR results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
# Cell lines that have at least one marker and a ranking from every method.
valid_cls = [cl for cl, markers in known_markers.items()
             if markers and all(cl in rankings for rankings in method_rankings.values())]

if valid_cls:
    # rows = cell lines, columns = methods (same order as `methods`)
    scores_array = np.array([
        [aupr_score({cl: method_rankings[m][cl]}, {cl: known_markers[cl]})[0] for m in methods]
        for cl in valid_cls
    ])
    x_pos = np.arange(len(valid_cls))
    width = 0.25
    for i, (label, color) in enumerate(zip(methods, method_colors)):
        # offsets -width, 0, +width centre the three bars on each tick
        ax1.bar(x_pos + (i - 1) * width, scores_array[:, i], width, label=label, color=color, alpha=0.7)
    ax1.set_xlabel('Cell Lines'); ax1.set_ylabel('AUPR Score'); ax1.set_title('Literature Agreement by Cell Line')
    ax1.set_xticks(x_pos); ax1.set_xticklabels(valid_cls, rotation=45, ha='right'); ax1.legend()
else:
    ax1.text(0.5, 0.5, 'No valid cell lines for comparison', ha='center', va='center', transform=ax1.transAxes)

overall = [np.mean(s) if s else 0 for s in method_aupr]
errors = [np.std(s) if s else 0 for s in method_aupr]
bars = ax2.bar(methods, overall, yerr=errors, color=method_colors, alpha=0.7, capsize=5)
ax2.set_ylabel('Mean AUPR Score'); ax2.set_title('Overall Literature Agreement')
for bar, score in zip(bars, overall):
    ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f'{score:.3f}',
             ha='center', va='bottom', fontweight='bold')

plt.tight_layout(); plt.show()

# Final summary (computed values are reported, not just assigned).
mean_corr = np.mean(correlations) if correlations else float('nan')
modlyn_avg_spec = np.mean(list(modlyn_spec.values()))
scvi_avg_spec = np.mean(list(scvi_spec.values()))
modlyn_aupr_mean = np.mean(modlyn_aupr) if modlyn_aupr else 0
scvi_aupr_mean = np.mean(scvi_aupr) if scvi_aupr else 0

summary = pd.DataFrame(
    {'mean_aupr': overall, 'std_aupr': errors, 'n_cell_lines': [len(s) for s in method_aupr]},
    index=methods,
)
print(summary.round(3).to_string())
print(f"Mean Modlyn/scVI per-gene weight correlation: {mean_corr:.3f}")
print(f"Mean gene specificity - Modlyn: {modlyn_avg_spec:.3f}, scVI: {scvi_avg_spec:.3f}")

ln.finish()
