In [4]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Lasso, Ridge
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR

In [14]:
# inputs
#########
NAME = 'lebron_james'
SEED = 42  # fixed seed so repeated runs fit the same forest and report the same metrics
model = RandomForestRegressor(n_estimators=500, max_depth=5, random_state=SEED)
#########

# Load the pre-split game logs for the selected player.
train = pd.read_csv(f"./player_game_logs/{NAME}/{NAME}_TRAIN.csv")
test = pd.read_csv(f"./player_game_logs/{NAME}/{NAME}_TEST.csv")

# 'PTS' is the regression target; every other column is used as a feature.
X_train = train.drop(columns=['PTS'])
X_test = test.drop(columns=['PTS'])
y_train = train['PTS']
y_test = test['PTS']
n_train = y_train.count()
n_test = y_test.count()
test_mean = np.mean(y_test)

# scale features
# The scaler is fitted on the training split only and then applied to the test
# split, so no test-set statistics leak into training.
scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

model.fit(X_train_scaled, y_train)

# In-sample fit quality.
y_train_pred = model.predict(X_train_scaled)
mse_train = mean_squared_error(y_train, y_train_pred)
rmse_train = np.sqrt(mse_train)
r2_train = r2_score(y_train, y_train_pred)

# Held-out fit quality.
y_pred_test = model.predict(X_test_scaled)
mse_test = mean_squared_error(y_test, y_pred_test)
rmse_test = np.sqrt(mse_test)
r2_test = r2_score(y_test, y_pred_test)

# Round in one vectorised call and convert to plain Python floats so the printed
# pairs look the same regardless of how the installed NumPy reprs its scalars.
rounded_preds = np.round(y_pred_test, 4).tolist()

y = list(zip(y_test, rounded_preds))

print(f"player: {NAME}\n")
print(f"(actual_pts, predicted_pts) pairs: \n{str(y)}\n")
# Per-feature weights only exist for models exposing `coef_` (e.g. the linear
# models imported above); the guard keeps this cell runnable for any `model`.
if hasattr(model, "coef_"):
    print({f"{col}_weight": coef for col, coef in zip(X_train.columns, np.ravel(model.coef_))})
    print(f"bias: {model.intercept_}")
print(f"y_test_mean: {test_mean}")
print(f"r2_train: {r2_train}")
print(f"r2_test: {r2_test}")

player: lebron_james

(actual_pts, predicted_pts) pairs: 
[(27, 26.6795), (18, 26.4421), (28, 27.4694), (36, 28.5641), (27, 26.4849), (43, 27.8252), (37, 29.2979), (22, 26.1714), (30, 26.6056), (28, 26.7438), (32, 26.889), (19, 26.8761), (29, 25.5062), (38, 26.8078), (25, 26.0207), (37, 25.99), (36, 28.8972), (27, 26.673), (31, 26.671), (16, 25.2265), (24, 26.694), (32, 30.0118), (20, 26.8711), (20, 26.4468), (46, 29.3307), (15, 28.1189), (32, 28.9885), (32, 29.2483), (23, 28.0141), (10, 26.7757), (32, 28.7737), (23, 25.7454), (36, 26.4753), (10, 25.2243), (23, 26.3848), (19, 27.0414), (35, 29.7304), (42, 30.293), (32, 26.666), (22, 28.1895), (32, 21.715), (29, 30.0177), (26, 26.6312), (20, 27.3522), (21, 26.7696), (27, 27.8551), (31, 26.7062), (9, 24.9951), (39, 26.4932), (26, 26.5427), (32, 26.4964), (27, 27.9475), (23, 27.4206), (21, 26.3904), (31, 27.3338), (29, 29.2703), (22, 26.4212), (24, 26.2864), (33, 26.6071), (22, 29.7152), (18, 24.878), (19, 26.6415), (15, 25.3349), (17, 25