In [1]:
import os
import re
import time
import requests
import pandas as pd
from tqdm import tqdm
from bs4 import BeautifulSoup
from sklearn.model_selection import train_test_split

# Show every row and column when rendering DataFrames in this notebook
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
# Team metadata: school name, location, Sports Reference / NCAA keys and background colour
path = os.path.abspath(os.path.join('..', '..', 'data', 'teams.csv'))
df = pd.read_csv(path)
df.head()

Unnamed: 0,School,"City, State",SR key,NCAA key,NCAA School,NCAA Name,background-color
0,Abilene Christian,"Abilene, Texas",abilene-christian,abilene-christian,Abilene Christian,Abilene Christian University,#582C83
1,Air Force,"USAF Academy, Colorado",air-force,air-force,Air Force,Air Force Academy,#0032A0
2,Akron,"Akron, Ohio",akron,akron,Akron,University of Akron,#0F192B
3,Alabama,"Tuscaloosa, Alabama",alabama,alabama,Alabama,University of Alabama,#9D2235
4,Alabama A&M,"Normal, Alabama",alabama-am,alabama-am,Alabama A&M,Alabama A&M University,#862633


In [6]:
# Sports Reference school keys, one per row of teams.csv (used in URLs and file paths)
SR_SCHOOL_KEYS = df['SR key'].tolist()

def _gamelog_url(school_key, season, isWomens, page_suffix):
    """Build a sports-reference.com CBB gamelog URL; ``page_suffix`` selects the page variant."""
    # 'gender' rather than 'type' so the builtin is not shadowed
    gender = 'women' if isWomens else 'men'
    return f'https://www.sports-reference.com/cbb/schools/{school_key}/{gender}/{season}-gamelogs{page_suffix}.html'

def get_gamelog_basic_url(school_key, season, isWomens = False):
    """Return the URL of the basic gamelog page for ``school_key`` in ``season``.

    Args:
        school_key: Sports Reference school key, e.g. ``'connecticut'``.
        season: Season year (str or int).
        isWomens: Use the women's page instead of the men's.
    """
    return _gamelog_url(school_key, season, isWomens, '')

def get_gamelog_advanced_url(school_key, season, isWomens = False):
    """Return the URL of the advanced gamelog page for ``school_key`` in ``season``.

    Args:
        school_key: Sports Reference school key, e.g. ``'connecticut'``.
        season: Season year (str or int).
        isWomens: Use the women's page instead of the men's.
    """
    return _gamelog_url(school_key, season, isWomens, '-advanced')

def get_team_season_file_path(school_key, season, filename, data_dir='../../data'):
    """Return the absolute path of ``filename`` inside a school's season directory.

    The path is ``<data_dir>/seasons/<season>/<school_key>/<filename>``. The
    containing directory is created if it does not exist yet, so the returned
    path can be written to immediately.

    Args:
        school_key: Sports Reference school key, e.g. ``'connecticut'``.
        season: Season year (str or int; only its string form is used).
        filename: File name inside the school's season directory.
        data_dir: Root data directory, resolved against the current working
            directory unless absolute. Defaults to the notebook's ``../../data``.
    """
    file_path = os.path.abspath(f'{data_dir}/seasons/{season}/{school_key}/{filename}')
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    return file_path

## Download gamelogs HTML
Download the basic and advanced gamelog HTML pages for each team and season, pausing between requests so we do not hammer sports-reference.com.

In [7]:
def download_gamelog(school_key, season):
    """Download and save the basic and advanced gamelog HTML for one school/season.

    Each page is fetched after a 3 second pause and written to
    ``{school_key}_basic.html`` / ``{school_key}_advanced.html`` in the school's
    season directory.

    NOTE(review): the HTTP status is not checked, so an error page (e.g. 404 or
    429) is saved as-is; presumably the later CSV extraction then fails for
    this school.
    """
    # Previously called with an undefined name `school` and without `season` (NameError)
    basic_url = get_gamelog_basic_url(school_key, season)
    advanced_url = get_gamelog_advanced_url(school_key, season)

    # Timeout so a stalled connection cannot hang the whole batch download
    request_timeout = 30

    time.sleep(3) # Delay for 3 seconds
    basic_html = requests.get(basic_url, timeout=request_timeout).content
    time.sleep(3) # Delay for 3 seconds
    advanced_html = requests.get(advanced_url, timeout=request_timeout).content

    basic_file_path = get_team_season_file_path(school_key, season, f'{school_key}_basic.html')
    advanced_file_path = get_team_season_file_path(school_key, season, f'{school_key}_advanced.html')

    with open(basic_file_path, 'w') as file:
        file.write(basic_html.decode('utf-8'))
    with open(advanced_file_path, 'w') as file:
        file.write(advanced_html.decode('utf-8'))

def download_gamelogs_for_single_season(season):
    """Download basic and advanced gamelog HTML for every school in SR_SCHOOL_KEYS."""
    progress_bar = tqdm(SR_SCHOOL_KEYS, unit=f'school ({season})')
    for key in progress_bar:
        download_gamelog(key, season)

def download_gamelogs(seasons):
    """Download gamelogs for each season in ``seasons``, one season at a time."""
    for each_season in seasons:
        download_gamelogs_for_single_season(each_season)

## Change Opp name column to Opp key
The opponent name shown in a gamelog does not always match the school name saved in our team data, so we re-parse the HTML tables and use the Sports Reference (SR) school key from each opponent link instead.

In [21]:
def get_opposing_school_keys(school_key, season):
    """Return the SR school key of each opponent, in table order.

    Reads ``{school_key}_basic.html`` and, for every body row after the first
    two header rows, takes the link in the third data cell and extracts the
    key from its ``/schools/<key>/`` segment. Rows with fewer than three data
    cells (repeated header rows) are skipped.
    """
    html_path = get_team_season_file_path(school_key, season, f'{school_key}_basic.html')

    with open(html_path, 'r') as html_file:
        soup = BeautifulSoup(html_file, 'html.parser')
        table_rows = soup.find("table").find_all('tr')

    keys = []
    for table_row in table_rows[2:]:
        cells = table_row.find_all('td')
        if len(cells) < 3:
            # repeating header row
            continue
        href = cells[2].find('a')['href']
        keys.append(re.search(r'/schools/([^/]+)/', href).group(1))

    return keys

## Basic gamelog CSV
Extract the basic gamelog to csv

In [22]:
def create_basic_gamelog(school_key, season):
    """Convert a school's saved basic gamelog HTML into ``{school_key}_basic.csv``.

    Parses the first table in ``{school_key}_basic.html`` and then:

    - drops the opponent's box-score columns (top-level header containing 'Opponent');
    - flattens the two-level header;
    - names the unlabeled location column ``Location``;
    - drops the other unnamed columns and ``G``;
    - removes repeated header rows;
    - replaces the opponent-name column with Sports Reference school keys
      (``Opp key``) taken from the opponent links in the same HTML.

    Raises:
        ValueError: Propagated from ``pd.read_html`` (e.g. when no table is
            found); the season loop skips such schools.
        AssertionError: If the number of games does not match the number of
            opponent keys.
    """
    file_path = get_team_season_file_path(school_key, season, f'{school_key}_basic.html')

    team_df = pd.read_html(file_path)[0]

    # drop the opponent's box-score columns (the 'Opponent' header group)
    opponent_columns = [column for column in team_df.columns if 'Opponent' in column[0]]
    team_df = team_df.drop(opponent_columns, axis=1)

    # Use second level column names
    team_df.columns = [column[1] for column in team_df.columns]

    # the unlabeled column holding '@' / 'N' markers becomes 'Location'
    team_df = team_df.rename(columns={'Unnamed: 2_level_1': 'Location'})

    # remove unneeded columns
    unneeded_columns = [column for column in team_df.columns if 'Unnamed' in column] + ['G']
    team_df = team_df.drop(unneeded_columns, axis=1)

    # Drop repeating header rows
    team_df = team_df[team_df.Tm != 'Tm']
    team_df = team_df[team_df.FG != 'School']

    # Two columns are labelled 'Opp': the opponent's name (first) and the
    # opponent's score (second). Rename only the first by building a fresh
    # label list; writing into ``columns.values`` mutates the immutable Index
    # in place, and if that array is read-only (possible in newer pandas) the
    # resulting ValueError would be silently swallowed by the season loop.
    column_names = team_df.columns.to_list()
    column_names[column_names.index('Opp')] = 'Opp key'
    team_df.columns = column_names

    # Opponent names don't always match our team data, so use the SR keys
    # parsed from the opponent links instead.
    opp_school_keys = get_opposing_school_keys(school_key, season)
    # Both must describe the same games, row for row
    assert team_df.shape[0] == len(opp_school_keys), (
        f'{school_key} {season}: {team_df.shape[0]} games but {len(opp_school_keys)} opponent keys'
    )
    team_df['Opp key'] = opp_school_keys

    # save file
    csv_file_path = get_team_season_file_path(school_key, season, f'{school_key}_basic.csv')
    team_df.to_csv(csv_file_path, index=False)

def create_basic_gamelogs_for_single_season(season):
    """Build ``_basic.csv`` for every school in SR_SCHOOL_KEYS for one season.

    Schools whose conversion raises ValueError are skipped.
    """
    progress_bar = tqdm(SR_SCHOOL_KEYS, unit=f'school ({season} basic csv)')
    for key in progress_bar:
        try:
            create_basic_gamelog(key, season)
        except ValueError:
            pass

def create_basic_gamelogs(seasons):
    """Build basic gamelog CSVs for each season in ``seasons``."""
    for each_season in seasons:
        create_basic_gamelogs_for_single_season(each_season)

In [23]:
# Seasons 2020-2023, passed as strings (they are only used to build paths)
create_basic_gamelogs([str(year) for year in range(2020, 2024)])

100%|██████████| 491/491 [00:40<00:00, 12.24school (2020 basic csv)/s]
100%|██████████| 491/491 [00:34<00:00, 14.24school (2021 basic csv)/s]
100%|██████████| 491/491 [00:40<00:00, 12.05school (2022 basic csv)/s]
100%|██████████| 491/491 [00:41<00:00, 11.78school (2023 basic csv)/s]


Quick visual verification:

In [24]:
# Spot-check one team's basic CSV
school_key = 'connecticut'
season = 2020
file_path = get_team_season_file_path(school_key, season, filename=f'{school_key}_basic.csv')
team_df = pd.read_csv(file_path)
team_df.head()

Unnamed: 0,Date,Location,Opp key,W/L,Tm,Opp,FG,FGA,FG%,3P,3PA,3P%,FT,FTA,FT%,ORB,TRB,AST,STL,BLK,TOV,PF
0,2019-11-08,,sacred-heart,W,89,67,34,75,0.453,5,19,0.263,16,22,0.727,14,38,16,10,9,11,24
1,2019-11-13,,saint-josephs,L,87,96,22,63,0.349,14,29,0.483,29,38,0.763,10,43,7,5,7,13,23
2,2019-11-17,,florida,W,62,59,21,59,0.356,4,22,0.182,16,22,0.727,11,36,11,7,4,9,14
3,2019-11-21,N,buffalo,W,79,68,30,62,0.484,7,18,0.389,12,21,0.571,11,45,14,5,7,17,12
4,2019-11-22,N,xavier,L (2 OT),74,75,21,70,0.3,7,23,0.304,25,27,0.926,13,39,10,8,6,15,24


## Advanced gamelog CSV
Extract the advanced gamelog to csv

In [25]:
def create_advanced_gamelog(school_key, season):
    """Convert a school's saved advanced gamelog HTML into ``{school_key}_advanced.csv``.

    Parses the first table in ``{school_key}_advanced.html`` and then:

    - drops the 'Defensive Four Factors' column group;
    - flattens the two-level header;
    - names the unlabeled location column ``Location``;
    - drops the other unnamed columns and ``G``;
    - removes repeated header rows;
    - replaces the opponent-name column with Sports Reference school keys
      (``Opp key``) extracted from the basic gamelog HTML's opponent links.

    Raises:
        ValueError: Propagated from ``pd.read_html`` (e.g. when no table is
            found); the season loop skips such schools.
        AssertionError: If the number of games does not match the number of
            opponent keys.
    """
    file_path = get_team_season_file_path(school_key, season, f'{school_key}_advanced.html')

    team_df = pd.read_html(file_path)[0]

    # drop columns from 'Defensive Four Factors'
    defensive_columns = [column for column in team_df.columns if 'Defensive' in column[0]]
    team_df = team_df.drop(defensive_columns, axis=1)

    # Use second level column names
    team_df.columns = [column[1] for column in team_df.columns]

    # the unlabeled column holding '@' / 'N' markers becomes 'Location'
    team_df = team_df.rename(columns={'Unnamed: 2_level_1': 'Location'})

    # remove unneeded columns
    unneeded_columns = [column for column in team_df.columns if 'Unnamed' in column] + ['G']
    team_df = team_df.drop(unneeded_columns, axis=1)

    # Drop repeating header rows
    team_df = team_df[team_df.Tm != 'Tm']
    team_df = team_df[team_df['eFG%'] != 'Offensive Four Factors']

    # Two columns are labelled 'Opp': the opponent's name (first) and the
    # opponent's score (second). Rename only the first by building a fresh
    # label list; writing into ``columns.values`` mutates the immutable Index
    # in place, and if that array is read-only (possible in newer pandas) the
    # resulting ValueError would be silently swallowed by the season loop.
    column_names = team_df.columns.to_list()
    column_names[column_names.index('Opp')] = 'Opp key'
    team_df.columns = column_names

    # Opponent names don't always match our team data, so use the SR keys
    # parsed from the opponent links instead.
    opp_school_keys = get_opposing_school_keys(school_key, season)
    # Both must describe the same games, row for row
    assert team_df.shape[0] == len(opp_school_keys), (
        f'{school_key} {season}: {team_df.shape[0]} games but {len(opp_school_keys)} opponent keys'
    )
    team_df['Opp key'] = opp_school_keys

    # save file
    csv_file_path = get_team_season_file_path(school_key, season, f'{school_key}_advanced.csv')
    team_df.to_csv(csv_file_path, index=False)

def create_advanced_gamelogs_for_single_season(season):
    """Build ``_advanced.csv`` for every school in SR_SCHOOL_KEYS for one season.

    Schools whose conversion raises ValueError are skipped.
    """
    progress_bar = tqdm(SR_SCHOOL_KEYS, unit=f'school ({season} advanced csv)')
    for key in progress_bar:
        try:
            create_advanced_gamelog(key, season)
        except ValueError:
            pass

def create_advanced_gamelogs(seasons):
    """Build advanced gamelog CSVs for each season in ``seasons``."""
    for each_season in seasons:
        create_advanced_gamelogs_for_single_season(each_season)

In [26]:
# Seasons 2020-2023, passed as strings (they are only used to build paths)
create_advanced_gamelogs([str(year) for year in range(2020, 2024)])

100%|██████████| 491/491 [00:39<00:00, 12.57school (2020 advanced csv)/s]
100%|██████████| 491/491 [00:33<00:00, 14.62school (2021 advanced csv)/s]
100%|██████████| 491/491 [00:39<00:00, 12.45school (2022 advanced csv)/s]
100%|██████████| 491/491 [00:40<00:00, 12.15school (2023 advanced csv)/s]


Quick visual verification:

In [28]:
# Spot-check one team's advanced CSV
school_key = 'connecticut'
season = 2020
file_path = get_team_season_file_path(school_key, season, filename=f'{school_key}_advanced.csv')
team_df = pd.read_csv(file_path)
team_df.head()

Unnamed: 0,Date,Location,Opp key,W/L,Tm,Opp,ORtg,DRtg,Pace,FTr,3PAr,TS%,TRB%,AST%,STL%,BLK%,eFG%,TOV%,ORB%,FT/FGA
0,2019-11-08,,sacred-heart,W,89,67,107.2,80.7,82.9,0.293,0.253,0.521,48.1,47.1,12.0,22.5,0.487,11.4,35.9,0.213
1,2019-11-13,,saint-josephs,L,87,96,102.4,112.9,84.9,0.603,0.46,0.537,52.4,31.8,5.9,18.9,0.46,13.8,25.0,0.46
2,2019-11-17,,florida,W,62,59,93.9,89.4,66.3,0.373,0.373,0.446,51.4,52.4,10.6,11.4,0.39,11.5,28.2,0.271
3,2019-11-21,N,buffalo,W,79,68,102.6,88.3,76.6,0.339,0.29,0.549,53.6,46.7,6.5,14.6,0.54,19.1,28.9,0.194
4,2019-11-22,N,xavier,L (2 OT),74,75,86.0,87.2,68.5,0.386,0.329,0.447,47.0,47.6,9.3,14.3,0.35,15.3,29.5,0.357


## Combine basic and advanced gamelog CSVs

In [29]:
def combine_basic_advanced_gamelog(school_key, season):
    """Join a school's basic and advanced CSVs into ``{school_key}_merged.csv``.

    Rows are inner-joined on the shared game columns (Date, Location, Opp key,
    W/L, Tm, Opp). Missing Location values are then filled with 'H' (home).

    Raises:
        FileNotFoundError: If either input CSV is missing.
    """
    basic_path = get_team_season_file_path(school_key, season, f'{school_key}_basic.csv')
    advanced_path = get_team_season_file_path(school_key, season, f'{school_key}_advanced.csv')
    join_columns = ['Date', 'Location', 'Opp key', 'W/L', 'Tm', 'Opp']

    merged_team_df = pd.read_csv(basic_path).merge(pd.read_csv(advanced_path), on=join_columns)

    # Games without a location marker are home games
    merged_team_df['Location'] = merged_team_df['Location'].fillna('H')

    output_path = get_team_season_file_path(school_key, season, f'{school_key}_merged.csv')
    merged_team_df.to_csv(output_path, index=False)

def combine_basic_advanced_gamelogs_for_single_season(season):
    """Write ``_merged.csv`` for every school in SR_SCHOOL_KEYS for one season.

    Schools for which a FileNotFoundError is raised are skipped.
    """
    progress_bar = tqdm(SR_SCHOOL_KEYS, unit=f'school ({season} merged csv)')
    for key in progress_bar:
        try:
            combine_basic_advanced_gamelog(key, season)
        except FileNotFoundError:
            pass

def combine_basic_advanced_gamelogs(seasons):
    """Write merged gamelog CSVs for each season in ``seasons``."""
    for each_season in seasons:
        combine_basic_advanced_gamelogs_for_single_season(each_season)

In [30]:
# Seasons 2020-2023, passed as strings (they are only used to build paths)
combine_basic_advanced_gamelogs([str(year) for year in range(2020, 2024)])

100%|██████████| 491/491 [00:05<00:00, 91.63school (2020 advanced csv)/s] 
100%|██████████| 491/491 [00:05<00:00, 95.75school (2021 advanced csv)/s] 
100%|██████████| 491/491 [00:05<00:00, 92.60school (2022 advanced csv)/s] 
100%|██████████| 491/491 [00:05<00:00, 89.85school (2023 advanced csv)/s] 


Quick visual verification:

In [31]:
# Spot-check one team's merged CSV
school_key = 'connecticut'
season = 2020
file_path = get_team_season_file_path(school_key, season, filename=f'{school_key}_merged.csv')
team_df = pd.read_csv(file_path)
team_df.head()

Unnamed: 0,Date,Location,Opp key,W/L,Tm,Opp,FG,FGA,FG%,3P,3PA,3P%,FT,FTA,FT%,ORB,TRB,AST,STL,BLK,TOV,PF,ORtg,DRtg,Pace,FTr,3PAr,TS%,TRB%,AST%,STL%,BLK%,eFG%,TOV%,ORB%,FT/FGA
0,2019-11-08,H,sacred-heart,W,89,67,34,75,0.453,5,19,0.263,16,22,0.727,14,38,16,10,9,11,24,107.2,80.7,82.9,0.293,0.253,0.521,48.1,47.1,12.0,22.5,0.487,11.4,35.9,0.213
1,2019-11-13,H,saint-josephs,L,87,96,22,63,0.349,14,29,0.483,29,38,0.763,10,43,7,5,7,13,23,102.4,112.9,84.9,0.603,0.46,0.537,52.4,31.8,5.9,18.9,0.46,13.8,25.0,0.46
2,2019-11-17,H,florida,W,62,59,21,59,0.356,4,22,0.182,16,22,0.727,11,36,11,7,4,9,14,93.9,89.4,66.3,0.373,0.373,0.446,51.4,52.4,10.6,11.4,0.39,11.5,28.2,0.271
3,2019-11-21,N,buffalo,W,79,68,30,62,0.484,7,18,0.389,12,21,0.571,11,45,14,5,7,17,12,102.6,88.3,76.6,0.339,0.29,0.549,53.6,46.7,6.5,14.6,0.54,19.1,28.9,0.194
4,2019-11-22,N,xavier,L (2 OT),74,75,21,70,0.3,7,23,0.304,25,27,0.926,13,39,10,8,6,15,24,86.0,87.2,68.5,0.386,0.329,0.447,47.0,47.6,9.3,14.3,0.35,15.3,29.5,0.357


## Generating Moving Averages
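Next, we generate a CSV that adds simple (SMA), cumulative (CMA), and exponential (EMA) moving averages for each statistic. Every average is shifted by one game, so a row's features only use games played before it.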

In [70]:
# Window for SMA, min_periods for CMA, and span for EMA
SPAN = 5
# Placeholder written to 'Date' / 'Opp key' for the upcoming-game row
LATEST = 'LATEST'

# Game-identifying / result columns that are not averaged
META_LABELS = ['Date', 'Location', 'Opp key', 'W/L', 'Tm', 'Opp']

# Per-game team statistics
STAT_LABELS = [
    # basic box score
    'FG', 'FGA', 'FG%', '3P', '3PA', '3P%', 'FT', 'FTA', 'FT%',
    'ORB', 'TRB', 'AST', 'STL', 'BLK', 'TOV', 'PF',
    # advanced
    'ORtg', 'DRtg', 'Pace', 'FTr', '3PAr', 'TS%', 'TRB%', 'AST%', 'STL%', 'BLK%',
    'eFG%', 'TOV%', 'ORB%', 'FT/FGA',
]

In [76]:
def generate_moving_averages_for_school(school_key, season, keep_latest = False):
    """Add shifted moving averages of every statistic to a school's merged gamelog.

    Reads ``{school_key}_merged.csv`` for ``season`` and drops rows containing
    NaN. For every column not in META_LABELS it adds three features, each
    shifted by one row so a game's own stats never feed its own features:

    * ``<col>_SMA`` - rolling mean over the last ``SPAN`` games
    * ``<col>_CMA`` - expanding mean, available once ``SPAN`` games exist
    * ``<col>_EMA`` - exponential mean with ``span=SPAN`` and ``adjust=False``

    Rows that still contain NaN (the first games, which lack a full SMA) are
    dropped and the result is written to ``{school_key}_ma.csv``.

    Args:
        school_key: Sports Reference school key, e.g. ``'connecticut'``.
        season: Season year; only used to build file paths.
        keep_latest: If True, append a copy of the last game with ``Date`` and
            ``Opp key`` set to LATEST, so its shifted averages (which include
            the most recent game) can serve as features for an upcoming game.

    Raises:
        FileNotFoundError: If the merged CSV does not exist.
    """
    file_path = get_team_season_file_path(school_key, season, f'{school_key}_merged.csv')
    # Drop any rows with NULL values
    team_df = pd.read_csv(file_path).dropna()

    if keep_latest and not team_df.empty:
        # Duplicate the last game directly (not via ``.values``, which would turn
        # every column into object dtype) and mark it as the upcoming-game row.
        team_df = pd.concat([team_df, team_df.tail(1)], ignore_index=True)
        team_df.loc[team_df.index[-1], ['Date', 'Opp key']] = LATEST

    stat_columns = [column for column in team_df.columns if column not in META_LABELS]
    averages = {}
    for column in stat_columns:
        values = team_df[column]
        # shift(1): each row only sees games strictly before it
        averages[f"{column}_SMA"] = values.rolling(window=SPAN).mean().shift(1)
        averages[f"{column}_CMA"] = values.expanding(min_periods=SPAN).mean().shift(1)
        averages[f"{column}_EMA"] = values.ewm(span=SPAN, adjust=False).mean().shift(1)

    # Add all average columns in one concat instead of inserting them one by one
    team_df = pd.concat([team_df, pd.DataFrame(averages, index=team_df.index)], axis=1)

    # Drop any rows with NULL values (rows with no MA)
    team_df = team_df.dropna()

    ma_file_path = get_team_season_file_path(school_key, season, f'{school_key}_ma.csv')
    team_df.to_csv(ma_file_path, index=False)

def generate_moving_averages_for_single_season(season, keep_latest = False):
    """Write ``_ma.csv`` for every school in SR_SCHOOL_KEYS for one season.

    Schools for which a FileNotFoundError is raised are skipped.
    """
    progress_bar = tqdm(SR_SCHOOL_KEYS, unit=f'school ({season} ma csv)')
    for key in progress_bar:
        try:
            generate_moving_averages_for_school(key, season, keep_latest)
        except FileNotFoundError:
            pass

def generate_moving_averages(seasons, keep_latest = False):
    """Write moving-average CSVs for each season in ``seasons``."""
    for each_season in seasons:
        generate_moving_averages_for_single_season(each_season, keep_latest=keep_latest)

In [77]:
# Seasons 2020-2023, passed as strings (they are only used to build paths)
generate_moving_averages([str(year) for year in range(2020, 2024)])

100%|██████████| 491/491 [00:16<00:00, 29.99school (2020 ma csv)/s]
100%|██████████| 491/491 [00:15<00:00, 31.59school (2021 ma csv)/s]
100%|██████████| 491/491 [00:16<00:00, 30.01school (2022 ma csv)/s]
100%|██████████| 491/491 [00:16<00:00, 29.81school (2023 ma csv)/s]


Quick visual verification:

In [79]:
# Spot-check one team's moving-average CSV
school_key = 'connecticut'
season = 2023
file_path = get_team_season_file_path(school_key, season, filename=f'{school_key}_ma.csv')
team_df = pd.read_csv(file_path)
team_df.head()

Unnamed: 0,Date,Location,Opp key,W/L,Tm,Opp,FG,FGA,FG%,3P,3PA,3P%,FT,FTA,FT%,ORB,TRB,AST,STL,BLK,TOV,PF,ORtg,DRtg,Pace,FTr,3PAr,TS%,TRB%,AST%,STL%,BLK%,eFG%,TOV%,ORB%,FT/FGA,FG_SMA,FG_CMA,FG_EMA,FGA_SMA,FGA_CMA,FGA_EMA,FG%_SMA,FG%_CMA,FG%_EMA,3P_SMA,3P_CMA,3P_EMA,3PA_SMA,3PA_CMA,3PA_EMA,3P%_SMA,3P%_CMA,3P%_EMA,FT_SMA,FT_CMA,FT_EMA,FTA_SMA,FTA_CMA,FTA_EMA,FT%_SMA,FT%_CMA,FT%_EMA,ORB_SMA,ORB_CMA,ORB_EMA,TRB_SMA,TRB_CMA,TRB_EMA,AST_SMA,AST_CMA,AST_EMA,STL_SMA,STL_CMA,STL_EMA,BLK_SMA,BLK_CMA,BLK_EMA,TOV_SMA,TOV_CMA,TOV_EMA,PF_SMA,PF_CMA,PF_EMA,ORtg_SMA,ORtg_CMA,ORtg_EMA,DRtg_SMA,DRtg_CMA,DRtg_EMA,Pace_SMA,Pace_CMA,Pace_EMA,FTr_SMA,FTr_CMA,FTr_EMA,3PAr_SMA,3PAr_CMA,3PAr_EMA,TS%_SMA,TS%_CMA,TS%_EMA,TRB%_SMA,TRB%_CMA,TRB%_EMA,AST%_SMA,AST%_CMA,AST%_EMA,STL%_SMA,STL%_CMA,STL%_EMA,BLK%_SMA,BLK%_CMA,BLK%_EMA,eFG%_SMA,eFG%_CMA,eFG%_EMA,TOV%_SMA,TOV%_CMA,TOV%_EMA,ORB%_SMA,ORB%_CMA,ORB%_EMA,FT/FGA_SMA,FT/FGA_CMA,FT/FGA_EMA
0,2022-11-24,N,oregon,W,83,59,30,63,0.476,17,37,0.459,6,11,0.545,11,35,22,9,7,11,25,118.6,84.3,70.1,0.175,0.587,0.608,57.4,73.3,12.9,21.9,0.611,13.9,32.4,0.095,30.2,30.2,30.82716,58.8,58.8,59.098765,0.516,0.516,0.523235,8.6,8.6,9.08642,23.6,23.6,24.444444,0.3658,0.3658,0.370877,18.2,18.2,17.765432,25.4,25.4,24.283951,0.7176,0.7176,0.73116,10.0,10.0,9.271605,36.8,36.8,36.0,19.0,19.0,20.135802,7.6,7.6,7.864198,5.6,5.6,5.876543,13.2,13.2,12.790123,18.0,18.0,18.54321,118.88,118.88,120.17037,77.4,77.4,76.961728,73.52,73.52,73.7,0.4324,0.4324,0.411568,0.402,0.402,0.414062,0.6192,0.6192,0.630321,59.36,59.36,59.101235,62.7,62.7,64.975309,10.08,10.08,10.419753,15.4,15.4,15.977778,0.5908,0.5908,0.601457,15.62,15.62,15.239506,33.76,33.76,32.44321,0.3102,0.3102,0.302383
1,2022-11-25,N,alabama,W,82,67,26,60,0.433,9,24,0.375,21,24,0.875,8,30,18,8,7,11,23,110.8,90.5,74.1,0.4,0.4,0.574,48.4,69.2,10.8,20.6,0.508,13.3,24.2,0.35,30.4,30.166667,30.55144,58.8,59.5,60.399177,0.5192,0.509333,0.50749,11.0,10.0,11.72428,26.2,25.833333,28.62963,0.416,0.381333,0.400251,15.0,16.166667,13.843621,21.4,23.0,19.855967,0.6846,0.688833,0.669107,9.2,10.166667,9.847737,35.8,36.5,35.666667,20.2,19.5,20.757202,7.2,7.833333,8.242798,5.4,5.833333,6.251029,12.4,12.833333,12.193416,19.0,19.166667,20.695473,120.24,118.833333,119.646914,80.04,78.55,79.407819,72.32,72.95,72.5,0.369,0.3895,0.332712,0.4432,0.432833,0.471708,0.6314,0.617333,0.622881,57.5,59.033333,58.534156,66.32,64.466667,67.750206,9.76,10.55,11.246502,14.06,16.483333,17.951852,0.613,0.594167,0.604638,15.16,15.333333,14.793004,30.56,33.533333,32.428807,0.2594,0.274333,0.233255
2,2022-11-27,N,iowa-state,W,71,53,22,53,0.415,7,26,0.269,20,25,0.8,18,45,16,5,2,17,15,112.7,84.1,62.8,0.472,0.491,0.547,70.3,72.7,7.9,4.9,0.481,20.8,60.0,0.377,29.4,29.571429,29.034294,59.8,59.571429,60.266118,0.493,0.498429,0.48266,10.8,9.857143,10.816187,26.8,25.571429,27.08642,0.3958,0.380429,0.391834,16.4,16.857143,16.229081,21.0,23.142857,21.237311,0.752,0.715429,0.737738,8.6,9.857143,9.231824,33.8,35.571429,33.777778,20.4,19.285714,19.838134,8.0,7.857143,8.161866,6.2,6.0,6.500686,11.8,12.571429,11.79561,20.0,19.714286,21.463649,117.82,117.685714,116.697942,81.86,80.257143,83.105213,73.06,73.114286,73.033333,0.3544,0.391,0.355141,0.4468,0.428143,0.447805,0.6186,0.611143,0.606587,55.42,57.514286,55.156104,69.2,65.142857,68.233471,10.78,10.585714,11.097668,16.72,17.071429,18.834568,0.5836,0.581857,0.572425,14.38,15.042857,14.295336,28.52,32.2,29.685871,0.2784,0.285143,0.27217
3,2022-12-01,H,oklahoma-state,W,74,64,21,56,0.375,8,23,0.348,24,33,0.727,10,28,11,8,3,6,16,107.2,92.8,68.9,0.589,0.411,0.516,45.2,52.4,11.6,6.7,0.446,7.7,30.3,0.429,28.2,28.625,26.689529,57.8,58.75,57.844079,0.4872,0.488,0.460107,11.0,9.5,9.544124,27.4,25.625,26.72428,0.3974,0.3665,0.350889,16.0,17.25,17.486054,20.4,23.375,22.491541,0.7548,0.726,0.758492,9.8,10.875,12.15455,35.4,36.75,37.518519,20.2,18.875,18.558756,6.8,7.5,7.10791,5.4,5.5,5.000457,12.2,13.125,13.530407,20.4,19.125,19.309099,119.1,117.0625,115.365295,82.48,80.7375,83.436808,69.88,71.825,69.622222,0.36,0.401125,0.394094,0.472,0.436,0.462203,0.618,0.603125,0.586725,57.74,59.1125,60.20407,71.6,66.0875,69.722314,9.58,10.25,10.031779,14.78,15.55,14.189712,0.5814,0.56925,0.54195,15.26,15.7625,16.463557,33.46,35.675,39.790581,0.284,0.296625,0.307113
4,2022-12-07,@,florida,W,75,54,30,58,0.517,7,19,0.368,8,9,0.889,8,39,16,9,7,15,19,105.6,76.1,70.8,0.155,0.328,0.602,60.9,53.3,12.7,18.4,0.578,19.4,29.6,0.138,26.8,27.777778,24.793019,58.6,58.444444,57.229386,0.4546,0.475444,0.431738,10.6,9.333333,9.029416,27.6,25.333333,25.482853,0.376,0.364444,0.349926,16.8,18.0,19.657369,22.4,24.444444,25.994361,0.7262,0.726111,0.747995,10.8,10.777778,11.436366,34.4,35.777778,34.345679,18.4,18.0,16.039171,7.8,7.555556,7.405274,5.0,5.222222,4.333638,11.4,12.333333,11.020271,19.8,18.777778,18.206066,114.86,115.966667,112.64353,86.12,82.077778,86.557872,70.32,71.5,69.381481,0.3894,0.422,0.459063,0.4696,0.433222,0.445136,0.5846,0.593444,0.56315,55.78,57.566667,55.202713,67.8,64.566667,63.948209,11.0,10.4,10.554519,13.62,14.566667,11.693141,0.5436,0.555556,0.509967,14.06,14.866667,13.542372,34.76,35.077778,36.627054,0.2928,0.311333,0.347742


## Merge opponent data

In [87]:
# Every stat column plus its three moving averages, grouped per stat
all_stat_cols = [
    name
    for stat in STAT_LABELS
    for name in (stat, f'{stat}_SMA', f'{stat}_CMA', f'{stat}_EMA')
]
# The same columns as they appear for the away team after prefixing
opposing_stat_cols = [f'opp_{name}' for name in all_stat_cols]
# Mapping used to prefix the away team's stat columns with 'opp_'
rename_opposing_cols = {name: f'opp_{name}' for name in all_stat_cols}

def merge_opponent_data_for_school(school_key, season):
    """Join each of a school's games with the opponent's stats for that game.

    For every game in ``{school_key}_ma.csv``, finds the row with the same
    ``Date`` and ``Opp key == school_key`` in the opponent's ``_ma.csv``.

    Rows are arranged from the home team's perspective: when ``Location`` is
    '@' the opponent is the home side; otherwise (home or neutral) the school
    is. The away side's score columns are swapped, its stat columns are
    prefixed with ``opp_``, and it is joined to the home side on
    ``Date``/``Tm``/``Opp``. The result, sorted by date, is written to
    ``{school_key}_full.csv``.

    Games whose opponent has no ``_ma.csv`` are skipped. Nothing is written if
    the school has no games or none of its opponents' files exist.

    Raises:
        FileNotFoundError: If the school's own ``_ma.csv`` does not exist.
    """
    file_path = get_team_season_file_path(school_key, season, f'{school_key}_ma.csv')
    team_df = pd.read_csv(file_path)

    if team_df.shape[0] < 1:
        return

    # Each opponent file is read at most once (None marks a missing file)
    opponent_games_by_key = {}
    # Collect per-game frames and concat once (concat inside the loop is quadratic)
    home_frames, away_frames = [], []
    for index, game in team_df.iterrows():
        opponent_key = game['Opp key']

        if opponent_key not in opponent_games_by_key:
            opponent_file_path = get_team_season_file_path(opponent_key, season, f'{opponent_key}_ma.csv')
            try:
                opponent_games_by_key[opponent_key] = pd.read_csv(opponent_file_path)
            except FileNotFoundError:
                opponent_games_by_key[opponent_key] = None
        opponent_games = opponent_games_by_key[opponent_key]
        if opponent_games is None:
            continue

        opponent_df = opponent_games.loc[
            (opponent_games['Opp key'] == school_key) & (opponent_games['Date'] == game['Date'])
        ]
        current_df = team_df.loc[[index]]

        if game['Location'] == '@':
            home_frames.append(opponent_df)
            away_frames.append(current_df)
        else:
            home_frames.append(current_df)
            away_frames.append(opponent_df)

    if not home_frames:
        # No opponent file was found; previously this crashed with a KeyError
        # when dropping columns from an empty, column-less DataFrame.
        return

    home_df = pd.concat(home_frames)
    away_df = (
        pd.concat(away_frames)
        # flip score column names for away dataframe to match home dataframe
        .rename(columns={'Tm': 'Opp', 'Opp': 'Tm'})
        .drop(['Location', 'Opp key', 'W/L'], axis=1)
        .rename(columns=rename_opposing_cols)
    )

    merged_df = pd.merge(home_df, away_df, on=["Date", "Tm", "Opp"])
    merged_df = merged_df.sort_values(by='Date')

    merged_file_path = get_team_season_file_path(school_key, season, f'{school_key}_full.csv')
    merged_df.to_csv(merged_file_path, index=False)

def merge_opponent_data_for_single_season(season):
    """Write ``_full.csv`` for every school in SR_SCHOOL_KEYS for one season.

    Schools for which a FileNotFoundError is raised are skipped.
    """
    progress_bar = tqdm(SR_SCHOOL_KEYS, unit=f'school ({season} full csv)')
    for key in progress_bar:
        try:
            merge_opponent_data_for_school(key, season)
        except FileNotFoundError:
            pass

def merge_opponent_data(seasons):
    """Write opponent-merged CSVs for each season in ``seasons``."""
    for each_season in seasons:
        merge_opponent_data_for_single_season(each_season)

In [85]:
# Seasons 2020-2023, passed as strings (they are only used to build paths)
merge_opponent_data([str(year) for year in range(2020, 2024)])

100%|██████████| 491/491 [01:29<00:00,  5.47school (2020 full csv)/s]
100%|██████████| 491/491 [01:08<00:00,  7.12school (2021 full csv)/s]
100%|██████████| 491/491 [02:01<00:00,  4.04school (2022 full csv)/s]
100%|██████████| 491/491 [02:06<00:00,  3.87school (2023 full csv)/s]


Quick visual verification:

In [86]:
# Spot-check one team's opponent-merged CSV
school_key = 'connecticut'
season = 2023
file_path = get_team_season_file_path(school_key, season, filename=f'{school_key}_full.csv')
team_df = pd.read_csv(file_path)
team_df.head()

Unnamed: 0,Date,Location,Opp key,W/L,Tm,Opp,FG,FGA,FG%,3P,3PA,3P%,FT,FTA,FT%,ORB,TRB,AST,STL,BLK,TOV,PF,ORtg,DRtg,Pace,FTr,3PAr,TS%,TRB%,AST%,STL%,BLK%,eFG%,TOV%,ORB%,FT/FGA,FG_SMA,FG_CMA,FG_EMA,FGA_SMA,FGA_CMA,FGA_EMA,FG%_SMA,FG%_CMA,FG%_EMA,3P_SMA,3P_CMA,3P_EMA,3PA_SMA,3PA_CMA,3PA_EMA,3P%_SMA,3P%_CMA,3P%_EMA,FT_SMA,FT_CMA,FT_EMA,FTA_SMA,FTA_CMA,FTA_EMA,FT%_SMA,FT%_CMA,FT%_EMA,ORB_SMA,ORB_CMA,ORB_EMA,TRB_SMA,TRB_CMA,TRB_EMA,AST_SMA,AST_CMA,AST_EMA,STL_SMA,STL_CMA,STL_EMA,BLK_SMA,BLK_CMA,BLK_EMA,TOV_SMA,TOV_CMA,TOV_EMA,PF_SMA,PF_CMA,PF_EMA,ORtg_SMA,ORtg_CMA,ORtg_EMA,DRtg_SMA,DRtg_CMA,DRtg_EMA,Pace_SMA,Pace_CMA,Pace_EMA,FTr_SMA,FTr_CMA,FTr_EMA,3PAr_SMA,3PAr_CMA,3PAr_EMA,TS%_SMA,TS%_CMA,TS%_EMA,TRB%_SMA,TRB%_CMA,TRB%_EMA,AST%_SMA,AST%_CMA,AST%_EMA,STL%_SMA,STL%_CMA,STL%_EMA,BLK%_SMA,BLK%_CMA,BLK%_EMA,eFG%_SMA,eFG%_CMA,eFG%_EMA,TOV%_SMA,TOV%_CMA,TOV%_EMA,ORB%_SMA,ORB%_CMA,ORB%_EMA,FT/FGA_SMA,FT/FGA_CMA,FT/FGA_EMA,opp_FG,opp_FGA,opp_FG%,opp_3P,opp_3PA,opp_3P%,opp_FT,opp_FTA,opp_FT%,opp_ORB,opp_TRB,opp_AST,opp_STL,opp_BLK,opp_TOV,opp_PF,opp_ORtg,opp_DRtg,opp_Pace,opp_FTr,opp_3PAr,opp_TS%,opp_TRB%,opp_AST%,opp_STL%,opp_BLK%,opp_eFG%,opp_TOV%,opp_ORB%,opp_FT/FGA,opp_FG_SMA,opp_FG_CMA,opp_FG_EMA,opp_FGA_SMA,opp_FGA_CMA,opp_FGA_EMA,opp_FG%_SMA,opp_FG%_CMA,opp_FG%_EMA,opp_3P_SMA,opp_3P_CMA,opp_3P_EMA,opp_3PA_SMA,opp_3PA_CMA,opp_3PA_EMA,opp_3P%_SMA,opp_3P%_CMA,opp_3P%_EMA,opp_FT_SMA,opp_FT_CMA,opp_FT_EMA,opp_FTA_SMA,opp_FTA_CMA,opp_FTA_EMA,opp_FT%_SMA,opp_FT%_CMA,opp_FT%_EMA,opp_ORB_SMA,opp_ORB_CMA,opp_ORB_EMA,opp_TRB_SMA,opp_TRB_CMA,opp_TRB_EMA,opp_AST_SMA,opp_AST_CMA,opp_AST_EMA,opp_STL_SMA,opp_STL_CMA,opp_STL_EMA,opp_BLK_SMA,opp_BLK_CMA,opp_BLK_EMA,opp_TOV_SMA,opp_TOV_CMA,opp_TOV_EMA,opp_PF_SMA,opp_PF_CMA,opp_PF_EMA,opp_ORtg_SMA,opp_ORtg_CMA,opp_ORtg_EMA,opp_DRtg_SMA,opp_DRtg_CMA,opp_DRtg_EMA,opp_Pace_SMA,opp_Pace_CMA,opp_Pace_EMA,opp_FTr_SMA,opp_FTr_CMA,opp_FTr_EMA,opp_3PAr_SMA,opp_3PAr_CMA,opp_3PAr_EMA,opp_TS%_SMA,opp_TS%_CMA,opp_TS%_EMA,opp_TRB%_SMA,opp_TRB%_CMA,opp_TRB%
_EMA,opp_AST%_SMA,opp_AST%_CMA,opp_AST%_EMA,opp_STL%_SMA,opp_STL%_CMA,opp_STL%_EMA,opp_BLK%_SMA,opp_BLK%_CMA,opp_BLK%_EMA,opp_eFG%_SMA,opp_eFG%_CMA,opp_eFG%_EMA,opp_TOV%_SMA,opp_TOV%_CMA,opp_TOV%_EMA,opp_ORB%_SMA,opp_ORB%_CMA,opp_ORB%_EMA,opp_FT/FGA_SMA,opp_FT/FGA_CMA,opp_FT/FGA_EMA
0,2022-11-25,N,alabama,W,82,67,26,60,0.433,9,24,0.375,21,24,0.875,8,30,18,8,7,11,23,110.8,90.5,74.1,0.4,0.4,0.574,48.4,69.2,10.8,20.6,0.508,13.3,24.2,0.35,30.4,30.166667,30.55144,58.8,59.5,60.399177,0.5192,0.509333,0.50749,11.0,10.0,11.72428,26.2,25.833333,28.62963,0.416,0.381333,0.400251,15.0,16.166667,13.843621,21.4,23.0,19.855967,0.6846,0.688833,0.669107,9.2,10.166667,9.847737,35.8,36.5,35.666667,20.2,19.5,20.757202,7.2,7.833333,8.242798,5.4,5.833333,6.251029,12.4,12.833333,12.193416,19.0,19.166667,20.695473,120.24,118.833333,119.646914,80.04,78.55,79.407819,72.32,72.95,72.5,0.369,0.3895,0.332712,0.4432,0.432833,0.471708,0.6314,0.617333,0.622881,57.5,59.033333,58.534156,66.32,64.466667,67.750206,9.76,10.55,11.246502,14.06,16.483333,17.951852,0.613,0.594167,0.604638,15.16,15.333333,14.793004,30.56,33.533333,32.428807,0.2594,0.274333,0.233255,21.0,50.0,0.42,6.0,16.0,0.375,19.0,25.0,0.76,7.0,32.0,10.0,4.0,6.0,19.0,23.0,90.5,110.8,74.1,0.5,0.32,0.541,51.6,47.6,5.4,16.7,0.48,23.5,24.1,0.38,27.6,27.6,28.148148,63.6,63.6,64.506173,0.4386,0.4386,0.438938,10.4,10.4,10.395062,30.8,30.8,30.777778,0.3306,0.3306,0.326,18.4,18.4,17.246914,25.8,25.8,24.518519,0.696,0.696,0.693432,16.4,16.4,15.839506,49.4,49.4,48.320988,15.4,15.4,16.432099,6.0,6.0,5.901235,6.8,6.8,6.481481,15.4,15.4,14.419753,19.8,19.8,19.604938,113.2,113.2,113.32716,80.66,80.66,83.175309,74.62,74.62,74.497531,0.4202,0.4202,0.387889,0.4812,0.4812,0.474827,0.5572,0.5572,0.553432,63.32,63.32,61.261728,54.88,54.88,57.497531,8.06,8.06,7.941975,16.9,16.9,15.306173,0.5214,0.5214,0.520247,16.68,16.68,15.68642,44.62,44.62,42.544444,0.3032,0.3032,0.274864
1,2022-11-27,N,iowa-state,W,71,53,22,53,0.415,7,26,0.269,20,25,0.8,18,45,16,5,2,17,15,112.7,84.1,62.8,0.472,0.491,0.547,70.3,72.7,7.9,4.9,0.481,20.8,60.0,0.377,29.4,29.571429,29.034294,59.8,59.571429,60.266118,0.493,0.498429,0.48266,10.8,9.857143,10.816187,26.8,25.571429,27.08642,0.3958,0.380429,0.391834,16.4,16.857143,16.229081,21.0,23.142857,21.237311,0.752,0.715429,0.737738,8.6,9.857143,9.231824,33.8,35.571429,33.777778,20.4,19.285714,19.838134,8.0,7.857143,8.161866,6.2,6.0,6.500686,11.8,12.571429,11.79561,20.0,19.714286,21.463649,117.82,117.685714,116.697942,81.86,80.257143,83.105213,73.06,73.114286,73.033333,0.3544,0.391,0.355141,0.4468,0.428143,0.447805,0.6186,0.611143,0.606587,55.42,57.514286,55.156104,69.2,65.142857,68.233471,10.78,10.585714,11.097668,16.72,17.071429,18.834568,0.5836,0.581857,0.572425,14.38,15.042857,14.295336,28.52,32.2,29.685871,0.2784,0.285143,0.27217,22.0,54.0,0.407,3.0,13.0,0.231,6.0,10.0,0.6,7.0,19.0,9.0,13.0,0.0,10.0,21.0,84.1,112.7,62.8,0.185,0.241,0.451,29.7,40.9,20.6,0.0,0.435,14.5,20.6,0.111,29.0,29.0,28.333333,64.0,64.0,62.555556,0.4538,0.4538,0.45337,7.0,7.0,7.444444,23.0,23.0,22.765432,0.3132,0.3132,0.337309,12.4,12.4,12.580247,16.2,16.2,16.382716,0.7638,0.7638,0.766432,12.6,12.6,11.432099,34.8,34.8,32.654321,17.2,17.2,17.765432,11.0,11.0,10.814815,3.2,3.2,3.111111,12.6,12.6,12.518519,19.8,19.8,20.62963,108.86,108.86,107.932099,78.14,78.14,82.996296,69.4,69.4,69.081481,0.2552,0.2552,0.263444,0.3586,0.3586,0.36263,0.5402,0.5402,0.545457,55.44,55.44,53.046914,59.38,59.38,62.777778,15.34,15.34,15.080247,13.16,13.16,12.393827,0.5092,0.5092,0.513728,14.88,14.88,15.01358,37.82,37.82,35.283951,0.1952,0.1952,0.201988
2,2022-12-01,H,oklahoma-state,W,74,64,21,56,0.375,8,23,0.348,24,33,0.727,10,28,11,8,3,6,16,107.2,92.8,68.9,0.589,0.411,0.516,45.2,52.4,11.6,6.7,0.446,7.7,30.3,0.429,28.2,28.625,26.689529,57.8,58.75,57.844079,0.4872,0.488,0.460107,11.0,9.5,9.544124,27.4,25.625,26.72428,0.3974,0.3665,0.350889,16.0,17.25,17.486054,20.4,23.375,22.491541,0.7548,0.726,0.758492,9.8,10.875,12.15455,35.4,36.75,37.518519,20.2,18.875,18.558756,6.8,7.5,7.10791,5.4,5.5,5.000457,12.2,13.125,13.530407,20.4,19.125,19.309099,119.1,117.0625,115.365295,82.48,80.7375,83.436808,69.88,71.825,69.622222,0.36,0.401125,0.394094,0.472,0.436,0.462203,0.618,0.603125,0.586725,57.74,59.1125,60.20407,71.6,66.0875,69.722314,9.58,10.25,10.031779,14.78,15.55,14.189712,0.5814,0.56925,0.54195,15.26,15.7625,16.463557,33.46,35.675,39.790581,0.284,0.296625,0.307113,22.0,56.0,0.393,4.0,11.0,0.364,16.0,19.0,0.842,11.0,34.0,5.0,4.0,3.0,16.0,25.0,92.8,107.2,68.9,0.339,0.196,0.492,54.8,22.7,5.8,9.1,0.429,19.7,37.9,0.286,28.6,27.571429,28.727023,60.0,59.571429,58.971193,0.4762,0.462143,0.488634,6.2,5.857143,5.801097,19.8,20.0,19.460905,0.3004,0.284571,0.292995,14.4,14.142857,14.03155,19.8,20.285714,21.027435,0.7382,0.699286,0.667617,11.0,10.571429,11.322359,37.8,38.0,37.742112,15.8,14.571429,14.259259,7.0,7.142857,7.585734,5.8,6.0,6.33059,13.8,14.285714,13.617284,15.4,15.857143,13.622771,106.86,102.5,106.883676,84.52,84.742857,82.777366,71.5,72.342857,71.85583,0.3382,0.346,0.363668,0.3264,0.332857,0.325915,0.561,0.542571,0.561052,58.74,58.157143,59.866255,56.64,53.557143,50.473937,9.6,9.728571,10.477503,16.78,17.328571,18.435254,0.5264,0.510143,0.536724,16.46,17.0,16.381756,37.3,34.742857,38.7893,0.245,0.240286,0.241519
3,2022-12-07,H,connecticut,L,54,75,16,53,0.302,4,15,0.267,18,28,0.643,6,25,6,8,6,12,14,76.1,105.6,70.8,0.528,0.283,0.407,39.1,37.5,11.3,15.4,0.34,15.3,16.2,0.34,30.8,29.222222,30.206066,61.8,61.444444,59.438957,0.5006,0.478333,0.508084,8.6,7.555556,9.068587,21.4,19.666667,20.828837,0.3926,0.374889,0.423205,11.8,15.0,14.844688,16.4,20.111111,19.58299,0.688,0.724333,0.731288,7.6,8.555556,8.090992,30.6,34.444444,32.690901,14.6,13.666667,14.050602,7.0,6.111111,6.685109,5.6,6.111111,5.417924,10.6,10.888889,10.725194,18.4,17.0,17.447798,113.98,110.655556,118.554885,96.84,93.022222,88.874638,72.78,73.722222,71.792471,0.2842,0.342556,0.344753,0.3482,0.320444,0.353255,0.5948,0.574889,0.61665,50.04,53.522222,53.599985,46.64,46.388889,46.16592,9.54,8.3,9.278326,13.56,14.7,13.252751,0.5712,0.540111,0.585274,13.22,13.277778,13.488599,25.66,27.511111,29.760997,0.2082,0.257,0.263779,30.0,58.0,0.517,7.0,19.0,0.368,8.0,9.0,0.889,8.0,39.0,16.0,9.0,7.0,15.0,19.0,105.6,76.1,70.8,0.155,0.328,0.602,60.9,53.3,12.7,18.4,0.578,19.4,29.6,0.138,26.8,27.777778,24.793019,58.6,58.444444,57.229386,0.4546,0.475444,0.431738,10.6,9.333333,9.029416,27.6,25.333333,25.482853,0.376,0.364444,0.349926,16.8,18.0,19.657369,22.4,24.444444,25.994361,0.7262,0.726111,0.747995,10.8,10.777778,11.436366,34.4,35.777778,34.345679,18.4,18.0,16.039171,7.8,7.555556,7.405274,5.0,5.222222,4.333638,11.4,12.333333,11.020271,19.8,18.777778,18.206066,114.86,115.966667,112.64353,86.12,82.077778,86.557872,70.32,71.5,69.381481,0.3894,0.422,0.459063,0.4696,0.433222,0.445136,0.5846,0.593444,0.56315,55.78,57.566667,55.202713,67.8,64.566667,63.948209,11.0,10.4,10.554519,13.62,14.566667,11.693141,0.5436,0.555556,0.509967,14.06,14.866667,13.542372,34.76,35.077778,36.627054,0.2928,0.311333,0.347742
4,2022-12-10,H,long-island-university,W,114,61,45,72,0.625,14,32,0.438,10,13,0.769,9,38,29,10,6,11,17,142.5,76.3,80.3,0.181,0.444,0.729,67.9,64.4,12.5,14.0,0.722,12.3,39.1,0.139,25.8,28.0,26.52868,58.0,58.4,57.486257,0.4432,0.4796,0.460158,9.6,9.1,8.352944,25.8,24.7,23.321902,0.3638,0.3648,0.355951,15.8,17.0,15.77158,20.4,22.9,20.329574,0.7672,0.7424,0.794996,11.0,10.5,10.290911,35.4,36.1,35.897119,16.6,17.8,16.026114,7.8,7.7,7.936849,5.2,5.4,5.222425,12.0,12.6,12.346848,19.6,18.8,18.470711,110.98,114.93,110.295687,85.56,81.48,83.071915,69.34,71.43,69.854321,0.3582,0.3953,0.357709,0.4434,0.4227,0.40609,0.5694,0.5943,0.5761,56.44,57.9,57.101809,64.18,63.44,60.398806,11.18,10.63,11.269679,14.5,14.95,13.928761,0.5248,0.5578,0.532645,15.02,15.32,15.494914,35.3,34.53,34.284703,0.2778,0.294,0.277828,22.0,52.0,0.423,2.0,9.0,0.222,15.0,24.0,0.625,4.0,18.0,5.0,4.0,2.0,21.0,12.0,76.3,142.5,80.3,0.462,0.173,0.481,32.1,22.7,5.0,5.0,0.442,24.9,12.1,0.288,22.4,24.0,23.169182,55.4,58.625,56.155921,0.4034,0.40525,0.411432,4.8,6.375,4.836305,16.0,18.375,15.237311,0.2772,0.308625,0.273423,13.4,12.125,12.902149,19.8,18.75,18.773663,0.678,0.647,0.69162,5.6,7.5,5.752172,25.2,30.0,26.508002,10.6,13.5,12.973022,4.8,5.5,4.371742,3.4,3.625,2.852309,14.2,15.25,14.177412,16.4,16.25,15.965706,86.56,88.1875,88.095885,111.34,106.525,112.7631,73.08,75.3125,73.10727,0.3646,0.3265,0.338619,0.286,0.30725,0.267833,0.4872,0.4885,0.492355,40.04,43.95,42.230818,46.28,53.375,54.948925,6.62,7.25,6.02995,8.26,8.9875,6.963283,0.4464,0.457375,0.453623,17.86,18.3375,17.742753,17.1,22.1625,17.974348,0.2482,0.21325,0.234017


## Generate dataset

In [115]:
# Per-game stats that survive into the model as moving-average features. Each stat yields one
# column per moving-average suffix (presumably simple / cumulative / exponential -- confirm),
# first for the team itself and then for its opponent (``opp_`` prefix).
_FEATURE_STATS = [
    'FG', 'FGA', 'FG%', '3P', '3PA', '3P%', 'FT', 'FTA', 'FT%', 'ORB', 'TRB', 'AST', 'STL', 'BLK', 'TOV', 'PF',
    'ORtg', 'DRtg', 'Pace', 'FTr', '3PAr', 'TS%', 'TRB%', 'AST%', 'STL%', 'BLK%', 'eFG%', 'TOV%', 'ORB%', 'FT/FGA',
]
_MOVING_AVERAGE_SUFFIXES = ['SMA', 'CMA', 'EMA']

# Final column order of every season dataset: team features, opponent features,
# then the neutral-site flag and the win label (182 columns in total).
FINAL_FEATURES = [
    f'{side}{stat}_{suffix}'
    for side in ('', 'opp_')
    for stat in _FEATURE_STATS
    for suffix in _MOVING_AVERAGE_SUFFIXES
] + ['Neutral', 'Win']

def generate_season_dataset(season):
    """Build the model-ready feature table for one season and write it to CSV.

    Reads ``{school_key}_full.csv`` for every school in ``SR_SCHOOL_KEYS`` that has one
    for ``season`` (schools without a file are skipped), sorts all rows by ``Date``,
    adds the ``Neutral`` feature and the ``Win`` label, drops the meta columns and the
    raw (non moving-average) stat columns for both sides, removes rows containing
    nulls and duplicate rows, and writes the ``FINAL_FEATURES`` columns (in that
    order) to ``../../data/{season}_dataset.csv``.

    Args:
        season: Season identifier used in the input and output paths, e.g. ``'2023'``.

    Raises:
        FileNotFoundError: if no team file exists for ``season``.
        KeyError: if any column of ``FINAL_FEATURES`` is absent after processing.
    """
    team_dfs = []
    for school_key in tqdm(SR_SCHOOL_KEYS, desc=str(season)):
        file_path = get_team_season_file_path(school_key, season, f'{school_key}_full.csv')
        if os.path.exists(file_path):
            team_dfs.append(pd.read_csv(file_path))

    if not team_dfs:
        raise FileNotFoundError(f'No team game-log files found for season {season}')

    # Concatenate once: calling pd.concat inside the loop copies all data each iteration.
    all_data_df = pd.concat(team_dfs, ignore_index=True).sort_values(by="Date")

    # Feature: 1 if the game was played at a neutral site ('N' location), else 0.
    all_data_df['Neutral'] = all_data_df['Location'].eq('N').astype(int)

    # Label: 1 = win, 0 = loss (team points 'Tm' vs opponent points 'Opp').
    all_data_df['Win'] = (all_data_df['Tm'] > all_data_df['Opp']).astype(int)

    # Remove meta columns and the raw per-game stats for both the team and its opponent,
    # keeping only the moving-average columns.
    all_data_df = all_data_df.drop(
        columns=[*META_LABELS, *STAT_LABELS, *(f'opp_{col}' for col in STAT_LABELS)]
    )

    # Drop rows with any null value, then exact duplicate rows.
    all_data_df = all_data_df.dropna().drop_duplicates()

    # Fail loudly on missing features: reindex() would silently add them as all-NaN columns,
    # so the old post-reindex assert could never fire.
    missing = [col for col in FINAL_FEATURES if col not in all_data_df.columns]
    if missing:
        raise KeyError(f'Season {season} dataset is missing feature columns: {missing}')
    all_data_df = all_data_df[FINAL_FEATURES]

    training_data_path = os.path.abspath(f'../../data/{season}_dataset.csv')
    all_data_df.to_csv(training_data_path, index=False)

def generate_datasets(seasons):
    """Build and write the per-season dataset CSV for each season in ``seasons``, in order."""
    for season_id in seasons:
        generate_season_dataset(season_id)

In [116]:
# Build one dataset CSV per season, 2020 through 2023 inclusive.
generate_datasets([str(year) for year in range(2020, 2024)])

100%|██████████| 491/491 [00:07<00:00, 61.97it/s]
100%|██████████| 491/491 [00:06<00:00, 80.30it/s] 
100%|██████████| 491/491 [00:08<00:00, 61.35it/s]
100%|██████████| 491/491 [00:07<00:00, 61.97it/s]


## Create full train test split
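A common convention is to train on 70% of the data and hold out the remaining 30% for testing. The split below is done separately for neutral-site games and home/away games, producing a training and a testing set for each.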

In [121]:
def generate_test_train(seasons=None, test_size=0.3, random_state=42, data_dir_path=None):
    """Merge season datasets and write neutral and home/away train/test splits.

    Reads ``{season}_dataset.csv`` for each requested season, drops rows with null
    values and duplicate rows, partitions the rows on the ``Neutral`` column, splits
    each partition with ``train_test_split`` and writes ``neutral_training_set.csv``,
    ``neutral_testing_set.csv``, ``home_away_training_set.csv`` and
    ``home_away_testing_set.csv`` into the data directory.

    Args:
        seasons: Iterable of seasons (e.g. ``['2020', '2021']``) whose dataset files
            are merged. ``None`` merges every ``NNNN_dataset.csv`` in the data
            directory (the previous behaviour, which ignored this argument).
        test_size: Fraction of each partition held out for testing.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible; pass ``None`` for a different random split on every run.
        data_dir_path: Directory holding the season datasets and receiving the
            output CSVs. Defaults to ``../../data/`` resolved against the CWD.

    Raises:
        FileNotFoundError: if a requested season's dataset file does not exist.
        ValueError: if no dataset files were selected.
    """
    if data_dir_path is None:
        data_dir_path = os.path.abspath('../../data/')

    if seasons is None:
        pattern = re.compile(r"\d{4}_dataset\.csv")
        # fullmatch so names like '2020_dataset.csv.bak' are excluded; sorted because
        # os.listdir order is arbitrary and concat order affects a seeded split.
        filenames = sorted(name for name in os.listdir(data_dir_path) if pattern.fullmatch(name))
    else:
        filenames = [f'{season}_dataset.csv' for season in seasons]
        missing = [name for name in filenames if not os.path.exists(os.path.join(data_dir_path, name))]
        if missing:
            raise FileNotFoundError(f'Missing season dataset(s) in {data_dir_path}: {missing}')

    if not filenames:
        raise ValueError(f'No season datasets selected from {data_dir_path}')

    merged_df = pd.concat(
        [pd.read_csv(os.path.join(data_dir_path, filename)) for filename in filenames],
        ignore_index=True,
    )

    print(f'Original shape: {merged_df.shape}')
    merged_df = merged_df.dropna()
    print(f'Shape after dropping rows with null values: {merged_df.shape}')
    merged_df = merged_df.drop_duplicates()
    print(f'Shape after dropping duplicate rows: {merged_df.shape}')

    # Split the DataFrame into two based on the 'Neutral' column.
    # NOTE(review): if each game appears twice (once from each team's file), a random
    # row-level split can put the two mirrored rows of one game in train and test --
    # confirm and consider splitting by game instead.
    neutral_df = merged_df[merged_df['Neutral'] == 1]
    home_away_df = merged_df[merged_df['Neutral'] == 0]

    neutral_train_df, neutral_test_df = train_test_split(
        neutral_df, test_size=test_size, random_state=random_state
    )
    home_away_train_df, home_away_test_df = train_test_split(
        home_away_df, test_size=test_size, random_state=random_state
    )

    print(len(neutral_train_df), 'neutral train examples')
    print(len(neutral_test_df), 'neutral test examples')
    print(len(home_away_train_df), 'home/away train examples')
    print(len(home_away_test_df), 'home/away test examples')

    neutral_train_df.to_csv(os.path.join(data_dir_path, 'neutral_training_set.csv'), index=False)
    neutral_test_df.to_csv(os.path.join(data_dir_path, 'neutral_testing_set.csv'), index=False)
    home_away_train_df.to_csv(os.path.join(data_dir_path, 'home_away_training_set.csv'), index=False)
    home_away_test_df.to_csv(os.path.join(data_dir_path, 'home_away_testing_set.csv'), index=False)

In [122]:
# Merge the 2020-2023 season datasets and write the train/test splits.
generate_test_train([str(year) for year in range(2020, 2024)])

Original shape: (18705, 182)
Shape after dropping rows with null values: (18705, 182)
Shape after dropping duplicate rows: (18705, 182)
2429 neutral train examples
1041 neutral test examples
10664 home/away train examples
4571 home/away test examples
