In [1]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as onp
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from tdata.datasets.oncourt_dataset import OnCourtDataset

In [3]:
dataset = OnCourtDataset()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[key] = _infer_fill_value(value)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[item] = s




In [4]:
df = dataset.get_stats_df()

In [45]:
# Matches from 2000 onwards that have serve statistics recorded.
to_use = df.loc[df['year'] >= 2000].dropna(subset=['W1S_1'])

# Serve-point win share per player: (W1S + W2S) / (W1SOF + W2SOF).
# NOTE(review): presumably suffix _1 is the winner and _2 the loser, W1S/W2S are
# first/second-serve points won and W1SOF/W2SOF the corresponding points played -- verify.
spw_1 = (to_use['W1S_1'] + to_use['W2S_1']).div(to_use['W1SOF_1'] + to_use['W2SOF_1'])
spw_2 = (to_use['W1S_2'] + to_use['W2S_2']).div(to_use['W1SOF_2'] + to_use['W2SOF_2'])

# Player-1-minus-player-2 margin as a plain array, one entry per match row.
margins = spw_1.sub(spw_2).values

In [46]:
# Per-match player names and surface labels, row-aligned with `margins`.
winners, losers, surface = (to_use[column].values for column in ('winner', 'loser', 'surface'))

In [47]:
from sklearn.preprocessing import LabelEncoder

In [48]:
surf_encoder = LabelEncoder()

In [49]:
surface_enc = surf_encoder.fit_transform(surface)

In [50]:
n_matches = winners.shape[0]

# One-hot encode each match's surface. Both players play on the same surface,
# so the loser's design rows are an exact copy of the winner's.
a_winner = onp.eye(len(surf_encoder.classes_))[surface_enc]
a_loser = a_winner.copy()

# Row layout [+one_hot | -one_hot]: dotted with [winner_ratings, loser_ratings]
# it yields the winner's minus the loser's rating on that match's surface.
a_full = onp.concatenate([a_winner, -a_loser], axis=1)

In [51]:
a_full.shape

(37379, 10)

In [52]:
from jax import hessian, jacobian
from functools import partial

In [53]:
from ml_tools.jax import weighted_sum, logistic_normal_integral_approx
from jax.scipy.stats import multivariate_normal, norm
from jax import jit
from jax.scipy.special import expit

In [54]:
## Model definition (user-supplied): likelihood, predictive likelihood, prior and log posterior

@jit
def calculate_likelihood(x, a, theta, y):
    """Log-likelihood of one match observation given latent ratings ``x``.

    The latent difference ``z = a @ x`` drives two terms:

    * the margin ``y[0]``, Gaussian with mean ``theta['factor'] * z + theta['offset']``
      and standard deviation ``theta['obs_sd']``;
    * a win term ``log(sigmoid(z))``.

    Args:
        x: Latent rating vector.
        a: Design vector; ``a @ x`` is the latent difference.
        theta: Dict with keys ``'factor'``, ``'offset'`` and ``'obs_sd'``.
        y: Observation array; only ``y[0]`` (the margin) is read.

    Returns:
        Scalar: margin log-density plus log win probability.
    """
    latent_diff = a @ x
    margin = y[0]

    margin_prob = norm.logpdf(margin, theta['factor'] * latent_diff + theta['offset'],
                              theta['obs_sd'])

    # log(sigmoid(z)) = -log(1 + exp(-z)). The logaddexp form stays finite when
    # sigmoid(z) underflows to 0 for very negative z (log(0) would give -inf).
    win_prob = -jnp.logaddexp(0., -latent_diff)

    return margin_prob + win_prob

@jit
def calculate_predictive_lik(x, a, cov_mat, theta, y):
    """Log predictive likelihood of one match under a Gaussian rating belief.

    ``x`` / ``cov_mat`` are the mean / covariance of the belief over ratings.
    ``weighted_sum`` projects them through ``a`` to the mean and variance of the
    latent difference (presumably ``a @ x`` and ``a @ cov_mat @ a`` -- confirm in ml_tools).

    Args:
        x: Rating mean vector.
        a: Design vector.
        cov_mat: Rating covariance matrix.
        theta: Dict with keys ``'factor'``, ``'offset'`` and ``'obs_sd'``.
        y: Observation array; only ``y[0]`` (the margin) is read.

    Returns:
        Scalar: log win-probability term plus margin log-density.
    """
    latent_mean, latent_var = weighted_sum(x, cov_mat, a)

    # The latent variance propagates through the linear link, widening the
    # margin's predictive sd to sqrt(obs_sd^2 + factor^2 * latent_var).
    predictive_sd = jnp.sqrt(theta['obs_sd'] ** 2 + theta['factor'] ** 2 * latent_var)
    margin_term = norm.logpdf(y[0], theta['factor'] * latent_mean + theta['offset'], predictive_sd)

    # Logistic link integrated against the latent Gaussian (approximate, per the helper's name).
    win_term = jnp.log(logistic_normal_integral_approx(latent_mean, latent_var))

    return win_term + margin_term

@jit
def calculate_prior(x, mu, cov_mat, theta):
    """Log density of ratings ``x`` under the Gaussian prior N(mu, cov_mat).

    ``theta`` is accepted so the signature parallels the likelihood functions;
    it is not used.
    """
    log_density = multivariate_normal.logpdf(x, mu, cov_mat)
    return log_density

@jit
def calculate_log_posterior(x, mu, cov_mat, a, theta, y):
    """Unnormalised log posterior of ratings ``x``: match log-likelihood plus Gaussian log prior."""
    log_prior = calculate_prior(x, mu, cov_mat, theta)
    log_lik = calculate_likelihood(x, a, theta, y)
    return log_lik + log_prior

In [55]:
## General code

# The model functions are grouped into a namedtuple (EloFunctions) and passed
# as the static jit argument `elo_functions` (static_argnums=4).

@partial(jit, static_argnums=4)
def calculate_update(mu, cov_mat, a, y, elo_functions, elo_params):
    """Take one Newton step on the log posterior, starting from the prior mean.

    Args:
        mu: Prior mean of the (stacked) ratings; also the point the step starts from.
        cov_mat: Prior covariance of the stacked ratings.
        a: Design vector for the match.
        y: Observation array for the match.
        elo_functions: Static (hashable) container providing ``log_post_jac_x``,
            ``log_post_hess_x`` and ``predictive_lik_fun``.
        elo_params: Container with ``theta`` (model parameters).

    Returns:
        ``(updated mean, predictive log-likelihood evaluated at the prior)``.
    """
    theta = elo_params.theta

    # Score the match under the prior belief, before updating.
    lik = elo_functions.predictive_lik_fun(mu, a, cov_mat, theta, y)

    # Gradient and Hessian of the log posterior, both evaluated at x = mu.
    gradient = elo_functions.log_post_jac_x(mu, mu, cov_mat, a, theta, y)
    hess = elo_functions.log_post_hess_x(mu, mu, cov_mat, a, theta, y)

    # Newton step: x_new = mu - H^{-1} g.
    step = jnp.linalg.solve(hess, gradient)
    return mu - step, lik

@partial(jit, static_argnums=4)
def compute_update(mu1, mu2, a, y, elo_functions, elo_params):
    """Jointly update two players' rating means for one match.

    Args:
        mu1: First player's rating mean (the winner, as called from ``update_ratings``).
        mu2: Second player's rating mean (the loser).
        a: Design vector over the stacked ``[mu1, mu2]`` ratings.
        y: Observation array for the match.
        elo_functions: Static container of model callables.
        elo_params: Container with ``theta`` and the per-player ``cov_mat``.

    Returns:
        ``(new_mu1, new_mu2, predictive log-likelihood)``.
    """
    # The two players' ratings are independent a priori: the joint covariance is
    # block-diagonal, with one copy of the shared covariance per player.
    joint_cov = jnp.kron(jnp.eye(2), elo_params.cov_mat)
    joint_mu = jnp.concatenate([mu1, mu2])

    updated, lik = calculate_update(joint_mu, joint_cov, a, y, elo_functions, elo_params)
    first, second = jnp.split(updated, 2)
    return first, second, lik

In [56]:
from collections import namedtuple
from jax import grad

# Containers passed to the jitted update functions:
# - EloFunctions holds the model callables and is passed as the static
#   (hashable) jit argument `elo_functions`;
# - EloParams holds the numeric parameters and is traced as an ordinary pytree.
EloFunctions = namedtuple('EloFunctions', 'log_post_jac_x,log_post_hess_x,predictive_lik_fun')
EloParams = namedtuple('EloParams', 'theta,cov_mat')

In [57]:
from jax.ops import index_update

def update_ratings(carry, x, elo_functions, elo_params):
    """``lax.scan`` step: update the winner's and loser's ratings for one match.

    Args:
        carry: Rating table indexed by encoded player id (presumably the
            (n_players, n_surfaces) ``init`` table -- confirm at the call site).
        x: ``(winner_idx, loser_idx, a, y)`` for the current match.
        elo_functions: Static container of model callables.
        elo_params: Container with ``theta`` and ``cov_mat``.

    Returns:
        ``(updated rating table, predictive log-likelihood of this match)``.
    """
    cur_winner, cur_loser, cur_a, cur_y = x

    new_winner_mean, new_loser_mean, lik = compute_update(
        carry[cur_winner], carry[cur_loser], cur_a, cur_y, elo_functions, elo_params)

    # Functional row updates; `.at[...].set` replaces the deprecated (and later
    # removed) `jax.ops.index_update`.
    carry = carry.at[cur_winner].set(new_winner_mean)
    carry = carry.at[cur_loser].set(new_loser_mean)

    return carry, lik

In [58]:
import numpy as onp

# Shared integer ids for every player who appears as either winner or loser.
encoder = LabelEncoder()
encoder.fit(onp.concatenate([winners, losers]))

winners_array = encoder.transform(winners)
losers_array = encoder.transform(losers)

In [59]:
from functools import partial, wraps
from jax.lax import scan

# Default starting ratings: zero for every player (rows) on every surface (columns).
init = jnp.zeros((len(encoder.classes_), len(surf_encoder.classes_)))

@partial(jit, static_argnums=4)
def calculate_ratings_scan(winners_array, losers_array, a_full, y_full, elo_functions, elo_params,
                           init_ratings=None):
    """Run the rating filter over all matches in order.

    Args:
        winners_array: Encoded winner id per match.
        losers_array: Encoded loser id per match.
        a_full: Design row per match.
        y_full: Observation row per match.
        elo_functions: Static container of model callables.
        elo_params: Container with ``theta`` and ``cov_mat``.
        init_ratings: Optional starting rating table. When ``None`` the
            notebook-level ``init`` is used, as before. Passing it explicitly avoids
            relying on a global that is baked into the compiled trace and would
            go stale if ``init`` were re-bound after the first call.

    Returns:
        ``(final rating table, total predictive log-likelihood over all matches)``.
    """
    if init_ratings is None:
        init_ratings = init

    fun_to_scan = partial(update_ratings, elo_functions=elo_functions, elo_params=elo_params)

    ratings, liks = scan(fun_to_scan, init_ratings, (winners_array, losers_array, a_full, y_full))

    return ratings, jnp.sum(liks)

In [60]:
# Observation matrix for the scan: one row per match holding its margin,
# shape (n_matches, 1). A vectorised reshape replaces a per-element Python list.
y_full = jnp.asarray(margins).reshape(-1, 1)

In [61]:
from ml_tools.jax import lo_tri_from_elements

def get_starting_elts(cov_mat):
    """Flatten the Cholesky factor of ``cov_mat`` into its lower-triangular entries.

    Entries come out in the row-major order of ``numpy.tril_indices``. This is
    assumed to match the ordering ``lo_tri_from_elements`` expects -- TODO confirm.
    """
    chol = jnp.linalg.cholesky(cov_mat)
    rows, cols = onp.tril_indices(chol.shape[-2], 0, chol.shape[-1])
    return chol[rows, cols]

def cov_mat_from_elts(elts, n_latent, jitter=1e-6):
    """Build a covariance matrix ``L @ L.T + jitter * I`` from flat elements.

    ``L`` is the ``n_latent x n_latent`` matrix that ``lo_tri_from_elements``
    assembles from ``elts`` (presumably lower-triangular). The product is positive
    semi-definite, and the diagonal jitter keeps it strictly positive definite.
    """
    chol_factor = lo_tri_from_elements(elts, n_latent)
    return chol_factor @ chol_factor.T + jitter * jnp.eye(n_latent)

In [62]:
# One latent rating dimension per surface.
n_latent = len(surf_encoder.classes_)

# The covariance search starts from the identity matrix.
start_cov_mat = jnp.eye(n_latent)

# Flat optimisation vector: Cholesky elements of the starting covariance,
# followed by raw (factor, offset, obs_sd). update_params squares the raw
# factor and obs_sd, so [0.1, 0.1, 0.1] starts the search at factor=0.01,
# offset=0.1, obs_sd=0.01 -- not the 0.1 values given to elo_params.
# NOTE(review): confirm this mismatch is intentional.
starting_elts = get_starting_elts(start_cov_mat)
starting_elts = jnp.concatenate([starting_elts, jnp.array([0.1, 0.1, 0.1])])

In [63]:
from ml_tools.lin_alg import num_triangular_elts

def ratings_lik(*args):
    """Total predictive log-likelihood from ``calculate_ratings_scan``; the ratings are discarded."""
    _, total_lik = calculate_ratings_scan(*args)
    return total_lik

# Initial model parameters.
elo_params = EloParams(
    theta={'factor': 0.1, 'offset': 0.1, 'obs_sd': 0.1},
    cov_mat=start_cov_mat,
)

# Model callables. grad/hessian differentiate the log posterior with respect to
# its first argument, x.
elo_functions = EloFunctions(
    log_post_jac_x=jit(grad(calculate_log_posterior)),
    log_post_hess_x=jit(hessian(calculate_log_posterior)),
    predictive_lik_fun=calculate_predictive_lik,
)

def update_params(x, params):
    """Unpack the flat optimisation vector ``x`` into an ``EloParams``.

    Layout of ``x``:
    * the first ``num_triangular_elts(n_latent)`` entries are the flat
      covariance elements, passed to ``cov_mat_from_elts``;
    * the last three entries are the raw factor, offset and obs_sd. factor and
      obs_sd are squared so they stay non-negative.

    ``params`` is not used; it is kept for call compatibility.

    Prints theta and the covariance as progress output. Under jit/grad tracing
    these prints show tracers, not values.
    """
    n_cov_elts = num_triangular_elts(n_latent)
    cov_mat = cov_mat_from_elts(x[:n_cov_elts], n_latent)

    raw_factor, offset, raw_obs_sd = x[-3], x[-2], x[-1]
    theta = {'factor': raw_factor**2, 'offset': offset, 'obs_sd': raw_obs_sd**2}

    print(theta)
    print(cov_mat)

    return EloParams(theta=theta, cov_mat=cov_mat)

def to_optimise(x, start_params, functions):
    """Objective for ``scipy.optimize.minimize``: negative total predictive log-likelihood.

    Args:
        x: Flat parameter vector (see ``update_params`` for its layout).
        start_params: Forwarded to ``update_params``.
        functions: ``EloFunctions`` container of model callables.

    Returns:
        Minus the summed predictive log-likelihood over all matches. The
        likelihood itself is printed as progress output.
    """
    params = update_params(x, start_params)
    total_lik = ratings_lik(winners_array, losers_array, a_full, y_full, functions, params)
    print(total_lik)
    return -total_lik

# Bind the fixed arguments so the objective depends only on the flat vector x.
curried = partial(to_optimise, start_params=elo_params, functions=elo_functions)

# Gradient of the objective, passed to scipy as `jac=`. Because it is jitted, the
# print calls inside to_optimise/update_params run while tracing and show tracers,
# not values.
to_opt_grad = jit(grad(curried))

In [64]:
to_opt_grad(starting_elts)

{'factor': Traced<ShapedArray(float64[])>with<JVPTrace(level=1/1)>, 'offset': Traced<ShapedArray(float64[])>with<JVPTrace(level=1/1)>, 'obs_sd': Traced<ShapedArray(float64[])>with<JVPTrace(level=1/1)>}
Traced<ShapedArray(float64[5,5])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float64[])>with<JVPTrace(level=1/1)>


DeviceArray([-8.46712359e+03, -1.12566725e+03, -2.18366122e+05,
             -8.41919559e+02, -1.15994590e+03, -7.20876647e+04,
             -6.93970397e+02, -5.47303605e+03, -2.33122653e+03,
             -2.77063528e+05, -8.47274335e+02, -2.98065276e+03,
             -1.54959721e+03, -4.55994707e+03, -1.07627834e+05,
             -1.38377981e+07, -2.62580943e+04, -1.39721682e+07])

In [65]:
from scipy.optimize import minimize

In [None]:
result = minimize(curried, starting_elts, jac=to_opt_grad, tol=1e-3)

{'factor': 0.010000000000000002, 'offset': 0.1, 'obs_sd': 0.010000000000000002}
[[1.000001 0.       0.       0.       0.      ]
 [0.       1.000001 0.       0.       0.      ]
 [0.       0.       1.000001 0.       0.      ]
 [0.       0.       0.       1.000001 0.      ]
 [0.       0.       0.       0.       1.000001]]
-620104.4893622485
{'factor': 0.6570534481979355, 'offset': 0.10134838615844509, 'obs_sd': 0.6682872901181472}
[[1.00087078e+00 5.78295621e-05 4.32524259e-05 3.56517471e-05
  4.35275199e-05]
 [5.78295621e-05 1.02255349e+00 6.02351006e-05 2.84200831e-04
  1.54779104e-04]
 [4.32524259e-05 6.02351006e-05 1.00741829e+00 1.20172842e-04
  7.98793287e-05]
 [3.56517471e-05 2.84200831e-04 1.20172842e-04 1.02865864e+00
  2.37544621e-04]
 [4.35275199e-05 1.54779104e-04 7.98793287e-05 2.37544621e-04
  1.01108528e+00]]
-65099.05943369346
{'factor': 718.0580299287989, 'offset': 0.05095539799862883, 'obs_sd': 830.8522720466985}
[[ 3.57254620e-01  1.33067815e-02  1.92682302e-02  1.54046

{'factor': 0.9319167808071877, 'offset': 0.007432767862805424, 'obs_sd': 0.5260726652029891}
[[ 0.77447917 -0.04418625  0.00801049  0.00551322 -0.00320577]
 [-0.04418625  0.02404104  0.01049534 -0.00594255 -0.01172901]
 [ 0.00801049  0.01049534  0.00854277 -0.00153463 -0.00798534]
 [ 0.00551322 -0.00594255 -0.00153463  0.2345662  -0.03090193]
 [-0.00320577 -0.01172901 -0.00798534 -0.03090193  0.03543613]]
-45232.938969768205
{'factor': 0.9290548035188502, 'offset': 0.005121271683782821, 'obs_sd': 0.5279366186146367}
[[ 0.86381146 -0.0064999   0.00527259  0.00426353  0.0051468 ]
 [-0.0064999   0.02178391  0.00121292 -0.00407303  0.00098929]
 [ 0.00527259  0.00121292  0.01773397  0.00314282  0.00301755]
 [ 0.00426353 -0.00407303  0.00314282  0.2292791  -0.00287396]
 [ 0.0051468   0.00098929  0.00301755 -0.00287396  0.01292466]]
-44783.15553911831
{'factor': 0.9292256480511698, 'offset': 0.003683684114523839, 'obs_sd': 0.5278395156619033}
[[ 0.8066758   0.00652019  0.00524129  0.00415781 

In [29]:
final_params = update_params(result.x, elo_params)

{'factor': 0.024625988956615653, 'offset': 0.10212920904257114, 'obs_sd': 0.08778945834230846}
[[0.27056187 0.12372393 0.17677553 0.16298591]
 [0.12372393 0.34363353 0.23048651 0.20608115]
 [0.17677553 0.23048651 0.22904195 0.20191787]
 [0.16298591 0.20608115 0.20191787 0.24036722]]


In [30]:
result.success

True

In [32]:
final_ratings, _ = calculate_ratings_scan(winners_array, losers_array, a_full, y_full, elo_functions, final_params)

In [33]:
fed = encoder.transform(['Roger Federer'])[0]

In [34]:
final_ratings[fed]

DeviceArray([3.00153266, 3.93462673, 3.81033389, 3.68553774])

In [36]:
final_params.cov_mat

DeviceArray([[0.27056187, 0.12372393, 0.17677553, 0.16298591],
             [0.12372393, 0.34363353, 0.23048651, 0.20608115],
             [0.17677553, 0.23048651, 0.22904195, 0.20191787],
             [0.16298591, 0.20608115, 0.20191787, 0.24036722]])

In [41]:
import pandas as pd

# Label the rating table by player (rows) and surface (columns).
# NOTE(review): the reason for dropping 'Robin Soderling' is not stated -- document it.
# NOTE(review): this rebinds final_ratings from an array to a DataFrame; a distinct
# name would keep the cell idempotent on re-run.
final_ratings = pd.DataFrame(final_ratings, index=encoder.classes_, columns=surf_encoder.classes_).drop('Robin Soderling')

In [42]:
from tpr.models.utils import to_elo_scale

In [43]:
final_ratings_elo = to_elo_scale(final_ratings)

In [44]:
final_ratings_elo.sort_values('indoor_hard', ascending=False)['indoor_hard'].round(0).astype(int).iloc[:10]

Novak Djokovic           2145
Rafael Nadal             2145
Roger Federer            2140
Daniil Medvedev          2053
Juan Martin Del Potro    1976
Dominic Thiem            1906
Milos Raonic             1873
Stefanos Tsitsipas       1864
Alexander Zverev         1860
Andrey Rublev            1853
Name: indoor_hard, dtype: int64