# 🎯 USPSA Shot Detection Model Training

This notebook trains a convolutional neural network that classifies short audio windows from USPSA stage videos into three classes: non-shot, shot, and beep.

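## 📋 Prerequisites:
1. Upload your training data folder to Google Drive as `MyDrive/uspsa_training_data` (or update `DATA_PATH` in Step 2 to match your folder)
2. The folder should contain:
   - 20 video files (.mp4, .mov, .avi, .webm, or .mkv)
   - 20 JSON label files (one per video)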

## ⚡ How to Run:
1. Click **Runtime → Run all** (or run the cells one by one, in order)
2. When prompted, authorize Google Drive access
3. Wait for training to complete (~30-60 minutes)
4. Download the trained model and `normalization_params.json` at the end (both are required for inference)

---

## 📦 Step 1: Install Dependencies

In [None]:
!pip install -q librosa soundfile scikit-learn
# (IPython shell escape: only valid inside a notebook. TensorFlow and matplotlib
# are not installed here; this assumes the Colab runtime already provides them.)

# Suppress warnings for cleaner output
# NOTE: this silences every Python warning for the rest of the session, which can
# also hide useful diagnostics (e.g. from audio decoding) while debugging.
import warnings
warnings.filterwarnings('ignore')
import os
# Set before `import tensorflow` below so TensorFlow's native logger picks it up.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow warnings

import json
import numpy as np
import librosa
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Report the environment so a missing GPU is noticed before the long training run.
print("✅ Dependencies installed")
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

## 💾 Step 2: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Set your data path - UPDATE THIS if your folder name is different
DATA_PATH = '/content/drive/MyDrive/uspsa_training_data'

# Video container extensions accepted by this notebook. Filenames are compared
# in lower case so `.MP4` / `.MOV` files are counted, matching the
# case-insensitive video lookup used when the dataset is built.
VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi', '.webm', '.mkv')

# Verify the folder exists
if not os.path.exists(DATA_PATH):
    print(f"❌ ERROR: Folder not found at {DATA_PATH}")
    print("Please update DATA_PATH to match your Google Drive folder")
else:
    files = os.listdir(DATA_PATH)
    videos = [f for f in files if f.lower().endswith(VIDEO_EXTENSIONS)]
    jsons = [f for f in files if f.endswith('.json')]
    print(f"✅ Found {len(videos)} videos and {len(jsons)} JSON files")
    if len(videos) != len(jsons):
        print("⚠️ WARNING: Number of videos and JSON files don't match!")

## 📊 Step 3: Load and Validate Labels

In [None]:
def load_labels(data_path):
    """Load all JSON label files in *data_path*.

    Files are read in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order, so sorting keeps the label list (and everything built
    from it downstream) in the same order from run to run.

    Args:
        data_path: Directory containing the ``*.json`` label files.

    Returns:
        List of parsed JSON objects, one per label file.

    Raises:
        json.JSONDecodeError: If a label file is not valid JSON.
    """
    json_files = sorted(f for f in os.listdir(data_path) if f.endswith('.json'))

    labels_data = []
    for json_file in json_files:
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(os.path.join(data_path, json_file), 'r', encoding='utf-8') as f:
            labels_data.append(json.load(f))

    return labels_data

# Load every label file, then report per-video and overall event counts.
all_labels = load_labels(DATA_PATH)
print(f"📋 Loaded {len(all_labels)} label files\n")

total_shots = total_beeps = 0
first_person_count = third_person_count = 0

for entry in all_labels:
    event_types = [event['type'] for event in entry['labels']]
    n_shots = event_types.count('shot')
    n_beeps = event_types.count('beep')
    total_shots += n_shots
    total_beeps += n_beeps

    rec_type = entry['recording_type']
    # Any recording type other than 'first_person' is tallied as third-person.
    if rec_type == 'first_person':
        first_person_count += 1
    else:
        third_person_count += 1

    print(f"  {entry['video_id']}: {n_shots} shots, {n_beeps} beeps ({rec_type})")

summary_rows = [
    ("Total shots", total_shots),
    ("Total beeps", total_beeps),
    ("First-person videos", first_person_count),
    ("Third-person videos", third_person_count),
]
print("\n📊 Summary:")
for row_label, row_value in summary_rows:
    print(f"  {row_label}: {row_value}")
print(f"\n  Total labeled examples: {total_shots + total_beeps}")

# Check label files against the expected naming scheme stage1a ... stage10b.
# Matching is a case-insensitive prefix test on each label's video_id.
print("\n⚠️ VALIDATION:")

known_ids = [entry['video_id'].lower() for entry in all_labels]
expected_names = (f"stage{n}{side}" for n in range(1, 11) for side in 'ab')
missing_stages = [
    name for name in expected_names
    if not any(vid.startswith(name) for vid in known_ids)
]

if missing_stages:
    print(f"  ⚠️ Missing expected videos: {', '.join(missing_stages)}")
    print("  → This is OK if you have different video names")
else:
    print("  ✅ All expected stage videos (stage1a-stage10b) found!")

# A video id follows the stage format when it contains "stage" immediately
# followed by a non-zero digit (equivalent to testing stage1 ... stage98).
unexpected = [
    entry['video_id']
    for entry in all_labels
    if not any(f"stage{digit}" in entry['video_id'].lower() for digit in "123456789")
]

if unexpected:
    print(f"  ⚠️ Unexpected video files (not stage format): {', '.join(unexpected)}")
    print("  → These will still be used for training")

## 🎵 Step 4: Extract Audio Features

In [None]:
def extract_audio_segment(video_path, timestamp, duration=0.15, sr=22050,
                          n_mels=64, n_frames=64, fmax=8000, hop_length=512,
                          pad_value=0.0):
    """
    Extract a fixed-size mel spectrogram from a video's audio track.

    Args:
        video_path: Path to video file (decoded with ``librosa.load``).
        timestamp: Start of the segment, in seconds from the start of the file.
        duration: Duration of segment in seconds.
        sr: Sample rate the audio is resampled to.
        n_mels: Number of mel frequency bands (rows of the output).
        n_frames: Number of time frames (columns of the output). Shorter
            spectrograms are padded on the right; longer ones are truncated.
        fmax: Highest frequency (Hz) covered by the mel filter bank.
        hop_length: STFT hop size in samples.
        pad_value: dB value written into padded columns. The default 0.0 is
            0 dB under ``ref=1.0`` (mel power 1.0), which is NOT silence; it is
            kept as the default so features stay identical to models already
            trained with this function and to inference code mirroring it.

    Returns:
        ``(n_mels, n_frames)`` array of mel power in dB relative to a FIXED
        reference of 1.0 (so absolute loudness differences are preserved; a
        per-segment ``ref=np.max`` would erase them). Returns ``None`` if the
        audio could not be decoded or the segment contains no samples.

    Note:
        With the defaults (0.15 s at 22050 Hz, hop 512) the real audio spans
        only roughly 7 frames; the remaining columns are padding.
    """
    try:
        # Load only the requested window of audio from the video file
        y, _ = librosa.load(video_path, sr=sr, offset=timestamp, duration=duration, mono=True)

        # An offset at or past the end of the stream leaves nothing to analyse.
        # Reject it explicitly instead of depending on how librosa treats empty input.
        if y.size == 0:
            print(f"Error extracting audio: no samples at {timestamp:.2f}s in {video_path}")
            return None

        # Convert to mel spectrogram (visual representation of audio)
        mel_spec = librosa.feature.melspectrogram(
            y=y,
            sr=sr,
            n_mels=n_mels,
            fmax=fmax,
            hop_length=hop_length
        )

        # FIXED reference (1.0) instead of np.max: keeps loud shots loud and
        # quiet background quiet relative to each other across samples.
        mel_spec_db = librosa.power_to_db(mel_spec, ref=1.0)

        # Force a fixed number of time frames
        frames_available = mel_spec_db.shape[1]
        if frames_available < n_frames:
            mel_spec_db = np.pad(
                mel_spec_db,
                ((0, 0), (0, n_frames - frames_available)),
                mode='constant',
                constant_values=pad_value,
            )
        else:
            mel_spec_db = mel_spec_db[:, :n_frames]

        # Returned un-normalized; global normalization is applied later
        return mel_spec_db

    except Exception as e:
        # Best-effort: a single undecodable segment should not abort dataset building
        print(f"Error extracting audio: {e}")
        return None

print("✅ Audio extraction function ready (with FIXED dB reference)")

## 🔧 Step 5: Prepare Training Dataset

In [None]:
def prepare_dataset(labels_data, data_path, non_shots_per_video=20, min_gap=0.5, rng=None):
    """
    Extract features and labels from all videos for 3-CLASS CLASSIFICATION

    **AUTO-GENERATES non-shot samples** by randomly sampling timestamps
    that are at least ``min_gap`` seconds away from every labeled shot or beep.

    Args:
        labels_data: List of label data from JSON files. Each entry is read for
            'video_id', 'recording_type' and 'labels' (a list of dicts with
            'type' and 'time').
        data_path: Path to video directory
        non_shots_per_video: Number of random non-shot samples to generate per video
        min_gap: Minimum distance (seconds) between a random non-shot
            timestamp and any labeled shot/beep time.
        rng: Object with a ``uniform(low, high)`` method used to draw non-shot
            timestamps, e.g. ``np.random.default_rng(42)`` for a reproducible
            dataset. Defaults to NumPy's global random state (``np.random``).

    Returns:
        X: Array of spectrograms (features)
        y: Array of labels (0 = non-shot, 1 = shot, 2 = beep)
        metadata: List of dicts with info about each sample
    """
    rng = np.random if rng is None else rng

    # Label 'type' -> class id. Class 0 (non-shot) is generated, never labeled.
    label_classes = {'shot': 1, 'beep': 2}
    video_extensions = ('.mp4', '.mov', '.avi', '.webm', '.mkv')

    X = []
    y = []
    metadata = []

    # Case-insensitive lookup: lower-cased filename -> actual filename on disk
    actual_files = {
        filename.lower(): filename
        for filename in os.listdir(data_path)
        if filename.lower().endswith(video_extensions)
    }

    print(f"Found {len(actual_files)} video files in directory\n")

    for label_file in labels_data:
        video_id = label_file['video_id']
        video_filename = actual_files.get(video_id.lower())

        if video_filename is None:
            print(f"⚠️ Video not found: {video_id}")
            print("   (Checked for case-insensitive match)")
            continue

        video_path = os.path.join(data_path, video_filename)

        if not os.path.exists(video_path):
            print(f"⚠️ Video not found: {video_path}")
            continue

        print(f"Processing {video_filename}...")
        recording_type = label_file['recording_type']

        # Every labeled shot/beep time is excluded from non-shot sampling, even
        # when its own feature extraction fails; otherwise a random "non-shot"
        # window could land on a real shot and be mislabeled.
        excluded_times = []

        for label in label_file['labels']:
            class_id = label_classes.get(label['type'])
            if class_id is None:
                continue  # unknown label types are ignored

            excluded_times.append(label['time'])
            spec = extract_audio_segment(video_path, label['time'])
            if spec is not None:
                X.append(spec)
                y.append(class_id)
                metadata.append({
                    'video_id': video_filename,
                    'time': label['time'],
                    'type': label['type'],
                    'recording_type': recording_type
                })

        try:
            duration = librosa.get_duration(path=video_path)
        except Exception as e:
            print(f"  ⚠️ Could not get duration for {video_filename} ({e}), skipping non-shot generation")
            continue

        # Non-shot window starts avoid the first and last second of the video.
        low, high = 1.0, duration - 1.0
        if high <= low:
            print(f"  ⚠️ {video_filename} is too short ({duration:.2f}s), skipping non-shot generation")
            continue

        non_shot_count = 0
        attempts = 0
        max_attempts = non_shots_per_video * 10  # Try up to 10x to find good samples

        while non_shot_count < non_shots_per_video and attempts < max_attempts:
            attempts += 1
            random_time = rng.uniform(low, high)

            # Reject candidates too close to any labeled shot/beep
            if excluded_times and min(abs(random_time - t) for t in excluded_times) < min_gap:
                continue

            spec = extract_audio_segment(video_path, random_time)
            if spec is not None:
                X.append(spec)
                y.append(0)  # 0 = non-shot
                metadata.append({
                    'video_id': video_filename,
                    'time': random_time,
                    'type': 'non_shot_auto',
                    'recording_type': recording_type
                })
                non_shot_count += 1

        print(f"  Generated {non_shot_count} random non-shot samples")

    X = np.array(X)
    y = np.array(y)

    print(f"\n✅ Dataset prepared:")
    print(f"  Total samples: {len(X)}")
    print(f"  Non-shots (class 0): {np.sum(y == 0)}")
    print(f"  Shots (class 1): {np.sum(y == 1)}")
    print(f"  Beeps (class 2): {np.sum(y == 2)}")
    print(f"  Feature shape: {X.shape}")

    return X, y, metadata

# Prepare the dataset with auto-generated non-shots
# Non-shot timestamps are drawn from NumPy's global random state, which is not
# seeded in the cells above, so the non-shot samples differ between runs.
X, y, metadata = prepare_dataset(all_labels, DATA_PATH, non_shots_per_video=20)

## 🔧 Step 5.5: Apply Global Normalization (CRITICAL FIX)

In [None]:
print("\n" + "="*60)
print("🔧 APPLYING GLOBAL NORMALIZATION")
print("="*60)

# Keeps the division finite if every value is identical. Saved with the other
# parameters because inference has to apply exactly the same formula.
NORM_EPSILON = 1e-8

if X.size == 0:
    raise ValueError("X is empty - Step 5 produced no samples, nothing to normalize")

# Calculate global min/max across ALL samples.
# NOTE(review): this runs before the train/val/test split in Step 7, so the
# validation and test samples also contribute to min/max (mild data leakage).
global_min = X.min()
global_max = X.max()

# This cell overwrites X in place. The raw Step 4 features are dB values whose
# padding contributes exact zeros and whose real audio is expected to span many
# dB, so data that already lies entirely within [0, 1] almost certainly means the
# cell is being re-run. Normalizing twice would save min=0 / max=1 and silently
# break inference, so stop instead.
if global_min >= 0.0 and global_max <= 1.0:
    raise RuntimeError(
        "X already appears to be normalized to [0, 1]. Re-run Step 5 to rebuild "
        "the raw features before running this cell again."
    )

print(f"\nGlobal statistics from {len(X)} samples:")
print(f"  Min: {global_min:.4f}")
print(f"  Max: {global_max:.4f}")
print(f"  Mean: {X.mean():.4f}")
print(f"  Std: {X.std():.4f}")

# Apply global min-max normalization to all data
X_normalized = (X - global_min) / (global_max - global_min + NORM_EPSILON)

print(f"\nAfter global normalization:")
print(f"  Min: {X_normalized.min():.4f}")
print(f"  Max: {X_normalized.max():.4f}")
print(f"  Mean: {X_normalized.mean():.4f}")
print(f"  Std: {X_normalized.std():.4f}")

# Replace X with normalized version
X = X_normalized

# **CRITICAL: Save normalization parameters for inference**
normalization_params = {
    'global_min': float(global_min),
    'global_max': float(global_max),
    'sample_rate': 22050,
    'segment_duration': 0.15,
    'n_mels': 64,
    'spec_size': 64,
    # Remaining Step 4 / Step 5.5 settings so inference can reproduce features exactly
    'hop_length': 512,
    'fmax': 8000,
    'db_reference': 1.0,
    'epsilon': NORM_EPSILON,
}

# Save to Google Drive
norm_params_path = '/content/drive/MyDrive/normalization_params.json'
with open(norm_params_path, 'w', encoding='utf-8') as f:
    json.dump(normalization_params, f, indent=2)

print(f"\n✅ Normalization parameters saved to: {norm_params_path}")
print("\n⚠️  IMPORTANT: Download this file along with the model!")
print("   The server needs these values for inference!")

print("="*60)

## 📸 Step 6: Visualize Examples

In [None]:
# Show examples of non-shots, shots, and beeps
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

# Find examples of each class
non_shot_indices = np.where(y == 0)[0][:4]
shot_indices = np.where(y == 1)[0][:4]
beep_indices = np.where(y == 2)[0][:4] if np.sum(y == 2) > 0 else []

for i, idx in enumerate(non_shot_indices):
    axes[0, i].imshow(X[idx], aspect='auto', origin='lower', cmap='viridis')
    axes[0, i].set_title(f"Non-shot: {metadata[idx]['video_id'][:20]}...\n{metadata[idx]['time']:.2f}s")
    axes[0, i].axis('off')

for i, idx in enumerate(shot_indices):
    axes[1, i].imshow(X[idx], aspect='auto', origin='lower', cmap='viridis')
    axes[1, i].set_title(f"Shot: {metadata[idx]['video_id'][:20]}...\n{metadata[idx]['time']:.2f}s")
    axes[1, i].axis('off')

if len(beep_indices) > 0:
    for i, idx in enumerate(beep_indices):
        axes[2, i].imshow(X[idx], aspect='auto', origin='lower', cmap='viridis')
        axes[2, i].set_title(f"Beep: {metadata[idx]['video_id'][:20]}...\n{metadata[idx]['time']:.2f}s")
        axes[2, i].axis('off')
else:
    for i in range(4):
        axes[2, i].text(0.5, 0.5, 'No beep samples', ha='center', va='center')
        axes[2, i].axis('off')

plt.tight_layout()
plt.suptitle('Spectrogram Examples: Non-Shots (top) vs Shots (middle) vs Beeps (bottom)', y=1.02, fontsize=14)
plt.show()

print("📊 Notice the visual differences:")
print("  - Non-shots: Random patterns or sustained noise")
print("  - Shots: Bright vertical bands (sudden energy burst)")
print("  - Beeps: Sustained tone with consistent frequency pattern")

## 🔀 Step 7: Split Dataset

In [None]:
# Reshape for CNN (add channel dimension)
X = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)

# Convert labels to categorical (one-hot encoding) for 3-class classification
from tensorflow.keras.utils import to_categorical
y_categorical = to_categorical(y, num_classes=3)

# Split into train (70%), validation (15%), and test (15%)
X_train, X_temp, y_train, y_temp = train_test_split(X, y_categorical, test_size=0.3, random_state=42, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp.argmax(axis=1))

# Get class counts for each split
y_train_classes = y_train.argmax(axis=1)
y_val_classes = y_val.argmax(axis=1)
y_test_classes = y_test.argmax(axis=1)

print(f"✅ Dataset split:")
print(f"  Training: {len(X_train)} samples ({len(X_train)/len(X)*100:.1f}%)")
print(f"  Validation: {len(X_val)} samples ({len(X_val)/len(X)*100:.1f}%)")
print(f"  Test: {len(X_test)} samples ({len(X_test)/len(X)*100:.1f}%)")
print(f"\n  Train: {np.sum(y_train_classes == 0)} non-shots | {np.sum(y_train_classes == 1)} shots | {np.sum(y_train_classes == 2)} beeps")
print(f"  Val: {np.sum(y_val_classes == 0)} non-shots | {np.sum(y_val_classes == 1)} shots | {np.sum(y_val_classes == 2)} beeps")
print(f"  Test: {np.sum(y_test_classes == 0)} non-shots | {np.sum(y_test_classes == 1)} shots | {np.sum(y_test_classes == 2)} beeps")

# DIAGNOSTIC: Check if validation data looks normal
print(f"\n🔍 VALIDATION SET DIAGNOSTICS:")
print(f"  Val data shape: {X_val.shape}")
print(f"  Val data min: {X_val.min():.4f}, max: {X_val.max():.4f}")
print(f"  Val data mean: {X_val.mean():.4f}, std: {X_val.std():.4f}")
print(f"  Any NaN values: {np.isnan(X_val).any()}")
print(f"  Any Inf values: {np.isinf(X_val).any()}")

# Check a few validation samples
print(f"\n  Sample validation features (first 5):")
for i in range(min(5, len(X_val))):
    class_id = y_val_classes[i]
    class_name = ['non-shot', 'shot', 'beep'][class_id]
    print(f"    Val[{i}]: class={class_name}, min={X_val[i].min():.3f}, max={X_val[i].max():.3f}, mean={X_val[i].mean():.3f}")

## 🏗️ Step 8: Build Model Architecture

In [None]:
def create_model(input_shape=(64, 64, 1), num_classes=3, learning_rate=0.001):
    """
    Create CNN model for 3-class shot detection

    Architecture:
    - 3 Convolutional blocks (Conv2D + BatchNorm + MaxPool)
    - Dense layers for classification
    - Softmax output over ``num_classes`` (default: non-shot, shot, beep)

    Args:
        input_shape: Shape of one input spectrogram including the channel axis.
        num_classes: Number of output classes.
        learning_rate: Initial Adam learning rate.

    Returns:
        A compiled ``keras.Sequential`` model.

    Note:
        Keras ``Precision``/``Recall`` threshold every softmax output at 0.5 by
        default and pool all one-hot columns. They are therefore overall
        (micro-averaged) figures, not per-class ones; use the confusion matrix
        for per-class performance.
    """
    model = keras.Sequential([
        # Input layer
        layers.Input(shape=input_shape),

        # Block 1
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),

        # Block 2
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),

        # Block 3
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),

        # Flatten and dense layers
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(64, activation='relu'),

        # Output layer (one probability per class)
        layers.Dense(num_classes, activation='softmax')
    ])

    # Explicit metric names keep the History keys 'precision' / 'recall' stable
    # if this cell is re-run in the same session (some Keras versions otherwise
    # add a suffix such as 'precision_1', which breaks the plotting cell).
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            keras.metrics.Precision(name='precision'),
            keras.metrics.Recall(name='recall'),
        ]
    )

    return model

# Create the model
model = create_model()
model.summary()

print(f"\n✅ Model created with {model.count_params():,} parameters")
print(f"📊 3-class output: [non-shot, shot, beep]")

## 🎓 Step 9: Train the Model

In [None]:
# Calculate class weights to handle imbalance
from sklearn.utils.class_weight import compute_class_weight

# Get class labels from one-hot encoded y_train
y_train_classes = y_train.argmax(axis=1)

# Calculate automatic weights first
class_weights_array = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train_classes),
    y=y_train_classes
)

print(f"⚖️ Automatic class weights calculated:")
print(f"  Non-shot (0): {class_weights_array[0]:.2f}")
print(f"  Shot (1): {class_weights_array[1]:.2f}")
print(f"  Beep (2): {class_weights_array[2]:.2f}")

# **CRITICAL FIX: Cap beep weight to prevent overfitting**
# Extreme beep weight causes model to memorize rather than learn
class_weights = {
    0: class_weights_array[0],
    1: class_weights_array[1],
    2: min(class_weights_array[2], 5.0)  # Cap beep weight at 5.0
}

print(f"\n✅ Adjusted class weights (beep capped at 5.0 to prevent overfitting):")
print(f"  Non-shot (0): {class_weights[0]:.2f}")
print(f"  Shot (1): {class_weights[1]:.2f}")
print(f"  Beep (2): {class_weights[2]:.2f}")
print(f"  This balances learning without overfitting on rare classes\n")

# Callbacks for training
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=15,  # Increased patience to allow more learning
        restore_best_weights=True,
        verbose=1
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=7,  # Increased patience for learning rate reduction
        min_lr=1e-7,
        verbose=1
    )
]

# Train the model with adjusted class weights
print("🚀 Starting training...\n")

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=50,
    batch_size=32,
    class_weight=class_weights,  # ← Using adjusted weights
    callbacks=callbacks,
    verbose=1
)

print("\n✅ Training complete!")

## 🧪 Step 9.5: Sanity-Check Predictions on Training Samples

In [None]:
print("\n" + "="*60)
print("🧪 VERIFYING MODEL LEARNED TO DISTINGUISH CLASSES")
print("="*60)

# Spot-check on TRAINING samples: this shows the model can fit the data it saw,
# not that it generalizes (the test-set evaluation covers that).
non_shot_indices = np.where(y_train_classes == 0)[0][:5]
shot_indices = np.where(y_train_classes == 1)[0][:5]
beep_indices = np.where(y_train_classes == 2)[0][:5] if np.sum(y_train_classes == 2) > 0 else []

class_names = ['non-shot', 'shot', 'beep']

# (expected class, line label, heading, sample indices, whether to show the section)
spot_checks = [
    (0, "Non-shot", "NON-SHOT", non_shot_indices, True),
    (1, "Shot", "SHOT", shot_indices, True),
    (2, "Beep", "BEEP", beep_indices, len(beep_indices) > 0),
]

for expected, line_label, heading, indices, show in spot_checks:
    if not show:
        continue
    print(f"\nPredictions on {heading} samples (should predict class {expected}):")
    for idx in indices:
        probs = model.predict(X_train[idx:idx+1], verbose=0)[0]
        guess = probs.argmax()
        verdict = '✅' if guess == expected else '❌ FAILED'
        print(f"  {line_label}: predicted={class_names[guess]} ({probs[guess] * 100:.1f}%) {verdict}")

# Mean probability assigned to the correct class over up to 20 samples per class
print("\n📊 Average confidence scores:")
for class_id, name in enumerate(class_names):
    subset = X_train[y_train_classes == class_id][:20]
    if len(subset) == 0:
        continue
    mean_conf = model.predict(subset, verbose=0)[:, class_id].mean() * 100
    print(f"  {name}: {mean_conf:.1f}% confident on own class")

print("="*60)

## 📈 Step 10: Visualize Training History

In [None]:
def _history_key(hist, base):
    """Return the History key for metric *base*, tolerating Keras numeric suffixes.

    Some Keras versions name re-created metric objects 'precision_1',
    'recall_2', ... instead of 'precision' / 'recall'; a hard-coded key then
    raises KeyError.

    Raises:
        KeyError: If no key matches *base* or *base* followed by ``_<digits>``.
    """
    if base in hist:
        return base
    for key in hist:
        if key.startswith(base + '_') and key[len(base) + 1:].isdigit():
            return key
    raise KeyError(f"metric '{base}' not found in training history: {sorted(hist)}")


# One panel per tracked metric: (axis, History base name, display label)
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
panels = [
    (axes[0, 0], 'accuracy', 'Accuracy'),
    (axes[0, 1], 'loss', 'Loss'),
    (axes[1, 0], 'precision', 'Precision'),
    (axes[1, 1], 'recall', 'Recall'),
]

for ax, metric, label in panels:
    train_key = _history_key(history.history, metric)
    ax.plot(history.history[train_key], label='Train')
    # Keras records validation values under the same name prefixed with 'val_'
    ax.plot(history.history['val_' + train_key], label='Validation')
    ax.set_title(f'Model {label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Evaluate on test set
print("="*60)
print("🎯 EVALUATING ON TEST SET")
print("="*60)

# Order matches compile(): loss, then accuracy, precision, recall
test_loss, test_acc, test_precision, test_recall = model.evaluate(X_test, y_test, verbose=0)

print(f"\nTest Set Performance:")
print(f"  Accuracy:  {test_acc*100:.2f}%")
print(f"  Precision: {test_precision*100:.2f}%")
print(f"  Recall:    {test_recall*100:.2f}%")
print(f"  Loss:      {test_loss:.4f}")

class_ids = [0, 1, 2]
display_names = ['Non-shot', 'Shot', 'Beep']

# Confusion matrix for 3 classes
y_pred = model.predict(X_test, verbose=0).argmax(axis=1)
y_true = y_test.argmax(axis=1)
# Pin the label set so the matrix is always 3x3. Without it, a class absent from
# both y_true and y_pred (e.g. no beeps) shrinks the matrix and cm[2][...] below
# raises IndexError.
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

print(f"\nConfusion Matrix:")
print(f"                 Predicted")
print(f"              Non-shot  Shot  Beep")
print(f"  Actual")
print(f"  Non-shot  [{cm[0][0]:5d}  {cm[0][1]:4d}  {cm[0][2]:4d}]")
print(f"  Shot      [{cm[1][0]:5d}  {cm[1][1]:4d}  {cm[1][2]:4d}]")
print(f"  Beep      [{cm[2][0]:5d}  {cm[2][1]:4d}  {cm[2][2]:4d}]")

# Per-class accuracy (recall of each class)
print(f"\nPer-Class Performance:")
for class_id, class_name in zip(class_ids, display_names):
    class_mask = (y_true == class_id)
    class_correct = np.sum((y_pred == class_id) & class_mask)
    class_total = np.sum(class_mask)
    if class_total > 0:
        class_acc = class_correct / class_total * 100
        print(f"  {class_name}: {class_acc:.1f}% ({class_correct}/{class_total})")

# Per-class precision / recall / F1. The Keras precision/recall above are pooled
# over all classes.
print("\nClassification Report:")
print(classification_report(
    y_true, y_pred,
    labels=class_ids,
    target_names=display_names,
    zero_division=0
))

print("="*60)

In [ ]:
# Save model (using modern Keras format)
print("="*60)
print("💾 SAVING MODEL")
print("="*60)

local_model_path = '/content/shot_detector_model.keras'
model.save(local_model_path)
print(f"✅ Model saved to: {local_model_path}")

# Also save to Google Drive so the model survives the Colab session
drive_model_path = '/content/drive/MyDrive/shot_detector_model.keras'
model.save(drive_model_path)
print(f"✅ Model also saved to Google Drive: {drive_model_path}")

# Save model metadata (including 3-class info)
model_info = {
    'model_type': '3-class',
    'classes': ['non-shot', 'shot', 'beep'],
    'test_accuracy': float(test_acc),
    'test_precision': float(test_precision),
    'test_recall': float(test_recall),
    'total_training_samples': len(X_train),
    'total_non_shots_trained': int(np.sum(y_train_classes == 0)),
    'total_shots_trained': int(np.sum(y_train_classes == 1)),
    'total_beeps_trained': int(np.sum(y_train_classes == 2)),
    'input_shape': [64, 64, 1],
    'sample_rate': 22050,
    'segment_duration': 0.15,
    # Copy of the Step 5.5 normalization settings so this file is self-describing
    'normalization': normalization_params,
}

metadata_path = '/content/drive/MyDrive/shot_detector_metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
    json.dump(model_info, f, indent=2)
print(f"✅ Metadata saved to: {metadata_path}")

# Download model
print("\n📥 Downloading model...")
from google.colab import files
files.download(local_model_path)

# Download normalization parameters (CRITICAL!)
# Reuse the path Step 5.5 wrote to rather than repeating the literal, so the
# download cannot drift out of sync with where the file was actually saved.
print("\n📥 Downloading normalization parameters...")
files.download(norm_params_path)

print("\n" + "="*50)
print("🎉 TRAINING COMPLETE!")
print("="*50)
print(f"\nFinal Model Performance:")
print(f"  Test Accuracy: {test_acc*100:.2f}%")
print(f"  Test Precision: {test_precision*100:.2f}%")
print(f"  Test Recall: {test_recall*100:.2f}%")
print("\n✅ Downloaded files:")
print("  1. shot_detector_model.keras (3-class model)")
print("  2. normalization_params.json (REQUIRED for server!)")
print("\nReplace both files in your project directory.")
print("Server code will need to be updated for 3-class predictions.")
print("="*50)