In [None]:
import numpy as np
import pandas as pd
import anndata as ad
import torch
import lightning as L
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from modlyn.models import SimpleLogReg, SimpleLogRegDataModule
import lamindb as ln

# Pin the global NumPy and PyTorch random number generators to one shared
# seed so that a fresh "Restart & Run All" reproduces the same numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Synthetic multi-class problem whose labels come from a known linear model,
# so fitted weights can later be compared against a ground truth.
n_samples, n_features, n_classes = 1000, 20, 3

# Re-seed inside the cell: re-running only this cell regenerates the same data.
np.random.seed(42)
X_synthetic = np.random.randn(n_samples, n_features)

# Ground-truth parameters of the generating model (one weight row per class).
true_weights = np.random.randn(n_classes, n_features)
true_bias = np.random.randn(n_classes)

# Each sample is assigned the class with the largest linear score. No noise is
# added to the scores, so the labels are a deterministic function of X.
scores = X_synthetic.dot(true_weights.T) + true_bias
y_synthetic = scores.argmax(axis=1)

print(f"Synthetic data shape: {X_synthetic.shape}")
print(f"Class distribution: {np.bincount(y_synthetic)}")
print(f"True weights shape: {true_weights.shape}")

# Package features and labels as AnnData: integer labels go to obs['y'],
# a string version of the same labels to obs['cell_line'].
adata_synthetic = ad.AnnData(X=X_synthetic)
adata_synthetic.var_names = [f'feature_{i}' for i in range(n_features)]
adata_synthetic.obs['y'] = y_synthetic
adata_synthetic.obs['cell_line'] = [f'class_{i}' for i in y_synthetic]


In [None]:
# One stratified 80/20 split, reused by both the sklearn and the modlyn runs
X_train, X_val, y_train, y_val = train_test_split(
    X_synthetic,
    y_synthetic,
    test_size=0.2,
    random_state=42,
    stratify=y_synthetic,
)


def _fit_logreg(**overrides):
    """Fit a LogisticRegression on the training split.

    ``overrides`` are forwarded to the estimator on top of the shared
    ``random_state=42, max_iter=1000`` settings.

    Returns a ``(fitted_estimator, training_accuracy)`` pair.
    """
    clf = LogisticRegression(random_state=42, max_iter=1000, **overrides)
    clf.fit(X_train, y_train)
    return clf, clf.score(X_train, y_train)


print("=== SKLEARN LOGISTIC REGRESSION ANALYSIS ===")

# Baseline: sklearn defaults (L2 penalty, C=1.0)
sklearn_default, acc_default = _fit_logreg()
print(f"Sklearn (default L2): Train accuracy = {acc_default:.4f}")

# A very large C makes the L2 penalty negligible
sklearn_no_reg, acc_no_reg = _fit_logreg(C=1e10)
print(f"Sklearn (no regularization): Train accuracy = {acc_no_reg:.4f}")

# A small C means a strong L2 penalty
sklearn_high_reg, acc_high_reg = _fit_logreg(C=0.01)
print(f"Sklearn (high regularization): Train accuracy = {acc_high_reg:.4f}")

# Report the settings the default estimator actually used
print(f"\nSklearn default parameters:")
print(f"C (inverse regularization): {sklearn_default.C}")
print(f"Penalty: {sklearn_default.penalty}")
print(f"Solver: {sklearn_default.solver}")


In [None]:
print("=== MODLYN ANALYSIS ===")

# Wrap the exact train/validation arrays used for sklearn in AnnData objects,
# the input format taken by SimpleLogReg / SimpleLogRegDataModule below.
adata_train = ad.AnnData(X=X_train)
adata_train.obs['y'] = y_train
adata_train.obs['cell_line'] = [f'class_{i}' for i in y_train]
adata_train.var_names = [f'feature_{i}' for i in range(X_train.shape[1])]

adata_val = ad.AnnData(X=X_val)
adata_val.obs['y'] = y_val
adata_val.obs['cell_line'] = [f'class_{i}' for i in y_val]
adata_val.var_names = [f'feature_{i}' for i in range(X_val.shape[1])]

# Relating sklearn's C to an optimizer weight_decay:
# sklearn minimises  C * sum_i(loss_i) + 0.5 * ||w||^2, i.e. per sample
# mean(loss) + ||w||^2 / (2 * n * C).  A coupled L2 weight decay `wd` adds
# wd * w to the gradient (penalty 0.5 * wd * ||w||^2), so the matching value is
# wd = 1 / (C * n_train) -- NOT 1 / C.
# NOTE(review): this equivalence assumes the model averages the loss over the
# batch and uses coupled L2 decay (SGD/Adam style, not AdamW), and sklearn
# leaves the intercept unpenalised -- confirm against modlyn's optimizer setup.
wd_sklearn_equiv = 1.0 / (sklearn_default.C * len(X_train))

# Sweep from no regularisation to very strong regularisation, including the
# value that theoretically mirrors sklearn's default. Sorted so downstream
# line plots get monotonically increasing x values.
weight_decay_values = sorted({0.0, wd_sklearn_equiv, 0.01, 1.0, 100.0})

# The training features as a float32 tensor; identical for every run, so it is
# built once instead of once per loop iteration.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)

modlyn_results = {}

for wd in weight_decay_values:
    print(f"\nTesting weight_decay = {wd}")

    # Reset the torch RNG so every run starts from the same random weight
    # initialisation; otherwise differences between runs mix the effect of
    # weight_decay with the effect of a different starting point.
    torch.manual_seed(42)

    # Full-batch loaders: one optimizer step per epoch on the whole split
    datamodule = SimpleLogRegDataModule(
        adata_train=adata_train,
        adata_val=adata_val,
        label_column="y",
        train_dataloader_kwargs={"batch_size": len(adata_train), "num_workers": 0},
        val_dataloader_kwargs={"batch_size": len(adata_val), "num_workers": 0}
    )

    # Fresh model per weight_decay value
    model = SimpleLogReg(
        adata=adata_train,
        label_column="y",
        learning_rate=1e-2,
        weight_decay=wd
    )

    # With full-batch loaders, 100 epochs correspond to 100 optimizer steps
    trainer = L.Trainer(
        max_epochs=100,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False
    )

    trainer.fit(model=model, datamodule=datamodule)

    # Training accuracy. Switch to eval mode and feed the inputs on whatever
    # device the model currently lives on, then bring predictions back to CPU.
    # NOTE(review): comparing argmax indices with y_train assumes modlyn maps
    # integer label k to output index k -- confirm.
    model.eval()
    with torch.no_grad():
        predictions = model(X_train_tensor.to(model.device))
        predicted_classes = predictions.argmax(dim=1).cpu().numpy()
        accuracy = (predicted_classes == y_train).mean()

    print(f"  Train accuracy: {accuracy:.4f}")

    # Keep accuracy, a NumPy copy of the linear layer's weights, and the model
    weights = model.linear.weight.detach().cpu().numpy()
    modlyn_results[wd] = {
        'accuracy': accuracy,
        'weights': weights,
        'model': model
    }


In [None]:
print("=== WEIGHT CORRELATION ANALYSIS ===")


def _center_over_classes(weights):
    """Subtract, per feature, the mean weight across classes.

    Adding the same vector to every class row of a softmax weight matrix does
    not change the predicted probabilities, so that shared component is not
    determined by the data. Removing it makes two fits comparable.
    Matrices that are not 2-D or have a single row are returned unchanged.
    """
    weights = np.asarray(weights, dtype=float)
    if weights.ndim != 2 or weights.shape[0] < 2:
        return weights
    return weights - weights.mean(axis=0, keepdims=True)


def _weight_correlation(weights_a, weights_b):
    """Pearson correlation of two equally shaped, class-centered weight matrices.

    Raises:
        ValueError: if the two matrices do not have the same shape.
    """
    weights_a = np.asarray(weights_a, dtype=float)
    weights_b = np.asarray(weights_b, dtype=float)
    if weights_a.shape != weights_b.shape:
        raise ValueError(
            f"Cannot correlate weights of shape {weights_a.shape} with shape {weights_b.shape}"
        )
    flat_a = _center_over_classes(weights_a).ravel()
    flat_b = _center_over_classes(weights_b).ravel()
    return float(np.corrcoef(flat_a, flat_b)[0, 1])


# Reference coefficient matrices from the sklearn fits
# NOTE(review): the comparison assumes modlyn orders its output classes the
# same way as sklearn's classes_ -- confirm.
sklearn_weights_default = sklearn_default.coef_
sklearn_weights_no_reg = sklearn_no_reg.coef_

correlations = {}

for wd, results in modlyn_results.items():
    modlyn_weights = results['weights']

    corr_default = _weight_correlation(modlyn_weights, sklearn_weights_default)
    corr_no_reg = _weight_correlation(modlyn_weights, sklearn_weights_no_reg)

    correlations[wd] = {
        'vs_sklearn_default': corr_default,
        'vs_sklearn_no_reg': corr_no_reg,
        'modlyn_accuracy': results['accuracy']
    }

    print(f"Weight_decay {wd}:")
    print(f"  vs sklearn default: {corr_default:.4f}")
    print(f"  vs sklearn no reg: {corr_no_reg:.4f}")
    print(f"  modlyn accuracy: {results['accuracy']:.4f}")

print("\nSklearn accuracies for reference:")
print(f"  Default: {acc_default:.4f}")
print(f"  No reg: {acc_no_reg:.4f}")
print(f"  High reg: {acc_high_reg:.4f}")

# Pick the weight_decay whose weights correlate best with sklearn's default.
# NaN correlations (e.g. from constant weights) are ignored when possible,
# because comparisons against NaN would make max() pick an arbitrary entry.
finite_wds = [
    wd for wd, stats in correlations.items()
    if np.isfinite(stats['vs_sklearn_default'])
]
candidate_wds = finite_wds or list(correlations)
best_wd = max(candidate_wds, key=lambda wd: correlations[wd]['vs_sklearn_default'])
best_corr = correlations[best_wd]['vs_sklearn_default']

print("\n🎯 BEST MATCH FOUND:")
print(f"   weight_decay = {best_wd}")
print(f"   correlation = {best_corr:.4f}")

# Visualize results: accuracy (left) and weight correlation (right)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Sort explicitly so the line plots do not depend on dict insertion order
wd_values = sorted(correlations)

# 1. Accuracy comparison
modlyn_accs = [correlations[wd]['modlyn_accuracy'] for wd in wd_values]

axes[0].plot(wd_values, modlyn_accs, 'o-', label='Modlyn', linewidth=2)
axes[0].axhline(acc_default, color='red', linestyle='--', label='Sklearn default')
axes[0].axhline(acc_no_reg, color='orange', linestyle='--', label='Sklearn no reg')
axes[0].set_xlabel('Weight Decay')
axes[0].set_ylabel('Training Accuracy')
axes[0].set_title('Accuracy vs Regularization')
axes[0].set_xscale('symlog')  # symlog so weight_decay = 0 can be shown
axes[0].legend()
axes[0].grid(True)

# 2. Correlation vs weight decay
corr_values = [correlations[wd]['vs_sklearn_default'] for wd in wd_values]

axes[1].plot(wd_values, corr_values, 'o-', color='green', linewidth=2)
axes[1].axhline(0.95, color='red', linestyle='--', alpha=0.7, label='Good match (>0.95)')
axes[1].set_xlabel('Weight Decay')
axes[1].set_ylabel('Weight Correlation with Sklearn')
axes[1].set_title('Correlation vs Regularization')
axes[1].set_xscale('symlog')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()

print("\n=== RECOMMENDATIONS FOR YOUR VALIDATION NOTEBOOK ===")
print("Replace in your SimpleLogReg creation:")
print(f"  weight_decay={best_wd}  # (current issue: likely too low or zero)")
print("  learning_rate=1e-2")
print("  max_epochs=100  # (current issue: likely too few epochs)")
print("  batch_size=len(adata_train)  # (use full batch for small datasets)")

# Grade the best correlation found
if best_corr > 0.95:
    print(f"\n✅ Expected improvement: EXCELLENT match (correlation = {best_corr:.3f})")
elif best_corr > 0.8:
    print(f"\n⚠️  Expected improvement: GOOD match (correlation = {best_corr:.3f})")
    print("   May need additional tuning")
else:
    print(f"\n❌ Expected improvement: LIMITED (correlation = {best_corr:.3f})")
    print("   May need to investigate other factors")
