# Conditional Parameters: Making Spaces Context-Aware

## Overview

This notebook focuses on **conditional parameters**, one of SpaX's most powerful features. A conditional parameter's search space (or fixed value) is chosen based on the values of other parameters.

**What you'll learn:**
- Simple field conditions (if-then logic)
- Different condition types (equality, comparison, membership)
- Composite conditions (AND/OR/NOT)
- Nested conditions (conditions within conditions)
- Multi-field dependencies (conditions on multiple parameters)
- Real-world use cases

**Prerequisites:**
- Basic understanding of SpaX `Config` (if not, see notebook 00 first)
- Familiarity with Python type hints

**Why conditional parameters matter:**
- **Cleaner configs**: No invalid parameter combinations
- **Efficient HPO**: Don't waste trials on irrelevant parameters
- **Type safety**: Enforced dependencies between parameters
- **Complex logic**: Chain and nest conditions for sophisticated constraints

Let's dive in with a simple example.

In [1]:
# Import SpaX
import spax as sp


# The classic example: dropout only when enabled
class SimpleConfig(sp.Config):
    """Configuration with a simple conditional parameter."""

    # First, a boolean flag
    use_dropout: bool

    # Then, a parameter that depends on it
    dropout_rate: float = sp.Conditional(
        sp.FieldCondition("use_dropout", sp.EqualsTo(True)),
        true=sp.Float(gt=0.0, lt=0.5),  # When use_dropout=True
        false=0.0,  # When use_dropout=False
    )


# Sample some configurations
print("🎲 Sampling configurations:\n")

for i in range(4):
    config = SimpleConfig.random(seed=100 + i)
    print(
        f"Sample {i + 1}: use_dropout={config.use_dropout}, dropout_rate={config.dropout_rate}"
    )

🎲 Sampling configurations:

Sample 1: use_dropout=True, dropout_rate=0.22746350226602527
Sample 2: use_dropout=False, dropout_rate=0.0
Sample 3: use_dropout=True, dropout_rate=0.3079826435149187
Sample 4: use_dropout=False, dropout_rate=0.0


## Understanding Conditional Spaces

Let's break down the syntax:
```python
sp.Conditional(
    condition,      # The condition to evaluate
    true=...,       # Space or value when condition is True
    false=...,      # Space or value when condition is False
)
```

**Key components:**

1. **Condition**: `sp.FieldCondition("use_dropout", sp.EqualsTo(True))`
   - Checks if the field `use_dropout` equals `True`
   
2. **True branch**: `sp.Float(gt=0.0, lt=0.5)`
   - When condition is True, this becomes a searchable space
   
3. **False branch**: `0.0`
   - When condition is False, this is a fixed value
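Here is what a conditional does once the fields it depends on already have values: evaluate the condition, then either sample from the active branch (if it is a space) or return it unchanged (if it is a fixed value). The sketch below is plain Python, not SpaX code. `FloatRange` and `ConditionalSketch` are hypothetical stand-ins for `sp.Float` and `sp.Conditional`, and it samples a closed range rather than the open `gt`/`lt` bounds used above.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-in for a searchable float range (not SpaX's sp.Float).
@dataclass
class FloatRange:
    low: float
    high: float

    def sample(self, rng: random.Random) -> float:
        return rng.uniform(self.low, self.high)


# Hypothetical stand-in for sp.Conditional: a predicate plus two branches.
@dataclass
class ConditionalSketch:
    condition: Callable[[dict], bool]
    true: Any
    false: Any

    def resolve(self, values: dict, rng: random.Random) -> Any:
        # Pick the active branch using values resolved earlier.
        branch = self.true if self.condition(values) else self.false
        # Searchable branches are sampled; fixed branches are returned as-is.
        return branch.sample(rng) if isinstance(branch, FloatRange) else branch


dropout_rate = ConditionalSketch(
    condition=lambda v: v["use_dropout"] is True,
    true=FloatRange(0.0, 0.5),
    false=0.0,
)

rng = random.Random(0)
print(dropout_rate.resolve({"use_dropout": False}, rng))  # 0.0
print(0.0 <= dropout_rate.resolve({"use_dropout": True}, rng) <= 0.5)  # True
```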

**Important:** The field referenced in the condition (`use_dropout`) must exist in the Config class. SpaX automatically handles the dependency ordering during sampling and validation.
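One way to picture this dependency ordering is as a topological sort: every field must be resolved after the fields its condition reads. The sketch below uses the standard library's `graphlib` (Python 3.9+) on a hypothetical dependency map that mirrors the warmup and optimizer fields of `AdvancedConfig` in this notebook. It illustrates the idea and is not a description of SpaX's internals.

```python
from graphlib import CycleError, TopologicalSorter  # Python 3.9+

# Hypothetical map: field -> fields its condition reads.
dependencies = {
    "num_layers": set(),
    "optimizer": set(),
    "use_warmup": {"num_layers"},
    "warmup_steps": {"use_warmup"},
    "beta2": {"optimizer"},
    "momentum": {"optimizer"},
}

# static_order() lists each field after everything it depends on.
order = list(TopologicalSorter(dependencies).static_order())
print(order)

# A circular dependency has no valid order and is reported as CycleError.
cycle_detected = False
try:
    list(TopologicalSorter({"a": {"b"}, "b": {"a"}}).static_order())
except CycleError:
    cycle_detected = True
print("cycle detected:", cycle_detected)  # cycle detected: True
```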

Let's explore other condition types:

In [2]:
# Different condition types for various use cases


class OptimizerConfig(sp.Config):
    """Configuration showing different condition types."""

    optimizer: str = sp.Categorical(["adam", "sgd", "rmsprop"])

    # Condition 1: Equality - momentum only for SGD
    momentum: float = sp.Conditional(
        sp.FieldCondition("optimizer", sp.EqualsTo("sgd")),
        true=sp.Float(ge=0.0, le=1.0),
        false=0.0,
    )

    # Condition 2: Membership - beta2 only for adam/rmsprop
    beta2: float = sp.Conditional(
        sp.FieldCondition("optimizer", sp.In(["adam", "rmsprop"])),
        true=sp.Float(ge=0.9, le=0.999),
        false=0.0,
    )


# SpaX enforces conditional constraints - invalid combinations are rejected!
print("🔒 SpaX validates conditional constraints:\n")

try:
    # This will FAIL - momentum should be 0.0 for adam, not 0.9
    config = OptimizerConfig(optimizer="adam", momentum=0.9, beta2=0.99)
except Exception as e:
    print(f"❌ Invalid config rejected: {type(e).__name__}")
    print("   Reason: momentum must be 0.0 when optimizer='adam'\n")

# Valid configurations
print("✅ Valid configurations:\n")

# SGD: can have momentum, beta2 must be 0.0
config_sgd = OptimizerConfig(optimizer="sgd", momentum=0.9, beta2=0.0)
print(f"SGD:     momentum={config_sgd.momentum}, beta2={config_sgd.beta2}")

# Adam: momentum must be 0.0, can have beta2
config_adam = OptimizerConfig(optimizer="adam", momentum=0.0, beta2=0.99)
print(f"Adam:    momentum={config_adam.momentum}, beta2={config_adam.beta2}")

# RMSprop: momentum must be 0.0, can have beta2
config_rms = OptimizerConfig(optimizer="rmsprop", momentum=0.0, beta2=0.95)
print(f"RMSprop: momentum={config_rms.momentum}, beta2={config_rms.beta2}")

print("\n🎲 Random sampling respects conditions automatically:")
for i in range(3):
    config = OptimizerConfig.random(seed=400 + i)
    print(
        f"  {config.optimizer:8} → momentum={config.momentum:.3f}, beta2={config.beta2:.3f}"
    )

🔒 SpaX validates conditional constraints:

❌ Invalid config rejected: ValidationError
   Reason: momentum must be 0.0 when optimizer='adam'

✅ Valid configurations:

SGD:     momentum=0.9, beta2=0.0
Adam:    momentum=0.0, beta2=0.99
RMSprop: momentum=0.0, beta2=0.95

🎲 Random sampling respects conditions automatically:
  adam     → momentum=0.000, beta2=0.925
  sgd      → momentum=0.428, beta2=0.000
  rmsprop  → momentum=0.000, beta2=0.910


## 🔒 Built-in Validation: A Key Feature

Notice what just happened:

**SpaX automatically validates conditional constraints!**

When you try to create a config with `optimizer="adam"` and `momentum=0.9`, SpaX rejects it because the condition says momentum must be `0.0` when optimizer is not SGD.

**This is powerful because:**
- ✅ Prevents invalid parameter combinations at creation time
- ✅ Catches configuration errors early (not during training!)
- ✅ Sampling (random, Optuna trials, etc.) automatically respects all conditions
- ✅ HPO libraries can't suggest invalid combinations

**Manual creation rules:**
- You must provide values that satisfy the conditions
- For conditional parameters, check what the active branch expects
- Or use SpaX's sampling methods (`random()`, `from_trial()`, etc.) to handle it automatically!
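The creation-time check amounts to two steps: find the active branch, then verify the supplied value against it. Below is a hypothetical plain-Python helper illustrating that (`check_conditional` is not a SpaX function). It raises `ValueError`, whereas SpaX raised a `ValidationError` in the example above.

```python
from typing import Any, Callable


def check_conditional(
    values: dict,
    field: str,
    condition: Callable[[dict], bool],
    true_check: Callable[[Any], bool],
    false_value: Any,
) -> None:
    """Raise ValueError if values[field] doesn't fit the active branch (hypothetical helper)."""
    value = values[field]
    if condition(values):
        # True branch is a space: the value must lie inside it.
        if not true_check(value):
            raise ValueError(f"{field}={value!r} is outside the true-branch space")
    elif value != false_value:
        # False branch is a fixed value: the value must match it exactly.
        raise ValueError(f"{field} must be {false_value!r}, got {value!r}")


def is_sgd(values: dict) -> bool:
    return values["optimizer"] == "sgd"


def in_unit_interval(x: float) -> bool:
    return 0.0 <= x <= 1.0


# Valid: SGD may carry momentum in [0, 1].
check_conditional({"optimizer": "sgd", "momentum": 0.9}, "momentum", is_sgd, in_unit_interval, 0.0)

# Invalid: Adam selects the false branch, which fixes momentum at 0.0.
rejected = False
try:
    check_conditional({"optimizer": "adam", "momentum": 0.9}, "momentum", is_sgd, in_unit_interval, 0.0)
except ValueError as err:
    rejected = True
    print(f"rejected: {err}")  # rejected: momentum must be 0.0, got 0.9
```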

## Available Condition Types

SpaX provides a rich set of conditions for different scenarios:

### Equality Conditions
- **`EqualsTo(value)`** - Field equals a specific value
- **`NotEqualsTo(value)`** - Field does not equal a specific value

### Membership Conditions
- **`In(values)`** - Field is one of the specified values
- **`NotIn(values)`** - Field is not one of the specified values

### Comparison Conditions (Numeric)
- **`LargerThan(value, or_equals=False)`** - Field > value (or >= if or_equals=True)
- **`SmallerThan(value, or_equals=False)`** - Field < value (or <= if or_equals=True)

### Type Checking
- **`IsInstance(type_or_tuple)`** - Field is an instance of specified type(s)

### Composite Conditions
- **`And([condition1, condition2, ...])`** - All conditions must be True
- **`Or([condition1, condition2, ...])`** - At least one condition must be True
- **`Not(condition)`** - Negates a condition

### Custom Logic
- **`Lambda(func)`** - Custom function that returns bool
- **`MultiFieldLambdaCondition(fields, func)`** - Custom logic across multiple fields
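As a mental model, each condition type behaves like a small predicate, and `FieldCondition` applies a predicate to one field of the config. The lowercase functions below are hypothetical plain-Python analogues of the SpaX classes listed above. They assume conditions are evaluated against a plain dict of already-resolved values and are not the library's implementation.

```python
from typing import Any, Callable, Iterable, Sequence

Predicate = Callable[[Any], bool]   # tests a single field value
Condition = Callable[[dict], bool]  # tests a whole dict of values


# Value-level predicates (hypothetical analogues of EqualsTo, In, LargerThan, ...).
def equals_to(target: Any) -> Predicate:
    return lambda x: x == target


def not_equals_to(target: Any) -> Predicate:
    return lambda x: x != target


def in_(values: Iterable[Any]) -> Predicate:
    allowed = list(values)
    return lambda x: x in allowed


def larger_than(target: float, or_equals: bool = False) -> Predicate:
    return lambda x: (x >= target) if or_equals else (x > target)


def smaller_than(target: float, or_equals: bool = False) -> Predicate:
    return lambda x: (x <= target) if or_equals else (x < target)


# Analogue of FieldCondition: apply a predicate to one named field.
def field_condition(field: str, pred: Predicate) -> Condition:
    return lambda values: pred(values[field])


# Composite conditions combine whole-config conditions.
def and_(conds: Sequence[Condition]) -> Condition:
    return lambda values: all(c(values) for c in conds)


def or_(conds: Sequence[Condition]) -> Condition:
    return lambda values: any(c(values) for c in conds)


def not_(cond: Condition) -> Condition:
    return lambda values: not cond(values)


both_large = and_([
    field_condition("model_size", equals_to("large")),
    field_condition("dataset_size", equals_to("large")),
])
print(both_large({"model_size": "large", "dataset_size": "large"}))  # True
print(both_large({"model_size": "large", "dataset_size": "small"}))  # False
print(field_condition("num_layers", larger_than(5))({"num_layers": 5}))  # False
```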

Let's see some of these in action:

In [3]:
# Showcasing different condition types


class AdvancedConfig(sp.Config):
    """Configuration demonstrating various condition types."""

    num_layers: int = sp.Int(ge=1, le=10)
    optimizer: str = sp.Categorical(["adam", "sgd", "rmsprop"])

    # Comparison: Deep networks (>5 layers) need warmup
    use_warmup: bool = sp.Conditional(
        sp.FieldCondition("num_layers", sp.LargerThan(5)), true=True, false=False
    )

    # Membership: beta2 only for adam/rmsprop
    beta2: float = sp.Conditional(
        sp.FieldCondition("optimizer", sp.In(["adam", "rmsprop"])),
        true=sp.Float(ge=0.9, le=0.999),
        false=0.0,
    )

    # NotEqualsTo: momentum for non-adam optimizers
    momentum: float = sp.Conditional(
        sp.FieldCondition("optimizer", sp.NotEqualsTo("adam")),
        true=sp.Float(ge=0.0, le=1.0),
        false=0.0,
    )

    # Chained: warmup_steps depends on use_warmup
    warmup_steps: int = sp.Conditional(
        sp.FieldCondition("use_warmup", sp.EqualsTo(True)),
        true=sp.Int(ge=100, le=1000),
        false=0,
    )


print("🎯 Multiple condition types in action:\n")

for i in range(5):
    config = AdvancedConfig.random(seed=400 + i)
    print(
        f"layers={config.num_layers:2d}, opt={config.optimizer:8s} → "
        f"warmup={config.use_warmup!s:5}, warmup_steps={config.warmup_steps:4d}, "
        f"beta2={config.beta2:.3f}, momentum={config.momentum:.3f}"
    )

print("\n💡 Observations:")
print("   - layers > 5 → use_warmup=True → warmup_steps sampled")
print("   - optimizer in [adam, rmsprop] → beta2 sampled")
print("   - optimizer != adam → momentum sampled")

🎯 Multiple condition types in action:

layers= 5, opt=sgd      → warmup=False, warmup_steps=   0, beta2=0.000, momentum=0.787
layers= 8, opt=adam     → warmup=True , warmup_steps= 386, beta2=0.981, momentum=0.000
layers= 8, opt=adam     → warmup=True , warmup_steps= 515, beta2=0.964, momentum=0.000
layers= 4, opt=sgd      → warmup=False, warmup_steps=   0, beta2=0.000, momentum=0.036
layers= 2, opt=rmsprop  → warmup=False, warmup_steps=   0, beta2=0.988, momentum=0.518

💡 Observations:
   - layers > 5 → use_warmup=True → warmup_steps sampled
   - optimizer in [adam, rmsprop] → beta2 sampled
   - optimizer != adam → momentum sampled


## 🔄 Conditional Spaces: Not Just Values

So far, we've seen conditions that choose between a space and a fixed value. But conditionals can also **choose between different spaces**!

**Use cases:**
- Different ranges based on context
- Different space types based on choice (e.g., Float vs Int)
- Completely different parameter meanings
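Without the library, a conditional over two spaces is just a branch that decides which sampler runs. The function below is a hypothetical plain-Python sketch of the `custom_param` field in the next cell, not SpaX code. The type of the result depends on the branch, which is why that field is annotated `int | float`.

```python
import random
from typing import Union


def sample_custom_param(activation: str, rng: random.Random) -> Union[int, float]:
    """Sample from whichever space the activation choice activates (hypothetical)."""
    if activation == "custom":
        # Float branch: slope for the custom activation.
        return rng.uniform(0.01, 0.3)
    # Int branch: number of activation layers (randint bounds are inclusive).
    return rng.randint(2, 8)


rng = random.Random(42)
for activation in ["relu", "gelu", "custom"]:
    value = sample_custom_param(activation, rng)
    print(activation, type(value).__name__, value)
```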

Let's see practical examples:

In [4]:
# Conditional spaces: Different spaces based on conditions


class NormalizationConfig(sp.Config):
    """Simple example of spaces changing based on conditions."""

    normalization: str = sp.Categorical(["batch_norm", "layer_norm", "none"])

    # Example 1: Different ranges based on choice
    # Illustrative ranges: higher momentum for batch_norm, lower otherwise
    momentum: float = sp.Conditional(
        sp.FieldCondition("normalization", sp.EqualsTo("batch_norm")),
        true=sp.Float(ge=0.9, le=0.999),  # Higher momentum for batch_norm
        false=sp.Float(ge=0.5, le=0.9),  # Lower momentum otherwise
    )

    # Example 2: Float space vs Int space
    activation: str = sp.Categorical(["relu", "gelu", "custom"])

    # If using custom activation, choose the slope; otherwise choose layer count
    custom_param: int | float = sp.Conditional(
        sp.FieldCondition("activation", sp.EqualsTo("custom")),
        true=sp.Float(ge=0.01, le=0.3),  # Slope for custom activation
        false=sp.Int(ge=2, le=8),  # Number of activation layers
    )


print("🔄 Different spaces based on conditions:\n")

for i in range(6):
    config = NormalizationConfig.random(seed=500 + i)

    param_str = (
        f"{config.custom_param:.3f}"
        if isinstance(config.custom_param, float)
        else f"{config.custom_param}"
    )

    print(
        f"{config.normalization:11s}, {config.activation:5s} → "
        f"momentum={config.momentum:.3f}, custom_param={param_str}"
    )

print(
    "\n💡 Key point: Conditional branches can be ANY space type, not just fixed values!"
)

🔄 Different spaces based on conditions:

layer_norm , custom → momentum=0.602, custom_param=0.285
layer_norm , gelu  → momentum=0.756, custom_param=7
layer_norm , gelu  → momentum=0.792, custom_param=3
none       , relu  → momentum=0.651, custom_param=4
layer_norm , gelu  → momentum=0.505, custom_param=6
batch_norm , relu  → momentum=0.933, custom_param=8

💡 Key point: Conditional branches can be ANY space type, not just fixed values!


## 🔗 Nested Conditions and Field Paths

SpaX supports two powerful ways to create complex conditional logic:

### 1. Nested Conditions
You can pass one `FieldCondition` as the condition of another: `FieldCondition("optimizer", FieldCondition("name", ...))` checks `name` inside the `optimizer` sub-config.

### 2. Dotted Field Paths
Reference nested config fields using dot notation: `"model.optimizer.type"`

This is especially useful for:
- Nested configurations
- Deep parameter dependencies
- Modular config design

Let's see both in action:

In [5]:
# Nested conditions and dotted field paths

# First, a nested config structure
class OptimizerSubConfig(sp.Config):
    name: str = sp.Categorical(["adam", "sgd", "rmsprop"])
    base_lr: float = sp.Float(ge=1e-5, le=1e-2, distribution="log")


class ModelWithNestedConfig(sp.Config):
    """Demonstrating dotted paths and nested FieldConditions."""

    optimizer: OptimizerSubConfig

    # Method 1: Using dotted path notation
    momentum_v1: float = sp.Conditional(
        sp.FieldCondition("optimizer.name", sp.EqualsTo("sgd")),
        true=sp.Float(ge=0.0, le=1.0),
        false=0.0,
    )

    # Method 2: Using nested FieldCondition (equivalent to above!)
    momentum_v2: float = sp.Conditional(
        sp.FieldCondition("optimizer", sp.FieldCondition("name", sp.EqualsTo("sgd"))),
        true=sp.Float(ge=0.0, le=1.0),
        false=0.0,
    )


print("🔗 Two ways to reference nested fields:\n")

for i in range(4):
    config = ModelWithNestedConfig.random(seed=700 + i)
    print(
        f"optimizer.name={config.optimizer.name:8s} → "
        f"momentum_v1={config.momentum_v1:.3f}, momentum_v2={config.momentum_v2:.3f}"
    )

print("\n💡 Both methods are equivalent:")
print("   • Dotted path: sp.FieldCondition('optimizer.name', ...)")
print(
    "   • Nested:      sp.FieldCondition('optimizer', sp.FieldCondition('name', ...))"
)
print("\n   Use whichever is more readable for your use case!")

print("\n" + "=" * 60)


# For complex AND/OR logic, use composite conditions
class CompositeConditionsConfig(sp.Config):
    """For complex logic, use And/Or/Not."""

    model_size: str = sp.Categorical(["small", "large"])
    dataset_size: str = sp.Categorical(["small", "large"])

    # Use large batch only when BOTH model and dataset are large
    batch_size: int = sp.Conditional(
        sp.And(
            [
                sp.FieldCondition("model_size", sp.EqualsTo("large")),
                sp.FieldCondition("dataset_size", sp.EqualsTo("large")),
            ]
        ),
        true=sp.Int(ge=128, le=512),
        false=sp.Int(ge=16, le=64),
    )


print("\n🔀 Composite conditions (And/Or) for complex logic:\n")

for model in ["small", "large"]:
    for dataset in ["small", "large"]:
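        # Resample until this (model, dataset) combination comes up; only the
        # first draw is seeded, so the printed batch sizes vary between runs.
        config = CompositeConditionsConfig.random(seed=800)
        while config.model_size != model or config.dataset_size != dataset:
            config = CompositeConditionsConfig.random(seed=None)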
        print(
            f"model={model:5s}, dataset={dataset:5s} → batch_size={config.batch_size:3d}"
        )

🔗 Two ways to reference nested fields:

optimizer.name=adam     → momentum_v1=0.000, momentum_v2=0.000
optimizer.name=rmsprop  → momentum_v1=0.000, momentum_v2=0.000
optimizer.name=rmsprop  → momentum_v1=0.000, momentum_v2=0.000
optimizer.name=adam     → momentum_v1=0.000, momentum_v2=0.000

💡 Both methods are equivalent:
   • Dotted path: sp.FieldCondition('optimizer.name', ...)
   • Nested:      sp.FieldCondition('optimizer', sp.FieldCondition('name', ...))

   Use whichever is more readable for your use case!


🔀 Composite conditions (And/Or) for complex logic:

model=small, dataset=small → batch_size= 20
model=small, dataset=large → batch_size= 51
model=large, dataset=small → batch_size= 36
model=large, dataset=large → batch_size=152


## 🧮 Multi-Field Lambda Conditions

Sometimes you need conditional logic that depends on **multiple fields with custom rules** that can't be expressed with simple conditions.

**`MultiFieldLambdaCondition`** lets you write custom logic across multiple fields:
```python
sp.MultiFieldLambdaCondition(
    ["field1", "field2", "field3"],
    lambda data: data["field1"] + data["field2"] > data["field3"]
)
```

**Dotted paths work too for nested fields:**
```python
sp.MultiFieldLambdaCondition(
    ["field1.subfield", "field2", "field3"],
    lambda data: data["field1.subfield"] + data["field2"] > data["field3"]
)
```

**When to use it:**
- Mathematical relationships between parameters
- Complex business logic
- Conditions that can't be expressed with built-in condition types

**Important:** The lambda receives a dictionary with field names (or dotted paths) as keys.
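The `data` dictionary can be pictured as one lookup per listed path, keyed by the path string exactly as written. This plain-Python sketch uses a hypothetical `multi_field_condition` helper and `SimpleNamespace` stand-ins for configs. The rule inside mirrors the EMA example later in this notebook, but none of this is SpaX's implementation.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Callable, Sequence


def get_path(obj: Any, path: str) -> Any:
    # Follow one attribute per dot-separated segment (hypothetical helper).
    return reduce(getattr, path.split("."), obj)


def multi_field_condition(
    fields: Sequence[str], func: Callable[[dict], bool]
) -> Callable[[Any], bool]:
    def evaluate(config: Any) -> bool:
        # Keys are the field names / dotted paths exactly as listed.
        data = {path: get_path(config, path) for path in fields}
        return func(data)
    return evaluate


needs_ema = multi_field_condition(
    ["optimizer.name", "optimizer.base_lr", "weight_decay"],
    lambda data: (
        data["optimizer.name"] == "adam"
        and data["optimizer.base_lr"] < 1e-4
        and data["weight_decay"] > 0.01
    ),
)

cfg_a = SimpleNamespace(optimizer=SimpleNamespace(name="adam", base_lr=5e-5), weight_decay=0.05)
cfg_b = SimpleNamespace(optimizer=SimpleNamespace(name="adam", base_lr=1.2e-4), weight_decay=0.072)
print(needs_ema(cfg_a), needs_ema(cfg_b))  # True False
```

The second config matches the `adam, lr=1.20e-04` row in the output further down, which stays `False` because its learning rate is not below `1e-4`.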

Let's see practical examples:

In [6]:
# Multi-field lambda conditions for complex logic


class ResourceConfig(sp.Config):
    """Configuration with multi-field dependencies."""

    num_gpus: int = sp.Int(ge=1, le=8)
    batch_size_per_gpu: int = sp.Int(ge=8, le=128)
    num_workers: int = sp.Int(ge=1, le=16)

    # Large total batch → enable gradient accumulation
    # Condition: num_gpus * batch_size_per_gpu > 256
    use_gradient_accumulation: bool = sp.Conditional(
        sp.MultiFieldLambdaCondition(
            ["num_gpus", "batch_size_per_gpu"],
            lambda data: data["num_gpus"] * data["batch_size_per_gpu"] > 256,
        ),
        true=True,  # Need gradient accumulation for large total batch
        false=False,  # No need for small total batch
    )

    accumulation_steps: int = sp.Conditional(
        sp.FieldCondition("use_gradient_accumulation", sp.EqualsTo(True)),
        true=sp.Int(ge=2, le=8),
        false=1,
    )


print("🧮 Multi-field lambda conditions:\n")

for i in range(6):
    config = ResourceConfig.random(seed=900 + i)
    total_batch = config.num_gpus * config.batch_size_per_gpu
    print(
        f"gpus={config.num_gpus}, batch/gpu={config.batch_size_per_gpu:3d}, "
        f"total={total_batch:3d} → grad_accum={config.use_gradient_accumulation!s:5s}, "
        f"steps={config.accumulation_steps}"
    )

print("\n💡 When total_batch > 256, gradient accumulation is enabled!")

print("\n" + "=" * 60)


# Example with nested fields using dotted paths
class OptimizerSubConfig2(sp.Config):
    name: str = sp.Categorical(["adam", "sgd"])
    base_lr: float = sp.Float(ge=1e-5, le=1e-2, distribution="log")


class AdvancedTrainingConfig(sp.Config):
    """Multi-field lambda with dotted paths."""

    optimizer: OptimizerSubConfig2
    weight_decay: float = sp.Float(ge=0.0, le=0.1)

    # Custom logic: enable EMA for Adam with a low LR and high weight decay
    # (a combination that can be unstable)
    use_ema: bool = sp.Conditional(
        sp.MultiFieldLambdaCondition(
            ["optimizer.name", "optimizer.base_lr", "weight_decay"],
            lambda data: (
                data["optimizer.name"] == "adam"
                and data["optimizer.base_lr"] < 1e-4
                and data["weight_decay"] > 0.01
            ),
        ),
        true=True,  # Use EMA to stabilize training
        false=False,
    )


print("\n🔗 Multi-field lambda with dotted paths:\n")

for i in range(4):
    config = AdvancedTrainingConfig.random(seed=1000 + i)
    print(
        f"opt={config.optimizer.name:4s}, lr={config.optimizer.base_lr:.2e}, "
        f"wd={config.weight_decay:.3f} → use_ema={config.use_ema}"
    )

print("\n💡 EMA enabled when: adam + low LR (<1e-4) + high weight_decay (>0.01)")

🧮 Multi-field lambda conditions:

gpus=4, batch/gpu=127, total=508 → grad_accum=True , steps=8
gpus=5, batch/gpu= 63, total=315 → grad_accum=True , steps=6
gpus=7, batch/gpu= 42, total=294 → grad_accum=True , steps=2
gpus=2, batch/gpu= 13, total= 26 → grad_accum=False, steps=1
gpus=5, batch/gpu= 99, total=495 → grad_accum=True , steps=6
gpus=6, batch/gpu= 95, total=570 → grad_accum=True , steps=8

💡 When total_batch > 256, gradient accumulation is enabled!


🔗 Multi-field lambda with dotted paths:

opt=sgd , lr=1.02e-03, wd=0.010 → use_ema=False
opt=sgd , lr=1.50e-05, wd=0.076 → use_ema=False
opt=sgd , lr=1.84e-04, wd=0.022 → use_ema=False
opt=adam, lr=1.20e-04, wd=0.072 → use_ema=False

💡 EMA enabled when: adam + low LR (<1e-4) + high weight_decay (>0.01)


## 📝 Summary: Conditional Parameters

You've learned how to create sophisticated conditional logic in SpaX:

### ✅ Core Concepts
1. **Simple conditions**: `FieldCondition` with equality, membership, comparison
2. **Composite conditions**: `And`, `Or`, `Not` for complex logic
3. **Changing spaces**: Conditional branches can be different space types, not just fixed values
4. **Nested fields**: Use dotted paths (`"field.subfield"`) or nested `FieldCondition`
5. **Multi-field logic**: `MultiFieldLambdaCondition` for custom mathematical relationships

### ✅ Key Benefits
- **Validation**: SpaX automatically rejects invalid parameter combinations
- **Efficient sampling**: Random and HPO sampling respect all conditions automatically
- **Type safety**: Enforced dependencies between parameters
- **Flexibility**: From simple if-then to complex multi-field logic

### 🎯 When to Use What
- **Simple dependencies**: `FieldCondition` with built-in conditions
- **Complex AND/OR logic**: Composite conditions (`And`, `Or`, `Not`)
- **Nested configs**: Dotted paths for clean syntax
- **Custom math/logic**: `MultiFieldLambdaCondition`

### 💡 Remember
- Conditional fields are validated at creation time
- Sampling methods automatically handle condition ordering
- You can nest conditions and chain dependencies (e.g. `num_layers` → `use_warmup` → `warmup_steps`)

**Conditional parameters make your configurations robust, type-safe, and intelligent! 🚀**