In [1]:
import numpy as np
import pandas as pd

filename = 'nba_2014_games.csv'
results = pd.read_csv(filename)
results.loc[:5]

Unnamed: 0,Date,Unnamed: 1,Visitor/Neutral,PTS,Home/Neutral,PTS.1,Unnamed: 6,Notes
0,Tue Oct 29 2013,Box Score,Orlando Magic,87,Indiana Pacers,97,,
1,Tue Oct 29 2013,Box Score,Los Angeles Clippers,103,Los Angeles Lakers,116,,
2,Tue Oct 29 2013,Box Score,Chicago Bulls,95,Miami Heat,107,,
3,Wed Oct 30 2013,Box Score,Brooklyn Nets,94,Cleveland Cavaliers,98,,
4,Wed Oct 30 2013,Box Score,Atlanta Hawks,109,Dallas Mavericks,118,,
5,Wed Oct 30 2013,Box Score,Washington Wizards,102,Detroit Pistons,113,,


In [2]:
# 数据集清洗
results = pd.read_csv(filename, parse_dates=['Date'])
results.columns = ['Date', 'Score Type', 'Visitor Team', 'VisitorPts',
                  'Home Team', 'HomePts', 'OT?', 'Notes']
results.loc[:5]

Unnamed: 0,Date,Score Type,Visitor Team,VisitorPts,Home Team,HomePts,OT?,Notes
0,2013-10-29,Box Score,Orlando Magic,87,Indiana Pacers,97,,
1,2013-10-29,Box Score,Los Angeles Clippers,103,Los Angeles Lakers,116,,
2,2013-10-29,Box Score,Chicago Bulls,95,Miami Heat,107,,
3,2013-10-30,Box Score,Brooklyn Nets,94,Cleveland Cavaliers,98,,
4,2013-10-30,Box Score,Atlanta Hawks,109,Dallas Mavericks,118,,
5,2013-10-30,Box Score,Washington Wizards,102,Detroit Pistons,113,,


In [3]:
# 找出主场获胜的球队
results['HomeWin'] = results['VisitorPts'] < results['HomePts']
y_true = results['HomeWin'].values
results.loc[:5]
#print(y_true[-5:])

Unnamed: 0,Date,Score Type,Visitor Team,VisitorPts,Home Team,HomePts,OT?,Notes,HomeWin
0,2013-10-29,Box Score,Orlando Magic,87,Indiana Pacers,97,,,True
1,2013-10-29,Box Score,Los Angeles Clippers,103,Los Angeles Lakers,116,,,True
2,2013-10-29,Box Score,Chicago Bulls,95,Miami Heat,107,,,True
3,2013-10-30,Box Score,Brooklyn Nets,94,Cleveland Cavaliers,98,,,True
4,2013-10-30,Box Score,Atlanta Hawks,109,Dallas Mavericks,118,,,True
5,2013-10-30,Box Score,Washington Wizards,102,Detroit Pistons,113,,,True


In [4]:
print('Home Win percentage: %.1f' % (100 * results['HomeWin'].sum() / results['HomeWin'].count()) + '%')

Home Win percentage: 57.9%


In [5]:
results['HomeLastWin'] = False
results['VisitorLastWin'] = False
print(results.loc[:5])

        Date Score Type          Visitor Team  VisitorPts  \
0 2013-10-29  Box Score         Orlando Magic          87   
1 2013-10-29  Box Score  Los Angeles Clippers         103   
2 2013-10-29  Box Score         Chicago Bulls          95   
3 2013-10-30  Box Score         Brooklyn Nets          94   
4 2013-10-30  Box Score         Atlanta Hawks         109   
5 2013-10-30  Box Score    Washington Wizards         102   

             Home Team  HomePts  OT? Notes  HomeWin  HomeLastWin  \
0       Indiana Pacers       97  NaN   NaN     True        False   
1   Los Angeles Lakers      116  NaN   NaN     True        False   
2           Miami Heat      107  NaN   NaN     True        False   
3  Cleveland Cavaliers       98  NaN   NaN     True        False   
4     Dallas Mavericks      118  NaN   NaN     True        False   
5      Detroit Pistons      113  NaN   NaN     True        False   

   VisitorLastWin  
0           False  
1           False  
2           False  
3           Fal

In [6]:
# 创建字典，存储球队上次比赛的结果
from collections import defaultdict
won_last = defaultdict(int)

for index, row in results.iterrows():
    home_team = row['Home Team']
    visitor_team = row['Visitor Team']
    row['HomeLastWin'] = won_last[home_team]
    row['VisitorLastWin']  = won_last[visitor_team]
    results.loc[index] = row
    
    won_last[home_team] = row['HomeWin']
    won_last[visitor_team] = not row['HomeWin']
    
print(results.loc[20:25])

         Date Score Type            Visitor Team  VisitorPts  \
20 2013-11-01  Box Score         Milwaukee Bucks         105   
21 2013-11-01  Box Score              Miami Heat         100   
22 2013-11-01  Box Score     Cleveland Cavaliers          84   
23 2013-11-01  Box Score  Portland Trail Blazers         113   
24 2013-11-01  Box Score        Dallas Mavericks         105   
25 2013-11-01  Box Score       San Antonio Spurs          91   

             Home Team  HomePts  OT? Notes  HomeWin HomeLastWin VisitorLastWin  
20      Boston Celtics       98  NaN   NaN    False       False          False  
21       Brooklyn Nets      101  NaN   NaN     True       False          False  
22   Charlotte Bobcats       90  NaN   NaN     True       False           True  
23      Denver Nuggets       98  NaN   NaN    False       False          False  
24     Houston Rockets      113  NaN   NaN     True        True           True  
25  Los Angeles Lakers       85  NaN   NaN    False       False  

In [7]:
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=14)

In [8]:
from sklearn.model_selection import cross_val_score
x_previouswins = results[['HomeLastWin', 'VisitorLastWin']].values

scores = cross_val_score(clf, x_previouswins, y_true, scoring='accuracy')
print('Accuracy: %.1f' % (np.mean(scores)*100) + '%')

Accuracy: 57.4%


In [9]:
s_filename = 'nba_2013_standings.csv'
standings = pd.read_csv(s_filename, skiprows=[0,])
standings

Unnamed: 0,Rk,Team,Overall,Home,Road,E,W,A,C,SE,...,Post,≤3,≥10,Oct,Nov,Dec,Jan,Feb,Mar,Apr
0,1,Miami Heat,66-16,37-4,29-12,41-11,25-5,14-4,12-6,15-1,...,30-2,9-3,39-8,1-0,10-3,10-5,8-5,12-1,17-1,8-1
1,2,Oklahoma City Thunder,60-22,34-7,26-15,21-9,39-13,7-3,8-2,6-4,...,21-8,3-6,44-6,,13-4,11-2,11-5,7-4,12-5,6-2
2,3,San Antonio Spurs,58-24,35-6,23-18,25-5,33-19,8-2,9-1,8-2,...,16-12,9-5,31-10,1-0,12-4,12-4,12-3,8-3,10-4,3-6
3,4,Denver Nuggets,57-25,38-3,19-22,19-11,38-14,5-5,10-0,4-6,...,24-4,11-7,28-8,0-1,8-8,9-6,12-3,8-4,13-2,7-1
4,5,Los Angeles Clippers,56-26,32-9,24-17,21-9,35-17,7-3,8-2,6-4,...,17-9,3-5,38-12,1-0,8-6,16-0,9-7,8-5,7-7,7-1
5,6,Memphis Grizzlies,56-26,32-9,24-17,22-8,34-18,8-2,8-2,6-4,...,23-8,6-4,28-9,0-1,12-1,7-7,10-7,9-2,11-6,7-2
6,7,New York Knicks,54-28,31-10,23-18,37-15,17-13,10-6,12-6,15-3,...,22-10,7-5,31-12,,11-4,10-5,7-6,6-5,12-6,8-2
7,8,Brooklyn Nets,49-33,26-15,23-18,36-16,13-17,11-5,13-5,12-6,...,18-11,9-4,23-17,,11-4,5-11,11-4,7-5,8-7,7-2
8,9,Indiana Pacers,49-32,30-11,19-21,31-20,18-12,6-11,13-3,12-6,...,17-11,4-9,27-14,1-0,7-8,10-5,9-6,9-3,11-5,2-5
9,10,Golden State Warriors,47-35,28-13,19-22,19-11,28-24,7-3,5-5,7-3,...,17-13,5-3,20-18,1-0,8-6,12-4,8-7,4-8,9-7,5-3


In [13]:
results['HomeTeamRanksHigher'] = 0
for index, row in results.iterrows():
    home_team = results['Home Team']
    visitor_team = results['Visitor Team']
    if home_team == 'New Orleans Pelicans':
        home_team = 'New Orleans Hornets'
    elif visitor_team == 'New Orleans Pelicans':
        visitor_team = 'New Orleans Hornets'
#     home_team = ['New Orleans Hornets' if ele == 'New Orleans Pelicans' else ele for ele in home_team]
#     visitor_team = ['New Orleans Hornets' if ele == 'New Orleans Pelicans' else ele for ele in visitor_team]
    
    home_rank = standings[standings['Team']==home_team]['Rk'].values[0]
    visitor_rank = standings[standings['Team']==visitor_team]['Rk'].values[0]
    row['HomeTeamRanksHigher'] = int(home_rank > visitor_rank)
    results.loc[index] = row
results[:5]

ValueError: Arrays were different lengths: 30 vs 1319

In [11]:
x_homehigher = results[['HomeLastWin', 'VisitorLastWin', 'HomeTeamRanksHigher']].values
clf = DecisionTreeClassifier()
scores = cross_val_score(clf, x_homehigher, y_true, scoring='accuracy')
print('Using whether the home team is ranked higher')
print('Accuracy: {0:.1f}%'.format(np.mean(scores) * 100))

KeyError: "['HomeTeamRanksHigher'] not in index"

In [12]:
last_match_winner = defaultdict(int)
results['HomeTeamWonLast'] = 0

for index, row in results.iterrows():
    home_team = results['Home Team']
    visitor_team = results['Visitor Team']
    teams = tuple(sorted([home_team, visitor_team]))
    row['HomeTeamWonLast'] = 1 if last_match_winner[teams] == row['Home Team'] else 0
    results.loc[index] = row
    
    winner = row['Home Team'] if row['HomeWin'] else row['Visitor Team']
    last_match_winner[teams] = winner
results.loc[:5]

ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().