# In [1]:
import torch
import h5py
import pandas as pd
import numpy as np
from cross_validation import CrossValidationFramework
import wandb

# In [None]:
def _encode_labels(cluster_values):
    """Map cluster labels to integer codes.

    Each distinct label is assigned its index in the sorted list of distinct
    labels, so the smallest label becomes 0.

    Args:
        cluster_values: Array-like (e.g. a DataFrame) of cluster labels; it is
            flattened in row-major order before encoding.

    Returns:
        A 1-D integer NumPy array with one code per flattened input element.
        Empty input yields an empty array.
    """
    flat_labels = np.asarray(cluster_values).ravel()
    # np.unique returns the sorted distinct values plus, for every element,
    # its index into that sorted array -- the same mapping a
    # sorted(set(...)) lookup table would give, without a Python-level loop.
    _, codes = np.unique(flat_labels, return_inverse=True)
    return codes.ravel()


def _print_metrics(scenario_results):
    """Print mean/std accuracy and mean training time for each scenario.

    Args:
        scenario_results: Mapping of scenario name to a metrics mapping with
            the keys 'mean_accuracy', 'std_accuracy' and 'mean_time'.
    """
    for scenario, metrics in scenario_results.items():
        print(f"  {scenario}:")
        print(f"    Mean Accuracy: {metrics['mean_accuracy']:.4f} ± {metrics['std_accuracy']:.4f}")
        print(f"    Mean Training Time: {metrics['mean_time']:.2f}s")


def main(data_path="/path/to/your/data/"):
    """Run cross-validation for tree-based models and a neural network.

    Loads the 'matrix' dataset from 'train_data_1dataset.h5' and the cluster
    labels from 'train_data_1dataset_cluster.csv' (both inside ``data_path``),
    encodes the labels as integers, runs both cross-validation routines of
    ``CrossValidationFramework`` inside a wandb run, and prints a summary.

    Args:
        data_path: Directory containing the two input files. A trailing
            separator is optional.
    """
    import os  # local import: only needed for joining the input paths

    # Single source of truth for the hyperparameters: the same values are
    # logged to wandb and passed to the framework, so they cannot drift apart.
    config = {
        'n_splits': 5,
        'epochs': 50,
        'batch_size': 32,
        'nn_hidden_dims': [512, 256, 128],
    }

    # Device configuration
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load expression matrix
    with h5py.File(os.path.join(data_path, 'train_data_1dataset.h5'), 'r') as f:
        real_data = f['matrix'][:]

    # Load cluster info (transposed, as in the original pipeline) and convert
    # the cluster labels to numeric codes.
    cluster_vec = pd.read_csv(
        os.path.join(data_path, 'train_data_1dataset_cluster.csv')
    ).T
    real_labels = _encode_labels(cluster_vec)

    # The run is used as a context manager so it is always closed, even when
    # cross-validation raises part-way through.
    with wandb.init(project='single_cell_cross_validation', config=config):
        # Alternative kept as a usage hint: pass custom tree models, e.g.
        #   custom_models = {
        #       'random_forest': RandomForestClassifier(n_estimators=200),
        #       'xgboost': XGBClassifier(max_depth=7),
        #   }
        #   cv_framework = CrossValidationFramework(tree_models=custom_models)

        # Create framework instance (copy the list so the logged config and
        # the framework never share a mutable object).
        cv_framework = CrossValidationFramework(
            n_splits=config['n_splits'],
            nn_hidden_dims=list(config['nn_hidden_dims']),
        )

        # Perform cross-validation for tree-based models
        print("\nPerforming cross-validation for tree-based models...")
        tree_results = cv_framework.cross_validate_trees(
            real_data=real_data,
            real_labels=real_labels,
        )

        # Perform cross-validation for neural network
        print("\nPerforming cross-validation for neural network...")
        nn_results = cv_framework.cross_validate_nn(
            real_data=real_data,
            real_labels=real_labels,
            batch_size=config['batch_size'],
            epochs=config['epochs'],
            device=device,
        )

        # Print final results
        print("\nFinal Results:")

        print("\nTree-based Models:")
        for model_name, model_results in tree_results.items():
            print(f"\n{model_name}:")
            _print_metrics(model_results)

        print("\nNeural Network:")
        _print_metrics(nn_results)

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()