In [2]:
import ipywidgets as widgets
from ipywidgets import interact, Layout
import pandas as pd
import numpy as np
import json
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import seaborn as sns

# Global plotting defaults applied to every figure drawn in this notebook:
# seaborn's white-grid theme and a 12 pt base font size.
sns.set_style("whitegrid")
plt.rcParams['font.size'] = 12

In [3]:
# --- Load Model and Normalization Stats ---
# Loads the trained Keras model and the per-feature mean/std saved at training
# time. Run it once; paths are relative to this notebook's directory.

MODEL_PATH = '../nba_prediction_model.keras'
STATS_PATH = '../normalization_stats.json'
# Feature order fed to the model. Predictions select columns in exactly this order.
CONFIG_FEATURES = [
    'PTS_avg_home', 'FG_PCT_avg_home', 'FT_PCT_avg_home', 'FG3_PCT_avg_home', 'AST_avg_home', 'REB_avg_home',
    'PTS_avg_away', 'FG_PCT_avg_away', 'FT_PCT_avg_away', 'FG3_PCT_avg_away', 'AST_avg_away', 'REB_avg_away'
]

try:
    model = load_model(MODEL_PATH)
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(STATS_PATH, 'r', encoding='utf-8') as f:
        norm_stats = json.load(f)

    # Expected layout: {"mean": {feature: value, ...}, "std": {feature: value, ...}}.
    mean_stats = pd.Series(norm_stats['mean'])
    std_stats = pd.Series(norm_stats['std'])

    # Fail here, with a clear message, if the stats file does not cover every
    # model feature. Otherwise the gap only surfaces later as NaN inputs.
    missing_stats = [
        feature for feature in CONFIG_FEATURES
        if feature not in mean_stats.index or feature not in std_stats.index
    ]
    if missing_stats:
        raise KeyError(f"Normalization stats missing for features: {missing_stats}")

    print("✅ Model and normalization stats loaded successfully!")
except Exception as e:
    print(f"❌ Error loading files: {e}")
    print("Please make sure you have trained the model by running 'python -m src.train' in the root directory.")
    # Re-raise so 'Run All' stops at the real cause instead of failing in later
    # cells with a NameError on `model` / `mean_stats` / `std_stats`.
    raise

✅ Model and normalization stats loaded successfully!


In [4]:
# --- Prediction Function ---
# Normalizes raw per-team averages with the training statistics and returns the
# model's home-win probability.

def predict_winner(input_data, predictor=None, mean=None, std=None, features=None):
    """Predict the home team's win probability for a single matchup.

    Args:
        input_data: Mapping of feature name -> raw (un-normalized) value. Must
            contain every name in ``features``; extra keys are ignored.
        predictor: Object with a Keras-style ``predict(x, verbose=0)`` method.
            Defaults to the notebook-level ``model``.
        mean: pandas Series of per-feature training means indexed by feature
            name. Defaults to ``mean_stats``.
        std: pandas Series of per-feature training standard deviations indexed
            by feature name. Defaults to ``std_stats``.
        features: Ordered list of feature names the model expects. Defaults to
            ``CONFIG_FEATURES``.

    Returns:
        float: the first output value of the model for the single input row,
        which the caller treats as the probability that the home team wins.

    Raises:
        KeyError: if ``input_data`` or the normalization stats lack a feature.
    """
    predictor = model if predictor is None else predictor
    mean = mean_stats if mean is None else mean
    std = std_stats if std is None else std
    features = CONFIG_FEATURES if features is None else features

    # One-row frame whose columns are in exactly the order the model expects.
    input_df = pd.DataFrame([input_data])[features]

    missing = [name for name in features if name not in mean.index or name not in std.index]
    if missing:
        raise KeyError(f"Normalization stats missing for features: {missing}")

    # Pick the stats in the same label order as the columns. Subtracting a Series
    # whose labels differ (extra keys, other order) makes pandas outer-join the
    # labels. That can add NaN columns or reorder the features fed to the model.
    normalized_df = (input_df - mean.loc[features]) / std.loc[features]

    prediction_prob = predictor.predict(normalized_df, verbose=0)[0][0]
    return float(prediction_prob)

In [5]:
# --- Create Interactive Sliders ---

# Shared label style and layout for every slider (one Layout instance, as before).
style = {'description_width': 'initial'}
layout = Layout(width='80%')


def _stat_slider(value, lo, hi, step, description, **extra):
    """Build a FloatSlider using the shared ``style`` and ``layout``.

    ``extra`` is forwarded to ``widgets.FloatSlider`` (e.g. ``readout_format``).
    """
    return widgets.FloatSlider(
        value=value, min=lo, max=hi, step=step, description=description,
        style=style, layout=layout, **extra,
    )


# FloatSlider's readout precision is set by `readout_format` (default '.2f').
# A `format=` keyword is not a FloatSlider trait and is silently ignored. The FG%
# sliders step by 0.001, so they need three decimals to show each step.

# Sliders for Home Team
pts_home = _stat_slider(115, 80, 140, 0.1, 'Home Avg Points:')
fg_pct_home = _stat_slider(0.48, 0.35, 0.60, 0.001, 'Home Avg FG%:', readout_format='.3f')
ast_home = _stat_slider(25, 15, 40, 0.1, 'Home Avg Assists:')
reb_home = _stat_slider(44, 30, 60, 0.1, 'Home Avg Rebounds:')

# Sliders for Away Team
pts_away = _stat_slider(110, 80, 140, 0.1, 'Away Avg Points:')
fg_pct_away = _stat_slider(0.46, 0.35, 0.60, 0.001, 'Away Avg FG%:', readout_format='.3f')
ast_away = _stat_slider(23, 15, 40, 0.1, 'Away Avg Assists:')
reb_away = _stat_slider(42, 30, 60, 0.1, 'Away Avg Rebounds:')

# Output widget to display the results
output_widget = widgets.Output()

def update_prediction(**kwargs):
    """Recompute the prediction and redraw the win-probability chart.

    Receives one keyword argument per slider: ``pts_home``, ``fg_pct_home``,
    ``ast_home``, ``reb_home`` and the matching ``*_away`` names. Everything it
    draws goes into ``output_widget``, replacing the previous chart.
    """
    # Slider keyword -> model feature it supplies.
    slider_features = {
        'pts_home': 'PTS_avg_home',
        'fg_pct_home': 'FG_PCT_avg_home',
        'ast_home': 'AST_avg_home',
        'reb_home': 'REB_avg_home',
        'pts_away': 'PTS_avg_away',
        'fg_pct_away': 'FG_PCT_avg_away',
        'ast_away': 'AST_avg_away',
        'reb_away': 'REB_avg_away',
    }
    # Features without a slider are held at fixed values for simplicity
    # (they could become sliders too).
    fixed_features = {
        'FT_PCT_avg_home': 0.78,
        'FG3_PCT_avg_home': 0.36,
        'FT_PCT_avg_away': 0.78,
        'FG3_PCT_avg_away': 0.35,
    }

    with output_widget:
        output_widget.clear_output(wait=True)

        game_data = {feature: kwargs[arg] for arg, feature in slider_features.items()}
        game_data.update(fixed_features)

        home_win_prob = predict_winner(game_data)
        away_win_prob = 1 - home_win_prob

        # Horizontal bars: away team on the first row, home team on the second.
        fig, ax = plt.subplots(figsize=(8, 4))
        bars = ax.barh(
            ['Away Team', 'Home Team'],
            [away_win_prob * 100, home_win_prob * 100],
            color=['#E74C3C', '#3498DB'],
        )
        ax.bar_label(bars, fmt='%.2f%%', padding=5, fontsize=12, weight='bold')

        ax.set_title('Predicted Win Probability', fontsize=16, weight='bold')
        ax.set_xlabel('Probability (%)')
        ax.set_xlim(0, 100)
        # Bar labels carry the numbers, so x-axis ticks and tick labels are hidden.
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

        plt.show()

# Wire the sliders to update_prediction: `interactive` observes each control and
# calls the function with keyword arguments named after them. This object is not
# displayed itself; the custom two-column layout below is shown instead.
interactive_ui = widgets.interactive(
    update_prediction,
    pts_home=pts_home,
    fg_pct_home=fg_pct_home,
    ast_home=ast_home,
    reb_home=reb_home,
    pts_away=pts_away,
    fg_pct_away=fg_pct_away,
    ast_away=ast_away,
    reb_away=reb_away,
)

# One column per team, each headed by its label, then the chart output area.
home_box = widgets.VBox(children=[
    widgets.HTML(value="<h3><b>Home Team</b></h3>"),
    pts_home, fg_pct_home, ast_home, reb_home,
])
away_box = widgets.VBox(children=[
    widgets.HTML(value="<h3><b>Away Team</b></h3>"),
    pts_away, fg_pct_away, ast_away, reb_away,
])
ui_container = widgets.HBox(children=[home_box, away_box])

display(ui_container, output_widget)

HBox(children=(VBox(children=(HTML(value='<h3><b>Home Team</b></h3>'), FloatSlider(value=115.0, description='H…

Output()