# Tutorial 02 — Train a SOEN Model

In this tutorial, we’ll walk through training a pre-built SOEN model using the training configuration file located at:
`tutorial_notebooks/training/training_configs/pulse_net.yaml`.

We’ll use the `run_from_config` function to launch training. Once all training settings are defined in your YAML file, this function sets up the experiment and starts training with a single call.

You can launch it either from Python or from the command line:

- Python: `run_from_config(str(BASE_CONFIG), script_dir=Path.cwd())`, where `BASE_CONFIG` is the path to your training YAML file
- CLI: `python -m soen_toolkit.training --config path/to/training_config.yaml`
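The relative path used in the training cell (`training/training_configs/pulse_net.yaml`) only resolves when the notebook's working directory is `tutorial_notebooks/`. The following is a minimal sketch of building a `BASE_CONFIG` path that works from anywhere inside the repository, mirroring the parent-directory search in the setup cell. `find_config` is a hypothetical helper, not part of `soen_toolkit`.

```python
from pathlib import Path
from typing import Optional


def find_config(rel_path: str, start: Optional[Path] = None) -> Path:
    """Return the first existing `<ancestor>/<rel_path>`, searching upward from `start`.

    Hypothetical helper: it mirrors the search the setup cell uses to locate
    `src/soen_toolkit`, so the notebook works from any subdirectory.
    """
    start = Path.cwd() if start is None else Path(start)
    for parent in [start, *start.parents]:
        candidate = parent / rel_path
        if candidate.is_file():
            return candidate.resolve()
    raise FileNotFoundError(f"{rel_path} not found in {start} or any parent directory")


# Usage (assumes the repository layout described above):
# BASE_CONFIG = find_config("tutorial_notebooks/training/training_configs/pulse_net.yaml")
# run_from_config(str(BASE_CONFIG), script_dir=Path.cwd())
```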

### ML Task Overview

This example tackles a binary classification problem on time-series inputs:
- Class 1: Input contains a single pulse.
- Class 2: Input contains two distinct pulses.
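To make the task concrete, here is a sketch of how one-vs-two-pulse sequences like these could be generated with NumPy. This is not how the tutorial's dataset file was produced. The rectangular pulse shape, the noise level, and the label encoding (0 = one pulse, 1 = two pulses, stored as int64 class indices) are all assumptions.

```python
import numpy as np


def make_pulse_dataset(n_samples: int, seq_len: int = 64, pulse_width: int = 4,
                       noise: float = 0.05, seed: int = 0):
    """Generate hypothetical one-vs-two-pulse sequences.

    Returns data of shape [N, T, 1] (float32) and labels of shape [N] (int64),
    where label 0 means one pulse and label 1 means two pulses (assumed encoding).
    Requires seq_len >= 4 * pulse_width so two pulses fit with a gap between them.
    """
    rng = np.random.default_rng(seed)
    data = noise * rng.standard_normal((n_samples, seq_len, 1)).astype(np.float32)
    labels = rng.integers(0, 2, size=n_samples).astype(np.int64)
    for i, label in enumerate(labels):
        n_pulses = int(label) + 1
        # Give each pulse its own slot; leaving pulse_width samples free at the end
        # of every slot guarantees the pulses never touch or overlap.
        slot = seq_len // n_pulses
        for k in range(n_pulses):
            start = k * slot + rng.integers(0, slot - 2 * pulse_width + 1)
            data[i, start:start + pulse_width, 0] += 1.0
    return data, labels


data, labels = make_pulse_dataset(8)
print(data.shape, labels.shape)
```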

**Imports**

In [None]:
# Setup: Ensure soen_toolkit is importable
import sys
from pathlib import Path

# Add src directory to path if running from notebook location
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

from soen_toolkit.training.trainers.experiment import run_from_config

**Training**

We’ll use the example model and dataset to launch a local test training run. You can experiment by modifying the training YAML file as needed. For more detailed configurations, see: `src/soen_toolkit/training/examples/training_configs`.
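If you prefer to script small variations instead of hand-editing the file, one option is to load the YAML into a dictionary, override nested keys, and save a copy. The `set_dotted` helper below is a hypothetical sketch. The key names in the example come from the dataset notes later in this tutorial, but whether `pulse_net.yaml` nests them exactly this way, and the placeholder values, are assumptions. Loading and saving are shown in comments and assume PyYAML is installed.

```python
from typing import Any


def set_dotted(config: dict, dotted_key: str, value: Any) -> None:
    """Set config["a"]["b"]["c"] = value for dotted_key "a.b.c", creating missing levels."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


cfg = {"training": {"paradigm": "supervised"}}
set_dotted(cfg, "training.mapping", "seq2static")
set_dotted(cfg, "data.target_seq_len", 64)
print(cfg)

# With PyYAML (assumed installed), a variant config could be produced like this:
#   cfg = yaml.safe_load(Path("training/training_configs/pulse_net.yaml").read_text())
#   set_dotted(cfg, "training.mapping", "seq2static")
#   Path("pulse_net_variant.yaml").write_text(yaml.safe_dump(cfg))
# and the new file passed to run_from_config.
```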

Additional information about the training process can be found in: `src/soen_toolkit/training/README.md`.

If you want to build your own datasets, use the HDF5 file format. Full instructions are in `docs/DATASETS.md`.

In [None]:
# Launch training via Python API
run_from_config("training/training_configs/pulse_net.yaml", script_dir=Path.cwd())

In [ ]:
# ============================================================================
# VISUALIZATION: Plot training results using matplotlib (no tensorboard needed)
# ============================================================================

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob

def find_latest_log_dir(base_path="training_logs"):
    """Find the most recent training log directory."""
    # Look for tensorboard event files
    pattern = f"{base_path}/**/events.out.tfevents*"
    event_files = glob.glob(pattern, recursive=True)
    
    if not event_files:
        # Try lightning_logs
        pattern = "lightning_logs/**/events.out.tfevents*"
        event_files = glob.glob(pattern, recursive=True)
    
    if event_files:
        # Get the most recent one
        latest = max(event_files, key=lambda x: Path(x).stat().st_mtime)
        return Path(latest).parent
    return None

def parse_tensorboard_logs(log_dir):
    """Parse tensorboard logs using tbparse."""
    try:
        from tbparse import SummaryReader
        reader = SummaryReader(str(log_dir))
        df = reader.scalars
        return df
    except ImportError:
        print("tbparse is not installed (pip install tbparse); "
              "use the manual evaluation section below instead.")
        return None
    except Exception as e:
        print(f"Error parsing logs: {e}")
        return None

# Find and parse logs
log_dir = find_latest_log_dir()
if log_dir:
    print(f"Found logs at: {log_dir}")
    df = parse_tensorboard_logs(log_dir)
    
    if df is not None and len(df) > 0:
        # Get unique tags (metrics)
        tags = df['tag'].unique()
        print(f"Available metrics: {list(tags)}")
        
        # Create subplots
        n_metrics = min(len(tags), 6)  # Show up to 6 metrics
        fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        axes = axes.flatten()
        
        for ax, tag in zip(axes, tags[:n_metrics]):
            metric_data = df[df['tag'] == tag].sort_values('step')
            ax.plot(metric_data['step'], metric_data['value'], 'b-', linewidth=1.5)
            ax.set_xlabel('Step')
            ax.set_ylabel(tag.split('/')[-1])
            ax.set_title(tag)
            ax.grid(True, alpha=0.3)
        
        # Hide any subplots left over when there are fewer than 6 metrics
        for ax in axes[n_metrics:]:
            ax.set_visible(False)
        
        plt.tight_layout()
        plt.show()
    else:
        print("No scalar data found in logs.")
else:
    print("No training logs found. Run the training cell first.")

---


### Quick Notes on Datasets

`soen_toolkit.training` expects datasets in **HDF5 format** with the following structure:

- **Inputs** (`data`): `[N, T, D]`  
  - `N`: number of samples  
  - `T`: sequence length  
  - `D`: feature dimension (should equal the number of units in the input layer, i.e. layer ID 0)

- **Labels** (`labels`): shape depends on the task  
  - Classification (seq2static): `[N]` (int64 class indices)  
  - Classification (seq2seq): `[N, T]` (int64 per-timestep classes)  
  - Regression (seq2static): `[N, K]` (float32)  
  - Regression (seq2seq): `[N, T, K]` (float32)  
  - Unsupervised (seq2seq): labels optional; inputs are used as targets  
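A quick way to catch shape mistakes before launching a run is to check your arrays against this table. The helper below is a hypothetical sketch, not a `soen_toolkit` function. Its `task`/`mapping` strings simply mirror the rows above, and it assumes seq2seq labels use the same `T` as the inputs, as the table does.

```python
import numpy as np


def check_dataset_shapes(data: np.ndarray, labels, task: str, mapping: str,
                         input_dim: int) -> None:
    """Raise ValueError if `data`/`labels` don't match the layout in the table above.

    task: "classification", "regression", or "unsupervised"
    mapping: "seq2static" or "seq2seq"
    input_dim: number of units in the input layer (layer ID 0)
    """
    if data.ndim != 3:
        raise ValueError(f"data must be [N, T, D], got shape {data.shape}")
    n, t, d = data.shape
    if d != input_dim:
        raise ValueError(f"feature dim D={d} does not match input layer size {input_dim}")
    if task == "unsupervised":
        return  # labels are optional; inputs serve as targets
    expected_ndim = {
        ("classification", "seq2static"): 1,  # [N]
        ("classification", "seq2seq"): 2,     # [N, T]
        ("regression", "seq2static"): 2,      # [N, K]
        ("regression", "seq2seq"): 3,         # [N, T, K]
    }[(task, mapping)]
    if labels.ndim != expected_ndim or labels.shape[0] != n:
        raise ValueError(f"labels shape {labels.shape} is invalid for {task}/{mapping}")
    if mapping == "seq2seq" and labels.shape[1] != t:
        raise ValueError(f"labels have {labels.shape[1]} timesteps, data has {t}")
    if task == "classification" and labels.dtype != np.int64:
        raise ValueError(f"class labels should be int64, got {labels.dtype}")
```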

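**Recommended layout** (one group per split at the file root, each holding a `data` and a `labels` dataset):

- `train/data`, `train/labels`
- `val/data`, `val/labels`
- `test/data`, `test/labels`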
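Writing a file in this layout takes only a few lines of `h5py`. The following is a minimal sketch using random placeholder arrays. The split sizes, the sequence length, and the `write_split_hdf5` name are illustrative assumptions, and `docs/DATASETS.md` remains the authoritative description of the format.

```python
import h5py
import numpy as np


def write_split_hdf5(path, splits):
    """Write {"train": (data, labels), "val": ..., "test": ...} as root/<split>/{data, labels}."""
    with h5py.File(path, "w") as f:
        for name, (data, labels) in splits.items():
            group = f.create_group(name)
            # Inputs are stored as float32 [N, T, D]; labels keep their own dtype.
            group.create_dataset("data", data=np.asarray(data, dtype=np.float32))
            group.create_dataset("labels", data=np.asarray(labels))


# Example: a seq2static classification dataset with D = 1 input feature.
rng = np.random.default_rng(0)
splits = {
    name: (rng.standard_normal((n, 64, 1)), rng.integers(0, 2, size=n).astype(np.int64))
    for name, n in [("train", 80), ("val", 10), ("test", 10)]
}
write_split_hdf5("my_dataset.hdf5", splits)
```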

**Key config notes:**
- Set `training.paradigm` and `training.mapping` in your YAML (e.g., `supervised` + `seq2static`).  
- Use `data.target_seq_len` to align input/output sequence lengths.  
- Pooling for seq2static tasks is controlled via `model.time_pooling`.
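For seq2static tasks, the model's per-timestep outputs `[N, T, C]` have to be reduced to one vector per sample `[N, C]`. The sketch below shows what three common pooling choices compute. The mode names (`"final"`, `"mean"`, `"max"`) are assumptions for illustration and may not match the values `model.time_pooling` actually accepts; check the training README for those. Note that the manual evaluation cell later in this notebook takes the last timestep, which corresponds to `"final"` here.

```python
import numpy as np


def pool_time(outputs: np.ndarray, mode: str) -> np.ndarray:
    """Reduce [N, T, C] outputs to [N, C] (illustrative pooling modes)."""
    if mode == "final":
        return outputs[:, -1, :]        # last timestep only
    if mode == "mean":
        return outputs.mean(axis=1)     # average over time
    if mode == "max":
        return outputs.max(axis=1)      # strongest response over time
    raise ValueError(f"unknown pooling mode: {mode}")


# 2 samples, 3 timesteps, 2 output units (one per class).
outputs = np.array([[[0.0, 1.0], [3.0, 0.0], [1.0, 2.0]],
                    [[1.0, 1.0], [1.0, 3.0], [1.0, 2.0]]])
for mode in ("final", "mean", "max"):
    print(mode, pool_time(outputs, mode).argmax(axis=1))
# final [1 1]
# mean [0 1]
# max [0 1]
```

The pooling choice alone flips the predicted class of the first sample, which is why it is worth setting explicitly in the config.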


### Manual Evaluation and Visualization

If the log parsing above doesn't work (for example, because `tbparse` isn't installed), you can load the trained model and evaluate it manually:

In [None]:
# ============================================================================
# MANUAL EVALUATION: Load trained model and evaluate on test data
# ============================================================================

import glob  # used by find_latest_model below
import torch
import h5py
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Find the latest trained model
def find_latest_model(base_path="training_logs"):
    """Find the most recent trained model checkpoint."""
    patterns = [
        f"{base_path}/**/*.soen",
        f"{base_path}/**/*.ckpt", 
        "lightning_logs/**/*.ckpt",
    ]
    
    all_models = []
    for pattern in patterns:
        all_models.extend(glob.glob(pattern, recursive=True))
    
    if all_models:
        return max(all_models, key=lambda x: Path(x).stat().st_mtime)
    return None

# Load model
model_path = find_latest_model()
if model_path:
    print(f"Found trained model: {model_path}")
    
    # Load based on extension
    if model_path.endswith('.soen'):
        from soen_toolkit.core import SOENModelCore
        model = SOENModelCore.load(model_path)
    else:
        # Load from checkpoint
        from soen_toolkit.training.models import SOENLightningModule
        model = SOENLightningModule.load_from_checkpoint(model_path)
        model = model.model  # Get the underlying SOEN model
    
    model.eval()
    print("Model loaded successfully!")
    
    # Load test data
    data_path = Path("training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5")
    if data_path.exists():
        with h5py.File(data_path, 'r') as f:
            # Try test split, fall back to val
            split = 'test' if 'test' in f else 'val'
            test_data = torch.tensor(f[split]['data'][:], dtype=torch.float32)
            test_labels = torch.tensor(f[split]['labels'][:], dtype=torch.long)
        
        print(f"Loaded {split} data: {test_data.shape}")
        
        # Run inference
        with torch.no_grad():
            outputs, _ = model(test_data[:100])  # First 100 samples
            
            # Get predictions (assuming last timestep, argmax for classification)
            if outputs.dim() == 3:
                outputs = outputs[:, -1, :]  # Take last timestep
            predictions = outputs.argmax(dim=-1)
        
        # Calculate accuracy
        correct = (predictions == test_labels[:100]).sum().item()
        accuracy = correct / len(predictions) * 100
        print(f"\nTest Accuracy: {accuracy:.1f}% ({correct}/{len(predictions)})")
        
        # Visualize some predictions
        fig, axes = plt.subplots(2, 4, figsize=(16, 6))
        
        for i, ax in enumerate(axes.flatten()):
            if i >= len(test_data):
                break
            
            # Plot input signal
            signal = test_data[i, :, 0].numpy()
            ax.plot(signal, 'b-', linewidth=1.5)
            
            true_label = test_labels[i].item()
            pred_label = predictions[i].item()
            
            color = 'green' if true_label == pred_label else 'red'
            ax.set_title(f"True: {true_label}, Pred: {pred_label}", color=color)
            ax.set_xlabel("Time")
            ax.set_ylabel("Input")
            ax.grid(True, alpha=0.3)
        
        plt.suptitle(f"Sample Predictions (Accuracy: {accuracy:.1f}%)", fontsize=14)
        plt.tight_layout()
        plt.show()
        
        # Confusion matrix (labels=[0, 1] keeps it 2x2 even if a class is never predicted)
        from sklearn.metrics import confusion_matrix
        cm = confusion_matrix(test_labels[:100].numpy(), predictions.numpy(), labels=[0, 1])
        
        fig, ax = plt.subplots(figsize=(6, 5))
        im = ax.imshow(cm, cmap='Blues')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        
        # Add text annotations
        for i in range(2):
            for j in range(2):
                ax.text(j, i, str(cm[i, j]), ha='center', va='center', 
                       color='white' if cm[i, j] > cm.max()/2 else 'black', fontsize=14)
        
        plt.colorbar(im)
        plt.tight_layout()
        plt.show()
    else:
        print(f"Dataset not found at {data_path}")
else:
    print("No trained model found. Run training first.")
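The confusion-matrix step in the evaluation cell depends on scikit-learn. If it isn't installed, the same counts can be computed with NumPy alone. The sketch below uses a hypothetical `confusion_counts` helper; rows are true classes and columns are predicted classes, the same orientation as `sklearn.metrics.confusion_matrix`.

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_classes: int) -> np.ndarray:
    """Return a [num_classes, num_classes] matrix where entry [i, j] counts true i predicted as j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (i, j) pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]
cm = confusion_counts(y_true, y_pred, num_classes=2)
print(cm)                        # [[1 1]
                                 #  [1 2]]
print(np.trace(cm) / cm.sum())   # accuracy: 0.6
```

In the evaluation cell, the equivalent call would be `confusion_counts(test_labels[:100].numpy(), predictions.numpy(), num_classes=2)`.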