# Configs

In [1]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import warnings

from Learning.MLPClassifier import MLPClassifier
from Learning.dataset_helper_functions import *
import Learning.MLPClassifier
from sklearn.metrics import classification_report

# Silence pandas' deprecation notice about DataFrameGroupBy.apply touching the grouping columns.
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    message="DataFrameGroupBy.apply operated on the grouping columns.*",
)

# Accelerator preference: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)


Using device: mps


# Classification Settings

In [2]:

# Number of distinct highlight timestamps to select per quarter (see "Predict Highlights").
HIGHLIGHTS_MEAN_1Q = 3
HIGHLIGHTS_MEAN_2Q = 4
HIGHLIGHTS_MEAN_3Q = 4
HIGHLIGHTS_MEAN_4Q = 4

# KEEP THESE TRUE: the selection loop only implements the INC_FTS and INC_TIES strategy
# and raises NotImplementedError otherwise.
INC_FTS = True
INC_TIES = True

# NOTE(review): not referenced in this notebook; the selection loop defines its own
# LEFT_/RIGHT_CONTEXT_WINDOW.
CONTEXT_WINDOW = 4

highlights_per_qtr = {'1st': HIGHLIGHTS_MEAN_1Q, '2nd': HIGHLIGHTS_MEAN_2Q, '3rd': HIGHLIGHTS_MEAN_3Q, '4th': HIGHLIGHTS_MEAN_4Q}

seed = 42
TRAINING_END_IDX = 284986
data_path = "../../full season data/plays_with_onehot_v2_withoutOT.csv"
model_path = '/Users/galishai/PycharmProjects/AI_PROJECT_SPORTS_HIGHLIGHTS/Learning/saved_model/mlp_final_checkpoint_withoutOT_esf1.pth'
interval_weights_path = '/Users/galishai/PycharmProjects/AI_PROJECT_SPORTS_HIGHLIGHTS/Learning/nba_stats/unique_per_interval_minute.csv'

# Architecture of the saved checkpoint. It only loads with hidden_dim=64 (its first Linear
# layer has 64 outputs), so 64 is recorded here instead of the stale 32. Dropout has no
# effect at inference because the model is put in eval() mode.
trained_model_params = {
    'hidden_dim': 64,
    'dropout': 0.5,
}

# Data Prep

In [3]:
dataset = pd.read_csv(data_path)
freeze_seeds(seed)
df = get_dataset(ds=dataset, verbose=False, rm_ft_ds=False, add_game_idx=True,
                 play_context_window=2,
                 team_context_window=2,
                 compact_players=1,
                 compact_oncourt=1,
                 compact_current_team=1,
                 drop_home_away_teams=1,
                 group_all_plays=1,
                 enum_quarters=1)

# Split by *game* (not by play) 60 / 20 / 20 into train / val / test so that no game's plays
# land in more than one split. The seeded generator makes the split reproducible.
games_idx = df['game_id'].unique()
rng = np.random.default_rng(seed=seed)
shuffled_game_ids = rng.permutation(games_idx)

split_idx = int(0.6 * len(shuffled_game_ids))
train_game_ids = shuffled_game_ids[:split_idx]
test_game_ids = shuffled_game_ids[split_idx:]
split_test_ids = int(0.5 * len(test_game_ids))
val_game_ids = test_game_ids[:split_test_ids]
test_game_ids = test_game_ids[split_test_ids:]
print(len(train_game_ids))
print(len(val_game_ids))

test_mask = df['game_id'].isin(test_game_ids)
altered_df_train = df[df['game_id'].isin(train_game_ids)]
altered_df_test = df[test_mask]
unaltered_df_test = dataset[test_mask]
# Later cells copy columns between the processed and the raw test frames row-by-row after
# reset_index, so both must hold the same rows in the same order.
assert altered_df_test.index.equals(unaltered_df_test.index), "processed and raw test rows are misaligned"

altered_df_test = altered_df_test.reset_index(drop=True)
unaltered_df_test = unaltered_df_test.reset_index(drop=True)

# Label and grouping key are not model inputs.
FEATURE_DROP_COLS = ['is_highlight', 'game_id']
X_train = altered_df_train.drop(columns=FEATURE_DROP_COLS).to_numpy(dtype=np.float32)
y_train = altered_df_train['is_highlight'].to_numpy(dtype=np.int32)
# NOTE: despite the name, X_val / y_val are built from the held-out *test* games
# (test_game_ids), not from val_game_ids.
X_val = altered_df_test.drop(columns=FEATURE_DROP_COLS).to_numpy(dtype=np.float32)
y_val = altered_df_test['is_highlight'].to_numpy(dtype=np.int32)

# Standardize with statistics fitted on the training games only.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)

train_ds = TensorDataset(torch.from_numpy(X_train).float(),
                         torch.from_numpy(y_train).long())
val_ds = TensorDataset(torch.from_numpy(X_val).float(),
                       torch.from_numpy(y_val).long())


573
191


In [None]:
# Plain-text peek at the first five raw test plays.
print(unaltered_df_test.head(n=5))

In [None]:
# Disabled: builds a DataLoader over the held-out split. The code is kept as a bare string
# literal, so running this cell only echoes that string. X_test / y_test are not defined
# anywhere in this notebook (the held-out arrays are X_val / y_val).
#altered_df_test.reset_index(drop=True)
#unaltered_df_test.reset_index(drop=True)

'''test_ds = TensorDataset(torch.from_numpy(X_test).float(),
                         torch.from_numpy(y_test).long())
test_loader = DataLoader(test_ds, batch_size=2048, shuffle=False)'''

# Load Model

In [4]:
# Restore the trained MLP. torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint['model_state_dict']

# dropout=0 is harmless here: eval() below disables dropout at inference anyway.
model = MLPClassifier(input_dim=X_train.shape[1], hidden_dim=64, dropout=0)
model = model.to(device)
model.load_state_dict(state_dict)
model.eval()

MLPClassifier(
  (net): Sequential(
    (0): Linear(in_features=113, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0, inplace=False)
    (8): Linear(in_features=32, out_features=32, bias=True)
    (9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0, inplace=False)
    (12): Linear(in_features=32, out_features=1, bias=True)
  )
)

# Predict Play Highlight Probabilities

In [5]:
# Score every held-out play in one batch. no_grad() skips building the autograd graph,
# which is never used here and would otherwise keep activations for every row alive.
with torch.no_grad():
    logits = model(torch.from_numpy(X_val).float().to(device))
# The final Linear layer has a single output, so logits are (n_rows, 1); flatten to one
# probability per play before storing it as a column.
probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
unaltered_df_test['probs'] = probs
unaltered_df_test.head()

Unnamed: 0,time_left_qtr,play,distance,quarter,home_team,away_team,current_team,name,assister,win_difference,...,Oncourt_Player_27,Oncourt_Player_28,Oncourt_Player_29,Oncourt_Player_30,Oncourt_Player_31,Oncourt_Player_32,Oncourt_Player_33,date,is_highlight,probs
0,11:50,3,2,1st,WARRIORS,SUNS,WARRIORS,Chris Paul,Blank,1,...,0,0,0,0,0,0,0,"October 25, 2023",0,0.007733
1,11:48,7,0,1st,WARRIORS,SUNS,SUNS,Jusuf Nurkic,Blank,1,...,0,0,0,0,0,0,0,"October 25, 2023",0,0.009655
2,11:28,25,1,1st,WARRIORS,SUNS,SUNS,Jusuf Nurkic,Kevin Durant,1,...,0,0,0,0,0,0,0,"October 25, 2023",0,0.047389
3,11:28,9,0,1st,WARRIORS,SUNS,WARRIORS,Andrew Wiggins,Blank,1,...,0,0,0,0,0,0,0,"October 25, 2023",0,0.49477
4,11:28,10,0,1st,WARRIORS,SUNS,SUNS,Jusuf Nurkic,Blank,1,...,0,0,0,0,0,0,0,"October 25, 2023",0,0.15699


In [None]:
# Disabled (kept as a string literal, so running the cell only echoes it): initialise a
# 'predicted' column to 0.
'''unaltered_df_test.loc[:,'predicted'] = 0 #init predictions as 0
unaltered_df_test.head()'''


In [None]:
# Disabled scratch work, kept as a string literal (running the cell only echoes it).
# It refers to X_test / y_test, which are not defined in this notebook.
'''
unaltered_df_test.head()

logits = model(X_test).detach().cpu().numpy().squeeze()
unaltered_df_test.loc[:, 'logits'] = logits
unaltered_df_test'''

#unaltered_df_test.loc[:,'y_truth'] = y_test.detach().cpu().numpy()
#print(unaltered_df_test.head())

#unaltered_df_test.head()
#unaltered_df_test['y_truth'] = y_test.detach().cpu().numpy()

In [None]:
unaltered_df_test.head(n=5)

# Predict Highlights

In [6]:
# Fresh accumulators for the evaluation loop.
pred_accuracies, all_preds, all_labels = [], [], []

# Bring the processed seconds-left clock and the game id onto the raw test frame
# (both frames are row-aligned after reset_index in Data Prep).
unaltered_df_test['time_left_qtr_sec'] = altered_df_test['time_left_qtr']
unaltered_df_test['game_id'] = altered_df_test['game_id']

df_interval_weights = pd.read_csv(interval_weights_path)
df_interval_weights.columns = df_interval_weights.columns.str.strip().str.replace('"', '')

# Interval labels are parsed as "(start, end]"; extract both integer bounds.
interval_label = df_interval_weights['time_left_qtr']
df_interval_weights['start'] = interval_label.str.extract(r'\((\d+),')[0].astype(int)
df_interval_weights['end'] = interval_label.str.extract(r', (\d+)\]')[0].astype(int)
df_interval_weights['num_highlights_per_game'] = df_interval_weights['num_highlights_per_game'].astype(float)

unaltered_df_test.head()


Unnamed: 0,time_left_qtr,play,distance,quarter,home_team,away_team,current_team,name,assister,win_difference,...,Oncourt_Player_29,Oncourt_Player_30,Oncourt_Player_31,Oncourt_Player_32,Oncourt_Player_33,date,is_highlight,probs,time_left_qtr_sec,game_id
0,11:50,3,2,1st,WARRIORS,SUNS,WARRIORS,Chris Paul,Blank,1,...,0,0,0,0,0,"October 25, 2023",0,0.007733,710,2
1,11:48,7,0,1st,WARRIORS,SUNS,SUNS,Jusuf Nurkic,Blank,1,...,0,0,0,0,0,"October 25, 2023",0,0.009655,708,2
2,11:28,25,1,1st,WARRIORS,SUNS,SUNS,Jusuf Nurkic,Kevin Durant,1,...,0,0,0,0,0,"October 25, 2023",0,0.047389,688,2
3,11:28,9,0,1st,WARRIORS,SUNS,WARRIORS,Andrew Wiggins,Blank,1,...,0,0,0,0,0,"October 25, 2023",0,0.49477,688,2
4,11:28,10,0,1st,WARRIORS,SUNS,SUNS,Jusuf Nurkic,Blank,1,...,0,0,0,0,0,"October 25, 2023",0,0.15699,688,2


In [7]:
def get_weight(row, weights=None):
    """Look up the per-game highlight rate of the time interval containing a play.

    Args:
        row: Mapping/Series with ``quarter`` and ``time_left_qtr_sec`` entries.
        weights: DataFrame with ``quarter``, ``start``, ``end`` and
            ``num_highlights_per_game`` columns. Defaults to the module-level
            ``df_interval_weights``.

    Returns:
        ``num_highlights_per_game`` of the first interval in the same quarter with
        ``start < time_left_qtr_sec <= end``, or ``None`` if no interval contains it.
    """
    if weights is None:
        weights = df_interval_weights
    t = row['time_left_qtr_sec']
    in_quarter = weights[weights['quarter'] == row['quarter']]
    # Interval labels are "(start, end]" (left-open). An inclusive lower bound made a boundary
    # time such as 60 match both (0, 60] and (60, 120], so the result depended on row order.
    matches = in_quarter[(in_quarter['start'] < t) & (t <= in_quarter['end'])]
    if matches.empty:
        # e.g. t == 0 with a "(0, 60]" interval: fall back to an inclusive lower bound.
        matches = in_quarter[(in_quarter['start'] <= t) & (t <= in_quarter['end'])]
    if matches.empty:
        return None
    return matches['num_highlights_per_game'].values[0]

In [8]:
def get_game_intervals(df):
    """Return the positional ``[start, end)`` row range of every game in ``df``.

    Each game is assumed to occupy one contiguous run of rows. Starts are the
    positions where each ``game_id`` first appears, listed in row order; each
    game's end is the next game's start, or ``len(df)`` for the last game.

    Args:
        df: DataFrame with a ``game_id`` column (its index labels are ignored).

    Returns:
        ``(start_positions, end_positions)``: two equal-length lists of ints
        (both empty for an empty frame).
    """
    # First occurrence of each game in *row* order. The previous groupby version returned
    # starts in sorted game_id order, which only matched row order when ids ascended.
    is_first = ~df['game_id'].duplicated().to_numpy()
    start_positions = [pos for pos, first in enumerate(is_first) if first]
    end_positions = start_positions[1:] + [len(df)] if start_positions else []

    return start_positions, end_positions

In [9]:
test_start_position, test_end_positions = get_game_intervals(altered_df_test)

# One "[start,end]" line per held-out game (row positions; end is exclusive).
for game_start, game_end in zip(test_start_position, test_end_positions):
    print(f'[{game_start},{game_end}]')


[0,426]
[426,800]
[800,1217]
[1217,1601]
[1601,2011]
[2011,2406]
[2406,2787]
[2787,3174]
[3174,3547]
[3547,3933]
[3933,4309]
[4309,4726]
[4726,5118]
[5118,5512]
[5512,5888]
[5888,6282]
[6282,6656]
[6656,7034]
[7034,7426]
[7426,7809]
[7809,8195]
[8195,8612]
[8612,8976]
[8976,9359]
[9359,9734]
[9734,10132]
[10132,10543]
[10543,10905]
[10905,11287]
[11287,11664]
[11664,12011]
[12011,12410]
[12410,12768]
[12768,13178]
[13178,13563]
[13563,13974]
[13974,14365]
[14365,14748]
[14748,15144]
[15144,15483]
[15483,15897]
[15897,16270]
[16270,16646]
[16646,17019]
[17019,17401]
[17401,17771]
[17771,18197]
[18197,18618]
[18618,18996]
[18996,19412]
[19412,19782]
[19782,20131]
[20131,20522]
[20522,20896]
[20896,21263]
[21263,21640]
[21640,22037]
[22037,22410]
[22410,22774]
[22774,23099]
[23099,23491]
[23491,23853]
[23853,24245]
[24245,24633]
[24633,25030]
[25030,25378]
[25378,25746]
[25746,26115]
[26115,26484]
[26484,26859]
[26859,27254]
[27254,27644]
[27644,28012]
[28012,28426]
[28426,28799]
[28799,2

In [None]:
#WEIGHT_SCALE = 0
#unaltered_df_test.loc[:,'weighted_probs'] = unaltered_df_test['probs'] * (1+unaltered_df_test['weight']*WEIGHT_SCALE)

In [10]:
def select_unique_highlights_qtr(unaltered_quarter_df, num_to_select):
    """Choose the best-scoring distinct play times within one quarter.

    Plays sharing a ``time_left_qtr_sec`` value are collapsed into one candidate
    scored by their highest ``probs``. The ``num_to_select`` top candidates are
    returned, highest score first; ties keep the smaller timestamp.

    Args:
        unaltered_quarter_df: Plays of one quarter with ``time_left_qtr_sec`` and ``probs``.
        num_to_select: How many distinct timestamps to return.

    Returns:
        NumPy array of the selected ``time_left_qtr_sec`` values.
    """
    best_prob_per_time = unaltered_quarter_df.groupby('time_left_qtr_sec')['probs'].max()
    winners = best_prob_per_time.nlargest(num_to_select)
    return winners.index.values

In [None]:
#unaltered_df_test.head()
#y_test

In [11]:
# Turn per-play probabilities into highlight predictions.
# For every game and quarter:
#   1. take the highlights_per_qtr[quarter] timestamps with the highest max probability;
#   2. mark every play in that quarter whose time_left_qtr_sec is t itself or lies in
#      range(t - LEFT_CONTEXT_WINDOW, t + RIGHT_CONTEXT_WINDOW).
# NOTE(review): time_left_qtr_sec counts *down* within a quarter, so t + k seconds is earlier
# than the selected play. Confirm the "right" window is meant to extend backwards in time.
LEFT_CONTEXT_WINDOW = 0
RIGHT_CONTEXT_WINDOW = 4
QUARTERS = ['1st', '2nd', '3rd', '4th']

if not (INC_FTS and INC_TIES):
    # Only the "include free throws and ties" selection strategy is implemented.
    raise NotImplementedError

all_preds = []
all_labels = []
pred_accuracies = []
test_start_position, test_end_positions = get_game_intervals(altered_df_test)
unaltered_df_test['predicted'] = 0

for i, (start, end) in enumerate(zip(test_start_position, test_end_positions)):
    # Rows start..end-1 belong to this game (unaltered_df_test has a RangeIndex after reset_index).
    in_game = (unaltered_df_test.index >= start) & (unaltered_df_test.index < end)
    for quarter in QUARTERS:
        mask_q = (unaltered_df_test['quarter'] == quarter) & in_game
        num_highlights_to_choose = highlights_per_qtr.get(quarter, 0)
        curr_qtr_df = unaltered_df_test[mask_q]
        selected_play_times = select_unique_highlights_qtr(curr_qtr_df, num_highlights_to_choose)

        expanded_times = set()
        for t in selected_play_times:
            expanded_times.add(t)
            expanded_times.update(range(t - LEFT_CONTEXT_WINDOW, t + RIGHT_CONTEXT_WINDOW))

        mask = unaltered_df_test['time_left_qtr_sec'].isin(expanded_times) & mask_q
        unaltered_df_test.loc[mask, 'predicted'] = 1

    # .loc label slices are inclusive, hence end - 1.
    y_pred = unaltered_df_test.loc[start:end - 1, 'predicted']
    y_truth = unaltered_df_test.loc[start:end - 1, 'is_highlight']
    accuracy = (y_truth == y_pred).mean()
    pred_accuracies.append(accuracy)
    print(f"Game prediction accuracy: {accuracy:.3%}")
    all_preds.extend(y_pred)
    all_labels.extend(y_truth)



Game prediction accuracy: 89.671%
Game prediction accuracy: 92.513%
Game prediction accuracy: 91.367%
Game prediction accuracy: 92.708%
Game prediction accuracy: 94.146%
Game prediction accuracy: 91.392%
Game prediction accuracy: 94.488%
Game prediction accuracy: 95.349%
Game prediction accuracy: 92.493%
Game prediction accuracy: 94.041%
Game prediction accuracy: 96.011%
Game prediction accuracy: 91.847%
Game prediction accuracy: 93.112%
Game prediction accuracy: 94.924%
Game prediction accuracy: 94.415%
Game prediction accuracy: 91.624%
Game prediction accuracy: 90.107%
Game prediction accuracy: 91.270%
Game prediction accuracy: 94.133%
Game prediction accuracy: 92.167%
Game prediction accuracy: 91.969%
Game prediction accuracy: 94.484%
Game prediction accuracy: 92.857%
Game prediction accuracy: 94.517%
Game prediction accuracy: 94.400%
Game prediction accuracy: 94.724%
Game prediction accuracy: 91.971%
Game prediction accuracy: 95.580%
Game prediction accuracy: 91.099%
Game predictio

In [12]:


# Unweighted mean of the per-game accuracies, then a play-level report over all test games.
mean_game_accuracy = sum(pred_accuracies) / len(pred_accuracies)
print(f"Game prediction Final Mean Accuracy: {mean_game_accuracy:.3%}")
print(classification_report(all_labels, all_preds, digits=3))

Game prediction Final Mean Accuracy: 92.805%
              precision    recall  f1-score   support

           0      0.964     0.959     0.962     67111
           1      0.359     0.389     0.373      3908

    accuracy                          0.928     71019
   macro avg      0.661     0.674     0.668     71019
weighted avg      0.931     0.928     0.929     71019



In [None]:
#unaltered_df_game

In [None]:
# Drop the helper seconds column before export. errors='ignore' keeps the cell re-runnable:
# on a second run the column is already gone, and a plain drop raised KeyError.
unaltered_df_test = unaltered_df_test.drop(columns=['time_left_qtr_sec'], errors='ignore')

unaltered_df_test.to_csv("/Users/galishai/PycharmProjects/AI_PROJECT_SPORTS_HIGHLIGHTS/Learning/full game highlight classification/predicted_output_test.csv", index=False)

# Data Exploration

In [None]:
# First four held-out games. test_start_position / test_end_positions are row positions
# within altered_df_test (they come from get_game_intervals(altered_df_test)), so slice the
# row-aligned test frames. `altered_df` / `unaltered_df` are not defined in this notebook.
altered_df_games = altered_df_test.iloc[test_start_position[0]:test_end_positions[3]].copy()
unaltered_df_games = unaltered_df_test.iloc[test_start_position[0]:test_end_positions[3]].copy()
altered_df_games

In [None]:
# Replace the raw "mm:ss" clock with the processed seconds-left value and attach game ids.
unaltered_df_games['time_left_qtr'] = altered_df_games['time_left_qtr']
unaltered_df_games['game_id'] = altered_df_games['game_id']

# Play type of the following play within the same game and quarter (NaN for each quarter's last play).
play_by_period = unaltered_df_games.groupby(['game_id', 'quarter'])['play']
unaltered_df_games['next_play_type'] = play_by_period.shift(-1)
unaltered_df_games.head()


In [None]:
# Highlight plays of the selected games. (`unaltered_df_game` was never defined; the frame
# built above is `unaltered_df_games`.)
game_highlights = unaltered_df_games[unaltered_df_games['is_highlight'] != 0]

# Distinct highlight timestamps per quarter of each game. Grouping by quarter alone would
# merge equal timestamps from different games now that several games are selected.
num_unique_highlights = game_highlights.groupby(['game_id', 'quarter'])['time_left_qtr'].nunique()
num_unique_highlights

In [None]:
# Full-season frames. `altered_df` / `unaltered_df` are not defined in this notebook; following
# the Data Prep naming (altered_* come from `df`, unaltered_* from `dataset`), use those directly.
altered_df_all = df.copy()
unaltered_df_all = dataset.copy()

# Processed seconds-left clock and game id, aligned on the shared row index.
unaltered_df_all['time_left_qtr'] = altered_df_all['time_left_qtr']
unaltered_df_all['game_id'] = altered_df_all['game_id']

df_all_game_highlights = unaltered_df_all[unaltered_df_all['is_highlight'] != 0]

# Distinct highlight timestamps per (game, quarter).
num_unique_highlights_per_game = df_all_game_highlights.groupby(['game_id', 'quarter'])['time_left_qtr'].nunique()

In [None]:
# Display: distinct highlight timestamps per (game_id, quarter) over the full season.
num_unique_highlights_per_game

In [None]:
# Highlight plays only: drop every row label whose is_highlight equals 0.
non_highlight_labels = unaltered_df_all.index[(unaltered_df_all['is_highlight'] == 0).to_numpy()]
all_games_highlights = unaltered_df_all.drop(index=non_highlight_labels)


In [None]:
# Distinct highlight timestamps per (game, quarter). Despite its name, the
# `avg_num_highlights` column holds a count, not an average.
unique_highlights_per_game_qtr = (
    all_games_highlights
    .groupby(['game_id', 'quarter'])['time_left_qtr']
    .nunique()
    .to_frame(name='avg_num_highlights')
)
unique_highlights_per_game_qtr

## Unique highlights per quarter

In [None]:
# Mean number of distinct highlight timestamps per game, for each quarter.
unique_highlights_each_quarter = (
    unique_highlights_per_game_qtr
    .groupby(['quarter'])['avg_num_highlights']
    .mean()
    .to_frame(name='avg_highlights_qtr')
)
unique_highlights_each_quarter

## Unique highlights per interval

In [None]:
# Bin edges over seconds left in a 12-minute quarter. With a step of 720 this is a single
# (0, 720] bin. NOTE(review): per-minute bins would use a step of 60.
QUARTER_SECONDS = 12 * 60
interval_groups = np.arange(0, QUARTER_SECONDS + 1, QUARTER_SECONDS)
interval_groups


In [None]:
# Bin each highlight's time_left_qtr once, then count distinct timestamps and distinct
# players per (game, quarter, interval).
time_bins = pd.cut(all_games_highlights.time_left_qtr, interval_groups)
interval_keys = ['game_id', 'quarter', time_bins]

per_interval = all_games_highlights.groupby(interval_keys).agg(num_highlights=('time_left_qtr', 'nunique'))
unique_players_per_interval = all_games_highlights.groupby(interval_keys).agg(num_players=('name', 'nunique'))



In [None]:
# Plays per player, with players who have fewer than 400 plays pooled into "Other".
# (`unaltered_df` is not defined in this notebook; the raw play table is `dataset`.)
player_counts = dataset['name'].value_counts()
common = player_counts[player_counts >= 400].index
player_counts = player_counts.reset_index()
player_counts['name'] = player_counts['name'].where(player_counts['name'].isin(common), 'Other')
# Sum the play counts per bucket. value_counts() here counted *players* per bucket instead
# (1 for every common player, the number of rare players for "Other").
player_counts.groupby('name')['count'].sum().sort_values(ascending=False)

In [None]:
# Mean (over games) number of distinct highlight timestamps and distinct highlight players
# per quarter and time interval. `per_interval` only has `num_highlights`; the player counts
# live in `unique_players_per_interval` (column `num_players`), so join them first.
per_game = (
    per_interval
    .join(unique_players_per_interval)
    .groupby(['quarter', 'time_left_qtr'])
    .agg(num_highlights=('num_highlights', 'mean'), unique_players=('num_players', 'mean'))
)
per_game

In [None]:
# Look-ahead context on a copy of the raw play table. For each of the next `context_window`
# plays in the same game and quarter, record its play type and whether the same team made it.
# (`unaltered_df` / `altered_df` are not defined in this notebook; use `dataset` / `df`.)
context_window = 2
unaltered_df_copy = dataset.copy()
unaltered_df_copy['game_id'] = df['game_id']
by_period = unaltered_df_copy.groupby(['game_id', 'quarter'])

for i in range(context_window):
    step = i + 1
    # Play type `step` plays ahead. Nullable Int64 keeps integers while the last plays of each
    # quarter stay <NA>. The old `.dropna().astype(int)` was undone on assignment: index
    # alignment re-inserted NaN and the column stayed float.
    unaltered_df_copy[f'play_{step}_after'] = by_period['play'].shift(-step).astype('Int64')

    # 1 if the play `step` ahead is by the same team; 0 otherwise, including when no such play exists.
    same_team = unaltered_df_copy['current_team'] == by_period['current_team'].shift(-step)
    unaltered_df_copy[f'team_{step}_after'] = same_team.astype(int)

In [None]:
# First five current_team values of every (game, quarter) group.
test2 = unaltered_df_copy.groupby(by=['game_id', 'quarter'])['current_team']
test2.head(n=5)

In [None]:
# How often each player appears in highlight plays in any of the name / assister / stolen_by roles.
df_h = unaltered_df_copy[unaltered_df_copy['is_highlight'] == 1]

role_columns = ['name', 'assister', 'stolen_by']
all_players = pd.concat([df_h[col] for col in role_columns])
all_players = all_players[all_players != 'Blank']  # 'Blank' marks an empty role

star_power = all_players.value_counts().rename_axis('name').reset_index()
star_power


In [None]:
# Express each player's appearance count as a share of all highlight plays.
# NOTE: rewrites star_power['count'] in place, so re-running this cell divides again.
total_highlights = unaltered_df_copy['is_highlight'].eq(1).sum()
star_power['count'] = star_power['count'].div(total_highlights)
star_power

In [None]:


# Standardize star power to z-scores (pandas .std() is the sample std, ddof=1).
# NOTE: rewrites star_power['count'] in place; not idempotent on re-run.
count_col = star_power['count']
mean, std = count_col.mean(), count_col.std()
star_power['count'] = count_col.sub(mean).div(std)

star_power

In [None]:
# Player name -> standardized star power.
mapping = star_power.set_index('name')['count']

# Attach star power for each player role to the raw play table. Players missing from `mapping`
# get 0, the mean after standardization. `unaltered_df` is not defined in this notebook, so
# build a new frame from `dataset` instead of mutating the raw table.
dataset_star_power = dataset.assign(**{
    f'{role}_star_power': dataset[role].map(mapping).fillna(0).astype(float)
    for role in ('name', 'assister', 'stolen_by')
})

dataset_star_power.head(50)


In [None]:
# NOTE(review): unlike the other get_dataset calls in this notebook (which pass a DataFrame
# as `ds` plus individual compact_* flags), this passes a file path and `compact_mode`.
# Confirm get_dataset still accepts these arguments.
df_test = get_dataset(
    data_path,
    verbose=False,
    rm_ft_ds=False,
    add_game_idx=True,
    compact_mode=True,
    play_context_window=2,
    team_context_window=2,
)
unalt_df1 = pd.read_csv(data_path)
df_test.head(n=5)

In [None]:
import json

# Team name -> ordered list of player names, used for on-court star power below.
ROSTERS_PATH = '/Users/galishai/PycharmProjects/AI_PROJECT_SPORTS_HIGHLIGHTS/full season data/temp_rosters.json'
with open(ROSTERS_PATH) as roster_file:
    team_rosters_full = json.load(roster_file)

In [None]:
print(team_rosters_full)

In [None]:
# On-court flag columns; position i is assumed to line up with roster slot i of the current team.
oncourt_cols = [c for c in unalt_df1 if c.startswith('Oncourt_Player')]
mapping_dict = mapping.to_dict()

# Star power of each roster slot per team (0 for players without a star-power score).
team_star_power = {}
for team, roster in team_rosters_full.items():
    print(team)
    # One slot per on-court column at least. Sizing from the roster avoids the IndexError
    # that a fixed 33-slot list raised for rosters longer than 33.
    powers = [0] * max(len(oncourt_cols), len(roster))
    for i, player in enumerate(roster):
        if player in mapping_dict:
            powers[i] = mapping_dict[player]
            print(f'{player}, power: {mapping_dict[player]}')
    team_star_power[team] = powers

def calculate_oncourt_star(row, team_rosters):
    """Sum the star power of the current team's players flagged as on court in ``row``.

    Args:
        row: One play (Series) with ``current_team`` and the ``Oncourt_Player_*`` flags
            listed in the global ``oncourt_cols``.
        team_rosters: Team name -> list of player names, where list index i corresponds
            to ``oncourt_cols[i]``.

    Returns:
        Sum of ``mapping_dict`` star power over on-court players; players without a
        score contribute 0.
    """
    roster = team_rosters[row['current_team']]
    total = 0
    for i, col in enumerate(oncourt_cols):
        # A flag past the end of the roster has no player to look up; skip it instead of raising IndexError.
        if row[col] == 1 and i < len(roster):
            total += mapping_dict.get(roster[i], 0)
    return total

# Row-wise total star power of the current team's on-court players.
unalt_df1['oncourt_star_power'] = unalt_df1.apply(
    lambda play: calculate_oncourt_star(play, team_rosters_full),
    axis=1,
)


In [None]:
unalt_df1.head(n=50)

# Test

In [None]:


# Fresh raw plays, processed with context windows of 1 and the compact_* / grouping flags on.
dataset_test = pd.read_csv(data_path)
context1_options = dict(
    verbose=False,
    rm_ft_ds=False,
    add_game_idx=True,
    play_context_window=1,
    team_context_window=1,
    compact_players=1,
    compact_oncourt=1,
    compact_current_team=1,
    drop_home_away_teams=1,
    group_all_plays=1,
    enum_quarters=1,
)
res = get_dataset(dataset_test, **context1_options)


In [None]:
res.head(n=5)

In [None]:
dataset_test.head(n=5)

In [None]:
dataset_test = pd.read_csv(filepath_or_buffer=data_path)

In [None]:
# Same raw plays with context windows of 0 and every compact_* / grouping flag set to False.
baseline_options = dict(
    verbose=False, rm_ft_ds=False, add_game_idx=True,
    play_context_window=0, team_context_window=0,
    compact_players=False, compact_oncourt=False, compact_current_team=False,
    drop_home_away_teams=False, group_all_plays=False, enum_quarters=False,
)
res = get_dataset(dataset_test, **baseline_options)

In [None]:
res.head(n=5)