In [None]:
import pandas as pd

from ChromaVDB.chroma import ChromaFramework
from tqdm.notebook import tqdm

# Open the persisted Chroma vector store and fetch every stored record.
vdb = ChromaFramework(persist_directory="./ChromaVDB/chroma_db")
records = vdb.list_records()
# Parallel lists: names[i] / embs[i] are the 'name' and 'embeddings' fields of records[i].
# NOTE(review): presumably each 'name' is a gene symbol, since a later cell matches
# these names against column-name prefixes of the clinical table — confirm.
names = [record['name'] for record in records]
embs = [record['embeddings'] for record in records]

In [None]:
# Load the clinical / sequencing table; every later cell derives its frames from `data`.
data = pd.read_excel(io='data/2025_03_29.xlsx')

In [None]:
# Sample source to analyse (alternative: "LNF").
# gene_set = "LNF"
gene_set = "plasma"
# Measurement to analyse (alternative: "MUT").
# gene_measure = "MUT"
gene_measure = "VAF"

# Keep only the columns whose name mentions both the chosen source and measurement.
gene_data = data.loc[
    :,
    [name for name in data.columns if gene_measure in name and gene_set in name],
]

In [None]:
# Gene symbol = text before the first underscore of each selected column.
# dict.fromkeys de-duplicates while keeping first-seen column order; the previous
# list(set(...)) produced an order that depends on string hashing and therefore
# changed between kernel restarts, silently permuting the feature/embedding axes.
genes = list(dict.fromkeys(column.split('_')[0] for column in gene_data.columns))

# One-time name -> embedding lookup (first occurrence wins, matching names.index),
# replacing a linear scan of `names` for every gene.
# NOTE(review): assumes record names are hashable (strings) — confirm.
embedding_by_name = {}
for record_name, record_embedding in zip(names, embs):
    embedding_by_name.setdefault(record_name, record_embedding)

available_columns = set(gene_data.columns)

final_columns = []  # feature columns kept, in gene order
embeddings = []     # embeddings[i] belongs to the gene of final_columns[i]

for gene in genes:
    column = gene + "_" + gene_set + "_" + gene_measure
    if gene not in embedding_by_name:
        # No embedding stored for this gene: report it and leave it out.
        print(gene)
    elif column not in available_columns:
        # The reconstructed column name does not exist (e.g. an extra suffix in the
        # header); skipping here avoids a KeyError when the columns are selected later.
        print(f"{gene}: column '{column}' not found, skipped")
    else:
        final_columns.append(column)
        embeddings.append(embedding_by_name[gene])

print(len(final_columns))

In [None]:
# Build the modelling frame in one chain: selected feature columns + outcome column,
# missing values replaced by 0. `.assign` returns a new frame, so no column is written
# onto a slice of `data` (the previous in-place assignment triggered pandas'
# SettingWithCopyWarning).
# NOTE(review): fillna(0) is applied to 'pfs' as well, so a patient with a missing
# outcome is treated as class 0 — confirm this is intended; the alternative is to
# drop incomplete rows:
# gene_data = gene_data.dropna()
gene_data = (
    gene_data[final_columns]
    .assign(pfs=data['PFS_Cens_updated'])
    .fillna(0)
)

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, 
                            roc_auc_score, confusion_matrix, roc_curve, precision_recall_curve)
import matplotlib.pyplot as plt
import seaborn as sns

# Reproducibility: seed the NumPy and PyTorch generators once, before any model is
# built, so weight initialisation, dropout and batch shuffling repeat across
# Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# GPU Configuration and Optimization
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    print("CUDA version:", torch.version.cuda)
    print("GPU device:", torch.cuda.get_device_name(0))

    # Set GPU memory management
    # cuDNN autotuning favours speed; together with deterministic=False, GPU runs
    # may still differ slightly from run to run despite the seed above.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

class PFSClassificationMLP(nn.Module):
    """
    MLP that fuses per-gene embeddings with per-patient mutation features and
    outputs one logit per patient for binary PFS classification.

    Data structure:
    - gene_embeddings: (n_genes, embedding_dim) - one embedding per gene
    - patient_mutations: (n_patients, n_genes) - per-gene mutation feature
    - output: (n_patients, 1) logits (apply sigmoid for probabilities)

    Fusion strategies (``combination_method``):
    - 'weighted_sum': scale each gene embedding by a learned multiple of the
      mutation value and sum over genes.
    - 'attention':    add a learned projection of the mutation value to each gene
      embedding, run multi-head self-attention over genes, average-pool.
    - 'transformer':  concatenate embedding and mutation value, run a 2-layer
      transformer encoder over genes, average-pool.
    - 'gene_mlp':     concatenate embedding and mutation value, apply a shared
      per-gene MLP, sum over genes.
    """

    # Upper bound on attention heads; the actual count is the largest divisor of
    # the attention width that does not exceed this value.
    MAX_ATTENTION_HEADS = 8

    def __init__(self, n_genes, embedding_dim, hidden_dims=(512, 256, 128),
                 dropout_rate=0.3, combination_method='weighted_sum'):
        """
        Args:
            n_genes: number of genes (columns of the mutation matrix).
            embedding_dim: width of each gene embedding.
            hidden_dims: sizes of the hidden layers of the patient-level MLP.
            dropout_rate: dropout probability used throughout.
            combination_method: 'weighted_sum', 'attention', 'transformer' or 'gene_mlp'.

        Raises:
            ValueError: if combination_method is not one of the supported names.
        """
        super().__init__()

        self.n_genes = n_genes
        self.embedding_dim = embedding_dim
        self.combination_method = combination_method

        # Gene-level fusion methods
        if combination_method == 'weighted_sum':
            # A single learned scalar that rescales the mutation value.
            self.gene_weight = nn.Linear(1, 1, bias=False)
            patient_repr_dim = embedding_dim

        elif combination_method == 'attention':
            # Multi-head attention requires embed_dim % num_heads == 0; a fixed
            # head count of 8 failed for any embedding width not divisible by 8.
            self.gene_attention = nn.MultiheadAttention(
                embed_dim=embedding_dim,
                num_heads=self._pick_num_heads(embedding_dim),
                batch_first=True
            )
            # Project mutation status to same dim as embeddings for attention
            self.mutation_proj = nn.Linear(1, embedding_dim)
            patient_repr_dim = embedding_dim

        elif combination_method == 'transformer':
            # The token width is embedding_dim + 1 (embedding + mutation value),
            # which is rarely a multiple of 8, so the head count is derived from it.
            token_dim = embedding_dim + 1
            self.gene_transform = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=token_dim,
                    nhead=self._pick_num_heads(token_dim),
                    dim_feedforward=embedding_dim * 2,
                    dropout=dropout_rate,
                    batch_first=True
                ),
                num_layers=2
            )
            patient_repr_dim = token_dim

        elif combination_method == 'gene_mlp':
            # Shared MLP applied to every (embedding, mutation) pair, then aggregated.
            self.gene_mlp = nn.Sequential(
                nn.Linear(embedding_dim + 1, embedding_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(embedding_dim, embedding_dim // 2)
            )
            patient_repr_dim = embedding_dim // 2

        else:
            raise ValueError("combination_method must be 'weighted_sum', 'attention', 'transformer', or 'gene_mlp'")

        # Patient-level MLP for final binary classification
        layers = []
        prev_dim = patient_repr_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.BatchNorm1d(hidden_dim)
            ])
            prev_dim = hidden_dim

        # Output layer for binary classification (logits)
        layers.append(nn.Linear(prev_dim, 1))

        self.patient_mlp = nn.Sequential(*layers)

    @classmethod
    def _pick_num_heads(cls, model_dim):
        """Return the largest head count <= MAX_ATTENTION_HEADS that divides model_dim."""
        for heads in range(min(cls.MAX_ATTENTION_HEADS, model_dim), 0, -1):
            if model_dim % heads == 0:
                return heads
        return 1

    def forward(self, patient_mutations, gene_embeddings):
        """
        Forward pass combining gene embeddings with patient mutations.

        Args:
            patient_mutations: (batch_size, n_genes) mutation matrix
            gene_embeddings: (n_genes, embedding_dim) gene embeddings

        Returns:
            logits: (batch_size, 1) logits for binary classification

        Raises:
            ValueError: if self.combination_method was changed to an unsupported
                value after construction.
        """
        batch_size = patient_mutations.shape[0]
        mutations_expanded = patient_mutations.unsqueeze(-1)  # (batch_size, n_genes, 1)
        # (1, n_genes, embedding_dim); broadcast / expanded below instead of
        # materialising one copy of the embedding table per patient.
        gene_emb_expanded = gene_embeddings.unsqueeze(0)

        if self.combination_method == 'weighted_sum':
            # Learned rescaling of the mutation value, then broadcast multiply:
            # (batch_size, n_genes, 1) * (1, n_genes, embedding_dim)
            weighted_mutations = self.gene_weight(mutations_expanded)
            weighted_gene_embs = weighted_mutations * gene_emb_expanded

            # Sum across genes for each patient
            patient_repr = weighted_gene_embs.sum(dim=1)  # (batch_size, embedding_dim)

        elif self.combination_method == 'attention':
            # Project mutations to embedding dimension and add the gene embeddings
            mutation_features = self.mutation_proj(mutations_expanded)
            combined_features = gene_emb_expanded + mutation_features  # (batch_size, n_genes, embedding_dim)

            # Self-attention over genes
            attended_features, _ = self.gene_attention(
                combined_features, combined_features, combined_features
            )

            # Average pool across genes
            patient_repr = attended_features.mean(dim=1)  # (batch_size, embedding_dim)

        elif self.combination_method == 'transformer':
            # Concatenate: (batch_size, n_genes, embedding_dim + 1)
            gene_features = torch.cat(
                [gene_emb_expanded.expand(batch_size, -1, -1), mutations_expanded], dim=-1
            )

            # Apply transformer encoder, then global average pooling across genes
            transformed = self.gene_transform(gene_features)
            patient_repr = transformed.mean(dim=1)  # (batch_size, embedding_dim + 1)

        elif self.combination_method == 'gene_mlp':
            # Concatenate: (batch_size, n_genes, embedding_dim + 1)
            gene_features = torch.cat(
                [gene_emb_expanded.expand(batch_size, -1, -1), mutations_expanded], dim=-1
            )
            n_genes = gene_features.shape[1]

            # Flatten so the shared MLP sees one (embedding, mutation) row per gene
            gene_features_flat = gene_features.reshape(-1, self.embedding_dim + 1)
            gene_processed = self.gene_mlp(gene_features_flat)

            # Restore (batch_size, n_genes, embedding_dim // 2) and sum across genes
            gene_processed = gene_processed.reshape(batch_size, n_genes, -1)
            patient_repr = gene_processed.sum(dim=1)

        else:
            # Only reachable if the attribute was modified after __init__.
            raise ValueError(f"Unsupported combination_method: {self.combination_method!r}")

        # Final MLP for binary classification (returns logits)
        return self.patient_mlp(patient_repr)

class PFSClassifier:
    """
    Trains and evaluates a binary PFS classifier (PFSClassificationMLP) that fuses
    gene embeddings with per-patient mutation features.

    Typical use: ``clf = PFSClassifier(method); history = clf.train(X, E, y);
    metrics, preds, probs = clf.evaluate(X, E, y)``.
    """

    def __init__(self, combination_method='weighted_sum', device='cuda'):
        """
        Args:
            combination_method: fusion strategy forwarded to PFSClassificationMLP.
            device: requested torch device; falls back to 'cpu' when CUDA is unavailable.
        """
        requested = str(device)
        if not torch.cuda.is_available():
            self.device = 'cpu'
            print("CUDA not available, using CPU")
        elif requested.startswith('cuda'):
            self.device = requested
            torch.cuda.empty_cache()
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"Using GPU: {gpu_name} ({gpu_memory:.1f}GB)")
        else:
            # CUDA exists but another device was requested explicitly; honour it
            # instead of announcing a GPU that will not be used.
            self.device = requested
            print(f"Using device: {requested}")

        self.combination_method = combination_method
        self.model = None
        # Weights of the best validation epoch of the most recent train() call.
        self.best_model_state = None

    def _uses_cuda(self):
        """True when the configured device is a CUDA device ('cuda', 'cuda:0', ...)."""
        return str(self.device).startswith('cuda')

    def prepare_data(self, patient_mutations, gene_embeddings, pfs_binary):
        """
        Convert the three inputs to float32 tensors (no scaling is applied).

        Args:
            patient_mutations: (n_patients, n_genes) mutation matrix
            gene_embeddings: (n_genes, embedding_dim) gene embeddings
            pfs_binary: (n_patients,) binary PFS outcomes (0/1)

        Returns:
            (mutations_tensor, embeddings_tensor, pfs_tensor)
        """
        mutations_tensor = torch.as_tensor(np.asarray(patient_mutations, dtype=np.float32))
        embeddings_tensor = torch.as_tensor(np.asarray(gene_embeddings, dtype=np.float32))
        pfs_tensor = torch.as_tensor(np.asarray(pfs_binary, dtype=np.float32))

        return mutations_tensor, embeddings_tensor, pfs_tensor

    def create_model(self, n_genes, embedding_dim, **model_kwargs):
        """Create the model with this classifier's fusion method and move it to the device."""
        self.model = PFSClassificationMLP(
            n_genes=n_genes,
            embedding_dim=embedding_dim,
            combination_method=self.combination_method,
            **model_kwargs
        ).to(self.device)

    def _init_weights(self):
        """Xavier-initialise linear layers and reset batch-norm layers to identity."""
        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    @staticmethod
    def _split_indices(labels, test_size, validation_size, random_state=42):
        """Return (train_idx, val_idx, test_idx), stratified by label whenever possible."""
        indices = np.arange(len(labels))
        holdout_fraction = test_size + validation_size
        test_fraction = test_size / holdout_fraction

        def random_split():
            train_idx, temp_idx = train_test_split(
                indices, test_size=holdout_fraction, random_state=random_state)
            val_idx, test_idx = train_test_split(
                temp_idx, test_size=test_fraction, random_state=random_state)
            return train_idx, val_idx, test_idx

        if len(np.unique(labels)) < 2:
            print("Warning: Only one class present in dataset!")
            return random_split()

        try:
            train_idx, temp_idx = train_test_split(
                indices, test_size=holdout_fraction,
                random_state=random_state, stratify=labels)
            val_idx, test_idx = train_test_split(
                temp_idx, test_size=test_fraction,
                random_state=random_state, stratify=labels[temp_idx])
            return train_idx, val_idx, test_idx
        except ValueError as e:
            # Raised by scikit-learn when a class is too small to stratify.
            print(f"Stratification failed: {e}. Using random split.")
            return random_split()

    def _forward_logits(self, mutations, gene_embeddings):
        """Run the model and return logits of shape (batch,).

        Only the trailing singleton axis is removed: a bare ``.squeeze()`` also
        collapses the batch axis when a batch holds a single sample, which makes
        the loss fail on a shape mismatch with the targets.
        """
        return self.model(mutations, gene_embeddings).squeeze(-1)

    def train(self, patient_mutations, gene_embeddings, pfs_binary,
              test_size=0.15, validation_size=0.20, batch_size=512,
              epochs=200, lr=0.001, weight_decay=1e-4,
              early_stopping_patience=50, verbose=True):
        """
        Train the binary PFS classification model.

        The data are split into train / validation / test (stratified when
        possible, fixed random_state=42). Training uses Adam with a class-weighted
        BCE-with-logits loss; the learning rate is halved when validation AUC
        plateaus, training stops early after ``early_stopping_patience`` epochs
        without a validation-AUC improvement, and the weights of the best
        validation epoch are restored before returning.

        Args:
            patient_mutations: (n_patients, n_genes) mutation matrix
            gene_embeddings: (n_genes, embedding_dim) gene embeddings
            pfs_binary: (n_patients,) binary PFS outcomes (0/1)
            test_size, validation_size: fractions of patients held out.
            batch_size, epochs, lr, weight_decay: optimisation settings.
            early_stopping_patience: epochs without improvement before stopping.
            verbose: print progress and a final summary.

        Returns:
            training_history: dict of per-epoch lists with keys
            '{train,val,test}_losses', '{train,val,test}_acc', '{train,val,test}_auc'.
        """
        pfs_binary = np.asarray(pfs_binary)
        labels_int = pfs_binary.astype(int)

        # Prepare data
        mut_tensor, emb_tensor, pfs_tensor = self.prepare_data(patient_mutations, gene_embeddings, pfs_binary)
        n_patients = len(patient_mutations)
        n_genes, embedding_dim = emb_tensor.shape

        # minlength=2 keeps both entries present even when a class is absent;
        # without it class_counts[1] raised IndexError on all-zero labels.
        class_counts = np.bincount(labels_int, minlength=2)
        if min(class_counts) < 5:
            print(f"Warning: Very few samples in minority class: {class_counts}")
            print("Consider collecting more data or using different evaluation strategy")

        # Positive-class weight = negatives / positives; neutral (1.0) if a class is empty.
        if class_counts[0] == 0 or class_counts[1] == 0:
            pos_weight_value = 1.0
        else:
            pos_weight_value = class_counts[0] / class_counts[1]
        pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32).to(self.device)

        if verbose:
            print(f"Training binary classifier on {str(self.device).upper()}")
            print(f"Data: {n_patients} patients, {n_genes} genes, {embedding_dim}D gene embeddings")
            print(f"Class distribution: {class_counts[0]} negative (0), {class_counts[1]} positive (1)")
            print(f"Class balance ratio: {class_counts[1] / len(pfs_binary):.3f}")
            print(f"Positive class weight: {pos_weight.item():.3f}")

        train_idx, val_idx, test_idx = self._split_indices(pfs_binary, test_size, validation_size)

        # Check class distribution in splits
        train_classes = np.bincount(labels_int[train_idx], minlength=2)
        val_classes = np.bincount(labels_int[val_idx], minlength=2)
        test_classes = np.bincount(labels_int[test_idx], minlength=2)

        if verbose:
            print(f"Train split: {train_classes[0]} negative, {train_classes[1]} positive")
            print(f"Val split: {val_classes[0]} negative, {val_classes[1]} positive") 
            print(f"Test split: {test_classes[0]} negative, {test_classes[1]} positive")

        # Create splits and move to the device
        train_mut = mut_tensor[train_idx].to(self.device)
        train_pfs = pfs_tensor[train_idx].to(self.device)

        val_mut = mut_tensor[val_idx].to(self.device)
        val_pfs = pfs_tensor[val_idx].to(self.device)

        test_mut = mut_tensor[test_idx].to(self.device)
        test_pfs = pfs_tensor[test_idx].to(self.device)

        # Gene embeddings stay constant (moved once)
        gene_emb_gpu = emb_tensor.to(self.device)

        if verbose and self._uses_cuda():
            print(f"Data moved to GPU. Memory usage: {torch.cuda.memory_allocated()/1024**2:.1f}MB")

        # Create and initialise the model on first use; an existing model keeps
        # its weights and continues training.
        if self.model is None:
            self.create_model(n_genes, embedding_dim)
            self._init_weights()

        if verbose and self._uses_cuda():
            model_params = sum(p.numel() for p in self.model.parameters())
            print(f"Model created with {model_params:,} parameters. GPU memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")

        # Create data loader. BatchNorm1d cannot normalise a one-sample batch in
        # training mode, so a trailing batch of exactly one sample is dropped.
        train_dataset = TensorDataset(train_mut, train_pfs)
        n_train = len(train_dataset)
        drop_trailing_single = n_train > batch_size and n_train % batch_size == 1
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  pin_memory=False, drop_last=drop_trailing_single)

        # Optimizer and loss (with class weighting)
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5, mode='max')
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # Training history
        history = {
            'train_losses': [], 'val_losses': [], 'test_losses': [],
            'train_acc': [], 'val_acc': [], 'test_acc': [],
            'train_auc': [], 'val_auc': [], 'test_auc': []
        }

        # Reset the checkpoint so a state saved by an earlier train() call can
        # never be restored at the end of this one; -inf guarantees the first
        # epoch is always recorded.
        self.best_model_state = None
        best_val_auc = float('-inf')
        patience_counter = 0

        if verbose:
            print(f"Starting training for {epochs} epochs...")

        for epoch in range(epochs):
            # Training
            self.model.train()
            epoch_train_loss = 0
            train_logits = []
            train_targets = []

            for batch_mut, batch_pfs in train_loader:
                optimizer.zero_grad()
                logits = self._forward_logits(batch_mut, gene_emb_gpu)
                loss = criterion(logits, batch_pfs)
                loss.backward()
                optimizer.step()

                epoch_train_loss += loss.item()
                train_logits.append(logits.detach().cpu().numpy())
                train_targets.append(batch_pfs.detach().cpu().numpy())

            # Validation and test evaluation (test is only monitored, never used
            # for model selection)
            self.model.eval()
            with torch.no_grad():
                val_logits = self._forward_logits(val_mut, gene_emb_gpu)
                val_loss = criterion(val_logits, val_pfs)

                test_logits = self._forward_logits(test_mut, gene_emb_gpu)
                test_loss = criterion(test_logits, test_pfs)

            # Calculate metrics with robust error handling
            train_logits_all = np.concatenate(train_logits)
            train_targets_all = np.concatenate(train_targets)
            train_probs = torch.sigmoid(torch.tensor(train_logits_all)).numpy()
            train_preds = (train_probs > 0.5).astype(int)

            val_probs = torch.sigmoid(val_logits).cpu().numpy()
            val_preds = (val_probs > 0.5).astype(int)
            val_targets_np = val_pfs.cpu().numpy()

            test_probs = torch.sigmoid(test_logits).cpu().numpy()
            test_preds = (test_probs > 0.5).astype(int)
            test_targets_np = test_pfs.cpu().numpy()

            train_metrics = calculate_robust_metrics(train_targets_all, train_preds, train_probs)
            val_metrics = calculate_robust_metrics(val_targets_np, val_preds, val_probs)
            test_metrics = calculate_robust_metrics(test_targets_np, test_preds, test_probs)

            val_acc = val_metrics['accuracy']
            val_auc = val_metrics['auc']

            # Update history
            history['train_losses'].append(epoch_train_loss / len(train_loader))
            history['val_losses'].append(val_loss.item())
            history['test_losses'].append(test_loss.item())
            history['train_acc'].append(train_metrics['accuracy'])
            history['val_acc'].append(val_acc)
            history['test_acc'].append(test_metrics['accuracy'])
            history['train_auc'].append(train_metrics['auc'])
            history['val_auc'].append(val_auc)
            history['test_auc'].append(test_metrics['auc'])

            # Learning rate scheduling (based on validation AUC)
            scheduler.step(val_auc)

            # Early stopping (based on validation AUC)
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                patience_counter = 0
                self.best_model_state = {k: v.clone() for k, v in self.model.state_dict().items()}
            else:
                patience_counter += 1

            if patience_counter >= early_stopping_patience:
                if verbose:
                    print(f"Early stopping at epoch {epoch}")
                break

            if verbose and epoch % 25 == 0:
                print(f'Epoch {epoch}: Train Loss = {history["train_losses"][-1]:.4f}, '
                      f'Val Loss = {history["val_losses"][-1]:.4f}, '
                      f'Val Acc = {val_acc:.4f}, Val AUC = {val_auc:.4f}')

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        # Final evaluation with detailed metrics
        if verbose:
            print("\n=== Final Training Summary ===")
            self.model.eval()
            with torch.no_grad():
                split_probs = {}
                for split_name, split_mut in (('train', train_mut), ('val', val_mut), ('test', test_mut)):
                    split_logits = self._forward_logits(split_mut, gene_emb_gpu)
                    split_probs[split_name] = torch.sigmoid(split_logits).cpu().numpy()

                train_preds_final = (split_probs['train'] > 0.5).astype(int)
                val_probs_final = split_probs['val']
                val_preds_final = (val_probs_final > 0.5).astype(int)
                test_preds_final = (split_probs['test'] > 0.5).astype(int)

                print(f"Train - Predictions: {np.bincount(train_preds_final, minlength=2)}, Actual: {np.bincount(train_pfs.cpu().numpy().astype(int), minlength=2)}")
                print(f"Val   - Predictions: {np.bincount(val_preds_final, minlength=2)}, Actual: {np.bincount(val_pfs.cpu().numpy().astype(int), minlength=2)}")
                print(f"Test  - Predictions: {np.bincount(test_preds_final, minlength=2)}, Actual: {np.bincount(test_pfs.cpu().numpy().astype(int), minlength=2)}")

                # Show probability distributions
                print(f"Val probability range: [{val_probs_final.min():.3f}, {val_probs_final.max():.3f}]")
                print(f"Val probability mean: {val_probs_final.mean():.3f}")

        # Clear GPU cache
        if self._uses_cuda():
            torch.cuda.empty_cache()

        return history

    def predict(self, patient_mutations, gene_embeddings, batch_size=256, return_probs=False):
        """
        Predict for new patients, processing ``batch_size`` rows at a time.

        Args:
            patient_mutations: (n_patients, n_genes) mutation matrix
            gene_embeddings: (n_genes, embedding_dim) gene embeddings
            batch_size: rows per forward pass.
            return_probs: return sigmoid probabilities instead of 0/1 labels.

        Returns:
            (n_patients,) array of probabilities, or of 0/1 predictions
            (probability > 0.5). Empty input yields an empty array.

        Raises:
            RuntimeError: if no model has been created or trained yet.
        """
        if self.model is None:
            raise RuntimeError("No model available: call train() or create_model() first")
        self.model.eval()

        mutations = np.asarray(patient_mutations, dtype=np.float32)
        n_patients = len(mutations)
        if n_patients == 0:
            # np.concatenate rejects an empty list, so answer directly.
            empty = np.empty(0, dtype=np.float32)
            return empty if return_probs else empty.astype(int)

        # Move gene embeddings to the device once
        gene_emb_gpu = torch.as_tensor(np.asarray(gene_embeddings, dtype=np.float32)).to(self.device)

        all_logits = []
        with torch.no_grad():
            for start in range(0, n_patients, batch_size):
                batch_mut = torch.as_tensor(mutations[start:start + batch_size]).to(self.device)
                all_logits.append(self.model(batch_mut, gene_emb_gpu).cpu().numpy())

        # Combine all logits and convert to probabilities
        logits = np.concatenate(all_logits, axis=0).flatten()
        probabilities = torch.sigmoid(torch.tensor(logits)).numpy()

        if return_probs:
            return probabilities
        # Binary predictions at the 0.5 threshold
        return (probabilities > 0.5).astype(int)

    def evaluate(self, patient_mutations, gene_embeddings, true_labels):
        """
        Score the model on labelled data.

        Returns:
            (metrics, predictions, probabilities) where metrics holds 'accuracy',
            'precision', 'recall', 'f1' and 'auc' (auc is 0.5 when it cannot be
            computed, e.g. only one class present).
        """
        probabilities = self.predict(patient_mutations, gene_embeddings, return_probs=True)
        predictions = (probabilities > 0.5).astype(int)

        try:
            auc = roc_auc_score(true_labels, probabilities)
        except ValueError:
            auc = 0.5  # Random performance if only one class

        metrics = {
            'accuracy': accuracy_score(true_labels, predictions),
            'precision': precision_score(true_labels, predictions, zero_division=0),
            'recall': recall_score(true_labels, predictions, zero_division=0),
            'f1': f1_score(true_labels, predictions, zero_division=0),
            'auc': auc
        }

        return metrics, predictions, probabilities

def calculate_robust_metrics(true_labels, predictions, probabilities):
    """
    Compute binary-classification metrics without raising on degenerate input.

    Args:
        true_labels: 0/1 ground-truth labels (any array-like).
        predictions: 0/1 predicted labels, same length.
        probabilities: predicted probability of class 1, same length.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'specificity', 'f1',
        'balanced_accuracy' and 'auc'. Any ratio with an empty denominator is
        0.0; 'auc' is 0.5 when it is undefined (fewer than two classes present
        or scikit-learn rejects the input).
    """
    # Coerce to flat arrays: with plain lists `labels == 1` is a single False,
    # which silently produced all-zero counts.
    true_labels = np.asarray(true_labels).ravel()
    predictions = np.asarray(predictions).ravel()
    probabilities = np.asarray(probabilities).ravel()

    def _ratio(numerator, denominator):
        return float(numerator / denominator) if denominator > 0 else 0.0

    # Confusion-matrix counts
    actual_pos = true_labels == 1
    actual_neg = true_labels == 0
    tp = int(np.sum(actual_pos & (predictions == 1)))
    tn = int(np.sum(actual_neg & (predictions == 0)))
    fp = int(np.sum(actual_neg & (predictions == 1)))
    fn = int(np.sum(actual_pos & (predictions == 0)))

    metrics = {
        'accuracy': _ratio(tp + tn, true_labels.size),
        'precision': _ratio(tp, tp + fp),
        'recall': _ratio(tp, tp + fn),           # sensitivity
        'specificity': _ratio(tn, tn + fp),
    }

    prec = metrics['precision']
    rec = metrics['recall']
    metrics['f1'] = 2 * (prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    metrics['balanced_accuracy'] = (metrics['recall'] + metrics['specificity']) / 2

    # AUC needs both classes; only scikit-learn's ValueError is treated as
    # "undefined" so unrelated failures (and KeyboardInterrupt) still surface.
    auc = 0.5
    if np.unique(true_labels).size > 1:
        try:
            auc = float(roc_auc_score(true_labels, probabilities))
        except ValueError:
            auc = 0.5
    metrics['auc'] = auc

    return metrics

In [None]:
# Split the modelling frame into labels and features, and stack the gene embeddings.
pfs_labels_raw = gene_data['pfs'].to_numpy()
# The outcome must be strictly 0/1 before the integer cast below, otherwise
# fractional values would be truncated silently.
assert np.isin(pfs_labels_raw, (0, 1)).all(), "'pfs' must contain only 0/1 values"
# Integer labels: after fillna the column may be float, which np.bincount rejects.
pfs_binary = pfs_labels_raw.astype(int)

mutants = gene_data.drop(columns=['pfs'])
patient_mutations = mutants.to_numpy(dtype=float)
gene_embeddings = np.asarray(embeddings, dtype=float)

# Column j of patient_mutations must correspond to row j of gene_embeddings.
assert gene_embeddings.ndim == 2, "gene embeddings must form a (n_genes, embedding_dim) matrix"
assert patient_mutations.shape[1] == gene_embeddings.shape[0], "feature columns and embeddings are misaligned"
assert patient_mutations.shape[0] == pfs_binary.shape[0], "features and labels differ in length"

In [None]:
# Largest value in the feature matrix (shown as the cell output).
np.max(patient_mutations)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

def normalize_embeddings(embeddings, method='l2'):
    """
    Normalize embedding vectors.

    Args:
        embeddings: array-like of shape (n_vectors, emb_dim)
        method: 'l2' (each row scaled to unit length; all-zero rows stay zero),
                'standard' (per-column zero mean / unit variance) or
                'minmax' (per-column rescaling to the 0-1 range)

    Returns:
        normalized embeddings as a numpy array of the same shape

    Raises:
        ValueError: if method is not one of the three supported names.
    """
    embeddings = np.asarray(embeddings)

    if method == 'l2':
        # L2 normalization (unit vectors)
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Avoid division by zero: zero vectors are divided by 1 and stay zero
        norms = np.where(norms == 0, 1, norms)
        return embeddings / norms

    if method == 'standard':
        # Standardization (zero mean, unit variance)
        return StandardScaler().fit_transform(embeddings)

    if method == 'minmax':
        # MinMaxScaler was never imported by this notebook, so this branch used
        # to fail with NameError; import it where it is needed.
        from sklearn.preprocessing import MinMaxScaler
        return MinMaxScaler().fit_transform(embeddings)

    raise ValueError("Method must be 'l2', 'standard', or 'minmax'")

def compute_patient_embeddings_vectorized(embeddings, mutations):
    """
    Build one embedding per patient as the mutation-weighted mean of the
    L2-normalized gene embeddings.

    Args:
        embeddings: (n_genes, emb_dim) gene embeddings
        mutations: (n_patients, n_genes) weights; NaN and 0 count as "no mutation"

    Returns:
        (n_patients, emb_dim) array; rows of patients without any valid
        mutation are all zeros.
    """
    unit_embeddings = normalize_embeddings(embeddings, method='l2')

    # An entry contributes only if it is neither NaN nor zero
    has_mutation = np.logical_not(np.isnan(mutations)) & (mutations != 0)
    weights = np.where(has_mutation, mutations, 0)

    # Weighted sum of gene embeddings per patient: (n_patients, emb_dim)
    weighted_sum = weights @ unit_embeddings

    # Per-patient weight total, with zero totals replaced by 1 so the division is safe
    weight_totals = weights.sum(axis=1, keepdims=True)
    weight_totals[weight_totals == 0] = 1

    patient_embeddings = weighted_sum / weight_totals

    # Patients with no valid mutation get an explicit zero vector
    patient_embeddings[~has_mutation.any(axis=1)] = 0

    return patient_embeddings

# Alternative with more customization options
def visualize_embeddings_advanced(patient_embeddings, patient_classes, 
                                 class_names=None, colors=None,
                                 perplexity=50, random_state=42,
                                 figsize=(12, 8)):
    """
    Project patient embeddings to 2-D with t-SNE and draw a per-class scatter plot.

    Args:
        patient_embeddings: (n_patients, emb_dim) array
        patient_classes: (n_patients,) class label per patient
        class_names: legend label per class, in sorted-class order
                     (default: 'Class <label>')
        colors: one colour per class (default: matplotlib Set1)
        perplexity: requested t-SNE perplexity; lowered automatically when the
                    cohort has too few patients for it
        random_state: seed for t-SNE
        figsize: matplotlib figure size

    Returns:
        (n_patients, 2) array of t-SNE coordinates

    Raises:
        ValueError: with fewer than two patients, or when class_names / colors
            provide fewer entries than there are classes.
    """
    import inspect

    patient_classes = np.asarray(patient_classes)  # element-wise masks also for lists
    n_samples = len(patient_embeddings)
    if n_samples < 2:
        raise ValueError("t-SNE needs at least two patients")

    # Standardize embeddings
    embeddings_scaled = StandardScaler().fit_transform(patient_embeddings)

    # t-SNE requires perplexity < n_samples
    effective_perplexity = min(perplexity, n_samples - 1)
    # scikit-learn renamed TSNE's `n_iter` to `max_iter`; use whichever keyword
    # the installed version accepts.
    tsne_params = inspect.signature(TSNE.__init__).parameters
    iter_kwarg = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'
    tsne = TSNE(n_components=2, perplexity=effective_perplexity,
                random_state=random_state, **{iter_kwarg: 10000})
    embeddings_2d = tsne.fit_transform(embeddings_scaled)

    # Set up colors and labels
    unique_classes = np.unique(patient_classes)
    if colors is None:
        colors = plt.cm.Set1(np.linspace(0, 1, len(unique_classes)))
    if class_names is None:
        class_names = [f'Class {c}' for c in unique_classes]
    if len(class_names) < len(unique_classes) or len(colors) < len(unique_classes):
        raise ValueError(
            f"{len(unique_classes)} classes found; class_names and colors need one entry per class"
        )

    # Create plot
    fig, ax = plt.subplots(figsize=figsize)

    for i, class_label in enumerate(unique_classes):
        mask = patient_classes == class_label
        ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], 
                  c=[colors[i]], label=class_names[i], 
                  alpha=0.7, s=60, edgecolors='black', linewidth=0.5)

    ax.set_xlabel('t-SNE Component 1', fontsize=12)
    ax.set_ylabel('t-SNE Component 2', fontsize=12)
    ax.set_title('Patient Gene Mutation Embeddings (t-SNE)', fontsize=14, fontweight='bold')
    ax.legend(frameon=True, fancybox=True, shadow=True)
    ax.grid(True, alpha=0.3)

    # Add class distribution info
    class_counts = [np.sum(patient_classes == c) for c in unique_classes]
    info_text = f"Total patients: {len(patient_classes)}\n"
    for name, count in zip(class_names, class_counts):
        info_text += f"{name}: {count}\n"

    ax.text(0.02, 0.98, info_text, transform=ax.transAxes, 
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    plt.tight_layout()
    plt.show()

    return embeddings_2d

# Mutation-weighted mean embedding per patient.
final_embs = compute_patient_embeddings_vectorized(
    embeddings=gene_embeddings,
    mutations=patient_mutations,
)

# 2-D t-SNE view of the patient embeddings, coloured by outcome class.
embeddings_2d = visualize_embeddings_advanced(
    patient_embeddings=final_embs,
    patient_classes=pfs_binary,
    class_names=['Control', 'Disease'],
)


In [None]:
print("Data shapes:")
print(f"Patient mutations: {patient_mutations.shape} (n_patients, n_genes)")
print(f"Gene embeddings: {gene_embeddings.shape} (n_genes, embedding_dim)")
print(f"PFS binary: {pfs_binary.shape} (n_patients,)")
# np.bincount only accepts integer input; the label array may be float after fillna.
print(f"Class distribution: {np.bincount(np.asarray(pfs_binary).astype(int), minlength=2)} (0=poor, 1=good)")
print(f"Mutation rate: {patient_mutations.mean():.3f}")

# Test different gene-level combination methods
methods = ['weighted_sum', 'attention', 'gene_mlp']
results = {}

for method in methods:
    print(f"\n--- Training with {method} method ---")

    classifier = PFSClassifier(combination_method=method)

    # Train model
    history = classifier.train(
        patient_mutations, gene_embeddings, pfs_binary,
        epochs=150, batch_size=512, early_stopping_patience=75, verbose=True
    )

    # Held-out performance at the epoch with the best validation AUC (the first
    # such epoch, which is the checkpoint train() restores). These are the
    # numbers to compare methods on.
    best_epoch = int(np.argmax(history['val_auc']))
    held_out = {
        'best_epoch': best_epoch,
        'val_auc': history['val_auc'][best_epoch],
        'test_auc': history['test_auc'][best_epoch],
        'test_acc': history['test_acc'][best_epoch],
    }

    # Full-cohort evaluation. It includes the training patients, so it is
    # optimistic and is reported for reference only.
    metrics, predictions, probabilities = classifier.evaluate(patient_mutations, gene_embeddings, pfs_binary)
    results[method] = {
        'metrics': metrics, 
        'predictions': predictions, 
        'probabilities': probabilities,
        'history': history,
        'held_out': held_out
    }

    print(f"Results for {method} (full cohort, includes training patients):")
    print(f"  Accuracy: {metrics['accuracy']:.4f}")
    print(f"  F1 Score: {metrics['f1']:.4f}")
    print(f"  AUC-ROC: {metrics['auc']:.4f}")
    print(f"  Precision: {metrics['precision']:.4f}")
    print(f"  Recall: {metrics['recall']:.4f}")
    print(f"Held-out results for {method} (best epoch {best_epoch}):")
    print(f"  Validation AUC: {held_out['val_auc']:.4f}")
    print(f"  Test AUC: {held_out['test_auc']:.4f}")
    print(f"  Test Accuracy: {held_out['test_acc']:.4f}")

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Pick the best method on validation AUC: ranking on the full-cohort AUC would
# reward memorising the training patients. The test AUC is reported, not selected on.
best_method = max(results.keys(), key=lambda x: results[x]['held_out']['val_auc'])
best_held_out = results[best_method]['held_out']
print(f"\nBest method: {best_method} "
      f"(validation AUC = {best_held_out['val_auc']:.4f}, test AUC = {best_held_out['test_auc']:.4f})")