# Stage 2 Model Training

Train a classifier to predict the ZigZag swing-point type: HH (higher high), LH (lower high), HL (higher low), or LL (lower low).

The input is the filtered data produced by Stage 1 signal detection.

In [None]:
import sys
from pathlib import Path
import pickle
import numpy as np
import pandas as pd

# The project root is taken to be the parent of the working directory; this
# assumes the notebook runs from a subfolder such as notebooks/ -- TODO confirm.
# Prepending it to sys.path makes the `src` package importable below.
project_root = Path.cwd().resolve().parent
sys.path[:0] = [str(project_root)]

from src.stage2_trainer import Stage2Trainer

### Step 1: Load Stage 2 Data

In [None]:
# NOTE(review): this path is relative to the current working directory, while
# `src` is imported from its parent (project_root) -- confirm the Stage 2 data
# really lives under the notebook's own directory and not under project_root.
data_dir = Path('data/stage2')


def _load_pickle(path):
    """Return the object unpickled from *path*.

    pickle can execute arbitrary code while loading, so this must only be used
    on files produced by this project's own pipeline, never on untrusted input.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_split(split):
    """Load the (X, y) pair for one split: 'train', 'val' or 'test'.

    Reads ``X_stage2_<split>.pkl`` and ``y_stage2_<split>.pkl`` from data_dir.
    Raises ValueError when the feature and label containers differ in length,
    so a mismatched pair fails here instead of deep inside training.
    """
    X = _load_pickle(data_dir / f'X_stage2_{split}.pkl')
    y = _load_pickle(data_dir / f'y_stage2_{split}.pkl')
    if len(X) != len(y):
        raise ValueError(
            f'{split} split: X has {len(X)} rows but y has {len(y)} labels'
        )
    return X, y


X_stage2_train, y_stage2_train = _load_split('train')
X_stage2_val, y_stage2_val = _load_split('val')
X_stage2_test, y_stage2_test = _load_split('test')

print('Data loaded successfully')
print(f'Train shape: {X_stage2_train.shape}')
print(f'Val shape: {X_stage2_val.shape}')
print(f'Test shape: {X_stage2_test.shape}')

### Step 2: Initialize and Train Stage 2 Model

In [None]:
# Fit the Stage 2 classifier on the training split, monitoring the validation
# split; normalize=True and save_model=True are passed through to the trainer.
# NOTE(review): 'models/stage2' is relative to the working directory.
trainer = Stage2Trainer(model_dir='models/stage2')

train_split = (X_stage2_train, y_stage2_train)
val_split = (X_stage2_val, y_stage2_val)
# `results` is not read by the later cells shown here.
results = trainer.train(*train_split, *val_split, normalize=True, save_model=True)

### Step 3: Evaluate on Test Set

In [None]:
# Score the trained model on the held-out test split; the summary cell reads
# the "accuracy" and "f1_score" entries of the returned `metrics`.
held_out = (X_stage2_test, y_stage2_test)
metrics = trainer.evaluate(*held_out)

### Step 4: Cross-Validation

In [None]:
# Combine train and validation data for CV
X_combined = np.vstack([X_stage2_train, X_stage2_val])
y_combined = np.hstack([y_stage2_train, y_stage2_val])

cv_results = trainer.cross_validate(X_combined, y_combined, cv=5)

### Step 5: Summary

In [None]:
# Final report: test-set metrics from Step 3 and cross-validation statistics
# from Step 4. Literal-only strings are plain strings, not f-strings (ruff F541).
banner = '=' * 60
print('\n' + banner)
print('STAGE 2 TRAINING COMPLETE')
print(banner)
print(f'\nTest Accuracy: {metrics["accuracy"]:.4f}')
print(f'Test F1-Score: {metrics["f1_score"]:.4f}')
print(f'\nCross-Validation Mean Accuracy: {cv_results["mean_accuracy"]:.4f}')
print(f'Cross-Validation Std: {cv_results["std_accuracy"]:.4f}')
print('\nReady for inference!')
print(banner)