# The TrainingDataGenerator class

This tutorial demonstrates how to use the `TrainingDataGenerator` class to create training data for sequential sampling models (SSMs).
It also touches on the overall workflow design.

In [1]:
# Standard library imports
import numpy as np
import matplotlib.pyplot as plt
import time

# SSMS imports
from ssms.dataset_generators.lan_mlp import TrainingDataGenerator
from ssms.config import model_config, get_default_generator_config
from ssms.basic_simulators.simulator_class import Simulator

# Set random seed for reproducibility
np.random.seed(42)

print("✓ Imports successful")
print(f"✓ Available models: {len(model_config)} predefined models")

✓ Imports successful
✓ Available models: 106 predefined models


## What is TrainingDataGenerator?

The `TrainingDataGenerator` orchestrates the creation of training datasets for neural network-based likelihood approximation. 

It:

- Samples parameters from valid parameter spaces
- Generates RT/choice data (via simulation or analytical methods)
- Converts the raw data into `(feature, label)` pairs for the downstream networks we want to train
- Optionally saves the outputs as `.pickle` files

The `TrainingDataGenerator` is itself a lightweight class: the actual work is done by the `DataPipeline` it wraps.
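
The four steps listed above can be illustrated with a small NumPy-only sketch. Everything here is a stand-in: the parameter bounds, the helper names (`sample_parameters`, `toy_simulate`, `build_training_pairs`) and the simple Euler-Maruyama simulator are hypothetical and are not the package's own components. The label used is the probability of an upper-boundary choice, which is only one of the label types a real run would produce.

```python
import pickle

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical bounds for a 4-parameter DDM-like model (illustrative values only)
bounds = {"v": (-3.0, 3.0), "a": (0.3, 2.5), "z": (0.1, 0.9), "t": (0.0, 2.0)}


def sample_parameters(n_sets):
    """Step 1: draw parameter sets uniformly within the bounds."""
    low = np.array([b[0] for b in bounds.values()])
    high = np.array([b[1] for b in bounds.values()])
    return rng.uniform(low, high, size=(n_sets, len(bounds)))


def toy_simulate(theta, n_samples, delta_t=0.01, max_t=5.0):
    """Step 2: Euler-Maruyama simulation of a basic drift diffusion process."""
    v, a, z, t = theta
    x = np.full(n_samples, z * a)      # start between the boundaries 0 and a
    rt = np.full(n_samples, max_t)     # unfinished trials keep max_t
    choice = np.full(n_samples, -1)    # unfinished trials default to -1
    active = np.ones(n_samples, dtype=bool)
    for step in range(1, int(round(max_t / delta_t)) + 1):
        noise = rng.normal(size=active.sum())
        x[active] += v * delta_t + np.sqrt(delta_t) * noise
        hit_up = active & (x >= a)
        hit_down = active & (x <= 0)
        rt[hit_up | hit_down] = step * delta_t + t  # add non-decision time
        choice[hit_up] = 1
        choice[hit_down] = -1
        active &= ~(hit_up | hit_down)
        if not active.any():
            break
    return rt, choice


def build_training_pairs(n_sets, n_samples):
    """Step 3: turn raw simulations into (feature, label) pairs."""
    thetas = sample_parameters(n_sets)
    labels = np.empty(n_sets)
    for i, theta in enumerate(thetas):
        _, choice = toy_simulate(theta, n_samples)
        labels[i] = np.mean(choice == 1)  # fraction of upper-boundary choices
    return {"features": thetas, "labels": labels}


data = build_training_pairs(n_sets=5, n_samples=200)

# Step 4: serialize (kept in memory here; a real run would write a file)
blob = pickle.dumps(data, protocol=4)
print(data["features"].shape, data["labels"].shape, len(blob) > 0)
```

The real class bundles these stages behind a single call, so the sketch is mainly useful for seeing where each stage's inputs and outputs sit.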

### High-Level Flow

<div style="width: 100%; max-width: 700px; margin: 30px auto; padding: 0 20px; box-sizing: border-box;">
  <div style="width: 100%; text-align: center; padding: 15px; background-color: #f8f9fa; border: 2px solid #dee2e6; font-weight: 600; font-size: 16px; color: #495057; box-sizing: border-box; margin-bottom: 15px;">
    INPUTS
  </div>
  
  <div style="width: 100%; padding: 20px; background-color: #fff3e0; border: 3px solid #ffb300; text-align: center; box-sizing: border-box; margin-bottom: 15px;">
    <div style="font-size: 28px; margin-bottom: 10px;">⚙️</div>
    <div style="font-weight: 700; font-size: 17px; margin-bottom: 10px; color: #e65100;">config</div>
    <div style="font-size: 14px; color: #555; line-height: 1.6; margin-bottom: 12px;">
      <strong>Polymorphic parameter</strong>
    </div>
    <div style="background-color: #fff8e1; padding: 10px; border-radius: 5px; margin: 8px auto; max-width: 400px; border-left: 4px solid #ff6f00; text-align: left;">
      <div style="font-weight: 600; font-size: 12px; color: #e65100; margin-bottom: 4px;">Option 1: Dict (99%)</div>
      <div style="font-size: 11px; color: #666;">generator_config dict (samples, bins, estimator_type)</div>
    </div>
    <div style="background-color: #f3e5f5; padding: 10px; border-radius: 5px; margin: 8px auto; max-width: 400px; border-left: 4px solid #7b1fa2; text-align: left;">
      <div style="font-weight: 600; font-size: 12px; color: #7b1fa2; margin-bottom: 4px;">Option 2: Strategy (1%)</div>
      <div style="font-size: 11px; color: #666;">Custom strategy object (advanced customization)</div>
    </div>
  </div>
  
  <div style="width: 100%; padding: 20px; background-color: #e7f3ff; border: 3px solid #42a5f5; text-align: center; box-sizing: border-box; margin-bottom: 20px;">
    <div style="font-size: 28px; margin-bottom: 10px;">📋</div>
    <div style="font-weight: 700; font-size: 17px; margin-bottom: 10px; color: #0066cc;">model_config</div>
    <div style="font-size: 14px; color: #555; line-height: 1.6; margin-bottom: 8px;">
      <strong>Model specification</strong>
    </div>
    <div style="font-size: 13px; color: #666; line-height: 1.5;">
      Which SSM model? (DDM, angle, LCA...)<br/>
      Parameter bounds<br/>
      Boundary/drift functions
    </div>
  </div>
  
  <div style="text-align: center; font-size: 32px; color: #999; margin: 15px 0;">⬇</div>
  
  <div style="width: 100%; padding: 25px; background-color: #fff8e1; border: 3px solid #ffb300; text-align: center; margin-bottom: 20px; box-sizing: border-box;">
    <div style="font-size: 32px; margin-bottom: 12px;">📊</div>
    <div style="font-weight: 700; font-size: 22px; margin-bottom: 8px; color: #f57c00;">TrainingDataGenerator</div>
    <div style="font-size: 14px; color: #666;">Orchestrates the complete data generation workflow</div>
  </div>
  
  <div style="text-align: center; font-size: 32px; color: #999; margin: 15px 0;">⬇</div>
  
  <div style="width: 100%; text-align: center; padding: 15px; background-color: #f8f9fa; border: 2px solid #dee2e6; font-weight: 600; font-size: 16px; color: #495057; box-sizing: border-box; margin-bottom: 15px;">
    OUTPUT
  </div>
  
  <div style="width: 100%; padding: 20px; background-color: #e8f5e9; border: 2px solid #81c784; box-sizing: border-box;">
    <div style="text-align: center; font-size: 28px; margin-bottom: 8px;">📦</div>
    <div style="text-align: center; font-weight: 700; font-size: 18px; margin-bottom: 15px; color: #2e7d32;">Training Data</div>
    <div style="padding: 10px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #42a5f5; margin-bottom: 8px;">
      <strong style="color: #1976d2;">lan_data</strong><br/>
      <span style="font-size: 12px; color: #666;">RT/choice pairs for training</span>
    </div>
    <div style="padding: 10px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #ef5350; margin-bottom: 8px;">
      <strong style="color: #c62828;">lan_labels</strong><br/>
      <span style="font-size: 12px; color: #666;">Log-likelihoods</span>
    </div>
    <div style="padding: 10px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #66bb6a; margin-bottom: 8px;">
      <strong style="color: #388e3c;">theta</strong><br/>
      <span style="font-size: 12px; color: #666;">Parameter values</span>
    </div>
    <div style="padding: 10px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #ab47bc; margin-bottom: 8px;">
      <strong style="color: #7b1fa2;">cpn_labels</strong><br/>
      <span style="font-size: 12px; color: #666;">Choice probabilities</span>
    </div>
    <div style="padding: 10px; background-color: #f5f5f5; border: 1px solid #ddd; text-align: center; color: #999; font-style: italic; font-size: 14px;">
      ... and more
    </div>
  </div>
</div>

### The Single Injection Point: Pipeline

#### **What is a pipeline?**

The generation pipeline (`generation_pipeline`) is the **single customization point** of the `TrainingDataGenerator`. It encapsulates the entire data generation workflow, from parameter sampling to the assembly of the final training data.

#### **Built-in Strategies**:

**1. `SimulationPipeline`** (default for `estimator_type='kde'`)
- Uses the core native simulators of the `ssm-simulators` package
- Estimates likelihoods with a kernel density estimator (KDE), as in [our LAN paper](https://elifesciences.org/articles/65074)

**2. `PyDDMPipeline`** (used for `estimator_type='pyddm'`)
- Uses the analytical Fokker-Planck PDE solver from the `PyDDM` package
- Supports a more limited set of models (two-choice, Gaussian noise, etc.)
- Can be much faster than the KDE-based pipeline where it applies

**3. Custom Strategy**
- Implement your own `DataPipelineProtocol`
- Full control over parameter sampling, simulation, and data structuring
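
To make the KDE route of `SimulationPipeline` more concrete, here is a minimal NumPy-only sketch of estimating log-likelihoods from simulated reaction times. It is not the package's estimator: `gaussian_kde_logpdf` is a hypothetical helper, the bandwidth uses a plain normal-reference rule of thumb, and the "simulated" RTs are drawn from a shifted gamma distribution purely as a stand-in. The joint log-likelihood of an `(rt, choice)` pair is taken as the log choice probability plus the log density of the RT given that choice.

```python
import numpy as np


def gaussian_kde_logpdf(samples, query, bandwidth=None):
    """Log density of a 1-D Gaussian KDE built on `samples`, evaluated at `query`."""
    samples = np.asarray(samples, dtype=float)
    query = np.asarray(query, dtype=float)
    if bandwidth is None:
        # Normal-reference rule of thumb (an assumption, not the package's choice)
        bandwidth = 1.06 * samples.std(ddof=1) * samples.size ** (-1 / 5)
    z = (query[:, None] - samples[None, :]) / bandwidth
    log_kernel = -0.5 * z**2 - 0.5 * np.log(2 * np.pi) - np.log(bandwidth)
    # log-mean-exp over the samples for numerical stability
    m = log_kernel.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(log_kernel - m).mean(axis=1, keepdims=True))).ravel()


rng = np.random.default_rng(1)

# Stand-in for the simulated RTs of one choice under a fixed parameter set
rts = rng.gamma(shape=3.0, scale=0.2, size=2000) + 0.3
# Stand-in for the fraction of simulations that ended in this choice
p_choice = 0.7

query_rts = np.array([0.5, 0.9, 1.5])
log_lik = np.log(p_choice) + gaussian_kde_logpdf(rts, query_rts)
print(log_lik)
```

Values like `log_lik` play the role of labels, paired with the parameter set and the queried `(rt, choice)` as features.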

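#### **How do you control it?**

Pass a config dict as the first argument to get one of the built-in pipelines, or pass a pipeline object instead if you need full control.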

## Basic Example

In [2]:
model_name = "ddm"
estimator_type = "kde"

In [3]:
# Look up the predefined configuration for the chosen model
my_model_config = model_config[model_name]

# Start from the default generator config and adjust it for a simulation-based (KDE) run
pipeline_config = get_default_generator_config()
pipeline_config["estimator"]["type"] = estimator_type  # or just omit, this is the default
pipeline_config["pipeline"]["n_parameter_sets"] = 100
pipeline_config["simulator"]["n_samples"] = 5000
print(pipeline_config)

gen = TrainingDataGenerator(pipeline_config, my_model_config)

{'pipeline': {'n_parameter_sets': 100, 'n_subruns': 10, 'n_cpus': 'all', 'n_parameter_sets_rejected': 100}, 'estimator': {'type': 'kde', 'displace_t': False}, 'training': {'mixture_probabilities': [0.8, 0.1, 0.1], 'n_samples_per_param': 1000, 'separate_response_channels': False, 'negative_rt_cutoff': -66.77497}, 'simulator': {'delta_t': 0.001, 'max_t': 20.0, 'n_samples': 5000, 'smooth_unif': True, 'filters': {'mode': 20, 'choice_cnt': 0, 'mean_rt': 17, 'std': 0, 'mode_cnt_rel': 0.95}}, 'output': {'folder': 'data/lan_mlp/', 'pickle_protocol': 4, 'nbins': 0}, 'model': 'ddm', 'bin_pointwise': False}
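
The config is a nested dict, so several overrides are often applied at once. A small helper can do this on a copy and fail loudly on misspelled keys. This is a sketch only: `with_overrides` is a hypothetical helper, not part of the package, and `defaults` below is a hand-written stand-in that mirrors part of the printed structure.

```python
import copy


def with_overrides(base, overrides):
    """Return a deep copy of `base` with nested `overrides` merged in.

    Unknown keys raise a KeyError so that typos such as `n_sample`
    do not silently create new entries.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if key not in merged:
            raise KeyError(f"Unknown config key: {key!r}")
        if isinstance(value, dict) and isinstance(merged[key], dict):
            merged[key] = with_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Stand-in for part of the default config shown above
defaults = {
    "pipeline": {"n_parameter_sets": 100, "n_subruns": 10},
    "estimator": {"type": "kde", "displace_t": False},
    "simulator": {"n_samples": 5000, "max_t": 20.0},
}

small_run = with_overrides(
    defaults,
    {"pipeline": {"n_parameter_sets": 10}, "simulator": {"n_samples": 500}},
)
print(small_run["pipeline"], small_run["simulator"])
```

Working on a copy keeps the defaults intact, which helps when several configurations are compared in one session.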


### Generate training data

In [4]:
# Generate data
print("Generating training data...")
start_time = time.time()

training_data = gen.generate_data_training()

elapsed = time.time() - start_time
total_trials = pipeline_config['pipeline']['n_parameter_sets'] * pipeline_config['simulator']['n_samples']
print(f"✓ Data generation complete in {elapsed:.2f} seconds")
print(f"  ({total_trials / elapsed:.0f} trials/sec)")

Generating training data...
✓ Data generation complete in 3.02 seconds
  (165428 trials/sec)


### Inspect output

In [5]:
print("Output structure:")
print(f"  Keys: {list(training_data.keys())}")

print("\nData shapes:")
for key, value in training_data.items():
    if isinstance(value, np.ndarray):
        # Arrays: show the shape only
        print(f"  {key:30s}: {value.shape}")
    elif isinstance(value, dict):
        # Config dicts: show the contents
        print(f"  {key:30s}: {value}")
    elif value is None:
        print(f"  {key:30s}: None")

print("\n--- Understanding the components ---")
print("lan_data: one row per training sample [n_parameter_sets * n_samples_per_param, n_params + 2]")
print("lan_labels: log-likelihood for each row of lan_data")
print("theta: parameter values [n_parameter_sets, n_params]")
print("cpn_*, opn_*, gonogo_*: additional training data and labels (e.g. choice probabilities)")

Output structure:
  Keys: ['cpn_no_omission_data', 'lan_labels', 'opn_data', 'gonogo_labels', 'gonogo_data', 'binned_128', 'cpn_labels', 'opn_labels', 'lan_data', 'binned_256', 'cpn_data', 'cpn_no_omission_labels', 'theta', 'generator_config', 'model_config']

Data shapes:
  cpn_no_omission_data          : (100, 4)
  lan_labels                    : (100000,)
  opn_data                      : (100, 4)
  gonogo_labels                 : (100, 1)
  gonogo_data                   : (100, 4)
  binned_128                    : (100, 128, 2)
  cpn_labels                    : (100,)
  opn_labels                    : (100, 1)
  lan_data                      : (100000, 6)
  binned_256                    : (100, 256, 2)
  cpn_data                      : (100, 4)
  cpn_no_omission_labels        : (100,)
  theta                         : (100, 4)
  generator_config              : {'pipeline': {'n_parameter_sets': 100, 'n_subruns': 10, 'n_cpus': 12, 'n_parameter_sets_rejected': 100}, 'estimator': {'typ
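
The shapes above follow from the config: 100 parameter sets times `n_samples_per_param = 1000` gives the 100000 rows of `lan_data`, and 4 model parameters plus 2 extra columns give its 6 columns. The sketch below builds synthetic arrays with the same shapes and splits them for training. The column layout (parameters first, then RT, then choice) is an assumption made for illustration and should be checked against the actual output before relying on it.

```python
import numpy as np

rng = np.random.default_rng(0)
n_parameter_sets, n_samples_per_param, n_params = 100, 1000, 4
n_rows = n_parameter_sets * n_samples_per_param  # 100 * 1000 = 100000

# Synthetic stand-ins with the same shapes as the real arrays
theta = rng.uniform(size=(n_parameter_sets, n_params))
rt = rng.uniform(0.2, 3.0, size=(n_rows, 1))
choice = rng.choice([-1.0, 1.0], size=(n_rows, 1))
lan_data = np.hstack([np.repeat(theta, n_samples_per_param, axis=0), rt, choice])
lan_labels = rng.normal(size=n_rows)  # placeholder log-likelihoods

# ASSUMED column layout: parameters first, then RT, then choice
params_cols = lan_data[:, :n_params]
rt_col = lan_data[:, n_params]
choice_col = lan_data[:, n_params + 1]

# Shuffle the rows and hold out 10% for validation
perm = rng.permutation(n_rows)
n_val = n_rows // 10
val_idx, train_idx = perm[:n_val], perm[n_val:]
x_train, y_train = lan_data[train_idx], lan_labels[train_idx]
x_val, y_val = lan_data[val_idx], lan_labels[val_idx]
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
```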

## Advanced Example

In [6]:
from ssms.dataset_generators.estimator_builders.kde_builder import KDEEstimatorBuilder

from ssms.dataset_generators.pipelines import SimulationPipeline
from ssms.dataset_generators.strategies import MixtureTrainingStrategy

# Same model as in the basic example
model_config_advanced = model_config[model_name]

pipeline_config_advanced = get_default_generator_config()
pipeline_config_advanced["estimator"]["type"] = "kde"  # or just omit, this is the default
pipeline_config_advanced["pipeline"]["n_parameter_sets"] = 100
pipeline_config_advanced["simulator"]["n_samples"] = 5000

# Create custom pipeline with specialized components
custom_pipeline = SimulationPipeline(
    generator_config=pipeline_config_advanced,
    model_config=model_config_advanced,
    estimator_builder=KDEEstimatorBuilder,
    training_strategy=MixtureTrainingStrategy,
)

# Pass the pipeline object as `config`; no separate `model_config` is needed
# because the pipeline already holds it as an attribute
gen_advanced = TrainingDataGenerator(config=custom_pipeline)

### Advanced Example: Generate Training Data

In [7]:
# Generate data
print("Generating training data...")
start_time = time.time()

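# Use the generator built from the custom pipeline, not `gen` from the basic example
training_data_advanced = gen_advanced.generate_data_training()

elapsed = time.time() - start_time
total_trials = (
    pipeline_config_advanced["pipeline"]["n_parameter_sets"]
    * pipeline_config_advanced["simulator"]["n_samples"]
)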
print(f"✓ Data generation complete in {elapsed:.2f} seconds")
print(f"  ({total_trials / elapsed:.0f} trials/sec)")

Generating training data...
✓ Data generation complete in 2.64 seconds
  (189748 trials/sec)


### Default Behavior

When you pass a **config dict** as the first argument, `TrainingDataGenerator` auto-creates the appropriate strategy based on `estimator_type`:

- `estimator_type='kde'` → `SimulationPipeline`
- `estimator_type='pyddm'` → `PyDDMPipeline`

This should cover most basic use cases!
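
The dispatch described above (a dict selects a built-in pipeline, a pipeline object is used as given) can be sketched with toy classes. None of this is the package's code: `ToySimulationPipeline`, `ToyPyDDMPipeline`, `resolve_pipeline` and the `run` method name are all hypothetical, chosen only to show the pattern. The fallback to `'kde'` when no estimator type is set follows the default mentioned earlier.

```python
class ToySimulationPipeline:
    """Stand-in for a simulation + KDE pipeline."""

    def __init__(self, generator_config, model_config):
        self.generator_config = generator_config
        self.model_config = model_config

    def run(self):
        return "simulated"


class ToyPyDDMPipeline(ToySimulationPipeline):
    """Stand-in for a PDE-solver-based pipeline."""

    def run(self):
        return "solved"


_BUILT_IN = {"kde": ToySimulationPipeline, "pyddm": ToyPyDDMPipeline}


def resolve_pipeline(config, model_config=None):
    """Return a pipeline object for either kind of `config` argument."""
    if isinstance(config, dict):
        # Dict route: choose a built-in pipeline from the estimator type
        estimator_type = config.get("estimator", {}).get("type", "kde")
        if estimator_type not in _BUILT_IN:
            raise ValueError(f"Unknown estimator type: {estimator_type!r}")
        return _BUILT_IN[estimator_type](config, model_config)
    if callable(getattr(config, "run", None)):
        # Object route: anything that looks like a pipeline is used directly
        return config
    raise TypeError("config must be a dict or a pipeline-like object")


from_dict = resolve_pipeline({"estimator": {"type": "pyddm"}}, {"name": "ddm"})
custom = ToySimulationPipeline({"estimator": {"type": "kde"}}, {"name": "ddm"})
from_object = resolve_pipeline(custom)
print(type(from_dict).__name__, from_object is custom)
```

Checking for behavior (a `run` method) rather than a concrete class is what lets custom pipelines plug in without subclassing anything.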
