# Tutorial 02 — Train a SOEN Model

In this tutorial, we'll walk through training a pre-built SOEN model using the training configuration file located at:
`tutorial_notebooks/training/training_configs/pulse_net_trainable.yaml`.

We'll use the `run_from_config` function to launch training. This function makes it easy to set up an experiment — once all training settings are defined in your YAML file, you can start training with a single command.

You can run it either from Python (a script or this notebook) or directly from the command line.

Python:
`run_from_config("training/training_configs/pulse_net_trainable.yaml", script_dir=Path.cwd())`

CLI:
`python -m soen_toolkit.training --config path/to/training_config.yaml`

### ML Task Overview

This example tackles a binary classification problem on time-series inputs:
- **Class 0**: Input contains a single pulse.
- **Class 1**: Input contains two distinct pulses.

### Key Fix Applied
The original `pulse_net.yaml` had `J_1_to_2.learnable: false` which blocked gradient flow.
The new `pulse_net_trainable.yaml` fixes this with `learnable: true`.

**Imports**

In [None]:
# Setup: Ensure soen_toolkit is importable
import sys
from pathlib import Path

# When running from a repository checkout, make the in-repo package importable:
# walk upwards from the current working directory and prepend the first
# ``src`` folder that contains the ``soen_toolkit`` package to sys.path.
notebook_dir = Path.cwd()
_src_root = next(
    (
        ancestor / "src"
        for ancestor in (notebook_dir, *notebook_dir.parents)
        if (ancestor / "src" / "soen_toolkit").exists()
    ),
    None,
)
if _src_root is not None:
    sys.path.insert(0, str(_src_root))

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import glob

from soen_toolkit.training.trainers.experiment import run_from_config

**Visualize the Dataset**

Let's look at examples from each class to understand the task.

In [None]:
# ============================================================================
# VISUALIZE DATASET: One-pulse vs Two-pulse classification
# ============================================================================

def visualize_dataset(data_path="training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5", n_examples=4):
    """Plot example inputs of each class from the training split.

    Draws a 2 x ``n_examples`` grid. The top row shows class-0 samples (one
    pulse) and the bottom row shows class-1 samples (two pulses). Only feature
    0 of each sample is plotted against the time-step index.

    Args:
        data_path: HDF5 file containing ``train/data`` and ``train/labels``.
        n_examples: Number of samples to plot per class (one column each).

    Returns:
        ``(data, labels)``: the full training inputs and labels as NumPy arrays.
    """
    with h5py.File(data_path, 'r') as f:
        data = np.array(f['train']['data'])
        labels = np.array(f['train']['labels'])

    print(f"Dataset shape: {data.shape} (N samples, T timesteps, D features)")
    print(f"Labels shape: {labels.shape}")
    print(f"Class distribution: {np.bincount(labels)}")

    # squeeze=False keeps ``axes`` two-dimensional even when n_examples == 1,
    # so the ``axes[row, col]`` indexing below always works.
    fig, axes = plt.subplots(2, n_examples, figsize=(3 * n_examples, 5), squeeze=False)
    fig.suptitle("Input Signals: One-Pulse (Class 0) vs Two-Pulse (Class 1)", fontsize=12, fontweight='bold')

    # One entry per grid row: (class label, line style, y-axis row label).
    row_specs = [
        (0, 'b-', "Class 0\n(One Pulse)"),
        (1, 'r-', "Class 1\n(Two Pulses)"),
    ]
    for row, (cls, line_style, row_label) in enumerate(row_specs):
        class_idx = np.where(labels == cls)[0][:n_examples]
        for col in range(n_examples):
            ax = axes[row, col]
            if col >= len(class_idx):
                # Not enough samples of this class to fill the row: hide the
                # spare panel instead of leaving an empty frame.
                ax.set_visible(False)
                continue
            idx = class_idx[col]
            ax.plot(data[idx, :, 0], line_style, linewidth=1.5)
            ax.set_title(f"Sample {idx}", fontsize=10)
            ax.set_ylim(-0.1, 1.1)
            ax.grid(True, alpha=0.3)
            if col == 0:
                ax.set_ylabel(row_label, fontsize=10)
            if row == 1:
                ax.set_xlabel("Time step")

    plt.tight_layout()
    plt.show()

    return data, labels


# Visualize the dataset
train_data, train_labels = visualize_dataset()

**Training**

We’ll use the example model and dataset to launch a local test training run. You can experiment by modifying the training YAML file as needed. For more detailed configurations, see: `src/soen_toolkit/training/examples/training_configs`.

Additional information about the training process can be found in: `src/soen_toolkit/training/README.md`.

To build your own datasets, use the HDF5 file format; full instructions are in `docs/DATASETS.md`.

In [None]:
# Start training through the Python API with the trainable config, in which
# J_1_to_2.learnable is set to true (it was false in the original pulse_net.yaml).
TRAINING_CONFIG = "training/training_configs/pulse_net_trainable.yaml"
run_from_config(TRAINING_CONFIG, script_dir=Path.cwd())

In [None]:
# ============================================================================
# VISUALIZATION: Plot training results using matplotlib (no tensorboard needed)
# ============================================================================

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob

def find_latest_log_dir(search_paths=None):
    """Return the directory holding the most recently modified TensorBoard event file.

    Args:
        search_paths: Optional iterable of recursive glob patterns that match
            ``events.out.tfevents*`` files. Patterns are relative to the current
            working directory unless absolute. ``None`` selects the defaults:
            ``training/temp``, ``training_logs``, ``lightning_logs`` and any
            ``logs`` directory below the working directory.

    Returns:
        ``pathlib.Path`` of the parent directory of the newest matching event
        file, or ``None`` if no pattern matches anything.
    """
    if search_paths is None:
        search_paths = (
            "training/temp/**/events.out.tfevents*",
            "training_logs/**/events.out.tfevents*",
            "lightning_logs/**/events.out.tfevents*",
            "**/logs/**/events.out.tfevents*",
        )

    # Use a set: the generic "**/logs/**" pattern can match the same file as a
    # more specific pattern, and each file only needs to be stat'ed once.
    event_files = {
        match
        for pattern in search_paths
        for match in glob.glob(pattern, recursive=True)
    }
    if not event_files:
        return None

    latest = max(event_files, key=lambda p: Path(p).stat().st_mtime)
    return Path(latest).parent

def parse_tensorboard_logs(log_dir):
    """Read the scalar summaries stored under ``log_dir`` with tbparse.

    Returns the ``SummaryReader(...).scalars`` table. Returns ``None`` after
    printing a message when tbparse is not installed or reading the logs fails.
    This is a best-effort helper: failures are reported, not raised.
    """
    try:
        from tbparse import SummaryReader

        scalars = SummaryReader(str(log_dir)).scalars
    except ImportError:
        print("tbparse not available. Install with: pip install tbparse")
        scalars = None
    except Exception as err:
        print(f"Error parsing logs: {err}")
        scalars = None
    return scalars

# Find and parse logs
log_dir = find_latest_log_dir()
if log_dir:
    print(f"Found logs at: {log_dir}")
    df = parse_tensorboard_logs(log_dir)
    
    if df is not None and len(df) > 0:
        # Get unique tags (metrics)
        tags = df['tag'].unique()
        print(f"Available metrics: {list(tags)}")
        
        # Filter for important metrics
        important_tags = [t for t in tags if any(k in t.lower() for k in ['loss', 'accuracy', 'lr'])]
        if not important_tags:
            important_tags = list(tags)[:6]
        
        # Create subplots
        n_plots = min(len(important_tags), 6)
        fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        axes = axes.flatten()
        
        for i, tag in enumerate(important_tags[:6]):
            ax = axes[i]
            metric_data = df[df['tag'] == tag].sort_values('step')
            ax.plot(metric_data['step'], metric_data['value'], 'b-', linewidth=1.5)
            ax.set_xlabel('Step')
            ax.set_ylabel(tag.split('/')[-1])
            ax.set_title(tag, fontsize=10)
            ax.grid(True, alpha=0.3)
        
        # Hide unused subplots
        for i in range(len(important_tags), 6):
            axes[i].set_visible(False)
        
        plt.tight_layout()
        plt.show()
        
        # Print final metrics
        print("\n" + "="*60)
        print("FINAL METRICS")
        print("="*60)
        for tag in important_tags:
            metric_data = df[df['tag'] == tag].sort_values('step')
            if len(metric_data) > 0:
                print(f"{tag}: {metric_data['value'].iloc[-1]:.4f}")
    else:
        print("No scalar data found in logs.")
else:
    print("No training logs found. Run the training cell first.")
    print("Searched in: training/temp/, training_logs/, lightning_logs/")

**Visualize Predictions**

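After training, let's see how the model performs on held-out samples (the validation split if the dataset has one, otherwise training samples), plotting each input signal next to its prediction.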

In [None]:
# ============================================================================
# VISUALIZE PREDICTIONS: Show model predictions on test samples
# ============================================================================

def _find_latest_checkpoint(patterns):
    """Return the most recently modified file matching any glob pattern, or None."""
    matches = {m for pattern in patterns for m in glob.glob(pattern, recursive=True)}
    if not matches:
        return None
    return max(matches, key=lambda p: Path(p).stat().st_mtime)


def _load_checkpoint_weights(model, ckpt_path):
    """Copy weights from a (possibly Lightning-wrapped) checkpoint into ``model``.

    Strips a leading ``model.`` from every key before loading. Loading uses
    ``strict=False``, so any missing or unexpected keys are reported explicitly
    rather than silently ignored.
    """
    # NOTE(review): recent PyTorch versions default torch.load to
    # weights_only=True, which may reject checkpoints that pickle non-tensor
    # objects; confirm this works with the installed torch version.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)

    # Lightning stores the wrapped network under the ``model`` attribute.
    prefix = 'model.'
    clean_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    try:
        result = model.load_state_dict(clean_state_dict, strict=False)
    except Exception as e:
        print(f"Warning loading weights: {e}")
        return

    # torch.nn.Module.load_state_dict returns the keys it could not match;
    # getattr guards against an override that returns something else.
    missing = list(getattr(result, 'missing_keys', []))
    unexpected = list(getattr(result, 'unexpected_keys', []))
    if missing or unexpected:
        print(
            f"Warning: weights only partially loaded "
            f"({len(missing)} missing, {len(unexpected)} unexpected keys)"
        )
        if missing:
            print(f"  Missing keys: {missing}")
        if unexpected:
            print(f"  Unexpected keys: {unexpected}")
    else:
        print("Model weights loaded successfully")


def _load_eval_samples(data_path, n_samples):
    """Return up to ``n_samples`` (data, labels) from the ``val`` split, or ``train`` if absent."""
    with h5py.File(data_path, 'r') as f:
        split = 'val' if 'val' in f else 'train'
        data = np.array(f[split]['data'][:n_samples])
        labels = np.array(f[split]['labels'][:n_samples])
    return data, labels


def visualize_predictions(
    n_samples=8,
    *,
    data_path="training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5",
    model_spec_path="training/test_models/model_specs/1D_5D_2D_PulseNetSpec_trainable.yaml",
    ckpt_patterns=("training/temp/**/checkpoints/**/*.ckpt", "training/temp/**/*.ckpt"),
):
    """Load the newest checkpoint and plot its predictions on a few samples.

    Rebuilds the model from ``model_spec_path`` and loads weights from the most
    recently modified checkpoint matching ``ckpt_patterns``. It then runs the
    model on the first ``n_samples`` samples of the ``val`` split of
    ``data_path`` (or of the ``train`` split when there is no ``val`` group).
    Each input is plotted with its predicted class, confidence and true class,
    and a short accuracy summary is printed.

    Args:
        n_samples: Maximum number of samples to evaluate and plot.
        data_path: HDF5 dataset file.
        model_spec_path: YAML model spec that matches the trained checkpoint.
        ckpt_patterns: Recursive glob patterns used to locate checkpoints.

    Returns:
        None. Prints a message and returns early when no checkpoint or no
        samples are found.
    """
    from soen_toolkit.core.model_yaml import build_model_from_yaml

    latest_ckpt = _find_latest_checkpoint(ckpt_patterns)
    if latest_ckpt is None:
        print("No checkpoint found. Run training first.")
        return
    print(f"Loading checkpoint: {latest_ckpt}")

    # Rebuild the architecture from its YAML spec, then copy in trained weights.
    model = build_model_from_yaml(Path(model_spec_path))
    _load_checkpoint_weights(model, latest_ckpt)
    model.eval()

    test_data, test_labels = _load_eval_samples(Path(data_path), n_samples)
    # The split may hold fewer samples than requested; size everything below
    # from what was actually loaded so indexing never runs past the end.
    n_samples = len(test_data)
    if n_samples == 0:
        print(f"No samples found in {data_path}.")
        return

    # Run inference
    with torch.no_grad():
        x = torch.tensor(test_data, dtype=torch.float32)
        output, _ = model(x)

        # For 3-D outputs, max-pool over dim 1 (assumed to be time).
        # NOTE(review): this should match the training config's
        # ``model.time_pooling`` setting; confirm that is max pooling.
        pooled = output.max(dim=1).values if output.dim() == 3 else output

        probs = torch.softmax(pooled, dim=1)
        confidence, predicted = probs.max(dim=1)
        predictions = predicted.numpy()
        confidence = confidence.numpy()

    # Visualize
    n_cols = min(4, n_samples)
    n_rows = -(-n_samples // n_cols)  # ceiling division
    # squeeze=False always yields a 2-D array, even for a single subplot.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.5 * n_cols, 3 * n_rows), squeeze=False)
    fig.suptitle("Model Predictions on Test Samples", fontsize=14, fontweight='bold')

    class_names = ["One Pulse", "Two Pulses"]

    for i in range(n_samples):
        row, col = divmod(i, n_cols)
        ax = axes[row, col]

        # Plot input signal (feature 0)
        ax.plot(test_data[i, :, 0], 'b-', linewidth=1.5, alpha=0.8)
        ax.set_ylim(-0.1, 1.1)
        ax.grid(True, alpha=0.3)

        true_cls = int(test_labels[i])
        pred_cls = int(predictions[i])
        is_correct = pred_cls == true_cls
        color = 'green' if is_correct else 'red'
        symbol = '✓' if is_correct else '✗'

        ax.set_title(
            f"{symbol} Pred: {class_names[pred_cls]} ({confidence[i]:.1%})\nTrue: {class_names[true_cls]}",
            fontsize=9,
            color=color,
            fontweight='normal' if is_correct else 'bold',
        )

        if col == 0:
            ax.set_ylabel("Signal")
        if row == n_rows - 1:
            ax.set_xlabel("Time step")

    # Hide empty subplots
    for ax in axes.flat[n_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Print summary
    accuracy = (predictions == test_labels).mean()
    print(f"\n{'=' * 50}")
    print("PREDICTION SUMMARY")
    print(f"{'=' * 50}")
    print(f"Accuracy on {n_samples} samples: {accuracy:.1%}")
    print(f"Correct: {(predictions == test_labels).sum()}/{n_samples}")

    # Confusion-matrix style breakdown (non-zero cells only)
    for true_class in [0, 1]:
        for pred_class in [0, 1]:
            count = ((test_labels == true_class) & (predictions == pred_class)).sum()
            if count > 0:
                print(f"  True {class_names[true_class]} -> Pred {class_names[pred_class]}: {count}")


# Visualize predictions
visualize_predictions(n_samples=8)

---


### Quick Notes on Datasets

soen_toolkit.training models expect datasets in **HDF5 format** with the following structure:

- **Inputs** (`data`): `[N, T, D]`  
  - `N`: number of samples  
  - `T`: sequence length  
  - `D`: feature dimension (must equal the number of units in the input layer, i.e. layer ID 0)

- **Labels** (`labels`): shape depends on the task  
  - Classification (seq2static): `[N]` (int64 class indices)  
  - Classification (seq2seq): `[N, T]` (int64 per-timestep classes)  
  - Regression (seq2static): `[N, K]` (float32)  
  - Regression (seq2seq): `[N, T, K]` (float32)  
  - Unsupervised (seq2seq): labels optional; inputs are used as targets  

**Recommended layout:**

```
root/
  train/{data, labels}
  val/{data, labels}
  test/{data, labels}
```

**Key config notes:**
- Set `training.paradigm` and `training.mapping` in your YAML (e.g., `supervised` + `seq2static`).  
- Use `data.target_seq_len` to align input/output sequence lengths.  
- Pooling for seq2static tasks is controlled via `model.time_pooling`.


## Manual Evaluation and Visualization

If the above log parsing doesn't work, you can manually evaluate the trained model:

In [None]:
# ============================================================================
# MANUAL EVALUATION: Load trained model and evaluate on test data
# ============================================================================

import torch
import h5py
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Find the latest trained model
def find_latest_model(base_path="training_logs", extra_patterns=None):
    """Return the path of the most recently modified trained model file.

    Searches ``base_path`` recursively for ``.soen`` exports and ``.ckpt``
    checkpoints, and also searches the ``extra_patterns`` glob patterns.

    Args:
        base_path: Root directory searched recursively for ``*.soen`` and ``*.ckpt``.
        extra_patterns: Additional recursive glob patterns. ``None`` selects the
            defaults ``lightning_logs/**/*.ckpt`` and ``training/temp/**/*.ckpt``;
            the latter is where this tutorial's other cells look for
            checkpoints. Pass an empty tuple to search ``base_path`` only.

    Returns:
        The newest matching path as a string, or ``None`` if nothing matches.
    """
    if extra_patterns is None:
        extra_patterns = ("lightning_logs/**/*.ckpt", "training/temp/**/*.ckpt")

    patterns = [f"{base_path}/**/*.soen", f"{base_path}/**/*.ckpt", *extra_patterns]
    # A set drops files matched by more than one pattern.
    candidates = {m for pattern in patterns for m in glob.glob(pattern, recursive=True)}
    if not candidates:
        return None
    return max(candidates, key=lambda p: Path(p).stat().st_mtime)

# Load model
model_path = find_latest_model()
if model_path:
    print(f"Found trained model: {model_path}")
    
    # Load based on extension
    if model_path.endswith('.soen'):
        from soen_toolkit.core import SOENModelCore
        model = SOENModelCore.load(model_path)
    else:
        # Load from checkpoint
        from soen_toolkit.training.models import SOENLightningModule
        model = SOENLightningModule.load_from_checkpoint(model_path)
        model = model.model  # Get the underlying SOEN model
    
    model.eval()
    print("Model loaded successfully!")
    
    # Load test data
    data_path = Path("training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5")
    if data_path.exists():
        with h5py.File(data_path, 'r') as f:
            # Try test split, fall back to val
            split = 'test' if 'test' in f else 'val'
            test_data = torch.tensor(f[split]['data'][:], dtype=torch.float32)
            test_labels = torch.tensor(f[split]['labels'][:], dtype=torch.long)
        
        print(f"Loaded {split} data: {test_data.shape}")
        
        # Run inference
        with torch.no_grad():
            outputs, _ = model(test_data[:100])  # First 100 samples
            
            # Get predictions (assuming last timestep, argmax for classification)
            if outputs.dim() == 3:
                outputs = outputs[:, -1, :]  # Take last timestep
            predictions = outputs.argmax(dim=-1)
        
        # Calculate accuracy
        correct = (predictions == test_labels[:100]).sum().item()
        accuracy = correct / len(predictions) * 100
        print(f"\nTest Accuracy: {accuracy:.1f}% ({correct}/{len(predictions)})")
        
        # Visualize some predictions
        fig, axes = plt.subplots(2, 4, figsize=(16, 6))
        
        for i, ax in enumerate(axes.flatten()):
            if i >= len(test_data):
                break
            
            # Plot input signal
            signal = test_data[i, :, 0].numpy()
            ax.plot(signal, 'b-', linewidth=1.5)
            
            true_label = test_labels[i].item()
            pred_label = predictions[i].item()
            
            color = 'green' if true_label == pred_label else 'red'
            ax.set_title(f"True: {true_label}, Pred: {pred_label}", color=color)
            ax.set_xlabel("Time")
            ax.set_ylabel("Input")
            ax.grid(True, alpha=0.3)
        
        plt.suptitle(f"Sample Predictions (Accuracy: {accuracy:.1f}%)", fontsize=14)
        plt.tight_layout()
        plt.show()
        
        # Confusion matrix
        from sklearn.metrics import confusion_matrix
        cm = confusion_matrix(test_labels[:100].numpy(), predictions.numpy())
        
        fig, ax = plt.subplots(figsize=(6, 5))
        im = ax.imshow(cm, cmap='Blues')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        
        # Add text annotations
        for i in range(2):
            for j in range(2):
                ax.text(j, i, str(cm[i, j]), ha='center', va='center', 
                       color='white' if cm[i, j] > cm.max()/2 else 'black', fontsize=14)
        
        plt.colorbar(im)
        plt.tight_layout()
        plt.show()
    else:
        print(f"Dataset not found at {data_path}")
else:
    print("No trained model found. Run training first.")