In [None]:
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
import torch
import lightning as L
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split

from scvi.model import LinearSCVI
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder

# Seed NumPy's legacy global RNG and PyTorch's default generator with a fixed
# value so that code drawing from those generators starts from the same state
# on every fresh run.
# NOTE(review): other RNG sources (Python's `random`, Lightning, scvi-tools
# settings) are not seeded here — confirm whether that matters for the results.
np.random.seed(42)
torch.manual_seed(42)

import scvi
from modlyn.models import SimpleLogReg
from modlyn.models._simple_logreg_datamodule import SimpleLogRegDataModule

import lamindb as ln
# Create and persist a lamindb project record named "scVI-Comparison".
project = ln.Project(name="scVI-Comparison")
project.save()
# Start lamindb tracking for this notebook, attaching it to the project by name.
ln.track(project="scVI-Comparison")
# NOTE(review): `ln.track()` is called a second time here, now without the
# project argument, and its return value is kept as `run`. Confirm that the
# second call is intended and that it returns a run object (rather than None)
# in the installed lamindb version.
run = ln.track()
# Report the scvi-tools version used for the comparison.
print(f"scvi-tools version: {scvi.__version__}")


In [None]:
# Suppress every Python warning from this point on.
warnings.filterwarnings("ignore")

# Load the AnnData artifact from the remote lamindb instance. The UID is kept
# in `dataset_id` so it is written down exactly once.
dataset_id = "RymV9PfXDGDbM9ek0000"
_source_registry = ln.Artifact.using("laminlabs/arrayloader-benchmarks")
artifact = _source_registry.get(dataset_id)
adata = artifact.load()

# Keep only the cell lines that have at least `min_cells_per_line` cells.
min_cells_per_line = 10
cell_line_counts = adata.obs['cell_line'].value_counts()
_has_enough_cells = cell_line_counts >= min_cells_per_line
valid_cell_lines = cell_line_counts[_has_enough_cells].index
_keep_cell = adata.obs['cell_line'].isin(valid_cell_lines)
adata = adata[_keep_cell].copy()

# Heuristic: a maximum above 10 is treated as "not yet log-scaled", in which
# case log1p is applied in place to the expression matrix.
_looks_unlogged = adata.X.max() > 10
if _looks_unlogged:
    sc.pp.log1p(adata)

# Integer class labels `y` are the category codes of the cell-line column.
_cell_line_categorical = adata.obs['cell_line'].astype('category')
adata.obs['cell_line'] = _cell_line_categorical
adata.obs['y'] = _cell_line_categorical.cat.codes

print(f"Dataset: {adata.shape}, Classes: {adata.obs['y'].nunique()}")


In [None]:
# Stratified 80/20 split of the observation index, keyed on the class label.
train_ids, val_ids = train_test_split(
    adata.obs.index,
    test_size=0.2,
    random_state=42,
    stratify=adata.obs['y'],
)
adata_train, adata_val = adata[train_ids].copy(), adata[val_ids].copy()

# Each loader serves its whole split as a single batch, in the main process.
_train_loader_kwargs = {"batch_size": len(adata_train), "num_workers": 0}
_val_loader_kwargs = {"batch_size": len(adata_val), "num_workers": 0}
datamodule = SimpleLogRegDataModule(
    adata_train=adata_train,
    adata_val=adata_val,
    label_column="y",
    train_dataloader_kwargs=_train_loader_kwargs,
    val_dataloader_kwargs=_val_loader_kwargs,
)

# Logistic-regression model configured from the training split.
modlyn_model = SimpleLogReg(
    adata=adata_train,
    label_column="y",
    learning_rate=1e-2,
    weight_decay=0.5,
)

# Ten epochs, with progress bar, logger and checkpointing all switched off.
trainer = L.Trainer(
    max_epochs=10,
    enable_progress_bar=False,
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(modlyn_model, datamodule)

# Extract Modlyn results
# The weight matrix of the model's linear layer, copied to host memory.
modlyn_weights = modlyn_model.linear.weight.detach().cpu().numpy()

# Switch to inference mode and run the forward pass on whatever device the
# parameters currently live on, so this cell works whether or not the trainer
# left the model on an accelerator.
modlyn_model.eval()
_model_device = next(modlyn_model.parameters()).device
# Densify the training matrix only if it is stored in a sparse format.
_X_train = adata_train.X.toarray() if hasattr(adata_train.X, 'toarray') else adata_train.X
with torch.no_grad():
    X_tensor = torch.tensor(_X_train, dtype=torch.float32, device=_model_device)
    # Move the predicted class indices back to the CPU before converting to NumPy.
    modlyn_predictions = modlyn_model(X_tensor).argmax(dim=1).cpu().numpy()

# This is accuracy on the *training* split (the cells the model was fit on),
# not a held-out estimate.
modlyn_accuracy = (modlyn_predictions == adata_train.obs['y'].values).mean()

print(f"Modlyn: {modlyn_accuracy:.3f} accuracy, weights {modlyn_weights.shape}")


In [None]:
from scvi.model import LinearSCVI

# Work on a private copy of the training split so that scvi-tools' setup does
# not modify `adata_train`.
adata_scvi = adata_train.copy()
adata_scvi.obs['cell_line'] = adata_scvi.obs['cell_line'].astype('category')

# Sanity print: both methods are given the same cells and genes.
print(f"Using same data for both methods:")
print(f"Modlyn data: {adata_train.shape}, cell lines: {adata_train.obs['cell_line'].nunique()}")
print(f"scVI data: {adata_scvi.shape}, cell lines: {adata_scvi.obs['cell_line'].nunique()}")

# Register the AnnData with the cell line as the label column and no batch key,
# then fit LinearSCVI for 10 epochs on all cells (train_size=1.0, no
# validation split).
# NOTE(review): `adata_train.X` may have been log1p-transformed by the
# preprocessing cell; scvi-tools models are usually fit on raw counts —
# confirm that the input scale is what LinearSCVI expects.
LinearSCVI.setup_anndata(adata_scvi, labels_key='cell_line', batch_key=None)
scvi_model = LinearSCVI(adata_scvi)
scvi_model.train(max_epochs=10, train_size=1.0, validation_size=None)


In [None]:
# Latent embedding of the cells the scVI model was trained on.
latent_repr = scvi_model.get_latent_representation()
le = LabelEncoder()
cell_line_encoded = le.fit_transform(adata_scvi.obs['cell_line'])

# Train a simple logistic regression classifier on latent space
classifier = LogisticRegression(random_state=42, max_iter=1000)
classifier.fit(latent_repr, cell_line_encoded)

# Make predictions on the same cells the classifier was fit on, so this is a
# training accuracy rather than a held-out estimate.
predictions = classifier.predict(latent_repr)
scvi_accuracy = accuracy_score(cell_line_encoded, predictions)

# Get model weights (loadings). Per the scvi-tools documentation these are a
# genes x latent-dimensions table, i.e. they describe latent factors, not cell
# lines.
loadings = scvi_model.get_loadings()

# Build a (cell lines x genes) matrix that is comparable with the Modlyn
# weights: push each class's coefficient vector in latent space through the
# linear decoder. Using the transposed loadings directly would pair class i
# with latent factor i in every downstream per-class comparison.
_class_coef = classifier.coef_
if _class_coef.shape[0] == 1 and len(le.classes_) == 2:
    # Binary problems expose a single coefficient row (for the second class);
    # expand it so that there is one row per class.
    _class_coef = np.vstack([-_class_coef, _class_coef])
_class_by_gene = pd.DataFrame(
    _class_coef @ loadings.values.T,
    index=le.classes_,
    columns=loadings.index,
)
# Order the rows like the categorical codes used as Modlyn's labels
# (LabelEncoder sorts labels, which need not match the category order).
scvi_weights = _class_by_gene.reindex(adata_scvi.obs['cell_line'].cat.categories).values

print(f"scVI: {scvi_accuracy:.3f} accuracy, weights {scvi_weights.shape}")


In [None]:
def calculate_gene_specificity(weights, gene_names):
    """Score how unevenly each gene's weight is spread across the rows of `weights`.

    Args:
        weights: 2-D array indexed as ``weights[row, gene]``.
        gene_names: Gene identifiers, positionally matched to the columns of
            ``weights``. Names beyond the last column are ignored.

    Returns:
        dict mapping each scored gene name to a dict with
        ``'specificity_score'`` (range of the column divided by its mean
        absolute value plus 1e-8) and ``'most_associated_class'`` (row index
        of the largest absolute weight).
    """
    n_columns = weights.shape[1]
    per_gene = {}
    # zip stops at the shorter input, so surplus gene names are never scored.
    for column_idx, name in zip(range(n_columns), gene_names):
        column = weights[:, column_idx]
        magnitudes = np.abs(column)
        spread = np.max(column) - np.min(column)
        per_gene[name] = {
            'specificity_score': spread / (np.mean(magnitudes) + 1e-8),
            'most_associated_class': np.argmax(magnitudes),
        }
    return per_gene

# Work on ndarray versions of both weight matrices.
modlyn_weights_array, scvi_weights_array = np.array(modlyn_weights), np.array(scvi_weights)
gene_names = adata.var.index.tolist()
class_names = adata.obs['cell_line'].cat.categories.tolist()

modlyn_specificity = calculate_gene_specificity(modlyn_weights_array, gene_names)
scvi_specificity = calculate_gene_specificity(scvi_weights_array, gene_names)


def _specificity_of(item):
    """Sort key: the specificity score of a (gene, stats) pair."""
    return item[1]['specificity_score']


# Ten highest-scoring genes per method, best first.
modlyn_specific = sorted(modlyn_specificity.items(), key=_specificity_of, reverse=True)[:10]
scvi_specific = sorted(scvi_specificity.items(), key=_specificity_of, reverse=True)[:10]

print("Top specific genes:")
print("Modlyn:", [entry[0] for entry in modlyn_specific[:5]])
print("scVI:  ", [entry[0] for entry in scvi_specific[:5]])

# Mean specificity over every scored gene.
modlyn_avg_specificity = np.mean([stats['specificity_score'] for stats in modlyn_specificity.values()])
scvi_avg_specificity = np.mean([stats['specificity_score'] for stats in scvi_specificity.values()])

print(f"Average specificity - Modlyn: {modlyn_avg_specificity:.3f}, scVI: {scvi_avg_specificity:.3f}")


In [None]:
# Bring the two weight matrices to a common shape before comparing them.
_modlyn_shape = modlyn_weights_array.shape
_scvi_shape = scvi_weights_array.shape

# Step 1: if the scVI matrix is exactly the transpose layout, flip it.
_scvi_is_transposed = (
    _modlyn_shape != _scvi_shape
    and _modlyn_shape[0] == _scvi_shape[1]
    and _modlyn_shape[1] == _scvi_shape[0]
)
if _scvi_is_transposed:
    scvi_weights_array = scvi_weights_array.T

# Step 2: if the shapes still disagree, crop both to the shared leading
# rows and columns. This silently drops whatever does not fit.
if modlyn_weights_array.shape != scvi_weights_array.shape:
    min_classes = min(modlyn_weights_array.shape[0], scvi_weights_array.shape[0])
    min_features = min(modlyn_weights_array.shape[1], scvi_weights_array.shape[1])
    modlyn_weights_array = modlyn_weights_array[:min_classes, :min_features]
    scvi_weights_array = scvi_weights_array[:min_classes, :min_features]

# Global agreement between the two flattened weight matrices.
_modlyn_flat = modlyn_weights_array.flatten()
_scvi_flat = scvi_weights_array.flatten()
correlation = np.corrcoef(_modlyn_flat, _scvi_flat)[0, 1]
spearman_corr, _ = spearmanr(_modlyn_flat, _scvi_flat)

# Row-by-row agreement: Pearson correlation and overlap of the ten
# largest-magnitude genes.
class_correlations, gene_overlaps = [], []
cell_lines = adata.obs['cell_line'].cat.categories[:modlyn_weights_array.shape[0]]

for i in range(len(cell_lines)):
    modlyn_class, scvi_class = modlyn_weights_array[i, :], scvi_weights_array[i, :]

    # The overlap is recorded for every row, including rows containing NaNs.
    modlyn_top_10 = np.argsort(np.abs(modlyn_class))[-10:]
    scvi_top_10 = np.argsort(np.abs(scvi_class))[-10:]
    overlap = np.intersect1d(modlyn_top_10, scvi_top_10).size
    gene_overlaps.append(overlap)

    # The correlation is recorded only when both rows are NaN-free and the
    # result itself is defined.
    if np.isnan(modlyn_class).any() or np.isnan(scvi_class).any():
        continue
    class_corr = np.corrcoef(modlyn_class, scvi_class)[0, 1]
    if not np.isnan(class_corr):
        class_correlations.append(class_corr)

# Paired weight vectors restricted to positions where both values are finite.
_all_modlyn = modlyn_weights_array.flatten()
_all_scvi = scvi_weights_array.flatten()
mask = np.isfinite(_all_modlyn) & np.isfinite(_all_scvi)
x = _all_modlyn[mask]
y = _all_scvi[mask]

# Top-10 overlap table for at most the first six rows; its columns are
# (Modlyn-only, shared, scVI-only) gene counts.
n_classes_show = min(6, len(cell_lines))
_shared_counts = np.array(
    [
        np.intersect1d(
            np.argsort(np.abs(modlyn_weights_array[i, :]))[-10:],
            np.argsort(np.abs(scvi_weights_array[i, :]))[-10:],
        ).size
        for i in range(n_classes_show)
    ],
    dtype=float,
)
overlap_matrix = np.column_stack([10 - _shared_counts, _shared_counts, 10 - _shared_counts])

# Mean absolute weight per row, for each method.
modlyn_magnitudes = np.abs(modlyn_weights_array).mean(axis=1)
scvi_magnitudes = np.abs(scvi_weights_array).mean(axis=1)

# Histogram of the per-row overlap counts, with bins 0..10.
overlap_counts = np.array(gene_overlaps)
overlap_hist = np.bincount(overlap_counts, minlength=11)


In [None]:
# Gene Specificity Analysis
# Gene identifiers are taken from the var index of the full dataset.
gene_names = adata.var.index.tolist()
# `cell_lines` is reassigned here to the complete list of cell-line categories
# (an earlier cell bound it to a slice of the category index).
cell_lines = adata.obs['cell_line'].cat.categories.tolist()

def calc_specificity(weights, genes):
    """Return a ``{gene: score}`` dict of per-gene weight specificity.

    The score of a gene is the range of its column in ``weights`` divided by
    the column's mean absolute value plus 1e-8. Only the first
    ``min(len(genes), weights.shape[1])`` genes are scored.
    """
    n_scored = min(len(genes), weights.shape[1])
    scores = {}
    for col in range(n_scored):
        gene_column = weights[:, col]
        spread = np.max(gene_column) - np.min(gene_column)
        scores[genes[col]] = spread / (np.mean(np.abs(gene_column)) + 1e-8)
    return scores

modlyn_spec = calc_specificity(modlyn_weights_array, gene_names)
scvi_spec = calc_specificity(scvi_weights_array, gene_names)

# Five highest-scoring (gene, score) pairs per method, best first.
modlyn_top = [(gene, modlyn_spec[gene]) for gene in sorted(modlyn_spec, key=modlyn_spec.get, reverse=True)[:5]]
scvi_top = [(gene, scvi_spec[gene]) for gene in sorted(scvi_spec, key=scvi_spec.get, reverse=True)[:5]]

_modlyn_mean_spec = np.mean(list(modlyn_spec.values()))
_scvi_mean_spec = np.mean(list(scvi_spec.values()))

print("Top specific genes:")
print("Modlyn:", [pair[0] for pair in modlyn_top])
print("scVI:  ", [pair[0] for pair in scvi_top])
print(f"Avg specificity - Modlyn: {_modlyn_mean_spec:.3f}, scVI: {_scvi_mean_spec:.3f}")


In [None]:
# Gene Analysis and Marker Enrichment
def gene_analysis(weights, genes, lines):
    """List the positively weighted genes of each row, largest magnitude first.

    Args:
        weights: 2-D array indexed as ``weights[line, gene]``.
        genes: Gene names, positionally matched to the columns.
        lines: Row labels; only the first ``min(len(lines), weights.shape[0])``
            rows are processed.

    Returns:
        ``{line: {'upregulated': [(gene, weight), ...]}}`` where each list
        keeps only weights strictly greater than zero, ordered by descending
        absolute weight.
    """
    per_line = {}
    for row in range(min(len(lines), weights.shape[0])):
        by_magnitude = np.argsort(np.abs(weights[row, :]))[::-1]
        positive = []
        for col in by_magnitude:
            if weights[row, col] > 0:
                positive.append((genes[col], weights[row, col]))
        per_line[lines[row]] = {'upregulated': positive}
    return per_line

# Per-cell-line lists of positively weighted genes for each method.
modlyn_gene_analysis = gene_analysis(modlyn_weights_array, gene_names, cell_lines)
scvi_gene_analysis = gene_analysis(scvi_weights_array, gene_names, cell_lines)

# Hand-written marker panels keyed by category name.
# NOTE(review): `known_markers` is reassigned in a later cell to a dict keyed
# by cell-line ID, so this panel is only valid until that cell runs.
known_markers = {
    'stem_cell': ['POU5F1', 'SOX2', 'NANOG', 'KLF4', 'MYC'],
    'fibroblast': ['COL1A1', 'COL1A2', 'FN1', 'ACTA2', 'VIM'],
    'epithelial': ['EPCAM', 'CDH1', 'KRT8', 'KRT18', 'KRT19'],
    'immune': ['PTPRC', 'CD3E', 'CD19', 'CD68', 'CD14'],
    'endothelial': ['PECAM1', 'VWF', 'CDH5', 'KDR'],
    'neural': ['TUBB3', 'MAP2', 'NCAM1', 'GFAP', 'S100B'],
    'cancer': ['TP53', 'KRAS', 'EGFR', 'MKI67', 'PCNA']
}

def check_markers(analysis, markers, name, top_n=10):
    """Report which known markers appear among each cell line's top genes.

    Args:
        analysis: ``{cell_line: {'upregulated': [(gene, weight), ...]}}``, as
            produced by ``gene_analysis``.
        markers: ``{category: [gene, ...]}`` marker panels.
        name: Label of the method being summarised; printed as the header so
            the reports of different methods can be told apart.
        top_n: Number of leading upregulated genes inspected per cell line.

    Returns:
        ``{category: [gene, ...]}`` with one entry per (cell line, marker)
        hit, so a marker found in several cell lines is listed several times.
        Categories without hits map to an empty list.
    """
    # Sets make the membership test independent of panel size.
    marker_sets = {cat: set(marker_list) for cat, marker_list in markers.items()}
    all_hits = {cat: [] for cat in markers}
    for genes in analysis.values():
        top_genes = [g for g, _ in genes['upregulated'][:top_n]]
        for cat, marker_set in marker_sets.items():
            all_hits[cat].extend(g for g in top_genes if g in marker_set)

    total_hits = sum(len(h) for h in all_hits.values())
    print(f"{name}: {total_hits} marker hits in top-{top_n} upregulated genes")
    for cat, hits in all_hits.items():
        if hits:
            print(f"  {cat}: {len(set(hits))} unique")
    return all_hits

# Marker hits among the top upregulated genes of every cell line, per method.
# The third argument is the method label handed to `check_markers`.
modlyn_markers = check_markers(modlyn_gene_analysis, known_markers, "Modlyn")
scvi_markers = check_markers(scvi_gene_analysis, known_markers, "scVI")


In [None]:
gene_names = adata.var.index.tolist()
cell_lines = adata.obs['cell_line'].cat.categories.tolist()[:modlyn_weights_array.shape[0]]

# Rank genes per cell line with scanpy's logistic-regression method.
# `adata_train` has already been through the notebook's preprocessing step
# (log1p applied when the raw maximum exceeded 10) and is the matrix Modlyn
# was trained on. It is deliberately NOT normalised or log-transformed again:
# doing so would double-transform the data and give scanpy a different input
# from the methods it is compared against.
adata_copy = adata_train.copy()
sc.tl.rank_genes_groups(adata_copy, groupby='cell_line', method='logreg', n_genes=len(gene_names), use_raw=False)
names_data = adata_copy.uns['rank_genes_groups']['names']

# Collect the ranked gene names of every cell line scanpy produced a result for.
_scanpy_groups = list(adata_copy.obs['cell_line'].cat.categories)
_has_named_fields = bool(getattr(getattr(names_data, 'dtype', None), 'names', None))
scanpy_rankings = {}
for cell_line in cell_lines:
    if cell_line not in _scanpy_groups:
        continue
    if _has_named_fields:
        # Structured array: one named field per group.
        scanpy_rankings[cell_line] = names_data[cell_line].tolist()
    elif names_data.ndim == 2:
        # Plain 2-D array: take the column at the group's category position.
        scanpy_rankings[cell_line] = names_data[:, _scanpy_groups.index(cell_line)].tolist()
# 2. Weight-based rankings for Modlyn and scVI
def weight_rankings(weights, genes, lines):
    """Rank every gene of each row by descending absolute weight.

    Args:
        weights: 2-D array indexed as ``weights[line, gene]``.
        genes: Gene names, positionally matched to the columns.
        lines: Row labels; only the first ``min(len(lines), weights.shape[0])``
            rows are ranked.

    Returns:
        ``{line: [gene, ...]}`` with the largest-magnitude gene first.
    """
    ranked = {}
    n_rows = min(len(lines), weights.shape[0])
    for row in range(n_rows):
        order = np.argsort(np.abs(weights[row, :]))[::-1]
        ranked[lines[row]] = [genes[col] for col in order]
    return ranked

# Per-cell-line gene rankings by absolute weight, one dict per method, keyed
# by the entries of `cell_lines`.
modlyn_rankings = weight_rankings(modlyn_weights_array, gene_names, cell_lines)
scvi_rankings = weight_rankings(scvi_weights_array, gene_names, cell_lines)

# 3. Create curated markers
# Hand-curated marker genes keyed by cell-line ID.
# NOTE(review): the ID -> marker assignments cannot be verified from this
# notebook; confirm them against the source they were taken from.
# NOTE(review): 'PSA' and 'PSMA' look like protein aliases rather than gene
# symbols (presumably KLK3 / FOLH1) — confirm, otherwise they will never match
# the dataset's gene names.
lit_markers = {
    'CVCL_0023': ['ESR1', 'PGR', 'GREB1'], 'CVCL_0069': ['EGFR', 'KRAS', 'TP53'],
    'CVCL_0131': ['APC', 'CTNNB1', 'TP53'], 'CVCL_0152': ['AFP', 'ALB', 'HNF4A'],
    'CVCL_0179': ['BCR', 'ABL1', 'CD34'], 'CVCL_0218': ['AR', 'PSA', 'PSMA'],
    'CVCL_0292': ['ERBB2', 'TOP2A', 'GRB7'], 'CVCL_0293': ['MITF', 'TYR', 'MLANA'],
    'CVCL_0320': ['GFAP', 'EGFR', 'IDH1'], 'CVCL_0332': ['MITF', 'DCT', 'TYR'],
    'CVCL_0334': ['ESR1', 'PGR', 'FOXA1'], 'CVCL_0359': ['CDX2', 'MUC2', 'KRT20'],
    'CVCL_0366': ['AFP', 'APOB', 'HNF1A'], 'CVCL_0371': ['APC', 'MSH2', 'MLH1'],
    'CVCL_0397': ['VIM', 'SNAI1', 'ZEB1']
}

# Keep only the markers that are actually measured in this dataset. A cell
# line whose markers are all absent gets an empty list; no substitute genes
# are invented, because arbitrary genes would act as fake ground truth in the
# literature-agreement scores. Empty entries are skipped by the scoring code.
gene_set = set(gene_names)
known_markers = {cvcl: [g for g in markers if g in gene_set]
                 for cvcl, markers in lit_markers.items()}

_lines_without_markers = [cvcl for cvcl, found in known_markers.items() if not found]
if _lines_without_markers:
    print(f"No curated marker present in the dataset for: {_lines_without_markers}")


In [None]:
# Restrict both matrices to the rows they have in common.
min_cell_lines = min(modlyn_weights_array.shape[0], scvi_weights_array.shape[0])
modlyn_subset = modlyn_weights_array[:min_cell_lines, :]
scvi_subset = scvi_weights_array[:min_cell_lines, :]

# Per-gene Pearson correlation across rows, skipping genes that are constant
# in either method or whose correlation is undefined.
correlations, valid_genes = [], []
_n_shared_genes = min(len(gene_names), modlyn_subset.shape[1], scvi_subset.shape[1])
for i in range(_n_shared_genes):
    modlyn_gene, scvi_gene = modlyn_subset[:, i], scvi_subset[:, i]
    if not (np.std(modlyn_gene) > 0 and np.std(scvi_gene) > 0):
        continue
    corr = np.corrcoef(modlyn_gene, scvi_gene)[0, 1]
    if np.isnan(corr):
        continue
    correlations.append(corr)
    valid_genes.append(gene_names[i])

# (gene, correlation) pairs ordered by descending absolute correlation.
gene_corrs = sorted(zip(valid_genes, correlations), key=lambda pair: abs(pair[1]), reverse=True)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

# Left panel: distribution of per-gene correlations with the mean marked.
_mean_corr = np.mean(correlations)
ax1.hist(correlations, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
ax1.axvline(_mean_corr, color='red', linestyle='--', label=f'Mean: {_mean_corr:.3f}')
ax1.set_xlabel('Gene Correlation')
ax1.set_title('Gene Correlation Distribution')
ax1.legend()

# Middle/right panels: the 15 entries at the head and at the tail of the
# |correlation|-sorted list, drawn as horizontal bars (green = positive,
# red = non-positive).
_bar_panels = [(ax2, gene_corrs[:15], 'Top 15'), (ax3, gene_corrs[-15:], 'Bottom 15')]
for ax, genes_subset, title in _bar_panels:
    if not genes_subset:
        continue
    genes, corrs = zip(*genes_subset)
    _rows = range(len(genes))
    _bar_colors = ['green' if c > 0 else 'red' for c in corrs]
    bars = ax.barh(_rows, corrs, color=_bar_colors, alpha=0.7)
    ax.set_yticks(_rows)
    ax.set_yticklabels(genes, fontsize=8)
    ax.set_title(f'{title} Correlated Genes')
    ax.invert_yaxis()
    # Value labels sit slightly to the right of each bar's end.
    _label_offset = 0.02*max(abs(min(corrs)), abs(max(corrs)))
    for i, corr in enumerate(corrs):
        ax.text(corr + _label_offset, i, f'{corr:.3f}', va='center', fontsize=7)

plt.tight_layout()
plt.show()

In [None]:
# Select the genes that are among the `top_n` most important (mean absolute
# weight across rows) for at least one of the two methods.
top_n = 25
modlyn_imp = np.abs(modlyn_weights_array).mean(axis=0)
scvi_imp = np.abs(scvi_weights_array).mean(axis=0)
_modlyn_best = np.argsort(modlyn_imp)[-top_n:]
_scvi_best = np.argsort(scvi_imp)[-top_n:]
# Sorted union of both index sets.
top_idx = np.union1d(_modlyn_best, _scvi_best)
top_genes = [gene_names[i] for i in top_idx]

# Row labels, one per weight-matrix row of each method.
n_cell_lines_modlyn = modlyn_weights_array.shape[0]
n_cell_lines_scvi = scvi_weights_array.shape[0]
cell_lines_modlyn = cell_lines[:n_cell_lines_modlyn]
cell_lines_scvi = cell_lines[:n_cell_lines_scvi]

modlyn_top = modlyn_weights_array[:, top_idx]
scvi_top = scvi_weights_array[:, top_idx]


def _global_zscore(block):
    """Standardise a matrix with its overall mean and standard deviation."""
    return (block - np.mean(block)) / (np.std(block) + 1e-8)


modlyn_norm = _global_zscore(modlyn_top)
scvi_norm = _global_zscore(scvi_top)
# One symmetric colour limit shared by both heatmaps.
vmax = max(np.abs(modlyn_norm).max(), np.abs(scvi_norm).max())

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

# One heatmap per method, drawn with identical styling and colour limits.
_heatmap_panels = [
    (ax1, modlyn_norm, 'Modlyn Normalized Gene Weights', cell_lines_modlyn),
    (ax2, scvi_norm, 'scVI Normalized Gene Weights', cell_lines_scvi),
]
_images = []
for _ax, _matrix, _panel_title, _row_labels in _heatmap_panels:
    _image = _ax.imshow(_matrix, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax)
    _ax.set_title(_panel_title, fontsize=14, fontweight='bold')
    _ax.set_xticks(range(len(top_genes)))
    _ax.set_xticklabels(top_genes, rotation=90, fontsize=8)
    _ax.set_yticks(range(len(_row_labels)))
    _ax.set_yticklabels(_row_labels, fontsize=8)
    plt.colorbar(_image, ax=_ax, label='Normalized Weight')
    _images.append(_image)
im1, im2 = _images

plt.tight_layout()
plt.show()


In [None]:
def aupr_score(rankings, markers, max_k=100):
    """Score how early known markers appear in each cell line's gene ranking.

    For every cell line present in both inputs with at least one marker, the
    score is the mean of precision@k over ``k = 1 .. min(len(ranking), max_k)``,
    where precision@k is the number of distinct markers among the first ``k``
    ranked genes divided by ``k``. Note that this is a mean precision@k, not
    the area under a precision-recall curve in the strict sense.

    Args:
        rankings: ``{cell_line: [gene, ...]}``, best gene first.
        markers: ``{cell_line: [gene, ...]}`` of known marker genes.
        max_k: Deepest rank considered (default 100).

    Returns:
        list of float scores, one per scored cell line, in the iteration
        order of ``markers``. A cell line with an empty ranking yields NaN.

    Raises:
        ValueError: if ``max_k`` is smaller than 1.
    """
    if max_k < 1:
        raise ValueError("max_k must be at least 1")

    scores = []
    for cl, marker_list in markers.items():
        if cl not in rankings or len(marker_list) == 0:
            continue
        marker_set = set(marker_list)
        # Running count of distinct markers seen so far; this avoids
        # rebuilding the top-k set for every k.
        found = set()
        precisions = []
        for k, gene in enumerate(rankings[cl][:max_k], start=1):
            if gene in marker_set:
                found.add(gene)
            precisions.append(len(found) / k)
        scores.append(np.mean(precisions))
    return scores

# Per-method score lists over all cell lines that have markers and a ranking.
scanpy_aupr = aupr_score(scanpy_rankings, known_markers)
modlyn_aupr = aupr_score(modlyn_rankings, known_markers)
scvi_aupr = aupr_score(scvi_rankings, known_markers)

# Left panel: per-cell-line scores; right panel (filled in afterwards): overall.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Only cell lines that have markers and were ranked by all three methods can
# be shown side by side.
_rankings_by_method = (scanpy_rankings, modlyn_rankings, scvi_rankings)
valid_cls = [
    cl for cl, cl_markers in known_markers.items()
    if all(cl in ranking for ranking in _rankings_by_method) and len(cl_markers) > 0
]

if not valid_cls:
    ax1.text(0.5, 0.5, 'No valid cell lines for comparison', ha='center', va='center', transform=ax1.transAxes)
    ax1.set_title('Literature Agreement by Cell Line')
else:
    # One row per cell line, one column per method (Scanpy, Modlyn, scVI).
    scores_data = [
        [aupr_score({cl: ranking[cl]}, {cl: known_markers[cl]})[0] for ranking in _rankings_by_method]
        for cl in valid_cls
    ]
    scores_array = np.array(scores_data)
    x_pos = np.arange(len(valid_cls))
    width = 0.25

    # Grouped bars: the three methods are centred on each tick.
    _series = [('Scanpy', 'blue'), ('Modlyn', 'red'), ('scVI', 'green')]
    for i, (label, color) in enumerate(_series):
        ax1.bar(x_pos + (i - 1) * width, scores_array[:, i], width, label=label, color=color, alpha=0.7)

    ax1.set_xlabel('Cell Lines')
    ax1.set_ylabel('AUPR Score')
    ax1.set_title('Literature Agreement by Cell Line')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(valid_cls, rotation=45, ha='right')
    ax1.legend()

# Mean and spread of the per-cell-line scores for each method; a method with
# no scored cell line is reported as 0.
methods = ['Scanpy', 'Modlyn', 'scVI']
_scores_by_method = [scanpy_aupr, modlyn_aupr, scvi_aupr]
overall = [np.mean(s) if s else 0 for s in _scores_by_method]
errors = [np.std(s) if s else 0 for s in _scores_by_method]

bars = ax2.bar(methods, overall, yerr=errors, color=['blue', 'red', 'green'], alpha=0.7, capsize=5)
ax2.set_ylabel('Mean AUPR Score')
ax2.set_title('Overall Literature Agreement')

# Print each mean just above its bar.
for bar, score in zip(bars, overall):
    _bar_center = bar.get_x() + bar.get_width()/2
    _label_height = bar.get_height() + 0.01
    ax2.text(_bar_center, _label_height, f'{score:.3f}',
             ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

# Text summary of the same numbers.
print("Literature Validation Results:")
for method, score, err in zip(methods, overall, errors):
    print(f"{method} AUPR: {score:.3f} ± {err:.3f}")
_best_method = methods[np.argmax(overall)]
print(f"Best method: {_best_method} (AUPR: {max(overall):.3f})")


In [None]:
# Close the lamindb tracking started in the first cell of the notebook.
ln.finish()
