# Training Analysis - ConvNeXt Model

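This notebook analyzes the training results of the ConvNeXt earthquake-precursor detection model. It loads the per-epoch history from `../models/training_history.csv`. It then plots the training and validation loss and the magnitude/azimuth accuracy curves, and reports the best validation epoch for each task. Finally, if `../models/loeo_results.json` exists, it summarizes the per-fold LOEO cross-validation results.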

In [None]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Load Training History

In [None]:
# Load the per-epoch training history (one row per epoch).
history_path = '../models/training_history.csv'
history = pd.read_csv(history_path)

# Fail fast with a clear message: every later cell indexes these columns,
# and `Series.idxmax` raises on an empty frame.
REQUIRED_HISTORY_COLUMNS = {
    'train_loss', 'val_loss',
    'train_mag_acc', 'val_mag_acc',
    'train_azi_acc', 'val_azi_acc',
    'lr',
}
missing_columns = REQUIRED_HISTORY_COLUMNS - set(history.columns)
assert not missing_columns, f"{history_path} is missing columns: {sorted(missing_columns)}"
assert len(history) > 0, f"{history_path} contains no epochs"

print(f"Total epochs: {len(history)}")
history.head()

## 2. Training Curves

In [None]:
def plot_train_val(ax, epochs, train_values, val_values, ylabel, title):
    """Draw matching train (blue) and validation (red) curves on one panel.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    epochs : sequence of int
        1-based epoch numbers used as the x-axis.
    train_values, val_values : pandas.Series
        Per-epoch metric values for the train and validation splits.
    ylabel, title : str
        Y-axis label and panel title.
    """
    ax.plot(epochs, train_values, 'b-', label='Train', linewidth=2)
    ax.plot(epochs, val_values, 'r-', label='Validation', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# `history` has one row per epoch; number the x-axis 1..N.
epochs = range(1, len(history) + 1)

plot_train_val(axes[0, 0], epochs, history['train_loss'], history['val_loss'],
               'Loss', 'Training and Validation Loss')
plot_train_val(axes[0, 1], epochs, history['train_mag_acc'], history['val_mag_acc'],
               'Accuracy', 'Magnitude Classification Accuracy')
plot_train_val(axes[1, 0], epochs, history['train_azi_acc'], history['val_azi_acc'],
               'Accuracy', 'Azimuth Classification Accuracy')

# Learning-rate panel: a single curve, so no legend.
lr_ax = axes[1, 1]
lr_ax.plot(epochs, history['lr'], 'g-', linewidth=2)
lr_ax.set_xlabel('Epoch')
lr_ax.set_ylabel('Learning Rate')
lr_ax.set_title('Learning Rate Schedule')
lr_ax.grid(True)

fig.tight_layout()
fig.savefig('training_curves.png', dpi=150)
plt.show()

## 3. Best Results

In [None]:
def report_best_epoch(history_df, metric, label):
    """Print the epoch with the highest validation ``metric`` and its scores.

    Parameters
    ----------
    history_df : pandas.DataFrame
        Per-epoch history containing ``train_<metric>`` and ``val_<metric>``.
    metric : str
        Column suffix shared by the train/val columns (e.g. ``'mag_acc'``).
    label : str
        Human-readable name used in the printed header.

    Returns
    -------
    The index label of the best row (first occurrence on ties, per ``idxmax``).
    """
    best_idx = history_df[f'val_{metric}'].idxmax()
    # idxmax returns an index *label*; convert it to a position so the 1-based
    # epoch number stays correct even if the frame is not RangeIndex-ed.
    epoch = history_df.index.get_loc(best_idx) + 1
    print(f"Best {label} Accuracy:")
    print(f"  Epoch: {epoch}")
    print(f"  Train: {history_df.loc[best_idx, f'train_{metric}']:.4f}")
    print(f"  Val:   {history_df.loc[best_idx, f'val_{metric}']:.4f}")
    return best_idx


best_mag_idx = report_best_epoch(history, 'mag_acc', 'Magnitude')
print()  # blank separator line between the two reports
best_azi_idx = report_best_epoch(history, 'azi_acc', 'Azimuth')

## 4. LOEO Cross-Validation Results

In [None]:
# Load LOEO results if available
loeo_path = '../models/loeo_results.json'


def format_mean_std(values):
    """Format the mean and standard deviation of ``values`` as percentages.

    Uses ``np.std`` with its default ``ddof=0`` (population std across folds).

    Returns
    -------
    str
        For example ``"60.00% ± 10.00%"`` for ``[0.5, 0.7]``.
    """
    return f"{np.mean(values):.2%} ± {np.std(values):.2%}"


def plot_fold_bars(ax, folds, accs, color, title):
    """Bar-plot per-fold accuracy on ``ax`` with a dashed line at the mean.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    folds : list of int
        1-based fold numbers (bar positions).
    accs : list of float
        Accuracy for each fold, aligned with ``folds``.
    color : str
        Bar colour.
    title : str
        Panel title.
    """
    mean_acc = np.mean(accs)
    ax.bar(folds, accs, color=color)
    ax.axhline(y=mean_acc, color='red', linestyle='--', label=f'Mean: {mean_acc:.2%}')
    # One tick per fold; the default locator can place fractional ticks.
    ax.set_xticks(folds)
    ax.set_xlabel('Fold')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.legend()


if Path(loeo_path).exists():
    # JSON is UTF-8; don't depend on the platform's default locale encoding.
    with open(loeo_path, encoding='utf-8') as f:
        loeo_results = json.load(f)

    fold_results = loeo_results['fold_results']
    if not fold_results:
        # np.mean/np.std of an empty list yield NaN with RuntimeWarnings.
        print("LOEO results file contains no folds.")
    else:
        folds = list(range(1, len(fold_results) + 1))
        mag_accs = [r['magnitude_accuracy'] for r in fold_results]
        azi_accs = [r['azimuth_accuracy'] for r in fold_results]

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        plot_fold_bars(axes[0], folds, mag_accs, 'steelblue', 'Magnitude Accuracy per Fold')
        plot_fold_bars(axes[1], folds, azi_accs, 'coral', 'Azimuth Accuracy per Fold')
        plt.tight_layout()
        plt.show()

        print("\nLOEO Summary:")
        print(f"Magnitude: {format_mean_std(mag_accs)}")
        print(f"Azimuth:   {format_mean_std(azi_accs)}")
else:
    print("LOEO results not found. Run LOEO validation first.")