# Tutorial: Using the `DataGenerator` class

This tutorial demonstrates how to use the `DataGenerator` class to create training data for sequential sampling models (SSMs).
It also covers the overall design of the data generation workflow.

In [1]:
# Standard library imports
import time

# Third-party imports
import matplotlib.pyplot as plt
import numpy as np

# SSMS imports
from ssms.dataset_generators.lan_mlp import DataGenerator
from ssms.config import model_config, get_default_generator_config
from ssms.basic_simulators.simulator_class import Simulator

# Set random seed for reproducibility
np.random.seed(42)

print("✓ Imports successful")
print(f"✓ Available models: {len(model_config)} predefined models")

✓ Imports successful
✓ Available models: 106 predefined models


## What is DataGenerator?

The `DataGenerator` orchestrates the creation of training datasets for neural network-based likelihood approximation. 

It:

- Samples parameters from valid parameter spaces
- Generates RT/choice data (via simulation or analytical methods)
- Converts the raw data into `(feature, label)` pairs for the downstream networks to be trained
- Optionally saves outputs via `.pickle` files

The `DataGenerator` itself is a lightweight class that delegates the actual work to a `DataPipeline`.
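The four steps listed above can be sketched with a toy example that uses only the Python standard library. This is an illustration of the idea, not the `ssms` implementation: the function names (`sample_theta`, `simulate_trial`, `kde_log_likelihood`, `build_training_data`), the simplified drift-diffusion simulator, the fixed KDE bandwidth, and the feature layout `[v, a, t, rt, choice]` are all assumptions made for this sketch.

```python
import math
import pickle
import random


def sample_theta(bounds, rng):
    """Step 1: draw one parameter set uniformly from (low, high) bounds."""
    return {name: rng.uniform(low, high) for name, (low, high) in bounds.items()}


def simulate_trial(theta, rng, dt=0.01, max_t=5.0):
    """Step 2: Euler-Maruyama simulation of one toy drift-diffusion trial."""
    x, elapsed = 0.0, 0.0
    while abs(x) < theta["a"] and elapsed < max_t:
        x += theta["v"] * dt + math.sqrt(dt) * rng.gauss(0.0, 1.0)
        elapsed += dt
    choice = 1 if x >= 0 else -1
    return elapsed + theta["t"], choice  # add non-decision time to the RT


def kde_log_likelihood(rt, choice, trials, bandwidth=0.1):
    """Gaussian KDE over RTs with the same choice, scaled by choice frequency."""
    same_choice_rts = [r for r, c in trials if c == choice]
    norm = len(trials) * bandwidth * math.sqrt(2.0 * math.pi)
    density = sum(
        math.exp(-0.5 * ((rt - r) / bandwidth) ** 2) for r in same_choice_rts
    ) / norm
    return math.log(max(density, 1e-29))  # floor avoids log(0)


def build_training_data(bounds, n_parameter_sets=3, n_samples=200,
                        n_training_points=50, seed=0):
    """Step 3: turn simulations into (feature, label) pairs."""
    rng = random.Random(seed)
    features, labels = [], []
    for _ in range(n_parameter_sets):
        theta = sample_theta(bounds, rng)
        trials = [simulate_trial(theta, rng) for _ in range(n_samples)]
        for rt, choice in rng.sample(trials, n_training_points):
            features.append([*theta.values(), rt, choice])
            labels.append(kde_log_likelihood(rt, choice, trials))
    return {"features": features, "labels": labels}


toy_bounds = {"v": (-2.0, 2.0), "a": (0.5, 1.5), "t": (0.1, 0.5)}
toy_data = build_training_data(toy_bounds)

# Step 4: serialize with pickle (in memory here; a real run would write a file)
restored = pickle.loads(pickle.dumps(toy_data))
```

Each feature row pairs one parameter set with a single observed `(rt, choice)`, and the label is the estimated log-likelihood of that observation under those parameters.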

## DataGenerator Architecture: Modular Design

The `DataGenerator` is built on a **modular, strategy-based architecture** that allows you to customize the data generation workflow through a single injection point.

### High-Level Flow

<table style="width: 100%; max-width: 800px; margin: 30px auto; border-collapse: separate; border-spacing: 0;">
  <tr>
    <td colspan="2" style="text-align: center; padding: 20px; background-color: #f8f9fa; border: 2px solid #dee2e6; font-weight: 600; font-size: 16px; color: #495057;">
      INPUTS
    </td>
  </tr>
  <tr>
    <td style="padding: 30px; background-color: #fff3e0; border: 3px solid #ffb300; text-align: center; width: 50%;">
        <div style="font-size: 32px; margin-bottom: 10px;">⚙️</div>
      <div style="font-weight: 700; font-size: 17px; margin-bottom: 10px; color: #e65100;">config</div>
      <div style="font-size: 14px; color: #555; line-height: 1.6; margin-bottom: 12px;">
        <strong>Polymorphic parameter</strong>
      </div>
      <div style="background-color: #fff8e1; padding: 12px; border-radius: 5px; margin-bottom: 8px; border-left: 4px solid #ff6f00;">
        <div style="font-weight: 600; font-size: 13px; color: #e65100; margin-bottom: 5px;">Option 1: Dict (99%)</div>
        <div style="font-size: 12px; color: #666;">generator_config dict<br/>(samples, bins, estimator_type)</div>
      </div>
      <div style="background-color: #f3e5f5; padding: 12px; border-radius: 5px; border-left: 4px solid #7b1fa2;">
        <div style="font-weight: 600; font-size: 13px; color: #7b1fa2; margin-bottom: 5px;">Option 2: Strategy (1%)</div>
        <div style="font-size: 12px; color: #666;">Custom strategy object<br/>(advanced customization)</div>
      </div>
    </td>
    <td style="padding: 30px; background-color: #e7f3ff; border: 3px solid #42a5f5; text-align: center; width: 50%;">
        <div style="font-size: 32px; margin-bottom: 10px;">📋</div>
      <div style="font-weight: 700; font-size: 17px; margin-bottom: 10px; color: #0066cc;">model_config</div>
      <div style="font-size: 14px; color: #555; line-height: 1.6; margin-bottom: 12px;">
        <strong>Model specification</strong>
      </div>
      <div style="font-size: 13px; color: #666; line-height: 1.5;">
        Which SSM model?<br/>
        <span style="font-size: 12px;">(DDM, angle, LCA...)</span><br/><br/>
        Parameter bounds<br/>
        Boundary/drift functions
      </div>
    </td>
  </tr>
  <tr>
    <td colspan="2" style="text-align: center; padding: 15px; background-color: white; font-size: 40px; color: #999;">
      ⬇
    </td>
  </tr>
  <tr>
    <td colspan="2" style="padding: 35px; background-color: #fff8e1; border: 3px solid #ffb300; text-align: center;">
      <div style="font-size: 36px; margin-bottom: 12px;">📊</div>
      <div style="font-weight: 700; font-size: 24px; margin-bottom: 8px; color: #f57c00;">DataGenerator</div>
      <div style="font-size: 14px; color: #666;">Orchestrates the complete data generation workflow</div>
    </td>
  </tr>
  <tr>
    <td colspan="2" style="text-align: center; padding: 15px; background-color: white; font-size: 40px; color: #999;">
      ⬇
    </td>
  </tr>
  <tr>
    <td colspan="2" style="text-align: center; padding: 20px; background-color: #f8f9fa; border: 2px solid #dee2e6; font-weight: 600; font-size: 16px; color: #495057;">
      OUTPUT
    </td>
  </tr>
  <tr>
    <td colspan="2" style="padding: 20px; background-color: #e8f5e9; border: 2px solid #81c784;">
      <div style="text-align: center; font-size: 28px; margin-bottom: 8px;">📦</div>
      <div style="text-align: center; font-weight: 700; font-size: 18px; margin-bottom: 15px; color: #2e7d32;">Training Data</div>
      <table style="width: 100%; border-collapse: collapse;">
        <tr>
          <td style="padding: 12px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #42a5f5; width: 50%;">
            <strong style="color: #1976d2;">lan_data</strong><br/>
            <span style="font-size: 12px; color: #666;">RT/choice pairs for training</span>
          </td>
          <td style="padding: 12px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #ef5350; width: 50%;">
            <strong style="color: #c62828;">lan_labels</strong><br/>
            <span style="font-size: 12px; color: #666;">Log-likelihoods</span>
          </td>
        </tr>
        <tr>
          <td style="padding: 12px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #66bb6a;">
            <strong style="color: #388e3c;">theta</strong><br/>
            <span style="font-size: 12px; color: #666;">Parameter values</span>
          </td>
          <td style="padding: 12px; background-color: white; border: 1px solid #ddd; border-left: 4px solid #ab47bc;">
            <strong style="color: #7b1fa2;">cpn_labels</strong><br/>
            <span style="font-size: 12px; color: #666;">Choice probabilities</span>
          </td>
        </tr>
        <tr>
          <td colspan="2" style="padding: 10px; background-color: #f5f5f5; border: 1px solid #ddd; text-align: center; color: #999; font-style: italic; font-size: 14px;">
            ... and more
          </td>
        </tr>
      </table>
    </td>
  </tr>
</table>

### The Single Injection Point: Pipeline

#### **What Is the Pipeline?**

The `generation_pipeline` parameter is the **single customization point** for the `DataGenerator`. It encapsulates the entire data generation workflow, from parameter sampling to the composition of the final training data.

#### **Built-in Strategies**:

**1. `SimulationPipeline`** (default for `estimator_type='kde'`)
- Uses the core native simulators of the `ssm-simulators` package
- Uses a kernel density estimate (KDE) as the likelihood estimator (as in [our LAN paper](https://elifesciences.org/articles/65074))

**2. `PyDDMPipeline`** (used for `estimator_type='pyddm'`)
- Uses an analytical Fokker-Planck PDE solver via the `PyDDM` package
- Model choice is somewhat more limited (two-choice, Gaussian noise, etc.)
- Can be much faster than the KDE-based strategy where it applies

**3. Custom Strategy**
- Implement your own `DataPipelineProtocol`
- Full control over parameter sampling, simulation, and data structuring
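The idea behind a protocol-based injection point can be sketched as follows. The methods that the real `DataPipelineProtocol` requires are not shown in this tutorial, so everything below is hypothetical: `ToyPipelineProtocol`, its single `run()` method, `ConstantPipeline`, and `ToyGenerator` are stand-ins invented for illustration, not part of `ssms`.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ToyPipelineProtocol(Protocol):
    """Hypothetical interface: anything with a run() method qualifies."""

    def run(self) -> dict[str, Any]: ...


class ConstantPipeline:
    """A custom pipeline that satisfies the protocol without inheriting from it."""

    def __init__(self, n_parameter_sets: int) -> None:
        self.n_parameter_sets = n_parameter_sets

    def run(self) -> dict[str, Any]:
        # A real pipeline would sample parameters and simulate here.
        thetas = [[0.5, 1.0] for _ in range(self.n_parameter_sets)]
        return {"theta": thetas, "labels": [0.0] * self.n_parameter_sets}


class ToyGenerator:
    """Thin orchestrator that delegates all work to the injected pipeline."""

    def __init__(self, pipeline: ToyPipelineProtocol) -> None:
        # runtime_checkable protocols only verify that the method exists.
        if not isinstance(pipeline, ToyPipelineProtocol):
            raise TypeError("pipeline must provide a run() method")
        self.pipeline = pipeline

    def generate(self) -> dict[str, Any]:
        return self.pipeline.run()


toy_output = ToyGenerator(ConstantPipeline(n_parameter_sets=4)).generate()
```

Because the check is structural, a custom pipeline only needs to provide the expected methods. Consult the `DataPipelineProtocol` definition in the package for the methods it actually requires.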

#### **How to Control It?**

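Pass either a config dict or a pipeline object as the first argument (`config`) of `DataGenerator`. With a config dict, the built-in pipeline is selected from the estimator type. With a pipeline object, that pipeline is used as is.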

## Basic Example

In [None]:
model_name = "ddm"
estimator_type = "kde"

In [None]:
# Look up the model specification by name
my_model_config = model_config[model_name]

# Start from the default generator config (simulation-based by default)
pipeline_config = get_default_generator_config()
pipeline_config['estimator']['type'] = estimator_type  # or just omit, this is the default

gen = DataGenerator(pipeline_config, my_model_config)

### Generate training data

In [None]:
# Generate data
print("Generating training data...")
start_time = time.time()

training_data = gen.generate_data_training()

elapsed = time.time() - start_time
total_trials = pipeline_config['pipeline']['n_parameter_sets'] * pipeline_config['simulator']['n_samples']
print(f"✓ Data generation complete in {elapsed:.2f} seconds")
print(f"  ({total_trials / elapsed:.0f} trials/sec)")

### Inspect output

In [None]:
print("Output structure:")
print(f"  Keys: {list(training_data.keys())}")

print("\nData shapes:")
for key, value in training_data.items():
    if value is not None:
        if isinstance(value, np.ndarray):
            print(f"  {key:30s}: {value.shape}")
        elif isinstance(value, dict):
            print(f"  {key:30s}: {value}")
    else:
        print(f"  {key:30s}: None")

print("\n--- Understanding the components ---")
print("data: RT/choice pairs [n_parameter_sets, n_samples, 2]")
print("theta: Parameter values [n_parameter_sets, n_params]")
print("choice_p: Choice probabilities for each trial")
print("cpn_*, opn_*, gonogo_*: Additional training labels")

## Advanced Example

In [None]:
from ssms.dataset_generators.estimator_builders.kde_builder import KDEEstimatorBuilder

from ssms.dataset_generators.pipelines import SimulationPipeline
from ssms.dataset_generators.strategies import MixtureTrainingStrategy

# Look up the model specification by name
model_config_advanced = model_config[model_name]

# Start from the default generator config (simulation-based by default)
pipeline_config_advanced = get_default_generator_config()
pipeline_config_advanced['estimator']['type'] = "kde"  # or just omit, this is the default



# Create custom pipeline with specialized components
custom_pipeline = SimulationPipeline(
    generator_config=pipeline_config_advanced,
    model_config=model_config_advanced,
    estimator_builder=KDEEstimatorBuilder,
    training_strategy=MixtureTrainingStrategy,
)

# Pass the pipeline directly as the `config` argument.
# No `model_config` is needed here: it is already an attribute of the pipeline.
gen_advanced = DataGenerator(config=custom_pipeline)

### Default Behavior

When you pass a **config dict** as the first argument, `DataGenerator` automatically creates the appropriate pipeline based on `estimator_type`:

- `estimator_type='kde'` → `SimulationPipeline`
- `estimator_type='pyddm'` → `PyDDMPipeline`

This should cover most basic use cases!
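One way to picture this dispatch is a small registry keyed by estimator type. The sketch below is an assumption about the general pattern, not the actual `DataGenerator` code: `ToySimulationPipeline`, `ToyAnalyticalPipeline`, `PIPELINE_REGISTRY`, and `resolve_pipeline` are hypothetical names. Only the nested `config['estimator']['type']` key mirrors the config layout used earlier in this tutorial.

```python
class _ToyPipeline:
    """Hypothetical base class that just stores the two configs."""

    def __init__(self, generator_config, model_config=None):
        self.generator_config = generator_config
        self.model_config = model_config


class ToySimulationPipeline(_ToyPipeline):
    """Stand-in for a simulation-plus-KDE pipeline."""


class ToyAnalyticalPipeline(_ToyPipeline):
    """Stand-in for an analytical-solver pipeline."""


# Hypothetical registry mapping estimator type to pipeline class
PIPELINE_REGISTRY = {
    "kde": ToySimulationPipeline,
    "pyddm": ToyAnalyticalPipeline,
}


def resolve_pipeline(config, model_config=None):
    """Return a pipeline for either a config dict or a ready-made pipeline."""
    if isinstance(config, dict):
        # Fall back to "kde" when no estimator type is given
        estimator_type = config.get("estimator", {}).get("type", "kde")
        if estimator_type not in PIPELINE_REGISTRY:
            raise ValueError(f"Unknown estimator type: {estimator_type!r}")
        return PIPELINE_REGISTRY[estimator_type](config, model_config)
    # Anything that is not a dict is treated as an already-built pipeline
    return config


kde_pipeline = resolve_pipeline({"estimator": {"type": "kde"}}, {"name": "ddm"})
pyddm_pipeline = resolve_pipeline({"estimator": {"type": "pyddm"}})
```

With this pattern, supporting a new estimator type only requires a new registry entry, while passing a pipeline object bypasses the registry entirely.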
