# ML Event Tagger - Training & Evaluation

This notebook evaluates the multi-label event classification model: it loads the model and training history produced by `train.py`, then reports detailed test-set metrics and visualizations.

## Overview

- **Goal:** Train a TensorFlow/Keras model to predict event tags
- **Architecture:** Embedding → Pooling → Dense layers → Sigmoid output
- **Data:** 100 labeled events, 21 tags
- **Split:** 70% train, 15% validation, 15% test


## 1. Setup & Imports


In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    multilabel_confusion_matrix,
    precision_score,
    recall_score,
)
from tensorflow import keras

# Import our modules
from ml_event_tagger.config import TAGS
from ml_event_tagger.model import create_model
from ml_event_tagger.train import (
    load_preprocessed_data,
    create_tokenizer,
    tokenize_texts
)

# Plot defaults applied to every figure created below
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

# Confirmation banner: the imports worked and the tag list is visible
tag_listing = ', '.join(TAGS)
print("✅ Imports loaded")
print(f"📊 Number of tags: {len(TAGS)}")
print(f"🏷️  Tags: {tag_listing}")


## 2. Load Data


In [None]:
# Fetch the three splits (texts + label matrices) and the tag name list
(
    train_texts, train_labels,
    val_texts, val_labels,
    test_texts, test_labels,
    tag_names,
) = load_preprocessed_data()

# Split sizes; labels are pre-padded so the counts line up in a column
print("✅ Data loaded:")
for padded_label, split_texts in (
    ("Train:", train_texts),
    ("Val:  ", val_texts),
    ("Test: ", test_texts),
):
    print(f"   {padded_label} {len(split_texts)} samples")
print(f"   Tags:  {len(tag_names)} tags")
print()

# Peek at the first training example: truncated text plus its active tags
first_text = train_texts[0]
first_labels = train_labels[0]

print("Sample preprocessed text:")
print(f"  '{first_text[:100]}...'")
print()
print("Sample labels:")
# NOTE(review): assumes TAGS is in the same order as the label columns
# (i.e. matches tag_names) — confirm against load_preprocessed_data
sample_tags = [TAGS[k] for k in range(len(TAGS)) if first_labels[k] == 1]
print(f"  Tags: {sample_tags}")


## 3. Tokenize Text


In [None]:
# Tokenizer settings
MAX_VOCAB_SIZE = 10000
MAX_LENGTH = 200

# Build the tokenizer from the training texts only
tokenizer = create_tokenizer(train_texts, vocab_size=MAX_VOCAB_SIZE)
# NOTE(review): counts every word in word_index; if the tokenizer caps its
# vocabulary at MAX_VOCAB_SIZE this may overstate the effective size — confirm
vocab_size = 1 + len(tokenizer.word_index)

print("✅ Tokenizer created")
print(f"   Vocabulary size: {vocab_size:,} words")
print(f"   Max sequence length: {MAX_LENGTH}")
print()

# Convert each split with the same tokenizer and target length
X_train, X_val, X_test = (
    tokenize_texts(tokenizer, split_texts, max_length=MAX_LENGTH)
    for split_texts in (train_texts, val_texts, test_texts)
)

print("✅ Texts tokenized:")
for padded_name, matrix in (("X_train:", X_train), ("X_val:  ", X_val), ("X_test: ", X_test)):
    print(f"   {padded_name} {matrix.shape}")
print()

# The 20 entries with the lowest word_index values
# (presumably the most frequent words — verify how create_tokenizer assigns indices)
word_freq = sorted(tokenizer.word_index.items(), key=lambda pair: pair[1])[:20]
top_words = [word for word, _rank in word_freq]
print("Top 20 most common words:")
print("  ", ", ".join(top_words))


## 4. Load Trained Model

We'll load the model that was trained using `train.py`.


In [None]:
# Restore the saved Keras model from disk (path is relative to the working directory)
model_path = "models/event_tagger_model.h5"
model = keras.models.load_model(model_path)

print("✅ Model loaded successfully", end="\n\n")
model.summary()


## 5. Evaluate on Test Set


In [None]:
# Run the model over the held-out test split
test_results = model.evaluate(X_test, test_labels, verbose=0)

# assumes the result order is [loss, binary accuracy, precision, recall]
# — TODO confirm against the metrics the model was compiled with
test_loss = test_results[0]
test_accuracy = test_results[1]
precision = test_results[2]
recall = test_results[3]

print("📊 Test Set Results:")
print("=" * 50)
print(f"  Loss:            {test_loss:.4f}")
print(f"  Binary Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.1f}%)")
print(f"  Precision:       {precision:.4f} ({precision*100:.1f}%)")
print(f"  Recall:          {recall:.4f} ({recall*100:.1f}%)")

# F1 is the harmonic mean of precision and recall; fall back to 0 when both are 0
pr_total = precision + recall
if pr_total > 0:
    f1 = 2 * (precision * recall) / pr_total
else:
    f1 = 0
print(f"  F1 Score:        {f1:.4f} ({f1*100:.1f}%)")
print()

# Success criterion: precision of at least 60%
print("✅ Success Criteria:")
mark, verdict = ("✅", "PASSED") if precision >= 0.60 else ("❌", "FAILED")
print(f"  {mark} Precision ≥ 60%: {precision*100:.1f}% ({verdict})")


## 6. Make Predictions


In [None]:
# Per-tag scores for every test event
y_pred_proba = model.predict(X_test, verbose=0)

# A tag counts as predicted when its score reaches 0.5
y_pred = (y_pred_proba >= 0.5).astype(int)

# Average number of tags per event, predicted vs. actual
avg_predicted_tags = y_pred.sum(axis=1).mean()
avg_actual_tags = test_labels.sum(axis=1).mean()

print("✅ Predictions generated")
print(f"   Shape: {y_pred.shape}")
print(f"   Predicted tags per event: {avg_predicted_tags:.1f} (average)")
print(f"   Actual tags per event:    {avg_actual_tags:.1f} (average)")


## 7. Training History Visualization


In [None]:
# Read the per-epoch metrics saved alongside the model
with open('models/training_history.json') as f:
    history = json.load(f)

epochs = range(1, len(history['loss']) + 1)

# One entry per panel: (history key, axis/legend label, panel title).
# The matching validation series is stored under 'val_<key>'.
panels = [
    ('loss', 'Loss', 'Binary Crossentropy Loss'),
    ('binary_accuracy', 'Accuracy', 'Binary Accuracy'),
    ('precision', 'Precision', 'Precision'),
    ('recall', 'Recall', 'Recall'),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Training History', fontsize=16, fontweight='bold')

# axes.flat walks the 2x2 grid row by row, matching the order of `panels`
for ax, (metric_key, label, panel_title) in zip(axes.flat, panels):
    ax.plot(epochs, history[metric_key], 'b-', label=f'Train {label}', linewidth=2)
    ax.plot(epochs, history[f'val_{metric_key}'], 'r-', label=f'Val {label}', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Values from the last recorded epoch
final_val_loss = history['val_loss'][-1]
final_val_precision = history['val_precision'][-1]
print(f"✅ Training completed after {len(epochs)} epochs")
print(f"   Final validation loss: {final_val_loss:.4f}")
print(f"   Final validation precision: {final_val_precision:.4f}")


## 8. Per-Tag Performance Analysis


In [None]:
# Calculate per-tag metrics on the test set.
# precision_score / recall_score / f1_score come from the notebook's imports cell,
# so the notebook's dependencies are all declared in one place.
# Each entry of per_tag_metrics is a dict with keys: tag, count, precision, recall, f1.
per_tag_metrics = []
for i, tag in enumerate(TAGS):
    y_true_tag = test_labels[:, i]
    y_pred_tag = y_pred[:, i]
    support = int(y_true_tag.sum())

    if support == 0:
        # No positive test samples for this tag: the metrics are not meaningful,
        # so store zeros here and let the table below render them as N/A.
        prec = rec = f1_tag = 0
    else:
        # zero_division=0 makes an undefined ratio (e.g. no predicted positives) count as 0
        prec = precision_score(y_true_tag, y_pred_tag, zero_division=0)
        rec = recall_score(y_true_tag, y_pred_tag, zero_division=0)
        f1_tag = f1_score(y_true_tag, y_pred_tag, zero_division=0)

    per_tag_metrics.append({
        'tag': tag,
        'count': support,
        'precision': prec,
        'recall': rec,
        'f1': f1_tag
    })

# Display table
print("Per-Tag Performance on Test Set:")
print("=" * 70)
print(f"{'Tag':<12} {'Count':<7} {'Precision':<12} {'Recall':<12} {'F1 Score':<12}")
print("-" * 70)
for m in per_tag_metrics:
    if m['count'] > 0:
        print(f"{m['tag']:<12} {m['count']:<7} {m['precision']:<12.2f} {m['recall']:<12.2f} {m['f1']:<12.2f}")
    else:
        # Tags absent from the test set have no measurable performance
        print(f"{m['tag']:<12} {m['count']:<7} {'N/A':<12} {'N/A':<12} {'N/A':<12}")
print("=" * 70)
print("=" * 70)


## 9. Per-Tag Precision/Recall Bar Chart


In [None]:
# Only tags that actually occur in the test set are charted
tags_with_samples = [m for m in per_tag_metrics if m['count'] > 0]

if not tags_with_samples:
    print("⚠️  No tags with test samples to plot")
else:
    tags = [entry['tag'] for entry in tags_with_samples]
    precisions = [entry['precision'] for entry in tags_with_samples]
    recalls = [entry['recall'] for entry in tags_with_samples]

    # Two bars per tag, placed side by side around each tick position
    x = np.arange(len(tags))
    width = 0.35
    half_width = width/2

    fig, ax = plt.subplots(figsize=(14, 6))
    bars1 = ax.bar(x - half_width, precisions, width, label='Precision', color='steelblue', alpha=0.8)
    bars2 = ax.bar(x + half_width, recalls, width, label='Recall', color='coral', alpha=0.8)

    ax.set_title('Per-Tag Precision and Recall on Test Set', fontsize=14, fontweight='bold')
    ax.set_xlabel('Tag', fontweight='bold')
    ax.set_ylabel('Score', fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(tags, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    # Headroom above 1.0 so the value labels are not clipped
    ax.set_ylim(0, 1.1)

    # Write the score on top of every non-zero bar
    for bar in list(bars1) + list(bars2):
        height = bar.get_height()
        if height <= 0:
            continue
        center_x = bar.get_x() + bar.get_width()/2.
        ax.text(center_x, height, f'{height:.2f}',
                ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    plt.show()


## 10. Tag Frequency Distribution


In [None]:
# Positive-label count per tag column in the test set
tag_counts = test_labels.sum(axis=0)

# Keep only tags that occur at least once, most frequent first
tag_freq = [
    (tag_name, int(tag_counts[col]))
    for col, tag_name in enumerate(TAGS)
    if tag_counts[col] > 0
]
tag_freq = sorted(tag_freq, key=lambda pair: pair[1], reverse=True)

if not tag_freq:
    print("⚠️  No tags found in test set")
else:
    tags, counts = zip(*tag_freq)
    positions = range(len(tags))

    fig, ax = plt.subplots(figsize=(12, 6))
    bars = ax.bar(positions, counts, color='teal', alpha=0.7)
    ax.set_title('Tag Frequency Distribution in Test Set', fontsize=14, fontweight='bold')
    ax.set_xlabel('Tag', fontweight='bold')
    ax.set_ylabel('Frequency', fontweight='bold')
    ax.set_xticks(positions)
    ax.set_xticklabels(tags, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    # Print the count above each bar
    for bar in bars:
        height = bar.get_height()
        center_x = bar.get_x() + bar.get_width()/2.
        ax.text(center_x, height, f'{int(height)}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Same data as a plain-text listing
    print("\nTag frequencies in test set:")
    for tag, count in tag_freq:
        print(f"  {tag:<12} {count:>3}")


## 11. Sample Predictions


In [None]:
# Walk through the first few test events and compare predicted vs. actual tags
print("Sample Predictions:")
print("=" * 80)

n_shown = min(5, len(test_texts))
for i in range(n_shown):
    truth_row = test_labels[i]
    pred_row = y_pred[i]
    proba_row = y_pred_proba[i]

    print(f"\nEvent {i+1}:")
    print(f"  Text: '{test_texts[i][:100]}...'")

    # Tag names whose label / thresholded prediction equals 1
    actual_tags = [tag for j, tag in enumerate(TAGS) if truth_row[j] == 1]
    predicted_tags = [tag for j, tag in enumerate(TAGS) if pred_row[j] == 1]
    print(f"  Actual tags:    {actual_tags}")
    print(f"  Predicted tags: {predicted_tags}")

    # Five highest-scoring tags, best first
    top5_indices = np.argsort(proba_row)[-5:][::-1]
    print("  Top 5 (with confidence):")
    for j in top5_indices:
        print(f"    {TAGS[j]:<12} {proba_row[j]:.3f}")

    # The prediction is perfect exactly when nothing was missed and nothing extra was added
    missed = set(actual_tags) - set(predicted_tags)
    extra = set(predicted_tags) - set(actual_tags)
    if not missed and not extra:
        print("  ✅ Perfect match!")
    if missed:
        print(f"  ⚠️  Missed: {missed}")
    if extra:
        print(f"  ⚠️  Extra: {extra}")
    print("-" * 80)


## 12. Summary & Conclusions


In [None]:
# Final recap of the dataset, model and test metrics.
# Reads values computed in earlier cells: the three text splits, vocab_size,
# MAX_LENGTH, model, test_results and f1.
print("="*80)
print("TRAINING & EVALUATION SUMMARY")
print("="*80)
print()
print("📊 Dataset:")
print(f"  Total events: {len(train_texts) + len(val_texts) + len(test_texts)}")
print(f"  Train: {len(train_texts)} | Val: {len(val_texts)} | Test: {len(test_texts)}")
print(f"  Number of tags: {len(TAGS)}")
print()
print("🧠 Model:")
print(f"  Architecture: Embedding → Pooling → Dense layers → Sigmoid")
print(f"  Vocabulary size: {vocab_size:,} words")
print(f"  Max sequence length: {MAX_LENGTH}")
# Take the parameter count from the loaded model rather than a hard-coded figure,
# so the summary stays correct if the architecture or vocabulary changes.
print(f"  Parameters: {model.count_params():,}")
print()
print("📈 Performance:")
# test_results indices: 1 = binary accuracy, 2 = precision, 3 = recall
# (same positions the evaluation cell prints)
print(f"  Test Precision: {test_results[2]:.1%}")
print(f"  Test Recall:    {test_results[3]:.1%}")
print(f"  Test F1 Score:  {f1:.1%}")
print(f"  Test Accuracy:  {test_results[1]:.1%}")
print()
print("✅ Success Criteria:")
if test_results[2] >= 0.60:
    print(f"  ✅ Precision ≥ 60%: {test_results[2]:.1%} (PASSED)")
else:
    print(f"  ❌ Precision ≥ 60%: {test_results[2]:.1%} (FAILED)")
print()
print("💡 Insights:")
# NOTE(review): the qualitative wording below ("High", "Moderate", "Good") is fixed
# text and is printed whatever the measured values are — re-read it after retraining.
print(f"  • High precision ({test_results[2]:.1%}) = Model is confident when it predicts")
print(f"  • Moderate recall ({test_results[3]:.1%}) = Model is conservative")
print(f"  • Good results for only {len(train_texts)} training samples!")
print(f"  • Performance will improve with more data (see ROADMAP v0.2)")
print()
print("="*80)
