In [2]:
import pandas as pd
import os
from datetime import datetime
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier

In [34]:
def team_name(filename):
    """Load one schedule CSV and tag every row with the team it belongs to.

    The team name is the part of the file's base name that precedes
    ``"-schedule"`` (e.g. ``".../Arizona-schedule.csv"`` -> ``"Arizona"``).

    Parameters
    ----------
    filename : str
        Path to a schedule CSV file.

    Returns
    -------
    pandas.DataFrame
        The CSV contents with an added ``"team"`` column.
    """
    df = pd.read_csv(filename)
    # Use the base name so the result does not depend on how many directories
    # precede the file in the path.
    df["team"] = os.path.basename(filename).split("-schedule")[0]
    # The frame must be returned; otherwise the caller receives None.
    return df

def extract_opponent(game_info):
    """Return the opponent's name from an ``opponent/venue`` cell.

    Three layouts are handled:

    * ``"@ Opponent"``          -> ``"Opponent"`` (text after the ``@``)
    * ``"Opponent @ Venue"``    -> ``"Opponent"`` (text before the ``@``)
    * ``"Opponent"``            -> ``"Opponent"``

    Surrounding whitespace is removed in every case, so a value with stray
    leading/trailing spaces yields the same name regardless of layout.

    Parameters
    ----------
    game_info : str
        Raw ``opponent/venue`` value.

    Returns
    -------
    str
        The opponent name with surrounding whitespace stripped.
    """
    # Strip first so leading spaces cannot hide a leading '@'.
    text = game_info.strip()
    if '@' not in text:
        return text
    parts = text.split("@")
    # An empty piece means nothing sits on one side of an '@' (the
    # "@ Opponent" layout), so the opponent is the second piece.
    if "" in parts:
        return parts[1].strip()
    return parts[0].strip()

def game_site(game_info):
    """Classify where a game was played from its ``opponent/venue`` cell.

    * no ``@`` at all                     -> ``"home"``
    * ``@`` with an empty side
      (e.g. ``"@ Opponent"``)             -> ``"away"``
    * ``@`` with text on both sides
      (e.g. ``"Opponent @ Venue"``)       -> ``"neutral"``

    Parameters
    ----------
    game_info : str
        Raw ``opponent/venue`` value.

    Returns
    -------
    str
        One of ``"home"``, ``"away"`` or ``"neutral"``.
    """
    # Strip first: otherwise " @ Opponent" splits into [" ", " Opponent"]
    # and an away game is misreported as neutral.
    text = game_info.strip()
    if '@' not in text:
        return "home"
    return "away" if "" in text.split("@") else "neutral"
def date_conversion(date_string):
    """Parse a ``month/day/year`` string (``%m/%d/%Y``) into a ``datetime.date``.

    Raises ``ValueError`` (from ``strptime``) when the text does not match
    the expected format.
    """
    parsed = datetime.strptime(date_string, "%m/%d/%Y")
    # Only the calendar date is kept; the time-of-day part is discarded.
    return parsed.date()

def extract_sets(score, idx):
    """Return one side of a dash-separated score as an integer.

    ``score`` is split on ``'-'`` and the piece at position ``idx`` is
    converted with ``int`` (e.g. ``extract_sets("3-1", 0)`` gives ``3``).
    """
    pieces = score.split('-')
    return int(pieces[idx])



In [48]:
# Load every schedule file under all_schedules/<directory>/ into one frame,
# deriving team, opponent, site, parsed date, set counts and a win flag.
df_list = []
# os.listdir returns entries in arbitrary order; sorting keeps the row order
# of result_df reproducible from run to run.
for directory in sorted(os.listdir("all_schedules")):
    directory_path = os.path.join("all_schedules", directory)
    if not os.path.isdir(directory_path):
        # Stray files at the top level (e.g. .DS_Store) would make the inner
        # listdir raise NotADirectoryError.
        continue
    for file in sorted(os.listdir(directory_path)):
        filename = os.path.join(directory_path, file)
        if not os.path.isfile(filename):
            # Nested folders (e.g. notebook checkpoint dirs) are not schedules.
            continue
        df = pd.read_csv(filename)
        # The team name is the part of the file name before "-schedule".
        df["team"] = file.split("-schedule")[0]
        df["opponent"] = df["opponent/venue"].apply(extract_opponent)
        df["game_site"] = df["opponent/venue"].apply(game_site)
        df["date"] = df["date"].apply(date_conversion)
        # Rows tagged 'away' read this team's sets from the second number of
        # `result`; all other rows read it from the first.
        # NOTE(review): this presumes `result` lists the home side first for
        # away games - confirm against the raw data.
        df['sets_won'] = df.apply(lambda row: extract_sets(row['result'], 1) if row['game_site'] == 'away' else extract_sets(row['result'], 0), axis=1)
        df['sets_lost'] = df.apply(lambda row: extract_sets(row['result'], 0) if row['game_site'] == 'away' else extract_sets(row['result'], 1), axis=1)
        # A match counts as a win when exactly three sets were won
        # (presumably best-of-five matches - TODO confirm).
        df["win"] = (df["sets_won"] == 3).astype(int)
        df_list.append(df)

# ignore_index gives the combined frame a fresh 0..n-1 index.
result_df = pd.concat(df_list, ignore_index=True)

In [29]:
# Number of rows in the combined frame built from all schedule files.
len(result_df)

137882

In [49]:
# TODO: drop the mp, attend, retatt, ms and tb columns.
# Preview the first five rows of the combined schedule table.
result_df.head()

Unnamed: 0,date,opponent/venue,result,sets,mp,kills,errors,total_attacks,hit_pct,assists,...,team,opponent,game_site,sets_won,sets_lost,win,retatt,attend,ms,tb
0,2016-08-26,"Florida @ Eugene, Ore.",3-1,4.0,0.0,67.0,12.0,143.0,0.385,64.0,...,Arizona,Florida,neutral,3,1,1,,,,
1,2016-08-27,"Texas @ Eugene, Ore.",3-0,3.0,0.0,48.0,14.0,112.0,0.304,42.0,...,Arizona,Texas,neutral,3,0,1,,,,
2,2016-09-02,Iowa St.,3-0,3.0,0.0,43.0,8.0,90.0,0.389,36.0,...,Arizona,Iowa St.,home,3,0,1,,,,
3,2016-09-03,Oregon St.,3-0,3.0,0.0,41.0,10.0,82.0,0.378,35.0,...,Arizona,Oregon St.,home,3,0,1,,,,
4,2016-09-09,"Arkansas @ Albuquerque, N.M.",3-0,3.0,0.0,40.0,15.0,98.0,0.255,37.0,...,Arizona,Arkansas,neutral,3,0,1,,,,


In [69]:
# Numeric per-match statistics used as model features.
# 'team', 'opponent' and 'game_site' hold strings and are left out, and
# 'sets_won' / 'sets_lost' must stay out: 'win' is computed directly from
# 'sets_won', so including them would leak the label into the features.
input_cols = ['sets', 'kills', 'errors',
       'total_attacks', 'hit_pct', 'assists', 'aces', 'serr', 'digs', 'rerr',
       'b_solo', 'b_assist', 'b_error', 'pts', 'bhe']

# Year held out for evaluation; every other year is used for training.
TEST_YEAR = 2023

# Build the year mask once instead of re-scanning the date column for each
# of the four splits. The mask shares result_df's index so .loc aligns on it.
is_test_year = pd.Series(
    [game_date.year == TEST_YEAR for game_date in result_df["date"]],
    index=result_df.index,
)

train_x = result_df.loc[~is_test_year, input_cols]
# train_y is kept as a one-column DataFrame (as before) so existing uses of
# it keep working; test_y is a Series so it compares element-wise with the
# 1-D prediction array.
train_y = result_df.loc[~is_test_year, ['win']]
test_x = result_df.loc[is_test_year, input_cols]
test_y = result_df.loc[is_test_year, 'win']

In [70]:
# NOTE: this rebinds DecisionTreeClassifier, shadowing the pyspark.ml class of
# the same name imported earlier in the notebook.
from sklearn.tree import DecisionTreeClassifier, export_text

# Fit a decision tree on the training split; random_state pins the seed.
clf = DecisionTreeClassifier(random_state=42)
clf.fit(train_x, train_y)

# Score the held-out split: share of predictions that match the true label.
predictions = clf.predict(test_x)
n_correct = int((predictions == test_y).sum())
accuracy = n_correct / len(test_y)
print(f"Accuracy: {accuracy:.2%}")

# Dump the learned split rules as indented text, labelled with feature names.
tree_rules = export_text(clf, feature_names=input_cols)
print("Decision Tree Rules:")
print(tree_rules)

Accuracy: 53.47%
Decision Tree Rules:
|--- kills <= 53.50
|   |--- rerr <= 4.50
|   |   |--- b_solo <= 0.50
|   |   |   |--- errors <= 12.50
|   |   |   |   |--- hit_pct <= 0.26
|   |   |   |   |   |--- class: 0
|   |   |   |   |--- hit_pct >  0.26
|   |   |   |   |   |--- rerr <= 3.50
|   |   |   |   |   |   |--- digs <= 32.50
|   |   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |   |--- digs >  32.50
|   |   |   |   |   |   |   |--- b_assist <= 3.00
|   |   |   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |   |   |--- b_assist >  3.00
|   |   |   |   |   |   |   |   |--- b_assist <= 10.50
|   |   |   |   |   |   |   |   |   |--- errors <= 7.50
|   |   |   |   |   |   |   |   |   |   |--- class: 0
|   |   |   |   |   |   |   |   |   |--- errors >  7.50
|   |   |   |   |   |   |   |   |   |   |--- hit_pct <= 0.26
|   |   |   |   |   |   |   |   |   |   |   |--- class: 0
|   |   |   |   |   |   |   |   |   |   |--- hit_pct >  0.26
|   |   |   |   |   |   |   |   |   | 