In [10]:
import numpy as np
import pandas as pd
import os
import time
import json
from glob import glob

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score


In [6]:

# Check what files we have
pbp_files = glob('./nba_data/pbp_*.json')
summary_files = glob('./nba_data/summary_*.json')

print(f"PBP files: {len(pbp_files)}")
print(f"Summary files: {len(summary_files)}")

# Load all games and extract features
data = []
for pbp_file in pbp_files:
    game_id = os.path.basename(pbp_file).replace('pbp_', '').replace('.json', '')
    summary_file = f'./nba_data/summary_{game_id}.json'
    
    # Skip if summary doesn't exist
    if not os.path.exists(summary_file):
        print(f"Missing summary for {game_id}, skipping...")
        continue
    
    with open(pbp_file) as f:
        pbp = json.load(f)
    with open(summary_file) as f:
        summary = json.load(f)
    
    home_won = summary['home']['points'] > summary['away']['points']
    
    # Extract snapshots on possession changes
    for period in pbp.get('periods', []):
        home_score = away_score = 0
        last_possession_team = None
        
        for i, event in enumerate(period.get('events', [])):
            # Update scores
            if event.get('event_type') == 'fieldgoalmade':
                points = event.get('points', 0)
                if event.get('attribution', {}).get('team', {}).get('id') == pbp['home']['id']:
                    home_score += points
                else:
                    away_score += points
            
            # Get current possession team
            current_team = event.get('attribution', {}).get('team', {}).get('id')
            
            # Sample on possession change OR every 10 events
            if current_team != last_possession_team or i % 10 == 0:
                data.append({
                    'quarter': period.get('number', 1),
                    'score_diff': home_score - away_score,
                    'home_won': int(home_won)
                })
                last_possession_team = current_team

PBP files: 100
Summary files: 99
Missing summary for c7968eb5-62d0-4595-b64c-aea8b6fee489, skipping...


In [7]:
# data preprocessing 
# extract ALL available features from play-by-play
data = []
for pbp_file in pbp_files:
    game_id = os.path.basename(pbp_file).replace('pbp_', '').replace('.json', '')
    summary_file = f'./nba_data/summary_{game_id}.json'
    
    if not os.path.exists(summary_file):
        continue
    
    try:
        with open(pbp_file) as f:
            pbp = json.load(f)
        with open(summary_file) as f:
            summary = json.load(f)
        
        # Check if required keys exist
        if 'home' not in pbp or 'away' not in pbp or 'periods' not in pbp:
            print(f"Skipping {game_id}: missing required keys")
            continue
        
        home_won = summary['home']['points'] > summary['away']['points']
        home_id = pbp['home']['id']
        
    except Exception as e:
        print(f"Error loading {game_id}: {e}")
        continue
    
    for period in pbp.get('periods', []):
        # Track everything
        home_score = away_score = 0
        home_timeouts = away_timeouts = 7
        home_fouls = away_fouls = 0
        home_fgm = home_fga = away_fgm = away_fga = 0
        home_3pm = home_3pa = away_3pm = away_3pa = 0
        home_ftm = home_fta = away_ftm = away_fta = 0
        home_turnovers = away_turnovers = 0
        home_rebounds = away_rebounds = 0
        home_assists = away_assists = 0
        home_steals = away_steals = 0
        home_blocks = away_blocks = 0
        home_recent_points = away_recent_points = []
        
        last_possession_team = None
        
        for i, event in enumerate(period.get('events', [])):
            event_type = event.get('event_type', '')
            team_id = event.get('attribution', {}).get('team', {}).get('id')
            is_home = team_id == home_id
            
            # Scores
            if event_type == 'fieldgoalmade':
                points = event.get('points', 0)
                if is_home:
                    home_score += points
                    home_fgm += 1
                    home_recent_points.append(points)
                    if points == 3:
                        home_3pm += 1
                else:
                    away_score += points
                    away_fgm += 1
                    away_recent_points.append(points)
                    if points == 3:
                        away_3pm += 1
            
            elif event_type == 'fieldgoalmissed':
                if is_home:
                    home_fga += 1
                    if event.get('three_point_shot'):
                        home_3pa += 1
                else:
                    away_fga += 1
                    if event.get('three_point_shot'):
                        away_3pa += 1
            
            elif event_type == 'freethrow':
                if event.get('made'):
                    if is_home:
                        home_score += 1
                        home_ftm += 1
                        home_recent_points.append(1)
                    else:
                        away_score += 1
                        away_ftm += 1
                        away_recent_points.append(1)
                if is_home:
                    home_fta += 1
                else:
                    away_fta += 1
            
            # Other events
            elif event_type == 'timeout':
                if is_home:
                    home_timeouts = max(0, home_timeouts - 1)
                else:
                    away_timeouts = max(0, away_timeouts - 1)
            
            elif event_type == 'foul':
                if is_home:
                    home_fouls += 1
                else:
                    away_fouls += 1
            
            elif event_type == 'turnover':
                if is_home:
                    home_turnovers += 1
                else:
                    away_turnovers += 1
            
            elif event_type == 'rebound':
                if is_home:
                    home_rebounds += 1
                else:
                    away_rebounds += 1
            
            elif event_type == 'assist':
                if is_home:
                    home_assists += 1
                else:
                    away_assists += 1
            
            elif event_type == 'steal':
                if is_home:
                    home_steals += 1
                else:
                    away_steals += 1
            
            elif event_type == 'block':
                if is_home:
                    home_blocks += 1
                else:
                    away_blocks += 1
            
            home_recent_points = home_recent_points[-10:]
            away_recent_points = away_recent_points[-10:]
            
            # Sample on possession change OR every 10 events
            if team_id != last_possession_team or i % 10 == 0:
                clock = event.get('clock', '12:00')
                try:
                    mins, secs = clock.split(':')
                    time_in_period = int(mins) * 60 + int(secs)
                except:
                    time_in_period = 720
                
                # Calculate percentages safely
                home_fg_pct = home_fgm / home_fga if home_fga > 0 else 0
                away_fg_pct = away_fgm / away_fga if away_fga > 0 else 0
                home_3p_pct = home_3pm / home_3pa if home_3pa > 0 else 0
                away_3p_pct = away_3pm / away_3pa if away_3pa > 0 else 0
                home_ft_pct = home_ftm / home_fta if home_fta > 0 else 0
                away_ft_pct = away_ftm / away_fta if away_fta > 0 else 0
                
                data.append({
                    'quarter': period.get('number', 1),
                    'time_in_period': time_in_period,
                    'home_score': home_score,
                    'away_score': away_score,
                    'score_diff': home_score - away_score,
                    'home_fgm': home_fgm,
                    'home_fga': home_fga,
                    'home_fg_pct': home_fg_pct,
                    'away_fgm': away_fgm,
                    'away_fga': away_fga,
                    'away_fg_pct': away_fg_pct,
                    'home_3pm': home_3pm,
                    'home_3pa': home_3pa,
                    'home_3p_pct': home_3p_pct,
                    'away_3pm': away_3pm,
                    'away_3pa': away_3pa,
                    'away_3p_pct': away_3p_pct,
                    'home_ftm': home_ftm,
                    'home_fta': home_fta,
                    'home_ft_pct': home_ft_pct,
                    'away_ftm': away_ftm,
                    'away_fta': away_fta,
                    'away_ft_pct': away_ft_pct,
                    'home_timeouts': home_timeouts,
                    'away_timeouts': away_timeouts,
                    'home_fouls': home_fouls,
                    'away_fouls': away_fouls,
                    'home_turnovers': home_turnovers,
                    'away_turnovers': away_turnovers,
                    'home_rebounds': home_rebounds,
                    'away_rebounds': away_rebounds,
                    'home_assists': home_assists,
                    'away_assists': away_assists,
                    'home_steals': home_steals,
                    'away_steals': away_steals,
                    'home_blocks': home_blocks,
                    'away_blocks': away_blocks,
                    'home_recent_scoring': sum(home_recent_points),
                    'away_recent_scoring': sum(away_recent_points),
                    'home_won': int(home_won)
                })
                last_possession_team = team_id

# Train with ALL features
df = pd.DataFrame(data)
print(f"\nTotal snapshots: {len(df)}")
print(f"Total features: {len(df.columns) - 1}")

features = [col for col in df.columns if col != 'home_won']
X = df[features]
y = df['home_won']

Skipping a06b10fa-fd80-4058-9e0a-d3d1d69cb6f1: missing required keys
Skipping 28dfca06-a3e0-43f5-a841-53fb7d75cc00: missing required keys

Total snapshots: 4898
Total features: 39


In [8]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [9]:
# Define models
models = {
    'Decision Tree': DecisionTreeClassifier(max_depth=10, random_state=42),
    'Random Forest': RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42),
    'Gradient Boosting': GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=42),
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
    'Neural Network': MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
}

# Train and evaluate all models
results = []

for name, model in models.items():
    print(f"Training {name}...")
    
    start_time = time.time()
    model.fit(X_train, y_train)
    train_time = time.time() - start_time
    
    # Predictions
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]
    
    # Metrics
    accuracy = accuracy_score(y_test, y_pred)
    logloss = log_loss(y_test, y_prob)
    auc = roc_auc_score(y_test, y_prob)
    
    results.append({
        'Model': name,
        'Accuracy': accuracy,
        'Log Loss': logloss,
        'AUC-ROC': auc,
        'Train Time (s)': train_time
    })

# Create comparison table
results_df = pd.DataFrame(results).sort_values('Log Loss')

print("\n" + "="*80)
print("MODEL COMPARISON")
print("="*80)
print(results_df.to_string(index=False))
print("\nBest model by Log Loss (calibration): " + results_df.iloc[0]['Model'])
print("Best model by Accuracy: " + results_df.sort_values('Accuracy', ascending=False).iloc[0]['Model'])
print("Best model by AUC-ROC: " + results_df.sort_values('AUC-ROC', ascending=False).iloc[0]['Model'])

Training Decision Tree...
Training Random Forest...
Training Gradient Boosting...
Training Logistic Regression...
Training Neural Network...

MODEL COMPARISON
              Model  Accuracy  Log Loss  AUC-ROC  Train Time (s)
Logistic Regression  0.518367  0.689183 0.542194        0.015362
      Random Forest  0.533673  0.691745 0.534368        0.160475
     Neural Network  0.494898  0.698625 0.491178        0.164747
  Gradient Boosting  0.510204  0.712241 0.511281        0.311555
      Decision Tree  0.505102  2.711164 0.494269        0.006461

Best model by Log Loss (calibration): Logistic Regression
Best model by Accuracy: Random Forest
Best model by AUC-ROC: Logistic Regression


In [13]:
# Better parameter grid for win probability
param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [10, 15, 20, 25],  # Deeper trees
    'min_samples_split': [2, 5],
    'min_samples_leaf': [1, 2],
    'max_features': ['sqrt', 0.5, 0.7]  # More features
}

# Use GridSearchCV on smaller grid
from sklearn.model_selection import GridSearchCV

rf = RandomForestClassifier(random_state=42)

grid_search = GridSearchCV(
    rf,
    param_grid,
    cv=3,  # Reduce folds
    scoring='neg_log_loss',
    n_jobs=-1,
    verbose=0
)

grid_search.fit(X_train, y_train)

best_model = grid_search.best_estimator_

print("\n" + "="*80)
print("IMPROVED HYPERPARAMETERS")
print("="*80)
for param, value in grid_search.best_params_.items():
    print(f"{param}: {value}")

y_pred = best_model.predict(X_test)
y_prob = best_model.predict_proba(X_test)[:, 1]

print("\n" + "="*80)
print("IMPROVED MODEL PERFORMANCE")
print("="*80)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(f"Log Loss: {log_loss(y_test, y_prob):.4f}")
print(f"AUC-ROC: {roc_auc_score(y_test, y_prob):.4f}")


IMPROVED HYPERPARAMETERS
max_depth: 10
max_features: sqrt
min_samples_leaf: 2
min_samples_split: 5
n_estimators: 300

IMPROVED MODEL PERFORMANCE
Accuracy: 0.5469
Log Loss: 0.6888
AUC-ROC: 0.5457


In [14]:
# Example observation - Q4, Home team up by 5
example = pd.DataFrame([{
    'quarter': 4,
    'time_in_period': 180,  # 3 minutes left
    'home_score': 105,
    'away_score': 100,
    'score_diff': 5,
    'home_fgm': 38, 'home_fga': 82, 'home_fg_pct': 0.463,
    'away_fgm': 35, 'away_fga': 80, 'away_fg_pct': 0.438,
    'home_3pm': 10, 'home_3pa': 28, 'home_3p_pct': 0.357,
    'away_3pm': 12, 'away_3pa': 30, 'away_3p_pct': 0.400,
    'home_ftm': 19, 'home_fta': 22, 'home_ft_pct': 0.864,
    'away_ftm': 18, 'away_fta': 24, 'away_ft_pct': 0.750,
    'home_timeouts': 2, 'away_timeouts': 1,
    'home_fouls': 18, 'away_fouls': 20,
    'home_turnovers': 12, 'away_turnovers': 14,
    'home_rebounds': 42, 'away_rebounds': 38,
    'home_assists': 22, 'away_assists': 19,
    'home_steals': 7, 'away_steals': 5,
    'home_blocks': 4, 'away_blocks': 3,
    'home_recent_scoring': 14, 'away_recent_scoring': 8
}])

# Get prediction
win_prob = best_model.predict_proba(example)[0, 1]

print(f"Home team win probability: {win_prob:.1%}")
print(f"Away team win probability: {(1-win_prob):.1%}")

Home team win probability: 65.0%
Away team win probability: 35.0%


In [17]:
# Test on multiple real game snapshots from test set
n_examples = 10

print("Probability vs Actual (Test Set Examples)")
print("="*60)

for i in range(n_examples):
    # Get actual test observation
    example = X_test.iloc[i:i+1]
    actual = y_test.iloc[i]
    
    # Predict
    prob = best_model.predict_proba(example)[0, 1]
    
    # Extract key features for display
    quarter = example['quarter'].values[0]
    score_diff = example['score_diff'].values[0]
    time = example.get('time_in_period', [0]).values[0] if 'time_in_period' in example.columns else 0
    
    print(f"\nExample {i+1}:")
    print(f"  Quarter: {quarter}, Score diff: {score_diff:+.0f}, Time: {time:.0f}s")
    print(f"  Predicted: {prob:.1%} home win")
    print(f"  Actual: {'HOME WON' if actual == 1 else 'AWAY WON'}")
    print(f"  Correct: {'✓' if (prob >= 0.5) == actual else '✗'}")

Probability vs Actual (Test Set Examples)

Example 1:
  Quarter: 1, Score diff: +0, Time: 661s
  Predicted: 52.1% home win
  Actual: AWAY WON
  Correct: ✗

Example 2:
  Quarter: 4, Score diff: +0, Time: 337s
  Predicted: 51.9% home win
  Actual: HOME WON
  Correct: ✓

Example 3:
  Quarter: 3, Score diff: +0, Time: 54s
  Predicted: 51.2% home win
  Actual: AWAY WON
  Correct: ✗

Example 4:
  Quarter: 1, Score diff: +0, Time: 720s
  Predicted: 49.6% home win
  Actual: HOME WON
  Correct: ✗

Example 5:
  Quarter: 2, Score diff: +0, Time: 95s
  Predicted: 50.2% home win
  Actual: HOME WON
  Correct: ✓

Example 6:
  Quarter: 3, Score diff: +0, Time: 456s
  Predicted: 43.5% home win
  Actual: HOME WON
  Correct: ✗

Example 7:
  Quarter: 1, Score diff: +0, Time: 47s
  Predicted: 31.1% home win
  Actual: AWAY WON
  Correct: ✓

Example 8:
  Quarter: 1, Score diff: +0, Time: 570s
  Predicted: 54.9% home win
  Actual: AWAY WON
  Correct: ✗

Example 9:
  Quarter: 4, Score diff: +0, Time: 517s
  Pr