In [1]:
import pandas as pd
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error

In [2]:
teams = pd.read_csv("teams.csv")

In [3]:
print(teams.head())

  team      country  year  events  athletes   age  height  weight  medals  \
0  AFG  Afghanistan  1964       8         8  22.0   161.0    64.2       0   
1  AFG  Afghanistan  1968       5         5  23.2   170.2    70.0       0   
2  AFG  Afghanistan  1972       8         8  29.0   168.3    63.8       0   
3  AFG  Afghanistan  1980      11        11  23.6   168.4    63.2       0   
4  AFG  Afghanistan  2004       5         5  18.6   170.8    64.8       0   

   prev_medals  prev_3_medals  
0          0.0            0.0  
1          0.0            0.0  
2          0.0            0.0  
3          0.0            0.0  
4          0.0            0.0  


In [4]:
teams = teams[['team', 'country', 'year', 'athletes', 'age', 'prev_medals', 'medals']]

In [5]:
print(teams.head())

  team      country  year  athletes   age  prev_medals  medals
0  AFG  Afghanistan  1964         8  22.0          0.0       0
1  AFG  Afghanistan  1968         5  23.2          0.0       0
2  AFG  Afghanistan  1972         8  29.0          0.0       0
3  AFG  Afghanistan  1980        11  23.6          0.0       0
4  AFG  Afghanistan  2004         5  18.6          0.0       0


In [6]:
teams.corr()["medals"]

year          -0.021603
athletes       0.840817
age            0.025096
prev_medals    0.920048
medals         1.000000
Name: medals, dtype: float64

In [7]:
print(teams[teams.isnull().any(axis=1)])

     team                           country  year  athletes   age  \
19    ALB                           Albania  1992         9  25.3   
26    ALG                           Algeria  1964         7  26.0   
39    AND                           Andorra  1976         3  28.3   
50    ANG                            Angola  1980        17  17.4   
59    ANT               Antigua and Barbuda  1976        17  23.2   
...   ...                               ...   ...       ...   ...   
2092  VIN  Saint Vincent and the Grenadines  1988         6  20.5   
2103  YAR                       North Yemen  1984         3  27.7   
2105  YEM                             Yemen  1992         8  19.6   
2112  YMD                       South Yemen  1988         5  23.6   
2120  ZAM                            Zambia  1964        15  21.7   

      prev_medals  medals  
19            NaN       0  
26            NaN       0  
39            NaN       0  
50            NaN       0  
59            NaN       0  
...

In [8]:
teams = teams.dropna()

In [9]:
print(teams.shape)

(2014, 7)


In [10]:
train = teams[teams["year"] < 2012].copy()
test = teams[teams["year"] >= 2012].copy()

In [11]:
reg = LinearRegression()

In [12]:
input = ["athletes", "prev_medals"]
output = ["medals"]

In [13]:
reg.fit(train[input], train[output])

LinearRegression()

In [14]:
predictions = reg.predict(test[input])
print(predictions)

[[-9.61221245e-01]
 [-1.17633261e+00]
 [-1.42503158e+00]
 [-1.71184673e+00]
 [ 2.15562926e+00]
 [ 3.91463636e+00]
 [-1.71184673e+00]
 [-1.85525431e+00]
 [ 3.67563128e-01]
 [-2.77770967e-01]
 [-1.85525431e+00]
 [-1.49673537e+00]
 [ 4.67519911e+01]
 [ 2.87550937e+01]
 [ 4.58450091e+00]
 [ 2.54773581e+00]
 [-1.85525431e+00]
 [-1.64014295e+00]
 [-1.85525431e+00]
 [-1.85525431e+00]
 [ 1.46556876e+02]
 [ 1.20571799e+02]
 [ 6.56314795e+00]
 [ 3.95275254e+00]
 [ 7.34283247e+00]
 [ 1.03117468e+01]
 [ 5.19171882e+00]
 [ 3.58517645e+00]
 [-1.64014295e+00]
 [-1.64014295e+00]
 [-1.56843916e+00]
 [-1.20992022e+00]
 [-1.71184673e+00]
 [-1.42503158e+00]
 [ 1.17929959e+01]
 [ 1.00049298e+01]
 [-1.78355052e+00]
 [-1.71184673e+00]
 [-1.56843916e+00]
 [-1.56843916e+00]
 [-1.99866189e+00]
 [-1.99866189e+00]
 [-1.56843916e+00]
 [-1.35332779e+00]
 [-1.92695810e+00]
 [-1.92695810e+00]
 [ 3.28912706e+01]
 [ 2.53042547e+01]
 [-1.78355052e+00]
 [-1.28162400e+00]
 [-1.85525431e+00]
 [-3.87590939e-01]
 [ 7.8348077

In [15]:
test["predictions"] = predictions

In [16]:
test.loc[test["predictions"] < 0, "predictions"] = 0

In [17]:
test["predictions"] = test["predictions"].round()

In [18]:
print(test["predictions"])

6       0.0
7       0.0
24      0.0
25      0.0
37      2.0
       ... 
2111    0.0
2131    0.0
2132    0.0
2142    2.0
2143    0.0
Name: predictions, Length: 405, dtype: float64


In [19]:
error = mean_absolute_error(test["medals"], test["predictions"])

In [20]:
print(error)

3.2987654320987656
