In [1]:
import os

import numpy as np
import pandas as pd
import tensorflow as tf


In [2]:
_data_dir = os.path.join('.', 'data')
_train_filename = os.path.join(_data_dir, 'play_prediction_train.csv')
_eval_filename = os.path.join(_data_dir, 'play_prediction_eval.csv')
_predict_filename = os.path.join(_data_dir, 'play_prediction_holdout.csv')

In [3]:
def load_prediction_data(filename):
    column_dtypes = {
        'gsis_id': np.int32,
        'drive_id': np.int32,
        'play_id': np.int32,
        'home_score': np.int32,
        'away_score': np.int32,
        'yardline': np.int32,
        'quarter': np.int32,
        'clock': np.int32,
        'down': np.int32,
        'yards_to_go': np.int32
    }
    df = pd.read_csv(filename)
    
    # Change identity categoricals to 0-based
    df.quarter = df.quarter - 1
    df.down = df.down - 1
    df = df.loc[df.pos_team == 'NE']
    df = df.reset_index()
    
    return df

In [4]:
train_data = load_prediction_data(_train_filename)
train_labels = train_data.pop('play_type')

eval_data = load_prediction_data(_eval_filename)
eval_labels = eval_data.pop('play_type')

train_data.head()

Unnamed: 0,index,gsis_id,drive_id,play_id,home_team,away_team,home_score,away_score,quarter,clock,down,yards_to_go,yardline,pos_team
0,61,2010102410,16,2671,SD,NE,3,19,2,512,2,1,49,NE
1,75,2015111509,9,1665,NYG,NE,10,10,1,780,0,10,-10,NE
2,148,2013121504,11,2051,MIA,NE,7,10,2,130,1,5,-11,NE
3,197,2013092912,11,2371,ATL,NE,10,10,2,463,0,10,32,NE
4,257,2016010301,4,546,MIA,NE,3,0,0,413,0,10,-30,NE


In [5]:
train_data.size

67116

In [6]:
unique_labels = ['rush', 'pass', 'punt', 'field goal']

In [7]:
train_input_fn = tf.estimator.inputs.pandas_input_fn(x=train_data, y=train_labels, shuffle=True, batch_size=32, num_epochs=None)
eval_input_fn = tf.estimator.inputs.pandas_input_fn(x=eval_data, y=eval_labels, shuffle=False, num_epochs=1)

In [8]:
numeric_features = [
    tf.feature_column.numeric_column('home_score'),
    tf.feature_column.numeric_column('away_score'),
    tf.feature_column.numeric_column('clock'),
    tf.feature_column.numeric_column('yards_to_go'),
    tf.feature_column.numeric_column('yardline')
]

categorical_features = [
    tf.feature_column.categorical_column_with_identity('quarter', num_buckets=4),
    tf.feature_column.categorical_column_with_identity('down', num_buckets=4)
]

features = numeric_features + [tf.feature_column.indicator_column(x) for x in categorical_features]

In [9]:
estimator = tf.estimator.DNNClassifier(feature_columns=features,
                                       hidden_units=[10, 10],
                                       n_classes=4,
                                       label_vocabulary=unique_labels,
                                       model_dir='./model')

estimator.train(train_input_fn, steps=50000)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000002AE6BAE2D30>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./model\model.ckpt-50264
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 50265 into ./model\model.ckpt

<tensorflow.python.estimator.canned.dnn.DNNClassifier at 0x2ae6bae28d0>

In [10]:
estimator.evaluate(eval_input_fn)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-05-02-18:06:07
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./model\model.ckpt-100264
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-05-02-18:06:08
INFO:tensorflow:Saving dict for global step 100264: accuracy = 0.64136326, average_loss = 0.6582385, global_step = 100264, loss = 77.253265


{'accuracy': 0.64136326,
 'average_loss': 0.6582385,
 'loss': 77.253265,
 'global_step': 100264}

In [11]:
predict_data = load_prediction_data(_predict_filename)
predict_labels = predict_data.pop('play_type')

predict_input_fn = tf.estimator.inputs.pandas_input_fn(predict_data, predict_labels, shuffle=False, num_epochs=1)

predictions = [row['classes'] for row in list(estimator.predict(predict_input_fn))]
predictions = pd.DataFrame(predictions, columns=['prediction'])

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./model\model.ckpt-100264
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


In [13]:
joined = predictions.join(predict_labels).join(predict_data)
joined.head()

Unnamed: 0,prediction,play_type,index,gsis_id,drive_id,play_id,home_team,away_team,home_score,away_score,quarter,clock,down,yards_to_go,yardline,pos_team
0,b'rush',pass,19,2012123001,18,3497,NE,MIA,27,0,3,335,0,10,27,NE
1,b'pass',rush,33,2009121308,24,4221,NE,CAR,20,10,3,866,1,10,40,NE
2,b'pass',pass,90,2010102410,16,2275,SD,NE,3,13,2,5,0,10,-29,NE
3,b'pass',rush,98,2011101610,8,1264,NE,DAL,3,3,1,92,0,8,42,NE
4,b'pass',pass,104,2011121811,12,1992,DEN,NE,16,24,1,876,0,10,-1,NE


In [14]:
joined.loc[joined.play_type == 'field goal']

Unnamed: 0,prediction,play_type,index,gsis_id,drive_id,play_id,home_team,away_team,home_score,away_score,quarter,clock,down,yards_to_go,yardline,pos_team
41,b'field goal',field goal,1074,2012120207,10,1794,MIA,NE,3,17,1,482,3,10,25,NE
51,b'field goal',field goal,1275,2011110609,19,2818,NE,NYG,3,10,2,567,3,3,36,NE
61,b'pass',field goal,1701,2011091200,10,2192,MIA,NE,7,14,1,893,2,21,20,NE
76,b'field goal',field goal,2073,2012101408,17,3476,SEA,NE,10,23,3,335,3,2,33,NE
108,b'field goal',field goal,2792,2013120802,21,3920,NE,CLE,14,19,3,552,3,11,18,NE
214,b'field goal',field goal,6015,2012082951,9,1227,NYG,NE,0,3,1,393,3,2,48,NE
223,b'rush',field goal,6197,2015102503,13,2505,NE,NYJ,16,10,2,226,3,6,44,NE
232,b'field goal',field goal,6419,2011101610,8,1364,NE,DAL,6,3,1,170,3,7,43,NE
260,b'field goal',field goal,7067,2009121308,22,3864,NE,CAR,20,10,3,659,3,9,21,NE
264,b'field goal',field goal,7166,2013102703,8,1659,NE,MIA,3,14,1,673,3,3,34,NE
