# Tutorial 02 — Train a SOEN Model

In this tutorial, we’ll walk through training a pre-built SOEN model using the training configuration file located at:
`tutorial_notebooks/training/training_configs/pulse_net.yaml`.

We’ll use the `run_from_config` function to launch training. It keeps experiment setup simple: once all training settings are defined in your YAML file, you can start training with a single command.

You can launch it either from Python or directly from the command line:

- Python: `run_from_config(str(BASE_CONFIG), script_dir=Path.cwd())`, where `BASE_CONFIG` is the path to your training YAML file
- CLI: `python -m soen_toolkit.training --config path/to/training_config.yaml`
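If you want the CLI entry point from inside a notebook, one option is to launch it in a subprocess that uses the notebook's own interpreter, so the same environment and installed packages are used. This is a minimal sketch: `launch_training_cli` is a hypothetical helper (not part of `soen_toolkit`), and the only flag it passes is the `--config` flag shown above.

```python
import subprocess
import sys
from pathlib import Path


def launch_training_cli(config_path, dry_run=False):
    """Hypothetical helper: run `python -m soen_toolkit.training --config <path>`.

    sys.executable is the interpreter running this code, so the subprocess
    sees the same installed packages. With dry_run=True the command is only
    built and returned, not executed.
    """
    cmd = [sys.executable, "-m", "soen_toolkit.training", "--config", str(Path(config_path))]
    if not dry_run:
        # check=True raises CalledProcessError if the training process exits with an error
        subprocess.run(cmd, check=True)
    return cmd
```

Unlike `run_from_config`, this runs training in a separate process, so the run does not share memory with the notebook kernel.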

### ML Task Overview

This example tackles a binary classification problem on time-series inputs:
- Class 1: Input contains a single pulse.
- Class 2: Input contains two distinct pulses.
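To make the task concrete, here is a minimal NumPy sketch that generates data of this kind. It is not the generator behind the tutorial's dataset: `make_pulse_dataset`, the rectangular pulse shape, and the 0/1 label encoding (0 = one pulse, 1 = two pulses) are assumptions. The arrays follow the `[N, T, D]` input and `[N]` label shapes described in the dataset notes at the end of this tutorial.

```python
import numpy as np


def make_pulse_dataset(n_samples, seq_len=100, pulse_width=5, seed=0):
    """Hypothetical toy generator for the one-pulse vs. two-pulse task.

    Returns:
        data:   float32 array of shape [N, T, 1] (D = 1 input feature)
        labels: int64 array of shape [N]; 0 = one pulse, 1 = two pulses (assumed encoding)
    """
    slot = 2 * pulse_width  # each slot holds one pulse followed by an equal-length gap
    n_slots = seq_len // slot
    if n_slots < 2:
        raise ValueError("seq_len must fit at least two pulse slots")

    rng = np.random.default_rng(seed)
    data = np.zeros((n_samples, seq_len, 1), dtype=np.float32)
    labels = rng.integers(0, 2, size=n_samples).astype(np.int64)

    for i, label in enumerate(labels):
        n_pulses = int(label) + 1
        # Distinct slots guarantee the pulses never touch or overlap
        starts = rng.choice(n_slots, size=n_pulses, replace=False) * slot
        for start in starts:
            data[i, start:start + pulse_width, 0] = 1.0
    return data, labels
```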

**Imports**

In [None]:
# Setup: Ensure soen_toolkit is importable
import sys
from pathlib import Path

# Add src directory to path if running from notebook location
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

from soen_toolkit.training.trainers.experiment import run_from_config

**Training**

We’ll use the example model and dataset to launch a local test training run. You can experiment by modifying the training YAML file as needed. For more detailed example configurations, see `src/soen_toolkit/training/examples/training_configs`.

More information about the training process is in `src/soen_toolkit/training/README.md`.

To build your own datasets, use the HDF5 file format; full instructions are in `docs/DATASETS.md`.

In [None]:
# Launch training via the Python API
# (the config path is relative to tutorial_notebooks/, the notebook's directory)
run_from_config("training/training_configs/pulse_net.yaml", script_dir=Path.cwd())
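The path in the cell above is relative, so it resolves only when the working directory is `tutorial_notebooks/`. If you run the notebook from elsewhere in the repository, one option is to search upward for the file, the same way the setup cell searches for `src/`. `find_config` is a hypothetical helper written for this sketch, not part of `soen_toolkit`.

```python
from pathlib import Path


def find_config(relative_path, start=None):
    """Hypothetical helper: return the first existing <ancestor>/<relative_path>.

    Searches `start` (default: the current working directory) and then each
    parent directory in turn; raises FileNotFoundError if nothing matches.
    """
    start = (Path.cwd() if start is None else Path(start)).resolve()
    for parent in [start, *start.parents]:
        candidate = parent / relative_path
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(f"{relative_path} not found in {start} or any parent directory")
```

For example: `run_from_config(str(find_config("tutorial_notebooks/training/training_configs/pulse_net.yaml")), script_dir=Path.cwd())`. Other relative paths referenced inside the YAML may still depend on where you run from.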

**Diagnostic: Check if weights are updating**

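Run this cell after training to sanity-check the weights in the most recent checkpoint. It inspects a single checkpoint, so it flags symptoms such as mostly-zero or near-constant weights rather than comparing against the initial values.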

In [None]:
# ============================================================================
# DIAGNOSTIC: Check if model weights are updating during training
# ============================================================================
import torch
import glob
from pathlib import Path

def diagnose_training():
    """Check if the trained model has meaningful weight updates."""
    
    # Find the latest checkpoint
    ckpt_patterns = [
        "training/temp/**/checkpoints/**/*.ckpt",
        "training/temp/**/*.ckpt",
    ]
    
    all_ckpts = []
    for pattern in ckpt_patterns:
        all_ckpts.extend(glob.glob(pattern, recursive=True))
    
    if not all_ckpts:
        print("No checkpoint found. Run training first.")
        return
    
    latest_ckpt = max(all_ckpts, key=lambda x: Path(x).stat().st_mtime)
    print(f"Loading checkpoint: {latest_ckpt}\n")
    
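    # Load checkpoint. PyTorch >= 2.6 defaults to weights_only=True, which can reject
    # checkpoints that also store non-tensor metadata; only load files you trust.
    ckpt = torch.load(latest_ckpt, map_location='cpu', weights_only=False)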
    state_dict = ckpt.get('state_dict', ckpt)
    
    print("=" * 70)
    print("WEIGHT DIAGNOSTICS")
    print("=" * 70)
    
    # Analyze each parameter
    total_params = 0
    zero_params = 0
    tiny_params = 0  # params with very small magnitude
    
    print(f"\n{'Parameter':<50} {'Shape':<15} {'Mean':<12} {'Std':<12} {'Max |x|':<12}")
    print("-" * 100)
    
    for name, param in state_dict.items():
        if 'weight' in name.lower() or 'bias' in name.lower() or 'connection' in name.lower():
            p = param.float()
            num_params = p.numel()
            total_params += num_params
            
            # Count zeros and tiny values
            zero_params += (p == 0).sum().item()
            tiny_params += (p.abs() < 1e-6).sum().item()
            
            mean_val = p.mean().item()
            std_val = p.std().item()
            max_val = p.abs().max().item()
            
            shape_str = str(list(p.shape))
            print(f"{name:<50} {shape_str:<15} {mean_val:>11.6f} {std_val:>11.6f} {max_val:>11.6f}")
    
    print("-" * 100)
    print(f"\nTotal params analyzed (weight/bias/connection tensors): {total_params:,}")
    print(f"Zero-valued params: {zero_params:,} ({100*zero_params/max(total_params,1):.1f}%)")
    print(f"Tiny params (<1e-6): {tiny_params:,} ({100*tiny_params/max(total_params,1):.1f}%)")
    
    # Check for gradient flow issues
    print("\n" + "=" * 70)
    print("DIAGNOSIS")
    print("=" * 70)
    
    if zero_params > total_params * 0.9:
        print("⚠️  WARNING: >90% of weights are zero - possible dead network")
    elif tiny_params > total_params * 0.8:
        print("⚠️  WARNING: >80% of weights are tiny - gradients may not be flowing")
    else:
        print("✓ Weights have reasonable magnitude distribution")
    
    # Check if weights look initialized vs trained
    for name, param in state_dict.items():
        if 'connection' in name.lower() and 'weight' not in name.lower():
            p = param.float()
            if p.std() < 0.001:
                print(f"⚠️  {name}: Very low std ({p.std().item():.6f}) - may not have trained")

# Run diagnostic
diagnose_training()
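The diagnostic above looks at one checkpoint in isolation, so it cannot tell trained weights apart from well-scaled initial ones. If your run saved more than one checkpoint (this depends on your checkpoint settings), a more direct check is to compare two state dicts. This is a minimal sketch using a relative-change metric chosen for this example; `compare_state_dicts` is hypothetical, and it converts values with `np.asarray`, which handles NumPy arrays and CPU PyTorch tensors that do not require grad.

```python
import numpy as np


def compare_state_dicts(before, after, rel_tol=1e-4):
    """Hypothetical helper: relative change ||after - before|| / ||before|| per tensor.

    `before` and `after` map names to array-likes. NumPy arrays and CPU PyTorch
    tensors both convert via np.asarray (call .float() on bfloat16 tensors first).
    Only names present in both dicts with matching shapes are compared.
    """
    report = {}
    for name in sorted(set(before) & set(after)):
        b = np.asarray(before[name], dtype=np.float64)
        a = np.asarray(after[name], dtype=np.float64)
        if b.shape != a.shape:
            continue
        # Guard against division by zero when the earlier tensor is all zeros
        denom = max(float(np.linalg.norm(b)), 1e-12)
        report[name] = float(np.linalg.norm(a - b)) / denom

    for name, change in report.items():
        if change < rel_tol:
            print(f"{name}: relative change {change:.2e} (may not be training)")
    return report
```

For example, load two checkpoints with `torch.load` as in the diagnostic cell and pass their `'state_dict'` entries: `compare_state_dicts(early['state_dict'], late['state_dict'])`.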

**View logs in TensorBoard (Optional)**

Start TensorBoard in a terminal so you can watch metrics live.

1. Activate your environment (if it is not already active).
2. Run TensorBoard, pointing it at the logs root printed during training (the "Logs root:" line):
```bash
tensorboard --logdir "/path/to/logs/root"
```

---


### Quick Notes on Datasets

`soen_toolkit.training` expects datasets in **HDF5 format** with the following structure:

- **Inputs** (`data`): `[N, T, D]`  
  - `N`: number of samples  
  - `T`: sequence length  
  - `D`: feature dimension (should equal the number of units in the input layer, ID 0)

- **Labels** (`labels`): shape depends on the task  
  - Classification (seq2static): `[N]` (int64 class indices)  
  - Classification (seq2seq): `[N, T]` (int64 per-timestep classes)  
  - Regression (seq2static): `[N, K]` (float32)  
  - Regression (seq2seq): `[N, T, K]` (float32)  
  - Unsupervised (seq2seq): labels optional; inputs are used as targets  
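These shape rules are easy to get wrong when building a dataset by hand. The NumPy sketch below checks arrays against them before you write the file. The function name and the `task`/`mapping` strings are this example's own conventions (they mirror the list above, not necessarily the exact values `soen_toolkit` accepts), and `num_input_units` is the unit count of layer 0, which you supply yourself.

```python
import numpy as np

# (task, mapping) -> (expected label shape, expected label dtype).
# "N" and "T" are filled in from the data array; None means "any size" (the K dimension).
_LABEL_SPECS = {
    ("classification", "seq2static"): (("N",), np.int64),
    ("classification", "seq2seq"): (("N", "T"), np.int64),
    ("regression", "seq2static"): (("N", None), np.float32),
    ("regression", "seq2seq"): (("N", "T", None), np.float32),
}


def check_dataset_shapes(data, labels, num_input_units, task, mapping):
    """Raise ValueError if data/labels break the [N, T, D] and label-shape rules."""
    if data.ndim != 3:
        raise ValueError(f"data must be [N, T, D], got shape {data.shape}")
    n, t, d = data.shape
    if d != num_input_units:
        raise ValueError(f"D = {d}, but the input layer (ID 0) has {num_input_units} units")
    if task == "unsupervised":
        return  # labels are optional; inputs are used as targets
    if (task, mapping) not in _LABEL_SPECS:
        raise ValueError(f"unknown task/mapping combination: {task!r}, {mapping!r}")
    if labels is None:
        raise ValueError(f"{task} tasks need labels")

    spec, dtype = _LABEL_SPECS[(task, mapping)]
    expected = tuple({"N": n, "T": t}.get(s, s) for s in spec)
    shape_ok = labels.ndim == len(expected) and all(
        e is None or e == actual for e, actual in zip(expected, labels.shape)
    )
    if not shape_ok:
        raise ValueError(f"labels shape {labels.shape} does not match expected {expected}")
    if labels.dtype != dtype:
        raise ValueError(f"labels dtype {labels.dtype}, expected {np.dtype(dtype)}")
```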

**Recommended layout** (groups under the HDF5 file root):

- `train/`: `data`, `labels`
- `val/`: `data`, `labels`
- `test/`: `data`, `labels`
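A minimal sketch of writing that layout with `h5py` (a third-party package you may need to install). `write_split_dataset` is a hypothetical helper; storing inputs as float32 is an assumption, since the notes above only fix the label dtypes, and `docs/DATASETS.md` remains the authoritative reference for the exact format.

```python
import h5py
import numpy as np


def write_split_dataset(path, splits):
    """Hypothetical helper: write {"train": (data, labels), ...} to one HDF5 file.

    Each split becomes a group holding a `data` dataset and, if labels are given,
    a `labels` dataset (labels keep the dtype you pass in, e.g. int64 class indices).
    """
    with h5py.File(path, "w") as f:  # "w" creates the file or overwrites an existing one
        for split, (data, labels) in splits.items():
            group = f.create_group(split)
            # Inputs stored as float32 [N, T, D] (dtype is an assumption of this sketch)
            group.create_dataset("data", data=np.asarray(data, dtype=np.float32))
            if labels is not None:
                group.create_dataset("labels", data=np.asarray(labels))
```

Reading back uses the same paths, e.g. inside `with h5py.File(path, "r") as f:`, `f["train/data"].shape` gives `(N, T, D)`.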

**Key config notes:**
- Set `training.paradigm` and `training.mapping` in your YAML (e.g., `supervised` + `seq2static`).  
- Use `data.target_seq_len` to align input/output sequence lengths.  
- Pooling for seq2static tasks is controlled via `model.time_pooling`.
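Before launching a run, it can help to confirm which of these settings your YAML sets explicitly (the toolkit may supply defaults for some of them). The sketch below assumes each dotted name maps to nested YAML mappings (for example `training:` containing `mapping:`) and works on the dictionary returned by a YAML loader such as PyYAML's `yaml.safe_load`; `get_dotted` and `unset_config_keys` are hypothetical helpers, not `soen_toolkit` functions.

```python
def get_dotted(config, dotted_key):
    """Look up a nested key such as "training.mapping"; raise KeyError if any level is missing."""
    node = config
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(dotted_key)
        node = node[part]
    return node


def unset_config_keys(config):
    """Return which of the settings from the key config notes are not set in `config`."""
    keys = ["training.paradigm", "training.mapping"]
    try:
        if get_dotted(config, "training.mapping") == "seq2static":
            keys.append("model.time_pooling")  # pooling setting used by seq2static tasks
    except KeyError:
        pass  # training.mapping itself is reported as unset below
    unset = []
    for key in keys:
        try:
            get_dotted(config, key)
        except KeyError:
            unset.append(key)
    return unset
```

For example: `unset_config_keys(yaml.safe_load(Path("training/training_configs/pulse_net.yaml").read_text()))`.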
