In [None]:
%pip install -U "jax[cuda]"

In [None]:
%pip install -U "git+https://github.com/briancf1/QDax.git#egg=qdax[examples]"

In [None]:
# Clone the repository to get experiment scripts
# NOTE(review): a failing `!git clone` (e.g. ./QDax already exists) only prints an
# error; it does not stop the cell. `%cd` is relative to the current directory, so
# re-running this cell after it has already moved into QDax/examples clones and
# cd's into a nested QDax/examples/QDax/examples. Restart the kernel (or cd back)
# before re-running.
!git clone https://github.com/briancf1/QDax.git
%cd QDax/examples

## Humanoid Omni Environment - Full 31-Seed Study

- **Environment**: humanoid_omni (biped, 17 DoF, final xy position descriptor)
- **Purpose**: Test Competition-GA generalization to complex bipedal locomotion
- **Timeline**: Start at 3:15am, finish ~12:15pm

In [None]:
import os
import json
import time
from datetime import datetime, timedelta
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import functools
import warnings
warnings.filterwarnings('ignore')

import jax
import jax.numpy as jnp

from qdax.core.dns_ga import DominatedNoveltySearchGA
from qdax.core.dns import DominatedNoveltySearch
import qdax.tasks.brax as environments
from qdax.tasks.brax.env_creators import scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.metrics import CSVLogger, default_qd_metrics

# Plot defaults shared by every figure drawn later in the notebook
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (14, 6)})

# Per-run CSV logs, repertoires and the summary JSON are all written here
os.makedirs("seed_variability_logs_humanoid_omni", exist_ok=True)

# Short environment report so saved outputs record where/when this ran
startup_report = [
    "Setup complete!",
    f"Current directory: {os.getcwd()}",
    f"JAX devices: {jax.devices()}",
    f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
]
print("\n".join(startup_report))

## Generate Random Seeds (SAME as ant_omni for consistency)

In [None]:
# Use SAME 31 seeds as ant_omni for direct comparison.
SEED_GENERATION_SEED = 2024
NUM_SEEDS = 31

# A private legacy RandomState produces exactly the values that
# `np.random.seed(2024); np.random.randint(...)` did, without re-seeding NumPy's
# global RNG for every later cell.
RANDOM_SEEDS = (
    np.random.RandomState(SEED_GENERATION_SEED)
    .randint(1, 100000, size=NUM_SEEDS)
    .tolist()  # plain Python ints -> JSON-serialisable, picklable to engines
)

print("="*80)
print(f"USING SAME {NUM_SEEDS} RANDOM SEEDS AS ANT_OMNI")
print("="*80)
print(f"Seeds: {RANDOM_SEEDS[:10]}... (showing first 10)")
print(f"Total: {len(RANDOM_SEEDS)} seeds")
print("="*80)

# Save seeds (create the log directory so this cell also works when run on its own)
seeds_path = Path("seed_variability_logs_humanoid_omni") / "random_seeds.json"
seeds_path.parent.mkdir(parents=True, exist_ok=True)
with open(seeds_path, 'w') as f:
    json.dump({'seeds': RANDOM_SEEDS, 'generation_seed': SEED_GENERATION_SEED}, f, indent=2)

## Experiment Configuration - Humanoid Omni

In [None]:
FIXED_PARAMS = {
    'batch_size': 100,
    'env_name': 'humanoid_omni',  # Bipedal humanoid, 17 DoF, xy position descriptor
    'episode_length': 100,
    'num_iterations': 3000,
    'policy_hidden_layer_sizes': (64, 64),
    'population_size': 1024,
    'k': 3,
    'line_sigma': 0.05,
    'iso_sigma': 0.01,
}

MAIN_CONFIGS = [
    # Baseline
    {
        'type': 'baseline',
        'name': 'DNS_baseline',
        'g_n': None,
        'num_ga_children': None,
        'num_ga_generations': None,
    },
    # Frequent GA
    {
        'type': 'dns-ga',
        'name': 'DNS-GA_g300_gen2',
        'g_n': 300,
        'num_ga_children': 2,
        'num_ga_generations': 2,
    },
    # Rare but deep GA
    {
        'type': 'dns-ga',
        'name': 'DNS-GA_g1000_gen4',
        'g_n': 1000,
        'num_ga_children': 2,
        'num_ga_generations': 4,
    },
]

# Planning assumptions for the runtime estimate below. These are empirical
# guesses, not measured by this notebook (TODO confirm on the current hardware).
EST_MINUTES_PER_EXPERIMENT = 13.5
NUM_PARALLEL_ENGINES = 2

print("="*80)
print("HUMANOID OMNI CONFIGURATION")
print("="*80)
print(f"\nEnvironment: {FIXED_PARAMS['env_name']}")
print(f"  Type: Bipedal humanoid")
print(f"  DoF: 17 (most complex)")
print(f"  Descriptor: Final xy position (2D)")
print(f"  Iterations: {FIXED_PARAMS['num_iterations']}")
print(f"  Seeds: {len(RANDOM_SEEDS)}")

print(f"\nConfigurations:")
for config in MAIN_CONFIGS:
    if config['type'] == 'baseline':
        print(f"  • {config['name']}: No GA")
    else:
        # Same count as calculate_ga_overhead_evals: num_iterations // g_n
        ga_calls = FIXED_PARAMS['num_iterations'] // config['g_n']
        print(f"  • {config['name']}: {ga_calls} GA calls")

total_exp = len(MAIN_CONFIGS) * len(RANDOM_SEEDS)
est_hours = (total_exp / NUM_PARALLEL_ENGINES) * EST_MINUTES_PER_EXPERIMENT / 60
# Computed from "now" instead of a hard-coded wall-clock time that goes stale.
est_finish = datetime.now() + timedelta(hours=est_hours)
print(f"\nTotal Experiments: {total_exp}")
print(f"Estimated time ({NUM_PARALLEL_ENGINES}-parallel): ~{est_hours:.1f} hours")
print(f"Expected completion (if started now): ~{est_finish.strftime('%Y-%m-%d %H:%M')}")
print("="*80)

## Helper Functions (same as ant_omni)

In [None]:
def calculate_ga_overhead_evals(g_n, num_iterations, population_size, num_ga_children, num_ga_generations):
    """Calculate total evaluations performed by Competition-GA.

    Args:
        g_n: GA period in iterations, or None when the GA is disabled.
        num_iterations: Total number of algorithm iterations in the run.
        population_size: Number of parents each GA call starts from.
        num_ga_children: Children produced per individual per GA generation.
        num_ga_generations: Number of GA generations per GA call.

    Returns:
        (total_ga_evals, num_ga_calls, evals_per_ga_call). All three are 0 when
        the GA is disabled (g_n is None) or g_n >= num_iterations.

    Raises:
        ValueError: if g_n is a non-positive period smaller than num_iterations.
            Previously this raised ZeroDivisionError (g_n == 0) or silently
            returned negative counts (g_n < 0).
    """
    if g_n is None or g_n >= num_iterations:
        return 0, 0, 0
    if g_n <= 0:
        raise ValueError(f"g_n must be a positive GA period, got {g_n}")

    num_ga_calls = num_iterations // g_n
    if num_ga_children == 1:
        # Degenerate geometric series: one offspring per parent per generation.
        evals_per_ga_call = population_size * num_ga_generations
    else:
        # population_size * (c + c^2 + ... + c^G) via the closed geometric-series
        # form; the integer division is exact because (c^G - 1) is divisible by (c - 1).
        evals_per_ga_call = population_size * num_ga_children * (num_ga_children**num_ga_generations - 1) // (num_ga_children - 1)
    total_ga_evals = num_ga_calls * evals_per_ga_call
    return total_ga_evals, num_ga_calls, evals_per_ga_call


def setup_environment(env_name, episode_length, policy_hidden_layer_sizes, batch_size, seed):
    """Initialize environment and policy network.

    Args:
        env_name: QDax Brax task name passed to `environments.create`.
        episode_length: Episode length passed to `environments.create`.
        policy_hidden_layer_sizes: Hidden-layer widths as any sequence of ints
            (tuple or list). The environment's action size is appended as the
            output layer.
        batch_size: Number of independent policy parameter sets to initialise.
        seed: Integer seed for `jax.random.key`.

    Returns:
        (env, policy_network, reset_fn, init_variables, key):
        - reset_fn is `jax.jit(env.reset)`.
        - init_variables is `policy_network.init` vmapped over `batch_size`
          keys on a zero observation batch.
        - key is the PRNG key remaining after splitting off those init keys.
    """
    env = environments.create(env_name, episode_length=episode_length)
    reset_fn = jax.jit(env.reset)
    key = jax.random.key(seed)

    # tuple(...) so a list of hidden sizes works too (list + tuple raises TypeError).
    policy_layer_sizes = tuple(policy_hidden_layer_sizes) + (env.action_size,)
    policy_network = MLP(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )

    # One init key per batch member; the dummy observation batch only fixes
    # input shapes for parameter initialisation.
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, num=batch_size)
    fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
    init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

    return env, policy_network, reset_fn, init_variables, key


def create_scoring_function(env, policy_network, reset_fn, episode_length, env_name):
    """Return the fitness/descriptor scorer for `env` driven by `policy_network`.

    The result is `scoring_function` (QDax's Brax scorer) with the episode
    length, the reset function, a one-step rollout function and the descriptor
    extractor registered for `env_name` bound as keyword arguments.
    """

    def play_step_fn(env_state, policy_params, key):
        # Advance the environment by one policy action and record the step.
        # policy_params and key are passed through unchanged.
        action = policy_network.apply(policy_params, env_state.obs)
        current_desc = env_state.info["state_descriptor"]
        stepped = env.step(env_state, action)

        step_record = QDTransition(
            obs=env_state.obs,
            next_obs=stepped.obs,
            rewards=stepped.reward,
            dones=stepped.done,
            actions=action,
            truncations=stepped.info["truncation"],
            state_desc=current_desc,
            next_state_desc=stepped.info["state_descriptor"],
        )
        return stepped, policy_params, key, step_record

    extractor = environments.descriptor_extractor[env_name]
    return functools.partial(
        scoring_function,
        episode_length=episode_length,
        play_reset_fn=reset_fn,
        play_step_fn=play_step_fn,
        descriptor_extractor=extractor,
    )


def create_mutation_function(iso_sigma):
    """Build the Gaussian mutation operator used by Competition-GA.

    The returned callable `(genotype, key) -> mutated_genotype` adds
    `normal(shape=leaf.shape) * iso_sigma` noise to every leaf of the
    `genotype` pytree. Each leaf draws from its own subkey of `key`, assigned
    in `tree_flatten` leaf order. The pytree structure is preserved.
    """
    def competition_ga_mutation_fn(genotype, key):
        leaves, treedef = jax.tree_util.tree_flatten(genotype)
        leaf_keys = jax.random.split(key, len(leaves))
        perturbed = [
            leaf + jax.random.normal(leaf_key, shape=leaf.shape) * iso_sigma
            for leaf, leaf_key in zip(leaves, leaf_keys)
        ]
        return jax.tree_util.tree_unflatten(treedef, perturbed)

    return competition_ga_mutation_fn

print("Helper functions loaded!")

## Single Experiment Runner

In [None]:
def run_single_experiment(config, seed, fixed_params):
    """Run a single experiment with given config and seed.

    Steps:
    - Build the environment, scorer, isoline-variation emitter and either
      DominatedNoveltySearch (config['type'] == 'baseline') or
      DominatedNoveltySearchGA.
    - Run exactly fixed_params['num_iterations'] updates via `jax.lax.scan`,
      in chunks of fixed_params.get('log_period', 100) (the last chunk may be
      shorter), appending one CSV row per chunk.
    - Save the final repertoire's descriptors and fitnesses to an .npz file.

    Args:
        config: One MAIN_CONFIGS entry ('type', 'name', 'g_n',
            'num_ga_children', 'num_ga_generations').
        seed: Integer PRNG seed.
        fixed_params: Shared hyper-parameters (see FIXED_PARAMS). An optional
            'log_period' key overrides the default chunk length of 100.

    Returns:
        Dict with the final QD score / max fitness / coverage, wall-clock time,
        GA evaluation overhead (total, number of calls, evals per call) and the
        paths of the CSV log and the repertoire file.
    """
    exp_name = f"{config['name']}_seed{seed}"
    log_dir = "seed_variability_logs_humanoid_omni"

    env, policy_network, reset_fn, init_variables, key = setup_environment(
        fixed_params['env_name'],
        fixed_params['episode_length'],
        fixed_params['policy_hidden_layer_sizes'],
        fixed_params['batch_size'],
        seed
    )

    scoring_fn = create_scoring_function(env, policy_network, reset_fn,
                                         fixed_params['episode_length'],
                                         fixed_params['env_name'])

    # QD-score offset handed to default_qd_metrics: per-step reward offset
    # registered for this env, times the episode length.
    reward_offset = environments.reward_offset[fixed_params['env_name']]
    metrics_function = functools.partial(
        default_qd_metrics,
        qd_offset=reward_offset * fixed_params['episode_length'],
    )

    variation_fn = functools.partial(
        isoline_variation,
        iso_sigma=fixed_params['iso_sigma'],
        line_sigma=fixed_params['line_sigma']
    )

    # All offspring come from isoline variation (variation_percentage=1.0, no mutation_fn).
    mixing_emitter = MixingEmitter(
        mutation_fn=None,
        variation_fn=variation_fn,
        variation_percentage=1.0,
        batch_size=fixed_params['batch_size']
    )

    if config['type'] == 'baseline':
        algorithm = DominatedNoveltySearch(
            scoring_function=scoring_fn,
            emitter=mixing_emitter,
            metrics_function=metrics_function,
            population_size=fixed_params['population_size'],
            k=fixed_params['k'],
        )
    else:
        mutation_fn = create_mutation_function(fixed_params['iso_sigma'])
        algorithm = DominatedNoveltySearchGA(
            scoring_function=scoring_fn,
            emitter=mixing_emitter,
            metrics_function=metrics_function,
            population_size=fixed_params['population_size'],
            k=fixed_params['k'],
            g_n=config['g_n'],
            num_ga_children=config['num_ga_children'],
            num_ga_generations=config['num_ga_generations'],
            mutation_fn=mutation_fn,
        )

    key, subkey = jax.random.split(key)
    repertoire, emitter_state, init_metrics = algorithm.init(init_variables, subkey)

    num_iterations = fixed_params['num_iterations']
    log_period = fixed_params.get('log_period', 100)
    # Full log periods plus a final partial chunk, so exactly num_iterations
    # updates run even when it is not a multiple of log_period.
    num_full_chunks, remainder = divmod(num_iterations, log_period)
    chunk_lengths = [log_period] * num_full_chunks + ([remainder] if remainder else [])

    # Per-iteration history, seeded with the initial repertoire's metrics as iteration 0.
    metric_names = ["iteration", "qd_score", "coverage", "max_fitness", "time"]
    metrics = {name: jnp.array([]) for name in metric_names}
    init_metrics = jax.tree.map(lambda x: jnp.array([x]) if x.shape == () else x, init_metrics)
    init_metrics["iteration"] = jnp.array([0], dtype=jnp.int32)
    init_metrics["time"] = jnp.array([0.0])
    metrics = jax.tree.map(
        lambda metric, init_metric: jnp.concatenate([metric, init_metric], axis=0),
        metrics, init_metrics
    )

    log_filename = os.path.join(log_dir, f"{exp_name}_logs.csv")
    csv_logger = CSVLogger(log_filename, header=list(metrics.keys()))
    csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))

    # DNS-GA's scan carry has an extra trailing counter starting at 1
    # (presumably the generation counter that schedules GA calls; verify in dns_ga).
    if config['type'] == 'baseline':
        scan_state = (repertoire, emitter_state, key)
    else:
        scan_state = (repertoire, emitter_state, key, 1)

    start_time_total = time.time()
    iterations_done = 0

    for chunk_length in chunk_lengths:
        start_time = time.time()

        scan_state, current_metrics = jax.lax.scan(
            algorithm.scan_update,
            scan_state,
            (),
            length=chunk_length,
        )
        # JAX dispatches work asynchronously: wait for this chunk to finish so
        # the recorded time covers the computation, not just the dispatch.
        jax.block_until_ready((scan_state, current_metrics))

        timelapse = time.time() - start_time

        current_metrics["iteration"] = jnp.arange(
            iterations_done + 1, iterations_done + chunk_length + 1, dtype=jnp.int32
        )
        current_metrics["time"] = jnp.repeat(timelapse, chunk_length)
        metrics = jax.tree.map(
            lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0),
            metrics, current_metrics
        )

        csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))
        iterations_done += chunk_length

    total_time = time.time() - start_time_total

    ga_total_evals, ga_num_calls, ga_evals_per_call = calculate_ga_overhead_evals(
        config.get('g_n'), num_iterations, fixed_params['population_size'],
        config.get('num_ga_children'), config.get('num_ga_generations')
    )

    # The repertoire is the first element of the scan carry for both algorithms.
    # Saved for behavior-space visualization.
    final_repertoire = scan_state[0]
    repertoire_file = os.path.join(log_dir, f"{exp_name}_repertoire.npz")
    jnp.savez(repertoire_file,
        descriptors=final_repertoire.descriptors,
        fitnesses=final_repertoire.fitnesses
    )

    return {
        'config_name': config['name'],
        'config_type': config['type'],
        'seed': seed,
        'g_n': config.get('g_n'),
        'num_ga_generations': config.get('num_ga_generations'),
        'final_qd_score': float(metrics['qd_score'][-1]),
        'final_max_fitness': float(metrics['max_fitness'][-1]),
        'final_coverage': float(metrics['coverage'][-1]),
        'total_time': total_time,
        'ga_overhead_evals': ga_total_evals,
        'ga_num_calls': ga_num_calls,
        'ga_evals_per_call': ga_evals_per_call,
        'log_file': log_filename,
        'repertoire_file': repertoire_file,
    }

print("Experiment runner ready!")

## Build Queue and Run Experiments (2-Parallel with ipyparallel)

In [None]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

print("="*80)
print(f"BUILDING EXPERIMENT QUEUE - {timestamp}")
print("="*80)

# Each entry is (exp_num, total_exp, config, seed), the layout that
# run_experiment_wrapper unpacks. Configs vary slowest, seeds fastest, and
# exp_num is 1-based. The second field was previously exp_num repeated
# instead of the total count.
total_experiments = len(MAIN_CONFIGS) * len(RANDOM_SEEDS)
experiment_queue = []
for config in MAIN_CONFIGS:
    for seed in RANDOM_SEEDS:
        experiment_queue.append((len(experiment_queue) + 1, total_experiments, config, seed))

print(f"\nTotal experiments: {len(experiment_queue)}")
print(f"Execution: 2-parallel with ipyparallel")
# Rough planning estimate: ~13.5 min per experiment, 2 running concurrently.
print(f"Estimated time: ~{len(experiment_queue) / 2 * 13.5 / 60:.1f} hours")
print("="*80)

In [None]:
print("\n" + "="*80)
print("RUNNING HUMANOID_OMNI EXPERIMENTS (2-PARALLEL)")
print("="*80)
run_started_at = datetime.now()
print(f"Start time: {run_started_at.strftime('%Y-%m-%d %H:%M:%S')}")
# Planning estimate only: ~13.5 min per experiment with 2 engines busy at once.
estimated_hours = len(experiment_queue) / 2 * 13.5 / 60
expected_end = run_started_at + timedelta(hours=estimated_hours)
print(f"Expected completion: ~{expected_end.strftime('%Y-%m-%d %H:%M')}")
print("="*80)

start_time_all = time.time()

import ipyparallel as ipp


def run_experiment_wrapper(exp_tuple):
    """Engine-side entry point: run one queued experiment without raising.

    Returns ('success', result_dict) or ('error', info_dict). The info dict
    carries the exception message and the formatted traceback, so failures
    can be diagnosed from the saved JSON.
    """
    exp_num, total_exp, config, seed = exp_tuple
    try:
        result = run_single_experiment(config, seed, FIXED_PARAMS)
        result['exp_num'] = exp_num
        return ('success', result)
    except Exception as e:
        import traceback  # imported here because this function executes on the engines
        return ('error', {
            'config_name': config['name'],
            'seed': seed,
            'error': str(e),
            'traceback': traceback.format_exc(),
        })


def _record_outcome(status, result):
    """Append one finished experiment to all_results / errors and report it."""
    if status == 'success':
        all_results.append(result)
        print(f"  ✓ Completed: {result['config_name']}, seed={result['seed']}, QD={result['final_qd_score']:.1f}")
    else:
        errors.append(result)
        print(f"  ✗ Failed: {result['config_name']}, seed={result['seed']}")


all_results = []
errors = []

cluster = ipp.Cluster(n=2)
rc = cluster.start_and_connect_sync()

try:
    print(f"✓ Cluster started with {len(rc)} engines")

    # NOTE(review): this kernel already initialised JAX in the setup cell and may
    # hold most of the GPU's memory; confirm the engines still have headroom.
    rc[:].execute("""
import os
# Several JAX processes share the GPU here; by default each would try to
# preallocate 75% of device memory, so allocate on demand instead.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
import jax
import jax.numpy as jnp
import functools
import time
from qdax.core.dns_ga import DominatedNoveltySearchGA
from qdax.core.dns import DominatedNoveltySearch
import qdax.tasks.brax as environments
from qdax.tasks.brax.env_creators import scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.metrics import CSVLogger, default_qd_metrics
""").wait()

    rc[:].push({
        'FIXED_PARAMS': FIXED_PARAMS,
        'setup_environment': setup_environment,
        'create_scoring_function': create_scoring_function,
        'create_mutation_function': create_mutation_function,
        'calculate_ga_overhead_evals': calculate_ga_overhead_evals,
        'run_single_experiment': run_single_experiment
    }).wait()

    print("✓ Engines initialized")

    rc[:].push({'run_experiment_wrapper': run_experiment_wrapper}).wait()

    lview = rc.load_balanced_view()

    print("Submitting all experiments...")
    async_results = [lview.apply_async(run_experiment_wrapper, exp_tuple) for exp_tuple in experiment_queue]
    print(f"✓ Submitted {len(async_results)} experiments")
    print("\nMonitoring progress...")

    # Each AsyncResult is paired with its queue entry so failures outside the
    # wrapper (e.g. a dead engine) can still be attributed to a config/seed.
    pending = list(zip(async_results, experiment_queue))
    completed_count = 0
    last_update = time.time()

    while pending:
        still_pending = []
        for ar, exp_tuple in pending:
            if not ar.ready():
                still_pending.append((ar, exp_tuple))
                continue
            try:
                status, result = ar.get()
            except Exception as exc:
                # Record the failure instead of aborting the monitoring loop.
                _, _, failed_config, failed_seed = exp_tuple
                status, result = 'error', {
                    'config_name': failed_config['name'],
                    'seed': failed_seed,
                    'error': repr(exc),
                }
            _record_outcome(status, result)
        pending = still_pending
        completed_count = len(all_results) + len(errors)

        if time.time() - last_update > 10:
            elapsed = time.time() - start_time_all
            pct = completed_count / len(experiment_queue) * 100
            if completed_count > 0:
                avg_time = elapsed / completed_count
                remaining_time = (len(experiment_queue) - completed_count) * avg_time / 3600
                print(f"📊 Progress: {completed_count}/{len(experiment_queue)} ({pct:.1f}%) | Elapsed: {elapsed/60:.1f}m | Remaining: ~{remaining_time:.2f}h")
            last_update = time.time()

        if pending:
            time.sleep(2)
finally:
    # Always release the engines (and the GPU memory they hold), even on
    # Ctrl-C or an unexpected error above.
    cluster.stop_cluster_sync()
    print("\n✓ Cluster stopped")

total_time = time.time() - start_time_all

print("\n" + "="*80)
print("HUMANOID_OMNI EXPERIMENTS COMPLETE!")
print("="*80)
print(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total time: {total_time / 60:.1f} minutes ({total_time / 3600:.2f} hours)")
print(f"Successful: {len(all_results)}/{len(experiment_queue)}")
print(f"Failed: {len(errors)}")

if errors:
    print("\nErrors:")
    for error in errors:
        print(f"  • {error['config_name']}, seed={error['seed']}: {error['error']}")

results_file = f"seed_variability_logs_humanoid_omni/all_results_{timestamp}.json"
with open(results_file, 'w') as f:
    json.dump({
        'results': all_results,
        'errors': errors,
        'total_time': total_time,
        'timestamp': timestamp,
        'environment': 'humanoid_omni',
    }, f, indent=2)

print(f"\nResults saved to: {results_file}")
print("="*80)

## Quick Results Summary

In [None]:
if len(all_results) > 0:
    df = pd.DataFrame(all_results)
    print("="*80)
    print("HUMANOID_OMNI RESULTS SUMMARY")
    print("="*80)
    print(f"\nTotal experiments: {len(df)}")
    print("\nFinal QD Scores by Configuration:")
    print(df.groupby('config_name')['final_qd_score'].agg(['mean', 'std', 'min', 'max']).round(2))
    print("\n" + "="*80)
    print("✓ Results ready for integration with ant_omni and walker2d data")
    print("="*80)
else:
    # Say so explicitly instead of producing no output at all.
    print("No successful experiments to summarise; see `errors` from the run cell.")

In [None]:
import os
import subprocess
from google.colab import drive
from datetime import datetime

# --- 1. Mount your Google Drive ---
# This will pop up an authorization window the first time.
drive.mount('/content/drive')

# --- 2. Create a unique filename with a timestamp ---
# Use a separate name so the experiment `timestamp` (which names
# all_results_<timestamp>.json) is not overwritten by this cell.
backup_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
zip_filename = f"QDax_backup_{backup_timestamp}.zip"
backup_dir = "/content/drive/MyDrive/Colab_Backups"
drive_save_path = f"{backup_dir}/{zip_filename}"

# Create the backup directory in your Drive if it doesn't exist
os.makedirs(backup_dir, exist_ok=True)

# --- 3. Zip the directory and save it to Google Drive ---
# -q = quiet (no file list)
# -r = recursive (include all subdirectories)
# subprocess.run(check=True) rather than `!zip`: a failed zip now raises
# instead of being followed by a misleading success message.
print(f"Zipping /content/QDax/ ...")
subprocess.run(["zip", "-q", "-r", drive_save_path, "/content/QDax/"], check=True)

print(f"✅ Successfully saved backup to: {drive_save_path}")