In [2]:
import torch
import h5py
import os
import pandas as pd
import numpy as np
import anndata as ad
from cross_validation import CrossValidationFramework

In [7]:
def _to_dense(matrix):
    """Return ``matrix`` as a dense NumPy array.

    Objects exposing ``toarray()`` (e.g. SciPy sparse matrices) are densified;
    anything else is passed through ``np.asarray``.
    """
    if hasattr(matrix, "toarray"):
        return matrix.toarray()
    return np.asarray(matrix)


def _encode_and_filter(data, labels, label_to_index, min_samples, title):
    """Integer-encode ``labels`` and drop samples of under-populated classes.

    Args:
        data: Array indexable by a boolean row mask (one row per sample).
        labels: Iterable of raw labels, one per row of ``data``.
        label_to_index: Mapping from raw label to integer class id.
        min_samples: Minimum number of samples a class needs to be kept.
        title: Heading used in the printed summary.

    Returns:
        Tuple ``(filtered_data, filtered_labels)`` where ``filtered_labels`` is
        an ``int64`` array of class ids.

    Raises:
        KeyError: If a label is missing from ``label_to_index``.
    """
    encoded = np.array([label_to_index[label] for label in labels], dtype=np.int64)
    # minlength keeps the count vector aligned with the mapping even when the
    # highest-numbered classes are absent from this particular dataset.
    class_counts = np.bincount(encoded, minlength=len(label_to_index))
    valid_classes = np.where(class_counts >= min_samples)[0]
    mask = np.isin(encoded, valid_classes)
    filtered_data = data[mask]
    filtered_labels = encoded[mask]

    print(f"-----{title}-----")
    print(f"Original number of samples: {len(mask)}")
    print(f"Samples after filtering: {len(filtered_labels)}")
    print(f"Removed classes with < {min_samples} samples")
    return filtered_data, filtered_labels


def main(selected_models=None, selected_scenarios=None,
         output_dir="/Users/guyshani/Documents/PHD/Aim_2/test_models/full_data_counts/run_20250302_091205_dataset+myannotations/",
         data_path="/Users/guyshani/Documents/PHD/Aim_2/PBMC_data/mouse/",
         label_key="myannotations",
         min_samples=10):
    """Cross-validate classifiers on real and generated expression data.

    Loads ``generated_data.h5ad`` from ``output_dir`` and
    ``test_data_library_counts_PBMC.h5ad`` from ``data_path``, integer-encodes
    the labels found in ``obs[label_key]``, removes classes with fewer than
    ``min_samples`` samples, runs the requested cross-validations through
    ``CrossValidationFramework`` and prints a summary of the results.

    Args:
        selected_models: List of model names to run, or ``None`` to run both
            the tree-based models and the neural network. The value is passed
            through unchanged to ``CrossValidationFramework``.
        selected_scenarios: Scenario names, passed through unchanged to
            ``CrossValidationFramework``.
        output_dir: Directory containing ``generated_data.h5ad``.
        data_path: Directory containing ``test_data_library_counts_PBMC.h5ad``.
        label_key: Column of ``obs`` holding the class labels.
        min_samples: Minimum number of samples a class needs to be kept.

    Returns:
        None. Results are printed.
    """
    # Device configuration
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load files
    generated_h5ad = os.path.join(output_dir, "generated_data.h5ad")  # The newly generated h5ad file
    real_h5ad = os.path.join(data_path, "test_data_library_counts_PBMC.h5ad")  # The original h5ad file
    adata_gen = ad.read_h5ad(generated_h5ad)
    adata_real = ad.read_h5ad(real_h5ad)

    raw_labels_real = np.asarray(adata_real.obs[label_key])
    raw_labels_gen = np.asarray(adata_gen.obs[label_key])

    # One mapping shared by both datasets, so that a given integer id refers to
    # the same label in the real and in the generated data even when one of
    # them is missing some classes.
    all_labels = sorted(set(raw_labels_real) | set(raw_labels_gen))
    label_to_index = {label: index for index, label in enumerate(all_labels)}

    # Keep only classes with at least min_samples samples (per dataset).
    real_data, labels_real = _encode_and_filter(
        _to_dense(adata_real.X), raw_labels_real, label_to_index, min_samples, "Real data"
    )
    gen_data, labels_gen = _encode_and_filter(
        _to_dense(adata_gen.X), raw_labels_gen, label_to_index, min_samples, "Generated data"
    )

    # Create framework instance with selected models
    cv_framework = CrossValidationFramework(
        n_splits=4,
        nn_hidden_dims=[512, 256, 128],
        selected_models=selected_models,
        selected_scenarios=selected_scenarios
    )

    # selected_models=None means "run everything"; testing membership on None
    # would raise TypeError, so resolve that case explicitly.
    tree_based = ['decision_tree', 'random_forest', 'gradient_boosting', 'xgboost']
    run_all = selected_models is None
    run_trees = run_all or any(model in selected_models for model in tree_based)
    run_nn = run_all or 'neural_network' in selected_models

    # Perform cross-validation for tree-based models if any are selected
    if run_trees:
        print("\nPerforming cross-validation for tree-based models...")
        tree_results = cv_framework.cross_validate_trees(
            real_data=real_data,
            real_labels=labels_real,
            generated_data=gen_data,
            generated_labels=labels_gen
        )
    else:
        tree_results = {}

    # Perform cross-validation for neural network if selected
    if run_nn:
        print("\nPerforming cross-validation for neural network...")
        nn_results = cv_framework.cross_validate_nn(
            real_data=real_data,
            real_labels=labels_real,
            generated_data=gen_data,
            generated_labels=labels_gen,
            batch_size=32,
            epochs=50,
            device=device
        )
    else:
        nn_results = {}

    # Print final results
    print("\nFinal Results:")

    print("\nTree-based Models:")
    for model_name, model_results in tree_results.items():
        print(f"\n{model_name}:")
        for scenario, metrics in model_results.items():
            print(f"  {scenario}:")
            print(f"    Mean Accuracy: {metrics['mean_accuracy']:.4f} ± {metrics['std_accuracy']:.4f}")
            print(f"    Mean Training Time: {metrics['mean_time']:.2f}s")

    print("\nNeural Network:")
    for scenario, metrics in nn_results.items():
        print(f"  {scenario}:")
        print(f"    Mean Accuracy: {metrics['mean_accuracy']:.4f} ± {metrics['std_accuracy']:.4f}")
        print(f"    Mean Training Time: {metrics['mean_time']:.2f}s")

if __name__ == '__main__':
    # Default run: a random forest evaluated on the 'generated_only' scenario.
    selected_models, selected_scenarios = ['random_forest'], ['generated_only']
    main(
        selected_models=selected_models,
        selected_scenarios=selected_scenarios,
    )

    # Other invocations kept for reference:
    #   random forest + neural network:
    #       main(['random_forest', 'neural_network'])
    #   all models:
    #       main()  # or main(None)
    #   tree-based models only:
    #       main(['decision_tree', 'random_forest', 'gradient_boosting', 'xgboost'])
    # Scenario names used in this notebook: ['real_only', 'generated_only', 'mixed']

Using device: mps
-----Real data-----
Original number of samples: 11854
Samples after filtering: 11854
Removed classes with < 10 samples
-----Generated data-----
Original number of samples: 11854
Samples after filtering: 11854
Removed classes with < 10 samples

Performing cross-validation for tree-based models...

Cross-validating random_forest...
  Scenario: generated_only
    Fold 1: Balanced accuracy = 0.0853, Macro F1 = 0.0640,Weighted ROC AUC = 0.5112, Time = 152.38s
    Fold 2: Balanced accuracy = 0.0848, Macro F1 = 0.0635,Weighted ROC AUC = 0.4934, Time = 151.56s
    Fold 3: Balanced accuracy = 0.0843, Macro F1 = 0.0623,Weighted ROC AUC = 0.5100, Time = 158.91s
    Fold 4: Balanced accuracy = 0.0829, Macro F1 = 0.0605,Weighted ROC AUC = 0.4947, Time = 211.71s

Final Results:

Tree-based Models:

random_forest:
  generated_only:
    Mean Accuracy: 0.0843 ± 0.0009
    Mean Training Time: 168.64s

Neural Network:


In [23]:
# Directory prefix of the generated expression matrix and its label table.
data_path_G = "/Users/guyshani/Documents/PHD/Aim_2/test_models/run_20250201_161547_dataset_singler_label/"

# Read the whole 'matrix' dataset of the HDF5 file into memory.
with h5py.File(f"{data_path_G}_generated_data.h5", 'r') as f:
    gen_data = f['matrix'][:]

# The per-sample labels live in the 'cell_type' column of the CSV.
meta_data_gen = pd.read_csv(f"{data_path_G}_generated_labels.csv")
labels_gen = meta_data_gen['cell_type']

# Map each distinct label to its rank in sorted order, then encode every sample.
unique_clusters = sorted(set(labels_gen.values.ravel()))
cluster_dict = dict(zip(unique_clusters, range(len(unique_clusters))))
labels_gen = np.vectorize(lambda name: cluster_dict[name])(labels_gen).ravel()

In [24]:
# Display the integer-encoded generated labels built in the previous cell.
labels_gen

array([ 0,  1,  1, ..., 13, 13, 13])