# Make your own predictions! ⛹️

This notebook loads one of the pretrained predictive models provided with this project and uses it to generate your
own predictions for upcoming regular-season NBA games.

The steps involved are:
- Define your data file
- Load the data
- Select and load the model
- Make the prediction

Details on the definition of the data file:

- A CSV file of previous game data is needed before predictions can begin.
- To predict the result of the upcoming game between HomeTeam and AwayTeam, the CSV needs to contain the results and statistics of each team's previous games.
- The first record in the CSV should contain the oldest game data, and the last record should contain the data of the game closest to the one being predicted.
- The CSV needs the columns listed below, where the `home` prefix refers to information about previous games of HomeTeam, and the `away` prefix to previous games of AwayTeam.
- By default, the code expects the CSV to have data from 8 previous games.
- Example CSV files are available in `data/predict_csv/`.

In [15]:
# 1. define your data file

header_cols = [
    "home_win",
    "home_home",
    "home_close_game",
    "home_ot_count",
    "home_q1_points",
    "home_q2_points",
    "home_q3_points",
    "home_q4_points",
    "home_ot_points",
    "home_final_points",
    "home_field_made",
    "home_field_percent",
    "home_three_made",
    "home_three_percent",
    "home_free_made",
    "home_free_percent",
    "home_offensive_rebounds",
    "home_defensive_rebounds",
    "home_total_rebounds",
    "home_assists",
    "home_steals",
    "home_blocks",
    "home_turnovers",
    "home_fouls",
    "home_plus_minus",
    "home_pts_paint",
    "home_pts_2nd_chance",
    "home_pts_fast_break",
    "home_largest_lead",
    "home_games_played",
    "home_win_percent",
    "home_opp_q1_points",
    "home_opp_q2_points",
    "home_opp_q3_points",
    "home_opp_q4_points",
    "home_opp_ot_points",
    "home_opp_final_points",
    "home_opp_field_made",
    "home_opp_field_percent",
    "home_opp_three_made",
    "home_opp_three_percent",
    "home_opp_free_made",
    "home_opp_free_percent",
    "home_opp_offensive_rebounds",
    "home_opp_defensive_rebounds",
    "home_opp_total_rebounds",
    "home_opp_assists",
    "home_opp_steals",
    "home_opp_blocks",
    "home_opp_turnovers",
    "home_opp_fouls",
    "home_opp_plus_minus",
    "home_opp_pts_paint",
    "home_opp_pts_2nd_chance",
    "home_opp_pts_fast_break",
    "home_opp_largest_lead",
    "home_opp_games_played",
    "home_opp_win_percent",
    "away_win",
    "away_home",
    "away_close_game",
    "away_ot_count",
    "away_q1_points",
    "away_q2_points",
    "away_q3_points",
    "away_q4_points",
    "away_ot_points",
    "away_final_points",
    "away_field_made",
    "away_field_percent",
    "away_three_made",
    "away_three_percent",
    "away_free_made",
    "away_free_percent",
    "away_offensive_rebounds",
    "away_defensive_rebounds",
    "away_total_rebounds",
    "away_assists",
    "away_steals",
    "away_blocks",
    "away_turnovers",
    "away_fouls",
    "away_plus_minus",
    "away_pts_paint",
    "away_pts_2nd_chance",
    "away_pts_fast_break",
    "away_largest_lead",
    "away_games_played",
    "away_win_percent",
    "away_opp_q1_points",
    "away_opp_q2_points",
    "away_opp_q3_points",
    "away_opp_q4_points",
    "away_opp_ot_points",
    "away_opp_final_points",
    "away_opp_field_made",
    "away_opp_field_percent",
    "away_opp_three_made",
    "away_opp_three_percent",
    "away_opp_free_made",
    "away_opp_free_percent",
    "away_opp_offensive_rebounds",
    "away_opp_defensive_rebounds",
    "away_opp_total_rebounds",
    "away_opp_assists",
    "away_opp_steals",
    "away_opp_blocks",
    "away_opp_turnovers",
    "away_opp_fouls",
    "away_opp_plus_minus",
    "away_opp_pts_paint",
    "away_opp_pts_2nd_chance",
    "away_opp_pts_fast_break",
    "away_opp_largest_lead",
    "away_opp_games_played",
    "away_opp_win_percent",
]
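
The column list above follows a regular pattern: for each side (`home`, `away`) there are 4 game-level flags plus 27 box-score statistics, and the same 27 statistics repeated with an `opp` infix, giving 58 columns per side and 116 in total. The sketch below rebuilds the list from that pattern and uses it to sanity-check a CSV before loading. The helper names `build_header_cols` and `check_csv` are hypothetical and not part of this project, and the check assumes the CSV starts with a header row containing exactly these column names, which may not match how `load_record_from_csv` actually reads the file.

```python
import csv
import io

# Game-level flags that exist only for the team itself (no "opp" variant).
TEAM_ONLY = ["win", "home", "close_game", "ot_count"]

# Box-score statistics recorded for both the team and its opponent.
SHARED = [
    "q1_points", "q2_points", "q3_points", "q4_points", "ot_points",
    "final_points", "field_made", "field_percent", "three_made",
    "three_percent", "free_made", "free_percent", "offensive_rebounds",
    "defensive_rebounds", "total_rebounds", "assists", "steals", "blocks",
    "turnovers", "fouls", "plus_minus", "pts_paint", "pts_2nd_chance",
    "pts_fast_break", "largest_lead", "games_played", "win_percent",
]


def build_header_cols():
    """Rebuild the 116 column names in the same order as the list above."""
    cols = []
    for side in ("home", "away"):
        cols += [f"{side}_{name}" for name in TEAM_ONLY + SHARED]
        cols += [f"{side}_opp_{name}" for name in SHARED]
    return cols


def check_csv(text, expected_cols, n_games=8):
    """Return a list of problems found in the CSV text (empty if none).

    Assumption: the first row is a header row with the column names.
    """
    rows = list(csv.reader(io.StringIO(text)))
    problems = []
    if not rows or rows[0] != expected_cols:
        problems.append("header row does not match the expected columns")
    n_rows = max(len(rows) - 1, 0)
    if n_rows != n_games:
        problems.append(f"expected {n_games} game rows, found {n_rows}")
    return problems
```

With a file on disk, this could be called as `check_csv(open(csv_path).read(), header_cols)`; an empty list means the header and row count look as expected.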

In [1]:
# 2. load the data from .csv file

from src.predict import load_record_from_csv

csv_path = "data/predict_csv/lac_home_win_vs_lal_2023_04_05.csv"
data = load_record_from_csv(file_path=csv_path)

print(f"Loaded data of size: {data.shape}")

Loaded data of size: torch.Size([1, 8, 116])
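
The printed size lines up with the data file description: 8 matches the default number of previous games and 116 matches the number of columns in `header_cols`. Reading the leading 1 as a batch dimension holding a single matchup is an assumption here, not something the notebook states. A minimal sketch of a shape check, using a hypothetical helper `describe_shape` that accepts any tuple-like shape such as `data.shape`:

```python
def describe_shape(shape, n_games=8, n_features=116):
    """Label the three dimensions of the loaded record and validate them.

    Assumption: the layout is (batch, games, features).
    """
    batch, games, features = tuple(shape)
    if games != n_games or features != n_features:
        raise ValueError(
            f"expected (*, {n_games}, {n_features}), got {tuple(shape)}"
        )
    return {"batch": batch, "games": games, "features": features}


print(describe_shape((1, 8, 116)))
```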


In [13]:
# 3. load the model from .pth file containing the state data

from src.persist import load_model

model_type = "TCN"  # choose between RNN, LSTM, GRU, TCN, TE
model_path = f"output/pretrained/{model_type}_state_data.pth"
model = load_model(file_path=model_path)

print(f"Loaded model: {model}")

Loaded model: TCN
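
To compare several architectures, it can help to check which pretrained files are present before calling `load_model`. The sketch below only builds and checks the paths; it assumes every model type listed in the cell above follows the same `{model_type}_state_data.pth` naming under `output/pretrained/`. The helper `pretrained_paths` is hypothetical.

```python
from pathlib import Path

# Model types named in the comment of the cell above.
MODEL_TYPES = ["RNN", "LSTM", "GRU", "TCN", "TE"]


def pretrained_paths(root="output/pretrained"):
    """Map each model type to the path where its state file is expected."""
    return {name: Path(root) / f"{name}_state_data.pth" for name in MODEL_TYPES}


for name, path in pretrained_paths().items():
    status = "found" if path.exists() else "missing"
    print(f"{name}: {path} ({status})")
```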


In [14]:
# 4. run the prediction code on the model

from src.predict import make_prediction

pred = make_prediction(model=model, data=data)

print(f"Home win prediction: {pred}")

The HOME team will win.
Home win prediction: 0.9627264142036438
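
The output suggests that `pred` is a score between 0 and 1 for a home win, with a high value reported as a home victory. The exact rule used inside `make_prediction` is not shown here, so the sketch below is only one plausible reading: it treats the value as a probability and uses an assumed cutoff of 0.5. The function `interpret_prediction` is hypothetical.

```python
def interpret_prediction(p_home_win, threshold=0.5):
    """Turn a home-win score into a predicted winner and its confidence.

    Assumptions: the score is a probability in [0, 1] and 0.5 is the cutoff.
    """
    if not 0.0 <= p_home_win <= 1.0:
        raise ValueError("expected a value between 0 and 1")
    winner = "HOME" if p_home_win >= threshold else "AWAY"
    confidence = p_home_win if winner == "HOME" else 1.0 - p_home_win
    return winner, confidence


winner, confidence = interpret_prediction(0.9627264142036438)
print(f"Predicted winner: {winner} ({confidence:.1%})")
```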
