# **AML Mistake Recognition - LSTM Baseline Training**

This notebook trains and evaluates the new LSTM baseline for mistake detection in procedural activities.

**What's new:**
- Train LSTM baseline from scratch
- Compare with MLP and Transformer baselines
- Evaluate on step and recordings splits
- Generate comparison visualizations

# **1 Initial Setup**

## 1.1 Install dependencies

In [None]:
!pip install torcheval wandb -q

## 1.2 Import the project (with LSTM baseline)

In [None]:
# Delete code/ if it already exists and you need to re-clone the project
!rm -rf code

In [None]:
# Clone the project repository
!git clone --recursive https://github.com/SimoneColu/AML_error_recognition.git code

### 1.2.1 Add LSTM Implementation Files

If your repository already contains the LSTM files, skip the next cell.

If it does not, add the files with Option A or Option B in the cell below:

In [None]:
# ONLY RUN THIS IF YOUR REPO DOESN'T HAVE THE LSTM FILES YET

# You have two options:
# Option A: Upload from your local machine (click the folder icon, then upload)
# Upload these files to /content/:
#   - er_lstm.py (goes to code/core/models/)
#   - updated base.py (goes to code/)
#   - updated constants.py (goes to code/)

# Option B: Copy from your Google Drive (if you've already uploaded them there)
from google.colab import drive
drive.mount('/content/drive')

# Copy files from Drive to the project
# !cp /content/drive/MyDrive/AML_DAAI_25_26/er_lstm.py code/core/models/
# !cp /content/drive/MyDrive/AML_DAAI_25_26/base.py code/
# !cp /content/drive/MyDrive/AML_DAAI_25_26/constants.py code/

### 1.2.2 Verify LSTM Files Are Present

In [None]:
# Check that LSTM model file exists
!ls -lh code/core/models/er_lstm.py

# Check that constants have LSTM variants
!grep "LSTM_VARIANT" code/constants.py

# Check that base.py supports LSTM
!grep -A 3 "LSTM_VARIANT" code/base.py | head -10

## 1.3 Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## 1.4 Set the constants

In [None]:
DRIVE_PATH = "/content/drive/MyDrive/AML_DAAI_25_26"
LOCAL_CODE_PATH = "/content/code"

# Create the necessary directories
!mkdir -p {LOCAL_CODE_PATH}/data/video
!mkdir -p {LOCAL_CODE_PATH}/checkpoints/error_recognition/LSTM/omnivore
!mkdir -p {LOCAL_CODE_PATH}/checkpoints/error_recognition/LSTM/slowfast
!mkdir -p {LOCAL_CODE_PATH}/checkpoints/error_recognition/GRU/omnivore
!mkdir -p {LOCAL_CODE_PATH}/checkpoints/error_recognition/LSTM_Attention/omnivore
!mkdir -p {LOCAL_CODE_PATH}/results
!mkdir -p {LOCAL_CODE_PATH}/plots

# **2 Extract Features**

## 2.1 Extract Omnivore Features

In [None]:
# Extract Omnivore features (quiet mode)
!unzip -q "{DRIVE_PATH}/data/backbone/omnivore.zip" -d {LOCAL_CODE_PATH}/data/video

## 2.2 Extract SlowFast Features (Optional)

Uncomment if you want to train on SlowFast features as well.

In [None]:
# !unzip -q "{DRIVE_PATH}/data/backbone/slowfast.zip" -d {LOCAL_CODE_PATH}/data/video

# **3 Train LSTM Baseline**

## 3.1 Train LSTM on Step Split (Omnivore)

This will train the LSTM model on the step split. Expected training time: **1-2 hours**

In [None]:
%%bash

cd code
python train_er.py \
  --variant LSTM \
  --backbone omnivore \
  --split step \
  --num_epochs 50 \
  --lr 1e-3 \
  --batch_size 1 \
  --weight_decay 1e-3 \
  --seed 42 \
  --ckpt_directory /content/code/checkpoints

## 3.2 Train LSTM on Recordings Split (Omnivore)

This will train on the more challenging recordings split.

In [None]:
%%bash

cd code
python train_er.py \
  --variant LSTM \
  --backbone omnivore \
  --split recordings \
  --num_epochs 50 \
  --lr 1e-3 \
  --batch_size 1 \
  --weight_decay 1e-3 \
  --seed 42 \
  --ckpt_directory /content/code/checkpoints

## 3.3 Train Alternative Variants (Optional)

### 3.3.1 LSTM with Attention

In [None]:
%%bash

cd code
python train_er.py \
  --variant LSTM_Attention \
  --backbone omnivore \
  --split step \
  --num_epochs 50 \
  --lr 1e-3 \
  --batch_size 1 \
  --weight_decay 1e-3 \
  --seed 42 \
  --ckpt_directory /content/code/checkpoints

### 3.3.2 GRU Variant

In [None]:
%%bash

cd code
python train_er.py \
  --variant GRU \
  --backbone omnivore \
  --split step \
  --num_epochs 50 \
  --lr 1e-3 \
  --batch_size 1 \
  --weight_decay 1e-3 \
  --seed 42 \
  --ckpt_directory /content/code/checkpoints
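
The four training cells above differ only in `--variant` and `--split`. A small helper can build the command from those two values, which keeps the remaining hyperparameters in one place. This is a sketch: `build_train_command` is a hypothetical helper, and the flags are copied from the cells above rather than checked against `train_er.py`.

```python
import shlex

# Hyperparameters shared by every training cell in this section.
COMMON_ARGS = {
    "num_epochs": "50",
    "lr": "1e-3",
    "batch_size": "1",
    "weight_decay": "1e-3",
    "seed": "42",
    "ckpt_directory": "/content/code/checkpoints",
}


def build_train_command(variant, split, backbone="omnivore"):
    """Return the train_er.py invocation as a list of arguments."""
    cmd = ["python", "train_er.py",
           "--variant", variant,
           "--backbone", backbone,
           "--split", split]
    for flag, value in COMMON_ARGS.items():
        cmd += [f"--{flag}", value]
    return cmd


# Dry run: print the commands instead of launching them.
for variant, split in [("LSTM", "step"), ("LSTM", "recordings"),
                       ("LSTM_Attention", "step"), ("GRU", "step")]:
    print(shlex.join(build_train_command(variant, split)))
```

To launch a run instead of printing it, pass the list to `subprocess.run(cmd, cwd="code", check=True)`.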

## 3.4 Monitor Training Progress

Check the training logs to see how the model is performing.

In [None]:
# View training statistics
!tail -20 code/stats/error_recognition/LSTM/omnivore/*_training_performance.txt
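
`tail` shows only the most recent epochs. To find the best epoch so far, the log can be parsed directly. The sketch below assumes the layout that section 6.4 relies on: a comma-separated file with one header row, the epoch in column 0, and validation F1 in column 6. `best_epoch` is a hypothetical helper; adjust the column index if your log differs.

```python
import csv


def best_epoch(log_path, metric_col=6):
    """Return (epoch, score) for the row with the highest value in metric_col.

    Returns None if the file contains no usable rows.
    """
    best = None
    with open(log_path, newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if len(row) <= metric_col:
                continue  # skip blank or truncated lines
            try:
                epoch, score = int(float(row[0])), float(row[metric_col])
            except ValueError:
                continue  # skip rows that are not numeric
            if best is None or score > best[1]:
                best = (epoch, score)
    return best
```

Call it with the path of one of the files matched by the glob above. Passing `metric_col=7` ranks epochs by AUC instead, under the same layout assumption.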

# **4 Evaluate LSTM Baseline**

## 4.1 Evaluate LSTM - Step Split

In [None]:
%%bash

cd code
python -m core.evaluate \
  --variant LSTM \
  --backbone omnivore \
  --ckpt checkpoints/error_recognition/LSTM/omnivore/error_recognition_step_omnivore_LSTM_video_best.pt \
  --split step \
  --threshold 0.6

## 4.2 Evaluate LSTM - Recordings Split

In [None]:
%%bash

cd code
python -m core.evaluate \
  --variant LSTM \
  --backbone omnivore \
  --ckpt checkpoints/error_recognition/LSTM/omnivore/error_recognition_recordings_omnivore_LSTM_video_best.pt \
  --split recordings \
  --threshold 0.4
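
The `--threshold` flag is presumably the cutoff applied to the model's predicted error probability: scores above it count as a mistake. Because precision and recall move in opposite directions as the cutoff changes, the two splits can favour different values. The sketch below illustrates the idea on made-up scores. It is not the code in `core.evaluate`, and the `>=` comparison is an assumption.

```python
def binary_metrics(probs, labels, threshold):
    """Precision, recall and F1 for the positive (mistake) class."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1


# Illustrative values only, not results from this project.
probs = [0.9, 0.7, 0.55, 0.45, 0.3, 0.1]
labels = [1, 0, 1, 1, 0, 0]

for t in (0.4, 0.6):
    p, r, f1 = binary_metrics(probs, labels, t)
    print(f"threshold={t}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
```

On this toy data the lower threshold trades precision for recall, which is why comparisons between models should use the same threshold on the same split.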

# **5 Compare with Existing Baselines**

## 5.1 Evaluate MLP Baseline (if checkpoint available)

In [None]:
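# Copy MLP checkpoint from Drive (create the target directory first, since cp will not)
!mkdir -p code/checkpoints/error_recognition_best/MLP/omnivore
!cp "{DRIVE_PATH}/data/checkpoint/MLP/error_recognition_MLP_omnivore_step_epoch_43.pt" code/checkpoints/error_recognition_best/MLP/omnivore/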

In [None]:
%%bash

cd code
python -m core.evaluate \
  --variant MLP \
  --backbone omnivore \
  --ckpt checkpoints/error_recognition_best/MLP/omnivore/error_recognition_MLP_omnivore_step_epoch_43.pt \
  --split step \
  --threshold 0.6

## 5.2 Evaluate Transformer Baseline (if checkpoint available)

In [None]:
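# Copy Transformer checkpoint from Drive (create the target directory first, since cp will not)
!mkdir -p code/checkpoints/error_recognition_best/Transformer/omnivore
!cp "{DRIVE_PATH}/data/checkpoint/Transformer/error_recognition_Transformer_omnivore_step_epoch_9.pt" code/checkpoints/error_recognition_best/Transformer/omnivore/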

In [None]:
%%bash

cd code
python -m core.evaluate \
  --variant Transformer \
  --backbone omnivore \
  --ckpt checkpoints/error_recognition_best/Transformer/omnivore/error_recognition_Transformer_omnivore_step_epoch_9.pt \
  --split step \
  --threshold 0.6

# **6 Visualize Results**

## 6.1 View Results CSV

In [None]:
import pandas as pd

# Read results for step split
df = pd.read_csv('code/results/error_recognition/combined_results/step_True_substep_True_threshold_0.6.csv')
print("\n===== STEP SPLIT RESULTS =====")
print(df)

# Filter for main variants
print("\n===== COMPARISON (MLP vs Transformer vs LSTM) =====")
comparison = df[df['Variant'].isin(['MLP', 'Transformer', 'LSTM'])]
print(comparison[['Variant', 'Backbone', 'Step F1', 'Step AUC', 'Step Precision', 'Step Recall']])
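
`pd.read_csv` and the column selection above raise `FileNotFoundError` or `KeyError` if evaluation has not run yet or the CSV uses different headers. A quick check with the standard library reports which expected columns are missing. This is a sketch: `missing_columns` is a hypothetical helper, and the expected names are simply the ones used in the cell above.

```python
import csv
from pathlib import Path

# Column names taken from the comparison cell above.
EXPECTED = ["Variant", "Backbone", "Step F1", "Step AUC",
            "Step Precision", "Step Recall"]


def missing_columns(csv_path, expected=EXPECTED):
    """Return the expected column names absent from the CSV header.

    Returns None when the file itself does not exist.
    """
    path = Path(csv_path)
    if not path.is_file():
        return None
    with path.open(newline="") as f:
        header = next(csv.reader(f), [])
    present = {name.strip() for name in header}
    return [name for name in expected if name not in present]
```

A result of `None` means the evaluation cells have not produced the file yet; an empty list means every column used above is available.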

## 6.2 Generate Comparison Plots

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 150

# Read results
df = pd.read_csv('code/results/error_recognition/combined_results/step_True_substep_True_threshold_0.6.csv')
main_variants = df[df['Variant'].isin(['MLP', 'Transformer', 'LSTM'])]

# Plot F1 comparison
plt.figure(figsize=(10, 6))
ax = sns.barplot(data=main_variants, x='Variant', y='Step F1', palette='Set2')

# Add value labels
for bar in ax.patches:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{height:.2f}',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.title('F1 Score Comparison (Step Split)', fontsize=16, fontweight='bold')
plt.ylabel('F1 Score', fontsize=12)
plt.xlabel('Model Variant', fontsize=12)
plt.ylim(0, 100)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('code/plots/f1_comparison_step.png', bbox_inches='tight')
plt.show()

print("\nPlot saved to: code/plots/f1_comparison_step.png")

## 6.3 Plot All Metrics Comparison

In [None]:
# Plot all metrics
metrics = ['Step F1', 'Step AUC', 'Step Precision', 'Step Recall']
plot_data = main_variants[['Variant'] + metrics].copy()
plot_data_melted = plot_data.melt(id_vars='Variant', var_name='Metric', value_name='Score')
plot_data_melted['Metric'] = plot_data_melted['Metric'].str.replace('Step ', '')

plt.figure(figsize=(14, 6))
ax = sns.barplot(data=plot_data_melted, x='Metric', y='Score', hue='Variant', palette='Set2')
plt.title('All Metrics Comparison (Step Split)', fontsize=16, fontweight='bold')
plt.ylabel('Score', fontsize=12)
plt.xlabel('Metric', fontsize=12)
plt.legend(title='Model', fontsize=11)
plt.ylim(0, 100)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('code/plots/all_metrics_comparison.png', bbox_inches='tight')
plt.show()

print("\nPlot saved to: code/plots/all_metrics_comparison.png")

## 6.4 Learning Curves (if training logs available)

In [None]:
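import glob
import numpy as np

# Read training log
try:
    # Use glob rather than `!ls`, which can return its error text as output when nothing matches.
    # If several logs match (e.g. step and recordings), the first in sorted order is used.
    log_files = sorted(glob.glob('code/stats/error_recognition/LSTM/omnivore/*_training_performance.txt'))
    log_file = log_files[0] if log_files else None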
    
    if log_file:
        # Parse log file
        data = np.genfromtxt(log_file, delimiter=',', skip_header=1)
        epochs = data[:, 0]
        train_loss = data[:, 1]
        val_loss = data[:, 2]
        f1 = data[:, 6]
        auc = data[:, 7]
        
        # Plot learning curves
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
        
        # Loss curves
        ax1.plot(epochs, train_loss, label='Train Loss', linewidth=2, color='#3498db')
        ax1.plot(epochs, val_loss, label='Val Loss', linewidth=2, color='#e74c3c')
        ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.legend(fontsize=11)
        ax1.grid(alpha=0.3)
        
        # Metrics curves
        ax2.plot(epochs, f1, label='F1 Score', linewidth=2, color='#2ecc71')
        ax2.plot(epochs, auc, label='AUC', linewidth=2, color='#f39c12')
        ax2.set_title('Validation Metrics', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Score', fontsize=12)
        ax2.legend(fontsize=11)
        ax2.grid(alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('code/plots/lstm_learning_curves.png', bbox_inches='tight')
        plt.show()
        
        print("\nPlot saved to: code/plots/lstm_learning_curves.png")
    else:
        print("Training log not found. Train the model first.")
except Exception as e:
    print(f"Error plotting learning curves: {e}")

# **7 Save Results to Drive**

In [None]:
# Create results directory in Drive
!mkdir -p "{DRIVE_PATH}/results/lstm_baseline"

# Copy checkpoints
!cp -r code/checkpoints/error_recognition/LSTM "{DRIVE_PATH}/results/lstm_baseline/checkpoints/"

# Copy results
!cp -r code/results/error_recognition "{DRIVE_PATH}/results/lstm_baseline/results/"

# Copy plots
!cp -r code/plots "{DRIVE_PATH}/results/lstm_baseline/plots/"

# Copy training logs
!cp -r code/stats/error_recognition/LSTM "{DRIVE_PATH}/results/lstm_baseline/training_logs/"

print("\nAll results saved to Google Drive!")
print(f"Location: {DRIVE_PATH}/results/lstm_baseline/")
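
Re-running the `cp -r` commands above once the destination folders exist nests the source inside them (for example `plots/plots/`). If the cell is likely to be run more than once, `shutil.copytree` with `dirs_exist_ok=True` merges into the existing folder instead. A minimal sketch, where `sync_dir` is a hypothetical helper:

```python
import shutil
from pathlib import Path


def sync_dir(src, dst):
    """Copy src into dst, merging with existing contents.

    Returns True if the copy ran, False if src does not exist.
    """
    src, dst = Path(src), Path(dst)
    if not src.is_dir():
        print(f"skipped (not found): {src}")
        return False
    # dirs_exist_ok requires Python 3.8+; files already in dst are overwritten.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return True
```

For example, `sync_dir("code/plots", f"{DRIVE_PATH}/results/lstm_baseline/plots")` mirrors the plots copy above and can be repeated safely.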

# **8 Summary and Next Steps**

## 8.1 Print Summary

In [None]:
print("="*60)
print("LSTM BASELINE TRAINING COMPLETE")
print("="*60)
print("\nWhat you've accomplished:")
print("✅ Trained LSTM baseline on step and/or recordings split")
print("✅ Evaluated LSTM performance with detailed metrics")
print("✅ Compared with MLP and Transformer baselines")
print("✅ Generated visualization plots")
print("✅ Saved all results to Google Drive")

print("\nNext steps:")
print("1. Analyze the per-error-type performance (check evaluation output)")
print("2. Conduct ablation studies (single-layer, unidirectional, etc.)")
print("3. Try LSTM_Attention and GRU variants")
print("4. Write your project report using the results")

print("\nDocumentation:")
print("📄 NEW_BASELINE_PROPOSAL.md - Detailed motivation and architecture")
print("📄 BASELINE_COMPARISON_ANALYSIS.md - Comprehensive comparison")
print("📄 LSTM_QUICKSTART_GUIDE.md - Training and evaluation guide")
print("📄 SUMMARY.md - High-level overview")
print("="*60)