In [1]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score

In [2]:
matches = pd.read_csv('nba_games_normalized.csv', index_col=0)
matches

Unnamed: 0_level_0,Season,Date,StartET,Team,Opponent,Venue,Result,TeamPoints,OpponentPoints,Attendance,FGA,FGM,3PA,3PM,3P%,FTA,FTM,FT%
Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,2023-24,2023-10-25,7:00p,ATL,CHA,away,L,110,116,16129,93.0,39.0,29.0,5.0,0.172,33.0,27.0,0.818
1,2023-24,2023-10-27,7:30p,ATL,NYK,home,L,120,126,17692,87.0,42.0,32.0,12.0,0.375,30.0,24.0,0.800
2,2023-24,2023-10-29,7:00p,ATL,MIL,away,W,127,110,17341,93.0,47.0,37.0,15.0,0.405,22.0,18.0,0.818
3,2023-24,2023-10-30,7:30p,ATL,MIN,home,W,127,113,15504,86.0,48.0,30.0,14.0,0.467,18.0,17.0,0.944
4,2023-24,2023-11-01,7:30p,ATL,WAS,home,W,130,121,15925,92.0,46.0,32.0,9.0,0.281,32.0,29.0,0.906
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4939,2024-25,2025-04-06,6:00p,WAS,BOS,away,L,90,124,19156,86.0,35.0,45.0,12.0,0.267,10.0,8.0,0.800
4940,2024-25,2025-04-08,7:00p,WAS,IND,away,L,98,104,16144,91.0,38.0,39.0,11.0,0.282,15.0,11.0,0.733
4941,2024-25,2025-04-09,7:00p,WAS,PHI,home,L,103,122,17222,87.0,34.0,33.0,10.0,0.303,30.0,25.0,0.833
4942,2024-25,2025-04-11,8:00p,WAS,CHI,away,L,89,119,21400,91.0,35.0,41.0,10.0,0.244,12.0,9.0,0.750


In [3]:
# cleaning up the data for use in the model
matches["date"] = pd.to_datetime(matches["Date"])
matches["venue_code"] = matches["Venue"].astype("category").cat.codes
matches["opp_code"] = matches["Opponent"].astype("category").cat.codes
matches["hour"] = matches["StartET"].str.replace(":.+", "", regex=True).astype("int")
matches["day_code"] = matches["date"].dt.dayofweek
matches["target"] = (matches["Result"] == "W").astype("int")

In [4]:
rf = RandomForestClassifier(n_estimators=50, min_samples_split=10, random_state=1)

In [5]:
training_set = matches[matches["date"] < '2024-09-30']
test_set = matches[matches["date"] >= '2024-09-30']

In [6]:
predictors = ["venue_code", "opp_code", "hour", "day_code"]

In [7]:
# Predict && print metrics of initial model
rf.fit(training_set[predictors], training_set["target"])
predictions = rf.predict(test_set[predictors])
print("Accuracy:", accuracy_score(test_set["target"], predictions)) # whether it was right or wrong
print("Precision:", precision_score(test_set["target"], predictions)) # when it predicted win, how often was it correct

actual_and_pred = pd.DataFrame(dict(actual=test_set["target"], prediction=predictions))
pd.crosstab(index=actual_and_pred["actual"], columns=actual_and_pred["prediction"])


Accuracy: 0.5651294498381877
Precision: 0.5681625740897545


prediction,0,1
actual,Unnamed: 1_level_1,Unnamed: 2_level_1
0,726,510
1,565,671


In [8]:
# showing how to group by team (prints all LAL games)
grouped_matches = matches.groupby("Team")
group = grouped_matches.get_group('LAL')
group

Unnamed: 0_level_0,Season,Date,StartET,Team,Opponent,Venue,Result,TeamPoints,OpponentPoints,Attendance,...,3P%,FTA,FTM,FT%,date,venue_code,opp_code,hour,day_code,target
Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1070,2023-24,2023-10-24,7:30p,LAL,DEN,away,L,107,119,19842,...,0.345,20.0,15.0,0.750,2023-10-24,0,7,7,1,0
1071,2023-24,2023-10-26,10:00p,LAL,PHX,home,W,100,95,18997,...,0.172,29.0,23.0,0.793,2023-10-26,1,23,10,3,1
1072,2023-24,2023-10-29,9:00p,LAL,SAC,away,L,127,132,18198,...,0.333,34.0,26.0,0.765,2023-10-29,0,25,9,6,0
1073,2023-24,2023-10-30,10:30p,LAL,ORL,home,W,106,103,18997,...,0.296,18.0,14.0,0.778,2023-10-30,1,21,10,0,1
1074,2023-24,2023-11-01,10:00p,LAL,LAC,home,W,130,125,18997,...,0.344,36.0,27.0,0.750,2023-11-01,1,12,10,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3621,2024-25,2025-04-06,3:30p,LAL,OKC,away,W,126,99,18203,...,0.550,17.0,12.0,0.706,2025-04-06,0,20,3,6,1
3622,2024-25,2025-04-08,8:00p,LAL,OKC,away,L,120,136,18203,...,0.450,37.0,28.0,0.757,2025-04-08,0,20,8,1,0
3623,2024-25,2025-04-09,7:30p,LAL,DAL,away,W,112,97,20841,...,0.333,21.0,16.0,0.762,2025-04-09,0,6,7,2,1
3624,2024-25,2025-04-11,10:30p,LAL,HOU,home,W,140,109,18997,...,0.514,25.0,15.0,0.600,2025-04-11,1,10,10,4,1


In [9]:
def rolling_averages(group, cols, new_cols):
  group = group.sort_values("date")
  rolling_stats = group[cols].rolling(3, closed='left').mean() # rolling average of last 3 games, excluding current game
  group[new_cols] = rolling_stats
  group = group.dropna(subset=new_cols) # drop rows where rolling averages are NaN (first 2 games of each team)
  return group

In [10]:
# create rolling average columns
cols = ["FGA", "FGM", "3PA", "3PM", "3P%", "FTA", "FTM", "FT%"]
new_cols = [f"rolling_{c}" for c in cols]

In [11]:
# test rolling_averages by showing rolling averages for LAL
rolling_averages(group, cols, new_cols)

Unnamed: 0_level_0,Season,Date,StartET,Team,Opponent,Venue,Result,TeamPoints,OpponentPoints,Attendance,...,day_code,target,rolling_FGA,rolling_FGM,rolling_3PA,rolling_3PM,rolling_3P%,rolling_FTA,rolling_FTM,rolling_FT%
Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1073,2023-24,2023-10-30,10:30p,LAL,ORL,home,W,106,103,18997,...,0,1,90.666667,40.000000,34.333333,10.000000,0.283333,27.666667,21.333333,0.769333
1074,2023-24,2023-11-01,10:00p,LAL,LAC,home,W,130,125,18997,...,2,1,87.000000,40.333333,33.666667,9.333333,0.267000,27.000000,21.000000,0.778667
1075,2023-24,2023-11-04,7:00p,LAL,ORL,away,L,101,120,18846,...,5,0,89.000000,43.666667,34.666667,11.333333,0.324333,29.333333,22.333333,0.764333
1076,2023-24,2023-11-06,7:30p,LAL,MIA,away,L,107,108,19725,...,0,0,84.333333,41.666667,29.666667,9.000000,0.302333,25.000000,20.000000,0.811000
1077,2023-24,2023-11-08,8:00p,LAL,HOU,away,L,94,128,18055,...,2,0,85.000000,42.333333,29.333333,9.000000,0.306333,23.666667,19.000000,0.813667
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3621,2024-25,2025-04-06,3:30p,LAL,OKC,away,W,126,99,18203,...,6,1,84.666667,38.333333,44.666667,16.666667,0.373667,23.666667,21.333333,0.899667
3622,2024-25,2025-04-08,8:00p,LAL,OKC,away,L,120,136,18203,...,1,0,83.000000,42.000000,42.666667,18.666667,0.441000,22.666667,19.333333,0.835000
3623,2024-25,2025-04-09,7:30p,LAL,DAL,away,W,112,97,20841,...,2,1,81.000000,41.333333,40.333333,18.666667,0.463333,27.333333,22.000000,0.797333
3624,2024-25,2025-04-11,10:30p,LAL,HOU,home,W,140,109,18997,...,4,1,84.666667,41.666667,38.666667,17.333333,0.444333,25.000000,18.666667,0.741667


In [12]:
parts = []

for team, group in matches.groupby("Team", sort=False):
    # workaround for groupby().apply() FutureWarning
    part = rolling_averages(group.copy(), cols, new_cols)
    parts.append(part)

matches_rolling = pd.concat(parts, ignore_index=True)
matches_rolling

Unnamed: 0,Season,Date,StartET,Team,Opponent,Venue,Result,TeamPoints,OpponentPoints,Attendance,...,day_code,target,rolling_FGA,rolling_FGM,rolling_3PA,rolling_3PM,rolling_3P%,rolling_FTA,rolling_FTM,rolling_FT%
0,2023-24,2023-10-30,7:30p,ATL,MIN,home,W,127,113,15504,...,0,1,91.000000,42.666667,32.666667,10.666667,0.317333,28.333333,23.000000,0.812000
1,2023-24,2023-11-01,7:30p,ATL,WAS,home,W,130,121,15925,...,2,1,88.666667,45.666667,33.000000,13.666667,0.415667,23.333333,19.666667,0.854000
2,2023-24,2023-11-04,7:00p,ATL,NOP,away,W,123,105,17237,...,5,1,90.333333,47.000000,33.000000,12.666667,0.384333,24.000000,21.333333,0.889333
3,2023-24,2023-11-06,8:00p,ATL,OKC,away,L,117,126,16486,...,0,0,90.333333,46.333333,34.333333,12.333333,0.363000,24.666667,21.666667,0.880667
4,2023-24,2023-11-09,9:30p,ATL,ORL,away,W,120,119,19986,...,3,1,95.666667,43.000000,38.333333,12.333333,0.318333,29.333333,25.000000,0.847333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4643,2024-25,2025-04-06,6:00p,WAS,BOS,away,L,90,124,19156,...,6,0,85.000000,38.333333,38.666667,14.000000,0.365667,15.666667,11.666667,0.754000
4644,2024-25,2025-04-08,7:00p,WAS,IND,away,L,98,104,16144,...,1,0,85.000000,38.666667,39.333333,13.666667,0.354000,13.000000,10.000000,0.780000
4645,2024-25,2025-04-09,7:00p,WAS,PHI,home,L,103,122,17222,...,2,0,87.333333,37.000000,38.000000,11.666667,0.316333,12.000000,9.333333,0.783667
4646,2024-25,2025-04-11,8:00p,WAS,CHI,away,L,89,119,21400,...,4,0,88.000000,35.666667,39.000000,11.000000,0.284000,18.333333,14.666667,0.788667


In [13]:
def make_predictions(data, predictors):
  training_set = data[data["date"] < '2024-09-30']
  test_set = data[data["date"] >= '2024-09-30']
  rf.fit(training_set[predictors], training_set["target"])
  predictions = rf.predict(test_set[predictors])
  combined = pd.DataFrame(dict(actual=test_set["target"], prediction=predictions), index=test_set.index)
  precision = precision_score(test_set["target"], predictions)
  return combined, precision

In [14]:
combined, precision = make_predictions(matches_rolling, predictors + new_cols)

In [15]:
precision

0.5669679539852095

In [16]:
combined = combined.merge(matches_rolling[["Date", "Team", "Opponent", "Result"]], left_index=True, right_index=True)
combined

Unnamed: 0,actual,prediction,Date,Team,Opponent,Result
80,1,1,2024-10-23,ATL,BKN,W
81,1,1,2024-10-25,ATL,CHA,W
82,0,0,2024-10-27,ATL,OKC,L
83,0,1,2024-10-28,ATL,WAS,L
84,0,1,2024-10-30,ATL,WAS,L
...,...,...,...,...,...,...
4643,0,0,2025-04-06,WAS,BOS,L
4644,0,1,2025-04-08,WAS,IND,L
4645,0,1,2025-04-09,WAS,PHI,L
4646,0,0,2025-04-11,WAS,CHI,L
