In [1]:
"""
Optuna GP-based Hyperparameter Optimization for Grok scRNA-seq Translation Model
Implements curriculum learning with adaptive loss normalization
"""

import itertools
import json
import os
import pickle
import sys
from datetime import datetime

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import optuna
import optuna.visualization as vis
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import umap
from optuna.pruners import MedianPruner
from optuna.samplers._gp import GPSampler
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, adjusted_rand_score
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Import your existing modules
#sys.path.append('/content/drive/MyDrive/Colab_Notebooks/translation/')
from GAN_functions import *

# Set up GPU memory management (runs once, when this cell is executed).
if torch.cuda.is_available():
    # Hand unused cached blocks back to the driver so a re-run of the
    # notebook starts from a clean allocator state.
    torch.cuda.empty_cache()
    # Let cuDNN benchmark and pick kernels per input shape.
    torch.backends.cudnn.benchmark = True

In [2]:
class GrokOptimizer:
    """Optuna-based optimizer for Grok CycleGAN model with curriculum learning"""
    
    def __init__(self, mouse_adata, human_adata, output_dir='optuna_results'):
        self.mouse_adata = mouse_adata
        self.human_adata = human_adata
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Prepare data scalers and dimensionality reduction
        self.setup_data_preprocessing()
        
        # Initialize best score tracking
        self.best_score = float('inf')
        self.best_config = None
        
    def setup_data_preprocessing(self):
        """Build expression tensors and fit scaler, PCA and UMAP on the human data.

        Sets ``mouse_tensor``, ``human_tensor``, ``scaler``, ``pca_model``,
        ``umap_model`` and ``human_umap_coords``, then delegates to
        ``setup_categorical_covariates`` for the cell-type encoding.
        """
        # Both species go through the same external preprocessing call.
        preprocess_options = {'library_size': 10000}
        self.mouse_tensor = preprocess_data(self.mouse_adata, 'counts', **preprocess_options)
        self.human_tensor = preprocess_data(self.human_adata, 'counts', **preprocess_options)

        # The scaler is fitted on human cells only.
        human_matrix = self.human_tensor.numpy()
        self.scaler = StandardScaler()
        self.scaler.fit(human_matrix)
        human_scaled = self.scaler.transform(human_matrix)

        # 50-component PCA followed by a 2-D UMAP, both fitted on human cells.
        self.pca_model = PCA(n_components=50)
        human_components = self.pca_model.fit_transform(human_scaled)
        self.umap_model = umap.UMAP(n_components=2, random_state=42)
        self.human_umap_coords = self.umap_model.fit_transform(human_components)

        # Integer encoding of the cell-type annotation.
        self.setup_categorical_covariates()
        
    def setup_categorical_covariates(self):
        """Setup cell type mappings"""
        unique_cell_types = sorted(
            set(self.mouse_adata.obs['myannotations']) | 
            set(self.human_adata.obs['myannotations'])
        )
        self.value_to_idx = {val: idx for idx, val in enumerate(unique_cell_types)}
        self.idx_to_cell_type = {idx: val for val, idx in self.value_to_idx.items()}
        
        mouse_cat = np.array([
            self.value_to_idx[val] for val in self.mouse_adata.obs['myannotations'].values
        ])
        human_cat = np.array([
            self.value_to_idx[val] for val in self.human_adata.obs['myannotations'].values
        ])
        
        self.cat_tensor = np.concatenate([mouse_cat, human_cat])[:, None]
        self.vocab_sizes = [len(unique_cell_types)]
        
    def create_curriculum_schedule(self, trial):
        """Create curriculum learning schedule based on trial parameters"""
        curriculum_config = {
            'warmup_epochs': trial.suggest_int('warmup_epochs', 10, 50),
            'cycle_focus_epochs': trial.suggest_int('cycle_focus_epochs', 20, 80),
            'adversarial_rampup': trial.suggest_int('adversarial_rampup', 10, 40),
            'full_training_start': trial.suggest_int('full_training_start', 50, 150),
            
            # Initial weights during warmup
            'initial_cycle_weight': trial.suggest_float('initial_cycle_weight', 5.0, 20.0),
            'initial_celltype_weight': trial.suggest_float('initial_celltype_weight', 0.1, 1.0),
            
            # Final weights
            'final_adv_weight': trial.suggest_float('final_adv_weight', 0.5, 2.0),
            'final_fm_weight': trial.suggest_float('final_fm_weight', 0.1, 0.5),
            'final_kl_weight': trial.suggest_float('final_kl_weight', 0.05, 0.2),
        }
        
        return curriculum_config
    
    def get_curriculum_weights(self, epoch, curriculum_config, base_config):
        """Get loss weights based on curriculum schedule"""
        weights = {}
        
        if epoch < curriculum_config['warmup_epochs']:
            # Phase 1: Focus on reconstruction
            progress = epoch / curriculum_config['warmup_epochs']
            weights['lambda_adv'] = 0.0
            weights['lambda_cycle'] = curriculum_config['initial_cycle_weight']
            weights['lambda_celltype'] = curriculum_config['initial_celltype_weight'] * progress
            weights['lambda_fm'] = 0.0
            weights['lambda_kl'] = base_config['lambda_kl'] * 0.1  # Small KL from start
            
        elif epoch < curriculum_config['cycle_focus_epochs']:
            # Phase 2: Strong cycle consistency
            weights['lambda_adv'] = 0.1  # Very small adversarial
            weights['lambda_cycle'] = base_config['lambda_cycle']
            weights['lambda_celltype'] = base_config['lambda_celltype']
            weights['lambda_fm'] = 0.0
            weights['lambda_kl'] = base_config['lambda_kl'] * 0.5
            
        elif epoch < curriculum_config['full_training_start']:
            # Phase 3: Gradual adversarial introduction
            progress = (epoch - curriculum_config['cycle_focus_epochs']) / \
                      (curriculum_config['full_training_start'] - curriculum_config['cycle_focus_epochs'])
            weights['lambda_adv'] = curriculum_config['final_adv_weight'] * progress
            weights['lambda_cycle'] = base_config['lambda_cycle']
            weights['lambda_celltype'] = base_config['lambda_celltype']
            weights['lambda_fm'] = curriculum_config['final_fm_weight'] * progress
            weights['lambda_kl'] = base_config['lambda_kl']
            
        else:
            # Phase 4: Full training
            weights['lambda_adv'] = curriculum_config['final_adv_weight']
            weights['lambda_cycle'] = base_config['lambda_cycle']
            weights['lambda_celltype'] = base_config['lambda_celltype']
            weights['lambda_fm'] = curriculum_config['final_fm_weight']
            weights['lambda_kl'] = base_config['lambda_kl']
            
        return weights
    
    def objective(self, trial):
        """Objective function for Optuna optimization"""
        try:
            # Suggest hyperparameters
            config = self.suggest_hyperparameters(trial)
            
            # Create curriculum schedule
            curriculum_config = self.create_curriculum_schedule(trial)
            
            # Build models
            models = self.build_models(config)
            
            # Train with curriculum learning
            metrics = self.train_with_curriculum(
                models, config, curriculum_config, trial
            )
            
            # Return multiple objectives
            return [
                metrics['final_fid'],           # Minimize FID
                metrics['final_cycle_loss'],    # Minimize cycle loss
                metrics['biological_score'],    # Minimize bio score (negative correlation)
                -metrics['cell_type_preservation']  # Maximize cell type preservation
            ]
            
        except Exception as e:
            print(f"Trial {trial.number} failed: {str(e)}")
            # Clean up GPU memory
            torch.cuda.empty_cache()
            raise optuna.TrialPruned()
    
    def suggest_hyperparameters(self, trial):
        """Suggest hyperparameters for the trial"""
        config = {
            # Training parameters
            'epochs': 300,  # Fixed for fair comparison
            'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64]),
            'lr': trial.suggest_float('lr', 1e-5, 5e-4, log=True),
            
            # Architecture parameters
            'hidden_dims': trial.suggest_categorical(
                'hidden_dims',
                [[2048, 1024, 512], [1024, 512, 256], [4096, 2048, 1024]]
            ),
            'latent_dim': trial.suggest_categorical('latent_dim', [128, 256, 512]),
            'num_attention_heads': trial.suggest_int('num_attention_heads', 4, 16, step=4),
            'residual_blocks': trial.suggest_int('residual_blocks', 3, 6),
            
            # Loss weights (base values, will be modified by curriculum)
            'lambda_cycle': trial.suggest_float('lambda_cycle', 5.0, 20.0),
            'lambda_celltype': trial.suggest_float('lambda_celltype', 0.1, 0.5),
            'lambda_gp': trial.suggest_float('lambda_gp', 5.0, 20.0),
            'lambda_kl': trial.suggest_float('lambda_kl', 0.05, 0.2),
            
            # Training strategies
            'nb_critic': trial.suggest_int('nb_critic', 3, 7),
            'use_self_attention': trial.suggest_categorical('use_self_attention', [True, False]),
            'use_sparse_attention': False,  # Keep false for stability
            
            # Optimizer parameters
            'beta1': trial.suggest_float('beta1', 0.0, 0.9),
            'beta2': trial.suggest_float('beta2', 0.9, 0.999),
            
            # Learning rate decay
            'use_lr_decay': trial.suggest_categorical('use_lr_decay', [True, False]),
            'decay_start_epoch': 100,
            'decay_factor': trial.suggest_float('decay_factor', 0.1, 0.5),
            
            # Other fixed parameters
            'library_size': 10000,
            'preprocess_type': 'counts',
            'temperature': 0.5,
            'sparsity_target': 0.9,
            'device': str(self.device),
            'use_wandb': False  # Disable for optimization
        }
        
        # Conditional parameters
        if config['use_lr_decay']:
            config['decay_type'] = trial.suggest_categorical('decay_type', ['linear', 'cosine'])
        
        return config
    
    def build_models(self, config):
        """Build generator and discriminator models"""
        # Generators
        g_m2h = GeneExpressionGenerator(
            self.mouse_tensor.shape[1], 
            self.human_tensor.shape[1],
            self.vocab_sizes, 
            config
        ).to(self.device)
        
        g_h2m = GeneExpressionGenerator(
            self.human_tensor.shape[1],
            self.mouse_tensor.shape[1],
            self.vocab_sizes,
            config
        ).to(self.device)
        
        # Discriminators
        d_m = GeneExpressionDiscriminator(
            self.mouse_tensor.shape[1],
            self.vocab_sizes,
            config
        ).to(self.device)
        
        d_h = GeneExpressionDiscriminator(
            self.human_tensor.shape[1],
            self.vocab_sizes,
            config
        ).to(self.device)
        
        return {
            'g_m2h': g_m2h,
            'g_h2m': g_h2m,
            'd_m': d_m,
            'd_h': d_h
        }
    
    def train_with_curriculum(self, models, config, curriculum_config, trial):
        """Train models with curriculum learning and adaptive loss normalization"""
        # Setup optimizers
        g_optimizer = torch.optim.Adam(
            itertools.chain(models['g_m2h'].parameters(), models['g_h2m'].parameters()),
            lr=config['lr'],
            betas=(config['beta1'], config['beta2'])
        )
        
        d_m_optimizer = torch.optim.Adam(
            models['d_m'].parameters(),
            lr=config['lr'],
            betas=(config['beta1'], config['beta2'])
        )
        
        d_h_optimizer = torch.optim.Adam(
            models['d_h'].parameters(),
            lr=config['lr'],
            betas=(config['beta1'], config['beta2'])
        )
        
        # Learning rate scheduler
        if config['use_lr_decay']:
            lr_scheduler = LRDecayScheduler(
                [g_optimizer, d_m_optimizer, d_h_optimizer],
                config['lr'],
                config['decay_start_epoch'],
                config['epochs'],
                config['decay_factor'],
                config.get('decay_type', 'cosine')
            )
        else:
            lr_scheduler = None
        
        # Prepare data loaders
        n_mouse = self.mouse_tensor.size(0)
        mouse_cat_tensor = torch.tensor(self.cat_tensor[:n_mouse], dtype=torch.long).to(self.device)
        human_cat_tensor = torch.tensor(self.cat_tensor[n_mouse:], dtype=torch.long).to(self.device)
        
        mouse_dataset = TensorDataset(self.mouse_tensor, mouse_cat_tensor)
        human_dataset = TensorDataset(self.human_tensor, human_cat_tensor)
        
        mouse_loader = DataLoader(mouse_dataset, batch_size=config['batch_size'], shuffle=True)
        human_loader = DataLoader(human_dataset, batch_size=config['batch_size'], shuffle=True)
        
        # Training metrics
        best_biological_score = float('inf')
        metrics_history = []
        
        # Initialize adaptive loss scaling (from your original code)
        moving_avg_adv = 0
        moving_avg_cycle = 0
        moving_avg_celltype = 0
        moving_avg_fm = 0
        moving_avg_kl = 0
        alpha = 0.3  # Smoothing factor
        
        # Training loop
        for epoch in range(config['epochs']):
            # Get curriculum weights
            curriculum_weights = self.get_curriculum_weights(
                epoch, curriculum_config, config
            )
            
            # Train one epoch
            epoch_metrics = self.train_epoch(
                models, mouse_loader, human_loader,
                g_optimizer, d_m_optimizer, d_h_optimizer,
                config, curriculum_weights,
                moving_avg_adv, moving_avg_cycle, moving_avg_celltype,
                moving_avg_fm, moving_avg_kl, alpha
            )
            
            # Update moving averages
            moving_avg_adv = alpha * epoch_metrics['avg_adv_loss'] + (1 - alpha) * moving_avg_adv
            moving_avg_cycle = alpha * epoch_metrics['avg_cycle_loss'] + (1 - alpha) * moving_avg_cycle
            moving_avg_celltype = alpha * epoch_metrics['avg_celltype_loss'] + (1 - alpha) * moving_avg_celltype
            moving_avg_fm = alpha * epoch_metrics['avg_fm_loss'] + (1 - alpha) * moving_avg_fm
            moving_avg_kl = alpha * epoch_metrics['avg_kl_loss'] + (1 - alpha) * moving_avg_kl
            
            # Update learning rate
            if lr_scheduler:
                current_lr = lr_scheduler.step(epoch)
                epoch_metrics['lr'] = current_lr
            
            metrics_history.append(epoch_metrics)
            
            # Evaluate and report to Optuna (every 10 epochs)
            if epoch % 10 == 0 and epoch > 0:
                eval_metrics = self.evaluate_models(models, epoch)
                
                # Report primary metric for pruning
                trial.report(eval_metrics['biological_score'], epoch)
                
                # Check for pruning
                if trial.should_prune():
                    print(f"Trial {trial.number} pruned at epoch {epoch}")
                    raise optuna.TrialPruned()
                
                # Check for GAN-specific failures
                if self.check_gan_failure(epoch_metrics):
                    print(f"Trial {trial.number} failed due to GAN instability")
                    raise optuna.TrialPruned()
                
                # Update best score
                if eval_metrics['biological_score'] < best_biological_score:
                    best_biological_score = eval_metrics['biological_score']
                    self.save_checkpoint(models, config, trial.number, epoch)
            
            # Progress logging
            if epoch % 20 == 0:
                print(f"Trial {trial.number}, Epoch {epoch}: "
                      f"G Loss: {epoch_metrics['avg_g_loss']:.4f}, "
                      f"Cycle: {epoch_metrics['avg_cycle_loss']:.4f}, "
                      f"Bio Score: {best_biological_score:.4f}")
        
        # Final evaluation
        final_metrics = self.evaluate_models(models, config['epochs'])
        final_metrics['training_history'] = metrics_history
        
        # Save trial results
        self.save_trial_results(trial, config, final_metrics)
        
        return final_metrics
    
    def train_epoch(self, models, mouse_loader, human_loader, 
                   g_optimizer, d_m_optimizer, d_h_optimizer,
                   config, curriculum_weights,
                   moving_avg_adv, moving_avg_cycle, moving_avg_celltype,
                   moving_avg_fm, moving_avg_kl, alpha):
        """Train one epoch with adaptive loss normalization"""
        # Tracking metrics
        g_losses, d_m_losses, d_h_losses = [], [], []
        cycle_losses, celltype_losses, fm_losses, kl_losses = [], [], [], []
        adv_losses = []
        
        mouse_iter = iter(mouse_loader)
        human_iter = iter(human_loader)
        
        for i in range(max(len(mouse_loader), len(human_loader))):
            # Get batches
            try:
                mouse_data, mouse_cat = next(mouse_iter)
            except StopIteration:
                mouse_iter = iter(mouse_loader)
                mouse_data, mouse_cat = next(mouse_iter)
            
            try:
                human_data, human_cat = next(human_iter)
            except StopIteration:
                human_iter = iter(human_loader)
                human_data, human_cat = next(human_iter)
            
            # Synchronize batch sizes
            min_batch_size = min(mouse_data.size(0), human_data.size(0))
            mouse_data = mouse_data[:min_batch_size].to(self.device)
            human_data = human_data[:min_batch_size].to(self.device)
            mouse_cat = mouse_cat[:min_batch_size].to(self.device)
            human_cat = human_cat[:min_batch_size].to(self.device)
            
            # Train discriminators
            if curriculum_weights['lambda_adv'] > 0:
                d_m_optimizer.zero_grad()
                d_h_optimizer.zero_grad()
                
                with torch.no_grad():
                    fake_human, _, _ = models['g_m2h'](mouse_data, mouse_cat)
                    fake_mouse, _, _ = models['g_h2m'](human_data, human_cat)
                
                # Discriminator losses
                real_m_validity, real_m_features = models['d_m'](mouse_data, mouse_cat, return_features=True)
                fake_m_validity, fake_m_features = models['d_m'](fake_mouse.detach(), human_cat, return_features=True)
                gp_m = compute_gradient_penalty(models['d_m'], mouse_data, fake_mouse, self.device, human_cat)
                d_m_loss = -torch.mean(real_m_validity) + torch.mean(fake_m_validity) + config['lambda_gp'] * gp_m
                
                real_h_validity, real_h_features = models['d_h'](human_data, human_cat, return_features=True)
                fake_h_validity, fake_h_features = models['d_h'](fake_human.detach(), mouse_cat, return_features=True)
                gp_h = compute_gradient_penalty(models['d_h'], human_data, fake_human, self.device, mouse_cat)
                d_h_loss = -torch.mean(real_h_validity) + torch.mean(fake_h_validity) + config['lambda_gp'] * gp_h
                
                d_m_loss.backward()
                d_h_loss.backward()
                d_m_optimizer.step()
                d_h_optimizer.step()
                
                d_m_losses.append(d_m_loss.item())
                d_h_losses.append(d_h_loss.item())
            
            # Train generators
            if i % config['nb_critic'] == 0:
                g_optimizer.zero_grad()
                
                # Forward pass
                fake_human, mu_h, log_var_h = models['g_m2h'](mouse_data, mouse_cat)
                fake_mouse, mu_m, log_var_m = models['g_h2m'](human_data, human_cat)
                
                # Calculate losses
                g_adv = 0
                if curriculum_weights['lambda_adv'] > 0:
                    g_adv = -torch.mean(models['d_h'](fake_human, human_cat)) - \
                            torch.mean(models['d_m'](fake_mouse, mouse_cat))
                    adv_losses.append(g_adv.item())
                
                # Cycle consistency loss
                if config['preprocess_type'] == 'counts':
                    cycle_mouse, _, _ = models['g_h2m'](fake_human, mouse_cat)
                    cycle_human, _, _ = models['g_m2h'](fake_mouse, human_cat)
                    g_cycle = compute_poisson_loss(mouse_data, cycle_mouse) + \
                             compute_poisson_loss(human_data, cycle_human)
                else:
                    cycle_mouse, _, _ = models['g_h2m'](fake_human, mouse_cat)
                    cycle_human, _, _ = models['g_m2h'](fake_mouse, human_cat)
                    g_cycle = F.l1_loss(mouse_data, cycle_mouse) + F.l1_loss(human_data, cycle_human)
                
                # Cell type loss
                g_celltype = F.cross_entropy(models['d_h'](fake_human, mouse_cat), mouse_cat.squeeze()) + \
                            F.cross_entropy(models['d_m'](fake_mouse, human_cat), human_cat.squeeze())
                
                # Feature matching loss
                g_fm = 0
                if curriculum_weights['lambda_fm'] > 0:
                    _, real_h_feat = models['d_h'](human_data, human_cat, return_features=True)
                    _, fake_h_feat = models['d_h'](fake_human, mouse_cat, return_features=True)
                    _, real_m_feat = models['d_m'](mouse_data, mouse_cat, return_features=True)
                    _, fake_m_feat = models['d_m'](fake_mouse, human_cat, return_features=True)
                    
                    for rf, ff in zip(real_h_feat, fake_h_feat):
                        g_fm += F.l1_loss(rf.detach(), ff)
                    for rf, ff in zip(real_m_feat, fake_m_feat):
                        g_fm += F.l1_loss(rf.detach(), ff)
                
                # KL loss
                kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var_h - mu_h ** 2 - log_var_h.exp(), dim=1)) + \
                         torch.mean(-0.5 * torch.sum(1 + log_var_m - mu_m ** 2 - log_var_m.exp(), dim=1))
                
                # Adaptive loss scaling (from your original code)
                if moving_avg_adv > 0:
                    scale_adv = 1.0 / (1.0 + moving_avg_adv)
                else:
                    scale_adv = 1.0
                scale_cycle = 1.0 / (1.0 + moving_avg_cycle) if moving_avg_cycle > 0 else 1.0
                scale_celltype = 1.0 / (1.0 + moving_avg_celltype) if moving_avg_celltype > 0 else 1.0
                scale_fm = 1.0 / (1.0 + moving_avg_fm) if moving_avg_fm > 0 else 1.0
                scale_kl = 1.0 / (1.0 + moving_avg_kl) if moving_avg_kl > 0 else 1.0
                
                # Total loss with curriculum weights and adaptive scaling
                g_loss = (curriculum_weights['lambda_adv'] * scale_adv * g_adv +
                         curriculum_weights['lambda_cycle'] * scale_cycle * g_cycle +
                         curriculum_weights['lambda_celltype'] * scale_celltype * g_celltype +
                         curriculum_weights['lambda_fm'] * scale_fm * g_fm +
                         curriculum_weights['lambda_kl'] * scale_kl * kl_loss)
                
                g_loss.backward()
                g_optimizer.step()
                
                # Track losses
                g_losses.append(g_loss.item())
                cycle_losses.append(g_cycle.item())
                celltype_losses.append(g_celltype.item())
                fm_losses.append(g_fm.item() if curriculum_weights['lambda_fm'] > 0 else 0)
                kl_losses.append(kl_loss.item())
        
        # Return epoch metrics
        return {
            'avg_g_loss': np.mean(g_losses) if g_losses else 0,
            'avg_d_m_loss': np.mean(d_m_losses) if d_m_losses else 0,
            'avg_d_h_loss': np.mean(d_h_losses) if d_h_losses else 0,
            'avg_cycle_loss': np.mean(cycle_losses) if cycle_losses else 0,
            'avg_celltype_loss': np.mean(celltype_losses) if celltype_losses else 0,
            'avg_fm_loss': np.mean(fm_losses) if fm_losses else 0,
            'avg_kl_loss': np.mean(kl_losses) if kl_losses else 0,
            'avg_adv_loss': np.mean(adv_losses) if adv_losses else 0,
        }
    
    def evaluate_models(self, models, epoch):
        """Evaluate the mouse->human generator on a random sample of mouse cells.

        Up to 1000 mouse cells are translated and scored on four metrics:
        centroid correlation per cell type, silhouette of the translated
        cells in the human UMAP space, a Frechet distance between human and
        translated cells in PCA space, and the L1 cycle-reconstruction error.

        Args:
            models: Dict with at least 'g_m2h' and 'g_h2m'.
            epoch: Epoch index supplied by the caller; not used in the scores.

        Returns:
            dict with 'biological_score' (lower is better),
            'cell_type_preservation', 'silhouette_score', 'final_fid' and
            'final_cycle_loss'.
        """
        models['g_m2h'].eval()
        models['g_h2m'].eval()

        try:
            with torch.no_grad():
                # Sample translations
                n_mouse = self.mouse_tensor.size(0)
                n_samples = min(1000, n_mouse)
                indices = np.random.choice(n_mouse, n_samples, replace=False)

                mouse_sample = self.mouse_tensor[indices].to(self.device)
                cat_sample = torch.tensor(
                    self.cat_tensor[:n_mouse][indices],
                    dtype=torch.long
                ).to(self.device)

                # Generate translations
                fake_human, _, _ = models['g_m2h'](mouse_sample, cat_sample)
                fake_human_np = fake_human.cpu().numpy()
                labels = [self.idx_to_cell_type[idx.item()] for idx in cat_sample]

                # 1. Cell type preservation (correlation of cell type centroids)
                real_centroids = compute_cell_type_centroids(
                    self.human_adata, self.human_tensor, self.device
                )
                fake_adata = ad.AnnData(fake_human_np)
                fake_adata.obs['myannotations'] = labels
                fake_centroids = compute_cell_type_centroids(
                    fake_adata, fake_human, self.device
                )

                correlations = []
                for ct in real_centroids:
                    if ct in fake_centroids:
                        corr = np.corrcoef(
                            real_centroids[ct].cpu().numpy(),
                            fake_centroids[ct].cpu().numpy()
                        )[0, 1]
                        if not np.isnan(corr):
                            correlations.append(corr)

                avg_correlation = np.mean(correlations) if correlations else 0

                if self.pca_model is not None and self.scaler is not None:
                    # 2. UMAP-based evaluation in the human-fitted embedding
                    fake_human_scaled = self.scaler.transform(fake_human_np)
                    fake_human_pca = self.pca_model.transform(fake_human_scaled)
                    fake_human_umap = self.umap_model.transform(fake_human_pca)

                    if len(set(labels)) > 1:
                        silhouette = silhouette_score(fake_human_umap, labels)
                    else:
                        silhouette = 0

                    # 3. Frechet distance in PCA space. The human statistics
                    # are computed once and cached, and a real score is
                    # returned on every call, including the first.
                    if not hasattr(self, 'human_pca_mean'):
                        human_pca = self.pca_model.transform(
                            self.scaler.transform(self.human_tensor.numpy())
                        )
                        self.human_pca_mean = np.mean(human_pca, axis=0)
                        self.human_pca_cov = np.cov(human_pca, rowvar=False)

                    fake_pca_mean = np.mean(fake_human_pca, axis=0)
                    fake_pca_cov = np.cov(fake_human_pca, rowvar=False)

                    mean_diff = np.sum((self.human_pca_mean - fake_pca_mean) ** 2)
                    # Tr(sqrtm(A @ B)) equals the sum of the square roots of
                    # the eigenvalues of A @ B. Those are real and >= 0 for
                    # covariance matrices; clip tiny negative round-off.
                    eigenvalues = np.linalg.eigvals(self.human_pca_cov @ fake_pca_cov)
                    trace_sqrt = np.sum(np.sqrt(np.clip(eigenvalues.real, 0.0, None)))
                    fid_score = float(
                        mean_diff
                        + np.trace(self.human_pca_cov)
                        + np.trace(fake_pca_cov)
                        - 2.0 * trace_sqrt
                    )
                else:
                    silhouette = 0
                    fid_score = 100.0  # Default high value when no embedding exists

                # 4. Cycle consistency check
                cycle_mouse, _, _ = models['g_h2m'](fake_human, cat_sample)
                cycle_loss = F.l1_loss(mouse_sample, cycle_mouse).item()
        finally:
            # Always hand the generators back in training mode.
            models['g_m2h'].train()
            models['g_h2m'].train()

        # Combine metrics into biological score
        biological_score = (
            -avg_correlation * 100 +  # Negative because we want high correlation
            (1 - silhouette) * 50 +   # Low silhouette is bad
            cycle_loss * 10           # Low cycle loss is good
        )

        return {
            'biological_score': biological_score,
            'cell_type_preservation': avg_correlation,
            'silhouette_score': silhouette,
            'final_fid': fid_score,
            'final_cycle_loss': cycle_loss
        }
    
    def check_gan_failure(self, metrics):
        """Check for GAN-specific failure modes"""
        # Mode collapse detection
        if metrics['avg_g_loss'] < 0.1 and metrics['avg_d_m_loss'] > 5.0:
            return True
        
        # Training divergence
        if (metrics['avg_g_loss'] > 10.0 or 
            metrics['avg_d_m_loss'] > 10.0 or
            np.isnan(metrics['avg_g_loss'])):
            return True
        
        return False
    
    def save_checkpoint(self, models, config, trial_number, epoch):
        """Save model checkpoint"""
        checkpoint_dir = os.path.join(self.output_dir, f'trial_{trial_number}')
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        checkpoint = {
            'epoch': epoch,
            'config': config,
            'g_m2h_state': models['g_m2h'].state_dict(),
            'g_h2m_state': models['g_h2m'].state_dict(),
            'd_m_state': models['d_m'].state_dict(),
            'd_h_state': models['d_h'].state_dict(),
        }
        
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'))
    
    def save_trial_results(self, trial, config, metrics):
        """Save detailed trial results"""
        results = {
            'trial_number': trial.number,
            'config': config,
            'metrics': metrics,
            'params': trial.params,
            'values': trial.values if hasattr(trial, 'values') else None,
        }
        
        with open(os.path.join(self.output_dir, f'trial_{trial.number}_results.pkl'), 'wb') as f:
            pickle.dump(results, f)
    
    def run_optimization(self, n_trials=100, n_jobs=1):
        """Create the four-objective study, run it, persist it and analyze it.

        Args:
            n_trials: Number of trials passed to ``study.optimize``.
            n_jobs: Number of parallel jobs passed to ``study.optimize``.

        Returns:
            The finished study (also pickled to ``<output_dir>/study.pkl``).
        """
        # Gaussian-process sampler with 20 start-up trials.
        gp_sampler = GPSampler(
            n_startup_trials=20,
            deterministic_objective=False,
            seed=42,
        )
        # NOTE(review): Optuna documents trial.report()/should_prune() as
        # unsupported for multi-objective studies - confirm this pruner has
        # any effect on a four-objective study.
        median_pruner = MedianPruner(
            n_startup_trials=5,
            n_warmup_steps=30,
            interval_steps=10,
        )

        # All four objectives returned by self.objective are minimized.
        study = optuna.create_study(
            study_name='grok_cyclegan_optimization',
            directions=['minimize'] * 4,
            sampler=gp_sampler,
            pruner=median_pruner,
        )
        study.set_user_attr("best_biological_score", float('inf'))

        optimize_options = dict(
            n_trials=n_trials,
            n_jobs=n_jobs,
            gc_after_trial=True,
            show_progress_bar=True,
        )
        study.optimize(self.objective, **optimize_options)

        # Persist the whole study before any analysis is attempted.
        study_path = os.path.join(self.output_dir, 'study.pkl')
        with open(study_path, 'wb') as handle:
            pickle.dump(study, handle)

        self.analyze_results(study)

        return study
    
    def analyze_results(self, study):
        """Analyze and visualize optimization results"""
        # Get Pareto front trials
        pareto_trials = study.best_trials
        
        print(f"\nOptimization completed with {len(study.trials)} trials")
        print(f"Number of Pareto optimal solutions: {len(pareto_trials)}")
        
        # Print best trials
        print("\nTop 5 Pareto optimal configurations:")
        for i, trial in enumerate(pareto_trials[:5]):
            print(f"\nTrial {trial.number}:")
            print(f"  Objectives: {trial.values}")
            print(f"  Key parameters:")
            for key in ['lr', 'batch_size', 'lambda_cycle', 'lambda_celltype']:
                if key in trial.params:
                    print(f"    {key}: {trial.params[key]}")
        
        # Visualizations
        self.create_optimization_plots(study)
        
        # Save best configuration
        best_trial = min(pareto_trials, key=lambda t: t.values[2])  # Best biological score
        best_config = self.suggest_hyperparameters(best_trial)
        
        with open(os.path.join(self.output_dir, 'best_config.json'), 'w') as f:
            json.dump(best_config, f, indent=4)
        
        print(f"\nBest configuration saved to {self.output_dir}/best_config.json")
    
    def create_optimization_plots(self, study):
        """Create visualization plots for the optimization results"""
        
        # Parameter importance
        fig = vis.plot_param_importances(study)
        fig.write_html(os.path.join(self.output_dir, 'param_importances.html'))
        
        # Optimization history
        fig = vis.plot_optimization_history(study)
        fig.write_html(os.path.join(self.output_dir, 'optimization_history.html'))
        
        # Pareto front
        if len(study.best_trials) > 1:
            fig = vis.plot_pareto_front(study, targets=['FID', 'Biological Score'])
            fig.write_html(os.path.join(self.output_dir, 'pareto_front.html'))
        
        print(f"Visualization plots saved to {self.output_dir}/")

In [None]:
def main(mouse_path=None, human_path=None, output_dir='optuna_grok_results',
         n_trials=100, n_jobs=1):
    """Load both datasets, run the hyperparameter search and print the best config.

    Args:
        mouse_path: Path to the mouse ``.h5ad`` file. Defaults to the
            original author's local file when omitted.
        human_path: Path to the human ``.h5ad`` file. Defaults to the
            original author's local file when omitted.
        output_dir: Directory that receives all optimization artefacts.
        n_trials: Number of trials; adjust to the computational budget.
        n_jobs: Number of parallel trials.
    """
    # Fall back to the original locations so that a bare main() still works.
    if mouse_path is None:
        mouse_path = "/Users/guyshani/Documents/PHD/Aim_2/PBMC_data/mouse/train_data_library_counts_PBMC.h5ad"
    if human_path is None:
        human_path = "/Users/guyshani/Documents/PHD/Aim_2/PBMC_data/human/320k_test/train_data_library_counts_PBMC_human.h5ad"

    # Load your data
    print("Loading scRNA-seq data...")
    mouse_adata = ad.read_h5ad(mouse_path)
    human_adata = ad.read_h5ad(human_path)

    # Create optimizer
    print("Initializing Grok optimizer...")
    optimizer = GrokOptimizer(mouse_adata, human_adata, output_dir=output_dir)

    # Run optimization
    print("Starting hyperparameter optimization...")
    study = optimizer.run_optimization(n_trials=n_trials, n_jobs=n_jobs)

    print("\nOptimization completed!")

    # Train final model with best parameters
    print("\nTraining final model with best parameters...")
    if not study.best_trials:
        # Every trial was pruned or failed: min() below would raise.
        print("No trial completed successfully; no best configuration is available.")
        return

    # Pareto-optimal trial with the lowest biological score (third objective).
    best_trial = min(study.best_trials, key=lambda t: t.values[2])
    best_config = optimizer.suggest_hyperparameters(best_trial)

    # You can now use best_config to train your final model
    print(f"Best configuration found:")
    print(json.dumps(best_config, indent=2))


if __name__ == "__main__":
    main()

Loading scRNA-seq data...
Initializing Grok optimizer...


