In [None]:
import pandas as pd

from ChromaVDB.chroma import ChromaFramework
from tqdm.notebook import tqdm

# Open the persisted Chroma store and fetch every record it holds.
vdb = ChromaFramework(persist_directory="./ChromaVDB/chroma_db")
records = vdb.list_records()

# Split the records into two parallel lists: names[i] and embs[i] always come
# from the same record.
names, embs = [], []
for record in records:
    names.append(record['name'])
    embs.append(record['embeddings'])

In [None]:
# Read the spreadsheet into a DataFrame (path is relative to the notebook's
# working directory).
excel_path = 'data/2025_03_29.xlsx'
data = pd.read_excel(excel_path)

In [None]:
# Substrings a column name must contain to be selected.
gene_set = "LNF"
gene_measure = "MUT"

# Keep the columns whose name mentions both the gene set and the measure,
# preserving their original order in `data`.
selected_columns = []
for col in data.columns:
    if gene_set in col and gene_measure in col:
        selected_columns.append(col)

gene_data = data[selected_columns]

In [None]:
# Gene label = text before the first underscore of each selected column name.
# Sorting fixes the gene order (and therefore the feature/embedding order)
# across kernel restarts; iterating a bare set of strings does not, because
# string hashing is randomised per interpreter process.
genes = sorted({gene.split('_')[0] for gene in gene_data.columns})

# One-off name -> embedding lookup instead of a linear `names.index(...)` per
# gene. `setdefault` keeps the FIRST occurrence of a duplicated name, which is
# what `names.index` returned.
embedding_by_name = {}
for name, emb in zip(names, embs):
    embedding_by_name.setdefault(name, emb)

available_columns = set(gene_data.columns)

# final_columns[i] and embeddings[i] always describe the same gene.
final_columns = []
embeddings = []

for gene in genes:
    column = gene + "_" + gene_set + "_" + gene_measure
    if gene not in embedding_by_name:
        # No embedding stored for this gene: report it and leave it out.
        print(gene)
    elif column not in available_columns:
        # The synthesised column name does not exist (e.g. the real column has
        # extra suffixes). Skip it here rather than hitting a KeyError when the
        # frame is indexed with `final_columns` later.
        print(f"{gene}: column {column!r} not found in gene_data, skipped")
    else:
        final_columns.append(column)
        embeddings.append(embedding_by_name[gene])

print(len(final_columns))

In [None]:
# Build the modelling frame in one chained expression:
#   1. keep only the genes that have an embedding,
#   2. attach the PFS target (aligned on the row index of `data`),
#   3. drop every patient with a missing mutation value or missing PFS.
# `.assign` returns a new frame, so no column is written into a slice of
# another DataFrame (the pattern that triggers SettingWithCopyWarning).
gene_data = (
    gene_data[final_columns]
    .assign(pfs=data['PFS_Months_updated'])
    .dropna()
)

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr

# Reproducibility: seed the global NumPy and PyTorch RNGs so that weight
# initialisation, dropout masks and DataLoader shuffling repeat across
# "Restart Kernel -> Run All". torch.manual_seed also seeds the CUDA generators.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# GPU Configuration and Optimization
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("GPU device:", torch.cuda.get_device_name(0))

    # Let cuDNN auto-tune its kernels for speed.
    # NOTE(review): with benchmark=True / deterministic=False, GPU runs may still
    # not be bit-for-bit identical even with the seed above. Flip both flags if
    # exact GPU reproducibility matters more than throughput.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

class PFSPredictionMLP(nn.Module):
    """
    MLP model that combines gene embeddings with patient binary mutation features
    to predict progression-free survival scores.

    Data structure:
    - gene_embeddings: (n_genes, embedding_dim) - one embedding vector per gene
    - patient_mutations: (n_patients, n_genes) - mutation matrix
    - pfs_scores: (n_patients,) - PFS score for each patient

    The model works in two stages: a gene-level fusion step (selected by
    ``combination_method``) turns the per-gene information of a patient into a
    single vector, then a patient-level MLP maps that vector to one scalar.

    Args:
        n_genes: number of genes (columns of the mutation matrix).
        embedding_dim: size of each gene embedding.
        hidden_dims: widths of the hidden layers of the patient-level MLP.
            The default list is only read, never mutated.
        dropout_rate: dropout probability used throughout.
        combination_method: 'weighted_sum', 'attention', 'transformer' or
            'gene_mlp'.

    Raises:
        ValueError: if ``combination_method`` is not one of the four names.
    """

    def __init__(self, n_genes, embedding_dim, hidden_dims=[512, 256, 128],
                 dropout_rate=0.3, combination_method='weighted_sum'):
        super(PFSPredictionMLP, self).__init__()

        self.n_genes = n_genes
        self.embedding_dim = embedding_dim
        self.combination_method = combination_method

        # Gene-level fusion methods
        if combination_method == 'weighted_sum':
            # A single learned scalar (1 -> 1 linear map, no bias) rescales the
            # mutation indicator before it weights the gene embeddings.
            self.gene_weight = nn.Linear(1, 1, bias=False)
            patient_repr_dim = embedding_dim

        elif combination_method == 'attention':
            # Self-attention over genes. The head count must divide embed_dim,
            # so fall back from 4 to the largest count that does.
            self.gene_attention = nn.MultiheadAttention(
                embed_dim=embedding_dim,
                num_heads=self._compatible_num_heads(embedding_dim),
                batch_first=True
            )
            # Project mutation status to same dim as embeddings for attention
            self.mutation_proj = nn.Linear(1, embedding_dim)
            patient_repr_dim = embedding_dim

        elif combination_method == 'transformer':
            # Transformer encoder over per-gene tokens of size embedding_dim + 1
            # (embedding concatenated with mutation status). That size is rarely
            # a multiple of 4, hence the head-count fallback.
            token_dim = embedding_dim + 1
            self.gene_transform = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=token_dim,
                    nhead=self._compatible_num_heads(token_dim),
                    dim_feedforward=embedding_dim * 2,
                    dropout=dropout_rate,
                    batch_first=True
                ),
                num_layers=2
            )
            patient_repr_dim = token_dim

        elif combination_method == 'gene_mlp':
            # Small MLP applied to every gene token, aggregated afterwards.
            self.gene_mlp = nn.Sequential(
                nn.Linear(embedding_dim + 1, embedding_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(embedding_dim, embedding_dim // 2)
            )
            patient_repr_dim = embedding_dim // 2

        else:
            raise ValueError("combination_method must be 'weighted_sum', 'attention', 'transformer', or 'gene_mlp'")

        # Patient-level MLP for final PFS prediction
        layers = []
        prev_dim = patient_repr_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.BatchNorm1d(hidden_dim)
            ])
            prev_dim = hidden_dim

        # Output layer for PFS prediction
        layers.append(nn.Linear(prev_dim, 1))

        self.patient_mlp = nn.Sequential(*layers)

    @staticmethod
    def _compatible_num_heads(model_dim, preferred=4):
        """Return the largest head count <= ``preferred`` that divides ``model_dim``."""
        for heads in range(preferred, 0, -1):
            if model_dim % heads == 0:
                return heads
        return 1

    def forward(self, patient_mutations, gene_embeddings):
        """
        Forward pass combining gene embeddings with patient mutations.

        Args:
            patient_mutations: (batch_size, n_genes) mutation matrix
            gene_embeddings: (n_genes, embedding_dim) gene embeddings

        Returns:
            pfs_prediction: (batch_size, 1) predicted PFS scores

        Raises:
            ValueError: if the two inputs are not 2-D or disagree on n_genes.
        """
        if patient_mutations.dim() != 2 or gene_embeddings.dim() != 2:
            raise ValueError(
                "patient_mutations must be (batch_size, n_genes) and "
                "gene_embeddings must be (n_genes, embedding_dim)"
            )
        if patient_mutations.shape[1] != gene_embeddings.shape[0]:
            raise ValueError(
                f"patient_mutations has {patient_mutations.shape[1]} genes but "
                f"gene_embeddings has {gene_embeddings.shape[0]}"
            )

        batch_size = patient_mutations.shape[0]
        mutations_expanded = patient_mutations.unsqueeze(-1)  # (batch_size, n_genes, 1)
        gene_emb_expanded = gene_embeddings.unsqueeze(0)      # (1, n_genes, embedding_dim)

        if self.combination_method == 'weighted_sum':
            # Rescale the mutation indicator with the learned scalar.
            weighted_mutations = self.gene_weight(mutations_expanded)  # (batch_size, n_genes, 1)

            # Broadcast multiply: (batch_size, n_genes, 1) * (1, n_genes, embedding_dim)
            weighted_gene_embs = weighted_mutations * gene_emb_expanded

            # Sum across genes for each patient
            patient_repr = weighted_gene_embs.sum(dim=1)  # (batch_size, embedding_dim)

        elif self.combination_method == 'attention':
            # Project mutations to embedding dimension
            mutation_features = self.mutation_proj(mutations_expanded)  # (batch_size, n_genes, embedding_dim)

            # Broadcasting adds the shared gene embeddings to every patient
            # without materialising a per-patient copy.
            combined_features = gene_emb_expanded + mutation_features

            # Self-attention over genes
            attended_features, _ = self.gene_attention(
                combined_features, combined_features, combined_features
            )  # (batch_size, n_genes, embedding_dim)

            # Average pool across genes
            patient_repr = attended_features.mean(dim=1)  # (batch_size, embedding_dim)

        elif self.combination_method == 'transformer':
            # expand() yields a broadcast view; torch.cat then builds the real
            # (batch_size, n_genes, embedding_dim + 1) token tensor.
            gene_emb_batch = gene_emb_expanded.expand(batch_size, -1, -1)
            gene_features = torch.cat([gene_emb_batch, mutations_expanded], dim=-1)

            # Apply transformer encoder
            transformed = self.gene_transform(gene_features)

            # Global average pooling across genes
            patient_repr = transformed.mean(dim=1)  # (batch_size, embedding_dim + 1)

        elif self.combination_method == 'gene_mlp':
            gene_emb_batch = gene_emb_expanded.expand(batch_size, -1, -1)
            gene_features = torch.cat([gene_emb_batch, mutations_expanded], dim=-1)

            # nn.Linear acts on the last dimension, so the per-gene MLP can be
            # applied to the 3-D tensor directly.
            gene_processed = self.gene_mlp(gene_features)  # (batch_size, n_genes, embedding_dim // 2)

            # Sum across genes for each patient
            patient_repr = gene_processed.sum(dim=1)  # (batch_size, embedding_dim // 2)

        # Final MLP for PFS prediction
        pfs_prediction = self.patient_mlp(patient_repr)
        return pfs_prediction

class PFSPredictor:
    """
    Main class for training and evaluating PFS prediction models with gene-level fusion.

    Typical use: construct, call ``train`` once, then ``predict`` / ``evaluate``.
    After ``train`` the row indices of each split are available in
    ``split_indices_`` (keys 'train', 'val', 'test').

    Args:
        combination_method: fusion strategy forwarded to PFSPredictionMLP.
        device: requested torch device. A CUDA request falls back to 'cpu'
            when CUDA is unavailable; any other request is used as given.
    """

    def __init__(self, combination_method='weighted_sum', device='cuda'):
        requested = str(device)
        if not requested.startswith('cuda'):
            # Honour an explicit non-CUDA request even when a GPU is present.
            self.device = requested
        elif torch.cuda.is_available():
            self.device = requested
            torch.cuda.empty_cache()
            cuda_device = torch.device(requested)
            gpu_name = torch.cuda.get_device_name(cuda_device)
            gpu_memory = torch.cuda.get_device_properties(cuda_device).total_memory / 1024**3
            print(f"Using GPU: {gpu_name} ({gpu_memory:.1f}GB)")
        else:
            self.device = 'cpu'
            print("CUDA not available, using CPU")

        self.combination_method = combination_method
        self.model = None
        self.pfs_scaler = StandardScaler()
        self.best_model_state = None
        self.split_indices_ = None

    def _uses_cuda(self):
        """True when the selected device is a CUDA device ('cuda', 'cuda:0', ...)."""
        return str(self.device).startswith('cuda')

    def prepare_data(self, patient_mutations, gene_embeddings, pfs_scores, fit_indices=None):
        """
        Prepare and normalize data for training.

        Args:
            patient_mutations: (n_patients, n_genes) mutation matrix
            gene_embeddings: (n_genes, embedding_dim) gene embeddings
            pfs_scores: (n_patients,) PFS scores
            fit_indices: optional row indices used to fit the PFS scaler.
                Pass the training rows so validation/test targets do not leak
                into the normalisation statistics. ``None`` fits on all rows.

        Returns:
            (mutations_tensor, embeddings_tensor, pfs_tensor) as float32 CPU
            tensors; only the PFS scores are standardised.
        """
        pfs_column = np.asarray(pfs_scores, dtype=np.float64).reshape(-1, 1)
        if fit_indices is None:
            self.pfs_scaler.fit(pfs_column)
        else:
            self.pfs_scaler.fit(pfs_column[fit_indices])
        pfs_normalized = np.asarray(self.pfs_scaler.transform(pfs_column)).reshape(-1)

        # Mutations and embeddings are passed through unchanged (cast only).
        mutations_tensor = torch.tensor(np.asarray(patient_mutations, dtype=np.float32))
        embeddings_tensor = torch.tensor(np.asarray(gene_embeddings, dtype=np.float32))
        pfs_tensor = torch.tensor(pfs_normalized.astype(np.float32))

        return mutations_tensor, embeddings_tensor, pfs_tensor

    def create_model(self, n_genes, embedding_dim, **model_kwargs):
        """Create the PFSPredictionMLP and move it to the selected device."""
        self.model = PFSPredictionMLP(
            n_genes=n_genes,
            embedding_dim=embedding_dim,
            combination_method=self.combination_method,
            **model_kwargs
        ).to(self.device)

    def train(self, patient_mutations, gene_embeddings, pfs_scores,
              test_size=0.15, validation_size=0.15, batch_size=64,
              epochs=200, lr=0.001, weight_decay=1e-4,
              early_stopping_patience=20, verbose=True):
        """
        Train the PFS prediction model.

        The rows are split into train/validation/test with a fixed
        ``random_state=42``. The PFS scaler is fitted on the training rows only.
        The weights with the lowest validation loss are restored at the end.

        Args:
            patient_mutations: (n_patients, n_genes) mutation matrix
            gene_embeddings: (n_genes, embedding_dim) gene embeddings
            pfs_scores: (n_patients,) PFS scores
            test_size, validation_size: fractions of patients held out.
            batch_size, epochs, lr, weight_decay: optimisation settings.
            early_stopping_patience: epochs without validation improvement
                before training stops.
            verbose: print progress information.

        Returns:
            training_history: dict with per-epoch loss and R² lists for the
            train, validation and test splits (normalised PFS scale).

        Raises:
            ValueError: if the input shapes are inconsistent.
        """
        mutations_np = np.asarray(patient_mutations, dtype=np.float32)
        embeddings_np = np.asarray(gene_embeddings, dtype=np.float32)
        pfs_np = np.asarray(pfs_scores, dtype=np.float64).reshape(-1)

        if embeddings_np.ndim != 2:
            raise ValueError("gene_embeddings must be 2-D: (n_genes, embedding_dim)")
        if mutations_np.ndim != 2 or mutations_np.shape[1] != embeddings_np.shape[0]:
            raise ValueError(
                "patient_mutations must be (n_patients, n_genes) with n_genes "
                "equal to the number of rows in gene_embeddings"
            )
        if len(pfs_np) != len(mutations_np):
            raise ValueError("pfs_scores must contain one value per patient")

        n_patients = len(mutations_np)
        n_genes, embedding_dim = embeddings_np.shape

        if verbose:
            print(f"Training on {str(self.device).upper()}")
            print(f"Data: {n_patients} patients, {n_genes} genes, {embedding_dim}D gene embeddings")

        # Split data: train/val/test. Done BEFORE normalisation so the scaler
        # only sees training targets.
        indices = np.arange(n_patients)
        holdout_fraction = test_size + validation_size
        train_idx, temp_idx = train_test_split(indices, test_size=holdout_fraction, random_state=42)
        val_idx, test_idx = train_test_split(temp_idx, test_size=test_size / holdout_fraction, random_state=42)
        self.split_indices_ = {'train': train_idx, 'val': val_idx, 'test': test_idx}

        mut_tensor, emb_tensor, pfs_tensor = self.prepare_data(
            mutations_np, embeddings_np, pfs_np, fit_indices=train_idx
        )

        # Create splits and move to the device
        train_mut = mut_tensor[train_idx].to(self.device)
        train_pfs = pfs_tensor[train_idx].to(self.device)

        val_mut = mut_tensor[val_idx].to(self.device)
        val_pfs = pfs_tensor[val_idx].to(self.device)

        test_mut = mut_tensor[test_idx].to(self.device)
        test_pfs = pfs_tensor[test_idx].to(self.device)

        # Gene embeddings stay constant (moved to the device once)
        gene_emb_gpu = emb_tensor.to(self.device)

        if verbose and self._uses_cuda():
            print(f"Data moved to GPU. Memory usage: {torch.cuda.memory_allocated()/1024**2:.1f}MB")

        # Create model
        if self.model is None:
            self.create_model(n_genes, embedding_dim)

        if verbose and self._uses_cuda():
            model_params = sum(p.numel() for p in self.model.parameters())
            print(f"Model created with {model_params:,} parameters. GPU memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")

        # A trailing batch with a single sample makes BatchNorm1d fail in
        # training mode. Drop it in that case only; with shuffle=True a
        # different sample is left out each epoch, so every row is still used.
        n_train = len(train_idx)
        drop_single_tail = n_train > batch_size and n_train % batch_size == 1

        train_dataset = TensorDataset(train_mut, train_pfs)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            pin_memory=False, drop_last=drop_single_tail
        )

        # Optimizer and loss
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)
        criterion = nn.MSELoss()

        # Training history
        history = {
            'train_losses': [], 'val_losses': [], 'test_losses': [],
            'train_r2': [], 'val_r2': [], 'test_r2': []
        }

        best_val_loss = float('inf')
        patience_counter = 0
        # Reset so a checkpoint from an earlier train() call is never restored.
        self.best_model_state = None

        if verbose:
            print(f"Starting training for {epochs} epochs...")

        for epoch in range(epochs):
            # Training
            self.model.train()
            epoch_train_loss = 0
            train_predictions = []
            train_targets = []

            for batch_mut, batch_pfs in train_loader:
                optimizer.zero_grad()
                # squeeze(-1) removes only the output dimension, so a batch of
                # one stays 1-D and matches the target shape.
                predictions = self.model(batch_mut, gene_emb_gpu).squeeze(-1)
                loss = criterion(predictions, batch_pfs)
                loss.backward()
                optimizer.step()

                epoch_train_loss += loss.item()
                train_predictions.append(predictions.detach().cpu().numpy())
                train_targets.append(batch_pfs.detach().cpu().numpy())

            # Validation and test evaluation (test is tracked for the curves
            # only; it never influences early stopping or the LR schedule).
            self.model.eval()
            with torch.no_grad():
                val_pred = self.model(val_mut, gene_emb_gpu).squeeze(-1)
                val_loss = criterion(val_pred, val_pfs).item()

                test_pred = self.model(test_mut, gene_emb_gpu).squeeze(-1)
                test_loss = criterion(test_pred, test_pfs).item()

            # Calculate R² scores
            train_pred_all = np.concatenate(train_predictions)
            train_true_all = np.concatenate(train_targets)
            train_r2 = r2_score(train_true_all, train_pred_all)

            val_r2 = r2_score(val_pfs.cpu().numpy(), val_pred.cpu().numpy())
            test_r2 = r2_score(test_pfs.cpu().numpy(), test_pred.cpu().numpy())

            # Update history
            history['train_losses'].append(epoch_train_loss / len(train_loader))
            history['val_losses'].append(val_loss)
            history['test_losses'].append(test_loss)
            history['train_r2'].append(train_r2)
            history['val_r2'].append(val_r2)
            history['test_r2'].append(test_r2)

            # Learning rate scheduling
            scheduler.step(val_loss)

            # Early stopping: checkpoint on improvement, otherwise count up.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.best_model_state = {k: v.clone() for k, v in self.model.state_dict().items()}
            else:
                patience_counter += 1

            if patience_counter >= early_stopping_patience:
                if verbose:
                    print(f"Early stopping at epoch {epoch}")
                break

            if verbose and epoch % 25 == 0:
                print(f'Epoch {epoch}: Train Loss = {history["train_losses"][-1]:.4f}, '
                      f'Val Loss = {history["val_losses"][-1]:.4f}, '
                      f'Val R² = {val_r2:.4f}')

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        # Clear GPU cache
        if self._uses_cuda():
            torch.cuda.empty_cache()

        return history

    def predict(self, patient_mutations, gene_embeddings, batch_size=256):
        """
        Make predictions on new data.

        Args:
            patient_mutations: (n_patients, n_genes) mutation matrix.
            gene_embeddings: (n_genes, embedding_dim) gene embeddings.
            batch_size: number of patients per forward pass.

        Returns:
            1-D NumPy array of predictions mapped back to the original PFS
            scale with the fitted scaler; empty if no patients are given.

        Raises:
            RuntimeError: if no model has been created yet.
        """
        if self.model is None:
            raise RuntimeError("Model has not been created; call train() or create_model() first.")

        mutations_np = np.asarray(patient_mutations, dtype=np.float32)
        n_patients = len(mutations_np)
        if n_patients == 0:
            return np.empty(0)

        self.model.eval()

        # Move gene embeddings to the device once
        gene_emb_gpu = torch.tensor(np.asarray(gene_embeddings, dtype=np.float32)).to(self.device)

        all_predictions = []
        with torch.no_grad():
            for start in range(0, n_patients, batch_size):
                batch_mut = torch.tensor(mutations_np[start:start + batch_size]).to(self.device)
                batch_pred = self.model(batch_mut, gene_emb_gpu).reshape(-1, 1)
                all_predictions.append(batch_pred.cpu().numpy())

        # Combine all predictions
        predictions = np.concatenate(all_predictions, axis=0)

        # Inverse transform to original scale
        predictions_original = self.pfs_scaler.inverse_transform(predictions)

        return np.asarray(predictions_original).flatten()

    def evaluate(self, patient_mutations, gene_embeddings, true_pfs):
        """
        Evaluate model performance on the given patients.

        Returns:
            (metrics, predictions): ``metrics`` is a dict with keys 'mse',
            'mae', 'rmse', 'r2', 'pearson_r', 'pearson_p', 'spearman_r' and
            'spearman_p'; ``predictions`` is in the original PFS scale.
        """
        predictions = self.predict(patient_mutations, gene_embeddings)

        # Regression metrics
        mse = mean_squared_error(true_pfs, predictions)
        mae = mean_absolute_error(true_pfs, predictions)
        r2 = r2_score(true_pfs, predictions)

        # Correlation metrics
        pearson_r, pearson_p = pearsonr(true_pfs, predictions)
        spearman_r, spearman_p = spearmanr(true_pfs, predictions)

        metrics = {
            'mse': mse,
            'mae': mae,
            'rmse': np.sqrt(mse),
            'r2': r2,
            'pearson_r': pearson_r,
            'pearson_p': pearson_p,
            'spearman_r': spearman_r,
            'spearman_p': spearman_p
        }

        return metrics, predictions

def plot_training_curves(history):
    """
    Draw the loss and R² curves side by side.

    Args:
        history: dict holding the per-epoch lists 'train_losses', 'val_losses',
            'test_losses', 'train_r2', 'val_r2' and 'test_r2'.
    """
    # (history-key suffix, legend suffix, y-axis label, panel title)
    panels = (
        ('losses', 'Loss', 'MSE Loss', 'Training Curves - Loss'),
        ('r2', 'R²', 'R² Score', 'Training Curves - R²'),
    )
    # (history-key prefix, legend prefix)
    splits = (('train', 'Train'), ('val', 'Validation'), ('test', 'Test'))

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    for ax, (key_suffix, legend_suffix, y_label, panel_title) in zip(axes, panels):
        for key_prefix, legend_prefix in splits:
            ax.plot(
                history[f'{key_prefix}_{key_suffix}'],
                label=f'{legend_prefix} {legend_suffix}',
                alpha=0.8,
            )
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_predictions(true_pfs, predictions, title="PFS Predictions vs True Values"):
    """
    Scatter predicted against true PFS values with the identity line.

    Args:
        true_pfs: array of observed values (must provide .min()/.max()).
        predictions: array of predicted values, same length as ``true_pfs``.
        title: first line of the figure title; R² and Pearson r are appended.
    """
    ax = plt.figure(figsize=(8, 6)).gca()

    # One point per sample
    ax.scatter(true_pfs, predictions, alpha=0.6, s=50)

    # Identity line spanning the combined range of both arrays
    lower = min(true_pfs.min(), predictions.min())
    upper = max(true_pfs.max(), predictions.max())
    diagonal = [lower, upper]
    ax.plot(diagonal, diagonal, 'r--', alpha=0.8, label='Perfect Prediction')

    # Summary statistics shown in the title
    r2 = r2_score(true_pfs, predictions)
    pearson_r = pearsonr(true_pfs, predictions)[0]

    ax.set_xlabel('True PFS Scores')
    ax.set_ylabel('Predicted PFS Scores')
    ax.set_title(f'{title}\nR² = {r2:.3f}, Pearson r = {pearson_r:.3f}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

In [None]:
# Target vector: one PFS value per retained patient.
pfs_scores = gene_data['pfs'].values

# Feature matrix: every remaining column is one gene's mutation indicator.
mutants = gene_data.drop(columns=['pfs'])
patient_mutations = mutants.values

# Row i is the embedding of the gene in column i of `patient_mutations`.
# An explicit float dtype makes ragged or non-numeric embeddings fail here.
gene_embeddings = np.array(embeddings, dtype=np.float32)

# Fail fast if the mutation columns and embedding rows got out of step.
assert gene_embeddings.ndim == 2, "expected a non-empty (n_genes, embedding_dim) embedding matrix"
assert list(mutants.columns) == list(final_columns), "mutation columns do not match final_columns"
assert patient_mutations.shape[1] == gene_embeddings.shape[0], "one embedding row is required per mutation column"
assert len(pfs_scores) == len(patient_mutations), "one PFS score is required per patient"

In [None]:
"""Example usage with CORRECT data structure."""
print("=== Gene-Level PFS Prediction ===")

print("Data shapes:")
print(f"Patient mutations: {patient_mutations.shape} (n_patients, n_genes)")
print(f"Gene embeddings: {gene_embeddings.shape} (n_genes, embedding_dim)")
print(f"PFS scores: {pfs_scores.shape} (n_patients,)")
print(f"PFS score range: {pfs_scores.min():.2f} to {pfs_scores.max():.2f}")
print(f"Mutation rate: {patient_mutations.mean():.3f}")

# Hold-out fractions, passed explicitly to PFSPredictor.train below.
TEST_SIZE = 0.15
VALIDATION_SIZE = 0.15

# Recreate the split PFSPredictor.train makes internally (same fractions, same
# two-step train_test_split with random_state=42), so metrics are reported on
# patients the model never trained on. Keep this in sync with train().
all_idx = np.arange(len(pfs_scores))
holdout_fraction = TEST_SIZE + VALIDATION_SIZE
train_idx, holdout_idx = train_test_split(all_idx, test_size=holdout_fraction, random_state=42)
val_idx, test_idx = train_test_split(holdout_idx, test_size=TEST_SIZE / holdout_fraction, random_state=42)

test_mutations = patient_mutations[test_idx]
test_pfs = pfs_scores[test_idx]

# Test different gene-level combination methods
methods = ['weighted_sum', 'attention', 'gene_mlp']
results = {}

for method in methods:
    print(f"\n--- Training with {method} method ---")

    predictor = PFSPredictor(combination_method=method)

    # Train model
    history = predictor.train(
        patient_mutations, gene_embeddings, pfs_scores,
        test_size=TEST_SIZE, validation_size=VALIDATION_SIZE,
        epochs=100, batch_size=32, verbose=True
    )

    # Evaluate on the held-out test patients only. Scoring the full dataset
    # would mix in training rows and inflate every metric.
    metrics, predictions = predictor.evaluate(test_mutations, gene_embeddings, test_pfs)
    results[method] = {
        'metrics': metrics,
        'predictions': predictions,
        'history': history,
        'best_val_loss': min(history['val_losses']),
    }

    print(f"Results (held-out test split, n={len(test_idx)}):")
    print(f"  R² Score: {metrics['r2']:.4f}")
    print(f"  RMSE: {metrics['rmse']:.4f}")
    print(f"  Pearson r: {metrics['pearson_r']:.4f}")

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Choose the method on validation loss so the test split stays an unbiased
# estimate for the winner; the R² printed below is its test-split score.
best_method = min(results.keys(), key=lambda x: results[x]['best_val_loss'])
print(f"\nBest method: {best_method} (R² = {results[best_method]['metrics']['r2']:.4f})")

# Plot results for best method
plot_training_curves(results[best_method]['history'])
plot_predictions(test_pfs, results[best_method]['predictions'],
                f"Best Method: {best_method}")