# Notebook for linear-model weight analysis: uncertainty estimation, volcano plots, dot plots, heatmaps, and a confounder-vs-biology interpretation framework.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.utils import resample
from scipy import stats
from scipy.stats import norm
import warnings
warnings.filterwarnings('ignore')

In [None]:
class LinearModelAnalyzer:
    """Analyse the weight matrix of a linear classifier.

    The analyzer reads ``model.linear.weight`` (rows are treated as classes,
    columns as genes) and offers uncertainty estimation, volcano / dot /
    heatmap / UMAP visualisations and a summary of ``adata.obs`` columns that
    look like technical confounders or biological variables.
    """

    def __init__(self, model, adata, datamodule=None):
        """Store the inputs and extract weights, bias, class and gene names.

        Args:
            model: Object exposing ``model.linear`` with ``weight``, ``bias``
                and ``out_features`` attributes.
            adata: Object exposing ``obs`` (a DataFrame-like with ``columns``)
                and ``n_vars``.
            datamodule: Optional object with a ``val_dataloader()`` method;
                required only for bootstrap uncertainty.
        """
        self.model = model
        self.adata = adata
        self.datamodule = datamodule
        self.weights = model.linear.weight.detach().cpu().numpy()
        self.bias = (
            model.linear.bias.detach().cpu().numpy()
            if model.linear.bias is not None
            else None
        )
        self.n_classes, self.n_genes = self.weights.shape

        # Class names come from the 'y' column when present, preferring the
        # categorical order over a plain sort of the observed values.
        if 'y' in adata.obs.columns:
            if hasattr(adata.obs['y'], 'cat'):
                self.class_names = adata.obs['y'].cat.categories.tolist()
            else:
                self.class_names = sorted(adata.obs['y'].unique())
        else:
            self.class_names = [f"Class_{i}" for i in range(model.linear.out_features)]

        # Every plotting method indexes weight rows by position in
        # class_names, so a length mismatch would raise or mislabel later.
        if len(self.class_names) != self.n_classes:
            print(
                f"Warning: found {len(self.class_names)} class labels but the model "
                f"has {self.n_classes} weight rows; using generic class names"
            )
            self.class_names = [f"Class_{i}" for i in range(self.n_classes)]

        # Placeholder gene labels, one per adata variable.
        self.gene_names = [f"Gene_{i:05d}" for i in range(adata.n_vars)]

    def bootstrap_uncertainty(self, n_bootstrap=100, sample_size=0.8):
        """Estimate weight uncertainty by refitting on bootstrap resamples.

        Validation batches are concatenated, resampled with replacement
        ``n_bootstrap`` times, and a weight matrix is fitted on each resample
        with ``_fit_bootstrap_weights``.

        Args:
            n_bootstrap: Number of resamples to attempt.
            sample_size: Fraction of the validation samples drawn per resample.

        Returns:
            ``(weight_means, weight_stds)`` across the successful fits. Falls
            back to ``_simple_weight_uncertainty()`` when no datamodule is
            set, when there is no validation data, or when 10 or fewer fits
            succeed (too few replicates for a usable spread).
        """
        print(f"Computing uncertainty via bootstrap (n={n_bootstrap})...")

        if self.datamodule is None:
            print("Warning: No datamodule provided, using simple weight-based uncertainty")
            return self._simple_weight_uncertainty()

        self.model.eval()

        # Collect the whole validation set on the CPU.
        features, labels = [], []
        for x, y in self.datamodule.val_dataloader():
            features.append(x.cpu())
            labels.append(y.cpu())

        # torch.cat([]) raises, so an empty loader must be handled up front.
        if not features:
            print("⚠️  Bootstrap failed, using simple uncertainty estimation")
            return self._simple_weight_uncertainty()

        all_x = torch.cat(features)
        all_y = torch.cat(labels)

        n_samples = len(all_x)
        bootstrap_size = int(sample_size * n_samples)
        # An empty resample would yield a NaN loss rather than an exception.
        if bootstrap_size < 1:
            print("⚠️  Bootstrap failed, using simple uncertainty estimation")
            return self._simple_weight_uncertainty()

        bootstrap_weights = []
        n_failed = 0
        first_error = None

        for i in range(n_bootstrap):
            if i % 20 == 0:
                print(f"  Bootstrap {i+1}/{n_bootstrap}")

            # Sample indices with replacement.
            indices = torch.randint(0, n_samples, (bootstrap_size,))

            try:
                weights_boot = self._fit_bootstrap_weights(all_x[indices], all_y[indices])
            except Exception as exc:
                # Best effort: a single failed fit must not abort the run, but
                # it is counted and reported instead of silently dropped.
                n_failed += 1
                if first_error is None:
                    first_error = exc
                continue
            bootstrap_weights.append(weights_boot)

        if n_failed:
            print(f"  {n_failed}/{n_bootstrap} bootstrap fits failed (first error: {first_error!r})")

        if len(bootstrap_weights) > 10:
            bootstrap_weights = np.array(bootstrap_weights)
            weight_means = np.mean(bootstrap_weights, axis=0)
            weight_stds = np.std(bootstrap_weights, axis=0)

            print(f"✅ Bootstrap completed with {len(bootstrap_weights)} successful fits")
            return weight_means, weight_stds

        print("⚠️  Bootstrap failed, using simple uncertainty estimation")
        return self._simple_weight_uncertainty()

    def _fit_bootstrap_weights(self, x, y, lr=0.01, n_steps=50):
        """Fit a bias-free softmax weight matrix with a few Adam steps.

        Args:
            x: 2-D feature tensor (samples x genes).
            y: Target tensor accepted by ``F.cross_entropy``.
            lr: Adam learning rate.
            n_steps: Number of optimisation steps.

        Returns:
            NumPy array of shape ``(n_classes, n_genes)``.
        """
        device = x.device
        weights = torch.randn(self.n_classes, self.n_genes, device=device) * 0.01
        weights.requires_grad_(True)

        # torch.mm needs matching dtypes and cross_entropy needs int64 class
        # indices; normalise here so float64 / int32 inputs do not fail.
        x = x.to(weights.dtype)
        y = y.to(weights.dtype) if torch.is_floating_point(y) else y.long()

        optimizer = torch.optim.Adam([weights], lr=lr)

        for _ in range(n_steps):
            logits = torch.mm(x, weights.t())
            loss = F.cross_entropy(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return weights.detach().cpu().numpy()

    def _simple_weight_uncertainty(self):
        """Return ``(self.weights, heuristic_stds)`` without any refitting.

        The heuristic spread is 10% of the absolute weight plus a term that
        grows as the absolute weight shrinks.
        """
        abs_weights = np.abs(self.weights)
        weight_stds = abs_weights * 0.1  # Simple heuristic
        # Higher uncertainty for smaller weights.
        weight_stds += 0.05 / (abs_weights + 0.01)

        return self.weights, weight_stds

    def create_volcano_plot(self, class1_idx=0, class2_idx=1, uncertainty=None):
        """Draw a volcano plot comparing the weights of two classes.

        X-axis: weight difference (class1 - class2).
        Y-axis: ``-log10(p)`` from a two-sided normal test when ``uncertainty``
        is given, otherwise ``log10(|difference| + 0.01)`` as a proxy.

        Args:
            class1_idx: Row index of the first class.
            class2_idx: Row index of the second class.
            uncertainty: Optional ``(means, stds)`` pair; only ``stds`` is used.

        Returns:
            ``(log_fc, neg_log_p)`` arrays, one entry per gene. The figure is
            also saved as ``volcano_plot_<class1>_vs_<class2>.png``.
        """
        class1_name = self.class_names[class1_idx]
        class2_name = self.class_names[class2_idx]
        has_uncertainty = uncertainty is not None

        log_fc = self.weights[class1_idx] - self.weights[class2_idx]

        if has_uncertainty:
            _, weight_stds = uncertainty
            se_diff = np.sqrt(weight_stds[class1_idx]**2 + weight_stds[class2_idx]**2)
            z_stats = np.abs(log_fc) / (se_diff + 1e-8)
            # The survival function keeps precision in the tail, where
            # 1 - cdf(z) rounds to exactly zero.
            p_values = 2 * norm.sf(z_stats)
            # The 1e-10 floor caps the y-axis at 10.
            neg_log_p = -np.log10(p_values + 1e-10)
        else:
            # Use weight magnitude as significance proxy.
            neg_log_p = np.log10(np.abs(log_fc) + 0.01)

        plt.figure(figsize=(12, 8))

        # Gray: below either threshold; red/blue: direction of the difference.
        colors = ['gray' if (abs(fc) < 0.5 or nlp < 2) else 'red' if fc > 0 else 'blue'
                  for fc, nlp in zip(log_fc, neg_log_p)]

        plt.scatter(log_fc, neg_log_p, c=colors, alpha=0.6, s=20)

        # Threshold guides; the p-value label only applies when the y-axis
        # really is -log10(p).
        plt.axhline(y=2, color='black', linestyle='--', alpha=0.5,
                    label='p=0.01' if has_uncertainty else None)
        plt.axvline(x=0.5, color='black', linestyle='--', alpha=0.5)
        plt.axvline(x=-0.5, color='black', linestyle='--', alpha=0.5)
        if has_uncertainty:
            plt.legend()

        plt.xlabel(f'Weight Difference ({class1_name} - {class2_name})')
        plt.ylabel('-log10(p-value)' if has_uncertainty else 'log10(|Weight Difference|)')
        plt.title(f'Volcano Plot: {class1_name} vs {class2_name}')
        plt.grid(True, alpha=0.3)

        # Annotate the ten highest points, skipping negligible effect sizes.
        for idx in np.argsort(neg_log_p)[-10:]:
            if abs(log_fc[idx]) > 0.3:
                plt.annotate(self.gene_names[idx],
                             (log_fc[idx], neg_log_p[idx]),
                             xytext=(5, 5), textcoords='offset points',
                             fontsize=8, alpha=0.8)

        plt.tight_layout()
        plt.savefig(f'volcano_plot_{class1_name}_vs_{class2_name}.png', dpi=300, bbox_inches='tight')
        plt.show()

        return log_fc, neg_log_p

    def create_dot_plot(self, top_k=20, uncertainty=None):
        """Draw a dot plot of the top genes per class.

        For every class the ``top_k`` genes with the largest absolute weight
        are collected; the first 20 classes are plotted, one row each, with
        dot size proportional to absolute weight and colour by sign.

        Args:
            top_k: Number of genes kept per class.
            uncertainty: Optional ``(means, stds)`` pair; ``stds`` is drawn as
                horizontal error bars. Without it, 0.1 is recorded per gene.

        Returns:
            DataFrame with one row per (class, gene) pair for all classes.
            The figure is also saved as ``dotplot_genes_per_class.png``.
        """
        has_uncertainty = uncertainty is not None
        fig, ax = plt.subplots(figsize=(15, max(8, len(self.class_names) * 0.4)))

        plot_data = []
        for class_idx, class_name in enumerate(self.class_names):
            class_weights = self.weights[class_idx]
            top_indices = np.argsort(np.abs(class_weights))[-top_k:][::-1]

            for rank, gene_idx in enumerate(top_indices):
                weight = class_weights[gene_idx]
                uncertainty_val = uncertainty[1][class_idx, gene_idx] if has_uncertainty else 0.1

                plot_data.append({
                    'class': class_name,
                    'gene': self.gene_names[gene_idx],
                    'weight': weight,
                    'abs_weight': abs(weight),
                    'uncertainty': uncertainty_val,
                    'rank': rank,
                    'class_idx': class_idx,
                    'gene_idx': gene_idx
                })

        df = pd.DataFrame(plot_data)

        n_shown = min(20, len(self.class_names))
        for class_idx in range(n_shown):
            # Select by index so duplicate class names cannot mix rows.
            class_data = df[df['class_idx'] == class_idx].head(top_k) if len(df) else df

            x_pos = class_data['weight'].values if len(df) else np.array([])
            if len(x_pos) == 0:
                continue
            y_pos = [class_idx] * len(x_pos)

            # Guard against an all-zero row, which would give NaN sizes.
            max_abs = class_data['abs_weight'].max()
            scale = max_abs if max_abs > 0 else 1.0
            sizes = class_data['abs_weight'].values / scale * 200

            colors = ['red' if w > 0 else 'blue' for w in x_pos]

            ax.scatter(x_pos, y_pos, s=sizes, c=colors, alpha=0.6)

            if has_uncertainty:
                ax.errorbar(x_pos, y_pos, xerr=class_data['uncertainty'].values,
                            fmt='none', color='black', alpha=0.3, capsize=2)

        ax.set_yticks(range(n_shown))
        ax.set_yticklabels(self.class_names[:n_shown])
        ax.set_xlabel('Gene Weight')
        ax.set_title(f'Top {top_k} Genes per Class (Dot Plot)')
        ax.grid(True, alpha=0.3)
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.5)

        plt.tight_layout()
        plt.savefig('dotplot_genes_per_class.png', dpi=300, bbox_inches='tight')
        plt.show()

        return df

    def create_heatmap_analysis(self, top_k=30):
        """Draw a gene-by-class weight heatmap and a class-correlation heatmap.

        Args:
            top_k: Number of genes (ranked by mean absolute weight across
                classes) shown in the first heatmap.

        Returns:
            ``(top_gene_indices, class_correlations)``. Figures are saved as
            ``heatmap_genes_vs_classes.png`` and ``heatmap_class_similarity.png``.
        """
        # 1. Gene importance heatmap
        gene_importance = np.mean(np.abs(self.weights), axis=0)
        top_gene_indices = np.argsort(gene_importance)[-top_k:][::-1]

        # Show at most 20 classes, evenly strided across the class list.
        n_total = len(self.class_names)
        n_classes_show = min(20, n_total)
        stride = max(1, n_total // n_classes_show) if n_classes_show else 1
        class_subset = list(range(0, n_total, stride))[:n_classes_show]

        weights_subset = self.weights[np.ix_(class_subset, top_gene_indices)]

        plt.figure(figsize=(15, 10))

        sns.heatmap(weights_subset,
                    xticklabels=[self.gene_names[i] for i in top_gene_indices],
                    yticklabels=[self.class_names[i] for i in class_subset],
                    cmap='RdBu_r', center=0,
                    cbar_kws={'label': 'Gene Weight'})

        plt.title(f'Heatmap: Top {top_k} Genes vs Classes')
        plt.xlabel('Genes')
        plt.ylabel('Classes')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig('heatmap_genes_vs_classes.png', dpi=300, bbox_inches='tight')
        plt.show()

        # 2. Class similarity heatmap (correlation between weight rows)
        plt.figure(figsize=(12, 10))
        class_correlations = np.corrcoef(self.weights)

        sns.heatmap(class_correlations,
                    xticklabels=self.class_names,
                    yticklabels=self.class_names,
                    cmap='coolwarm', center=0,
                    square=True,
                    cbar_kws={'label': 'Correlation'})

        plt.title('Class Similarity (Weight Pattern Correlation)')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig('heatmap_class_similarity.png', dpi=300, bbox_inches='tight')
        plt.show()

        return top_gene_indices, class_correlations

    def analyze_confounders_vs_biology(self):
        """List ``adata.obs`` columns that look technical or biological.

        Columns are classified by case-insensitive substring match on their
        names; for each match the number of unique values is printed and
        recorded. No variance decomposition is performed.

        Returns:
            Dict mapping column name to ``{'type', 'unique_values'}``. A column
            matching both keyword lists ends up recorded as ``'Biological'``
            because that group is processed last.
        """
        print("\n" + "="*60)
        print("CONFOUNDER vs BIOLOGICAL ANALYSIS")
        print("="*60)

        confounder_keys = ['plate', 'batch', 'barcode', 'sublibrary', 'sample']
        biological_keys = ['drug', 'cell_line', 'cell_type', 'tissue', 'treatment']

        obs_columns = self.adata.obs.columns.tolist()

        # str() keeps non-string column labels from raising on .lower().
        confounders = [col for col in obs_columns
                       if any(key in str(col).lower() for key in confounder_keys)]
        biological = [col for col in obs_columns
                      if any(key in str(col).lower() for key in biological_keys)]

        print(f"Potential confounders: {confounders}")
        print(f"Biological variables: {biological}")

        variance_analysis = {}

        for var_type, variables in [('Confounders', confounders), ('Biological', biological)]:
            print(f"\n{var_type}:")
            for var in variables:
                unique_vals = self.adata.obs[var].nunique()
                print(f"  {var}: {unique_vals} unique values")
                variance_analysis[var] = {
                    'type': var_type,
                    'unique_values': unique_vals
                }

        return variance_analysis

    def create_weight_umap(self, n_components=2):
        """Embed genes with UMAP using their per-class weights as features.

        Args:
            n_components: Embedding dimensionality passed to UMAP. The scatter
                plot uses the first two dimensions and is skipped when fewer
                than two are available.

        Returns:
            The embedding array (one row per gene), or ``None`` when the
            ``umap`` package cannot be imported. The figure is saved as
            ``umap_gene_weights.png``.
        """
        try:
            from umap import UMAP
        except ImportError:
            print("UMAP not available. Install with: pip install umap-learn")
            return None

        print("Creating UMAP of gene weight patterns...")

        # Transpose weights so genes are rows, classes are features.
        weights_for_umap = self.weights.T  # Shape: (n_genes, n_classes)

        umap_model = UMAP(n_components=n_components, random_state=42, n_neighbors=15, min_dist=0.1)
        gene_embedding = umap_model.fit_transform(weights_for_umap)

        # The plot needs two axes; a 1-D embedding is returned unplotted.
        if gene_embedding.shape[1] < 2:
            print("UMAP embedding has fewer than 2 components; skipping plot")
            return gene_embedding

        # Mean absolute weight across classes, used for colouring.
        gene_importance = np.mean(np.abs(weights_for_umap), axis=1)

        plt.figure(figsize=(12, 8))
        scatter = plt.scatter(gene_embedding[:, 0], gene_embedding[:, 1],
                              c=gene_importance, cmap='viridis', alpha=0.6, s=20)
        plt.colorbar(scatter, label='Gene Importance')
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        plt.title('UMAP of Gene Weight Patterns')

        # Annotate the 20 most important genes.
        for idx in np.argsort(gene_importance)[-20:]:
            plt.annotate(self.gene_names[idx],
                         (gene_embedding[idx, 0], gene_embedding[idx, 1]),
                         xytext=(5, 5), textcoords='offset points',
                         fontsize=8, alpha=0.7)

        plt.tight_layout()
        plt.savefig('umap_gene_weights.png', dpi=300, bbox_inches='tight')
        plt.show()

        return gene_embedding

    def comprehensive_analysis(self):
        """Run every analysis step in sequence and collect the results.

        Returns:
            Dict with keys ``uncertainty``, ``dot_data``, ``top_genes``,
            ``class_correlations``, ``variance_analysis`` and
            ``gene_embedding``.
        """
        print("🚀 Starting Comprehensive Linear Model Analysis")
        print("="*60)

        # 1. Estimate uncertainty
        print("\n📊 Step 1: Estimating weight uncertainty...")
        uncertainty = self.bootstrap_uncertainty(n_bootstrap=50)

        # 2. Volcano plots for the first few adjacent class pairs
        print("\n🌋 Step 2: Creating volcano plots...")
        for i in range(min(3, len(self.class_names)-1)):
            self.create_volcano_plot(i, i+1, uncertainty)

        # 3. Create dot plot
        print("\n🔴 Step 3: Creating dot plot...")
        dot_data = self.create_dot_plot(top_k=15, uncertainty=uncertainty)

        # 4. Create heatmaps
        print("\n🔥 Step 4: Creating heatmaps...")
        top_genes, class_corr = self.create_heatmap_analysis(top_k=25)

        # 5. Analyze confounders vs biology
        print("\n🧬 Step 5: Analyzing confounders vs biology...")
        variance_analysis = self.analyze_confounders_vs_biology()

        # 6. Create UMAP
        print("\n🗺️  Step 6: Creating UMAP visualization...")
        gene_embedding = self.create_weight_umap()

        # 7. Summary statistics
        print("\n📈 Step 7: Summary statistics...")
        self._print_summary_stats(uncertainty, top_genes)

        print("\n✅ Analysis complete! Check the generated plots.")

        return {
            'uncertainty': uncertainty,
            'dot_data': dot_data,
            'top_genes': top_genes,
            'class_correlations': class_corr,
            'variance_analysis': variance_analysis,
            'gene_embedding': gene_embedding
        }

    def _print_summary_stats(self, uncertainty, top_genes):
        """Print headline numbers about the weights and their uncertainty.

        Args:
            uncertainty: ``(means, stds)`` pair; only ``stds`` is used.
            top_genes: Gene indices ordered by importance; the first is
                reported as the most important gene.
        """
        _, weights_std = uncertainty

        print(f"Model has {self.n_classes} classes and {self.n_genes} genes")
        print(f"Average weight magnitude: {np.mean(np.abs(self.weights)):.4f}")
        print(f"Average weight uncertainty: {np.mean(weights_std):.4f}")
        print(f"Most variable class: {self.class_names[np.argmax(np.var(self.weights, axis=1))]}")
        print(f"Most important gene: {self.gene_names[top_genes[0]]}")

In [None]:
# Usage
def run_comprehensive_analysis(model, adata, datamodule=None):
    """Build a LinearModelAnalyzer and run its full analysis pipeline.

    Returns:
        ``(analyzer, results)`` where ``results`` is whatever
        ``comprehensive_analysis()`` returns.
    """
    pipeline = LinearModelAnalyzer(model, adata, datamodule)
    return pipeline, pipeline.comprehensive_analysis()

# Quick analysis function for immediate results
def quick_analysis(model, adata, datamodule=None):
    """Run a fast subset of the analysis using the heuristic uncertainty.

    Produces one volcano plot (class 0 vs class 1), a dot plot of the top 10
    genes per class, and the heatmaps for the top 20 genes.

    Returns:
        The LinearModelAnalyzer instance that was used.
    """
    analyzer = LinearModelAnalyzer(model, adata, datamodule)

    print("🚀 Quick Analysis Starting...")

    # Heuristic spread only: no bootstrap refits, so this step is cheap.
    weight_means, weight_stds = analyzer._simple_weight_uncertainty()
    heuristic = (weight_means, weight_stds)

    # Key visualisations; their return values are not needed here.
    analyzer.create_volcano_plot(0, 1, heuristic)
    analyzer.create_dot_plot(top_k=10, uncertainty=heuristic)
    analyzer.create_heatmap_analysis(top_k=20)

    print("✅ Quick analysis complete!")

    return analyzer