# Custom EloFunctions

This example walks through the creation of a custom Elo algorithm taking into account the different formats used in tennis matches.

Tennis matches on the men's professional circuit are broadly played in two formats: best of three sets and best of five sets. Best-of-five matches are longer and have been shown to be less prone to upsets, so the better player is more likely to win a best-of-five match than a best-of-three match (see, for example, [here](https://www.degruyter.com/view/j/jqas.2018.14.issue-1/jqas-2017-0077/jqas-2017-0077.xml) for more details).

From a model perspective, rather than using the usual win probability of:

$p(win | \delta) = \textrm{logit}^{-1}(b \delta)$

where $b = \log(10) / 400$ to match Elo's win probability, we could instead model it as:

$p(win | \delta) = \textrm{logit}^{-1}(b * (1 + \textrm{is_bo5} * \textrm{bo5_factor}) * \delta)$

where `is_bo5` is an indicator of whether the match is best of five, and `bo5_factor` is the (most likely positive) amount added to the multiplier for best-of-five matches.

In the following, we walk through how to specify `EloFunctions` to use with this framework to take this into account. To do this, we're going to build off the "basic.py" functions found in `jax_elo/elo_functions/basic.py` to keep things simple, so we won't be including a margin of victory.

In [1]:
# We first import what we'll need:
from functools import partial

import jax.numpy as jnp
from jax import jit, grad, hessian
from jax.scipy.special import expit
from jax.scipy.stats import multivariate_normal

from jax_elo.utils.normals import weighted_sum, logistic_normal_integral_approx
from jax_elo.core import EloFunctions, calculate_win_prob, get_starting_elts
from jax_elo.utils.flattening import reconstruct
from jax_elo.utils.linalg import num_mat_elts, pos_def_mat_from_tri_elts


# The pre-factor to switch to the Elo scale, as discussed
b = jnp.log(10) / 400.


# First, we specify the (log) likelihood.
@jit
def calculate_likelihood(x, mu, a, theta, y):
    """Log likelihood of a win, with a best-of-five adjustment.

    Computes log(logit^{-1}(b * (1 + is_bo5 * bo5_factor) * delta)), where
    delta = a @ x is the skill difference.

    Args:
        x: Skill vector; a @ x gives the skill difference (see paper).
        mu: Not used by this function.
        a: Weight vector applied to x to obtain the skill difference.
        theta: Dictionary of parameters; must contain 'bo5_factor'.
        y: Match covariates; y[0] is assumed to be the best-of-five indicator.

    Returns:
        The log probability of the observed win.
    """
    # a @ x will give us the skill difference (see paper for details):
    delta = a @ x

    # We assume that y[0] is the bo5 indicator:
    is_bo5 = y[0]

    # The bo5 factor has to be included in the dictionary of parameters theta.
    scale = b * (1 + is_bo5 * theta['bo5_factor'])

    # log(expit(z)) == -log(1 + exp(-z)). Evaluating it through logaddexp
    # avoids expit underflowing to zero (and hence log returning -inf) when
    # z is very negative.
    log_win_prob = -jnp.logaddexp(0., -scale * delta)

    return log_win_prob

# Now, we need to modify the (log) marginal likelihood in a similar way.
@jit
def calculate_marginal_lik(x, mu, a, cov_mat, theta, y):
    """Log marginal likelihood of a win, with a best-of-five adjustment.

    Args:
        x: Skill vector passed to weighted_sum.
        mu: Not used by this function.
        a: Weight vector passed to weighted_sum.
        cov_mat: Covariance matrix passed to weighted_sum.
        theta: Dictionary of parameters; must contain 'bo5_factor'.
        y: Match covariates; y[0] is read as the best-of-five indicator.

    Returns:
        The log of the logistic-normal integral approximation evaluated at
        the rescaled mean and variance of the skill difference.
    """
    # Mean and variance of the skill difference delta:
    delta_mean, delta_var = weighted_sum(x, cov_mat, a)

    # The best-of-five aware multiplier:
    scale = b * (1 + y[0] * theta['bo5_factor'])

    # Scaling a normal variable by c maps its mean to c * mean and its
    # variance to c^2 * variance, so rescale both before integrating:
    scaled_mean = scale * delta_mean
    scaled_var = scale**2 * delta_var

    return jnp.log(logistic_normal_integral_approx(scaled_mean, scaled_var))

# The (log) prior here is multivariate normal to account for correlated skills:
@jit
def calculate_prior(x, mu, cov_mat, theta):
    """Log prior of the skills: a multivariate normal log density.

    Args:
        x: Point at which the density is evaluated.
        mu: Mean of the multivariate normal.
        cov_mat: Covariance matrix of the multivariate normal.
        theta: Not used by this function.

    Returns:
        The log density of x under N(mu, cov_mat).
    """
    log_density = multivariate_normal.logpdf(x, mu, cov_mat)
    return log_density


# The log posterior is the log_prior plus the log_likelihood:
@jit
def calculate_log_posterior(x, mu, cov_mat, a, theta, y):
    """Log posterior of the skills x: log likelihood plus log prior.

    All arguments are forwarded to calculate_likelihood and
    calculate_prior; see those functions for their meaning.
    """
    log_lik = calculate_likelihood(x, mu, a, theta, y)
    log_prior = calculate_prior(x, mu, cov_mat, theta)

    return log_lik + log_prior

# The parse_theta function has to make the dictionary theta from a flat vector
# to use in the optimisation.
def parse_theta(flat_theta, summary):
    """Turn the flat parameter vector used in the optimisation into theta.

    Args:
        flat_theta: Flat array of parameters.
        summary: Summary passed to reconstruct to rebuild the dictionary.

    Returns:
        The dictionary theta, with its 'cov_mat' entry replaced by the full
        covariance matrix built from the triangular elements.
    """
    # Rebuild the dictionary from the flat array:
    theta = reconstruct(flat_theta, summary, jnp.reshape)

    # The optimiser works on the triangular elements rather than on the
    # covariance matrix itself, which ensures that we always end up with a
    # valid positive definite matrix. Recover the matrix here:
    tri_elts = theta['cov_mat']
    mat_size = num_mat_elts(len(tri_elts))
    theta['cov_mat'] = pos_def_mat_from_tri_elts(tri_elts, mat_size)

    return theta

# Finally, we need to define a new win probability function:
def calculate_win_prob_bo5(mu1, mu2, a, y, elo_params):
    """Win probability with the best-of-five adjusted pre-factor.

    Args:
        mu1, mu2, a: Forwarded unchanged to calculate_win_prob.
        y: Match covariates; y[0] is read as the best-of-five indicator.
        elo_params: Parameters; elo_params.theta must contain 'bo5_factor'.

    Returns:
        Whatever calculate_win_prob returns for the adjusted pre-factor.
    """
    best_of_five = y[0]
    scale = 1 + best_of_five * elo_params.theta['bo5_factor']

    # The usual function accepts a pre_factor argument, so we can reuse it:
    return calculate_win_prob(
        mu1, mu2, a, y, elo_params, pre_factor=b * scale)

# Bundle everything into the EloFunctions tuple. JAX derives the gradient
# and Hessian of the log posterior (with respect to its first argument, x)
# that are needed for the update:
bo5_functions = EloFunctions(
    win_prob_fun=jit(calculate_win_prob_bo5),
    parse_theta_fun=parse_theta,
    marginal_lik_fun=calculate_marginal_lik,
    log_post_hess_x=jit(hessian(calculate_log_posterior)),
    log_post_jac_x=jit(grad(calculate_log_posterior)),
)



### Fitting the parameters

In [2]:
# Now let's try it. Get the data:
from jax_elo.utils.data import get_data

# Point this to where your tennis_atp dataset is
df = get_data('/Users/ingramm/Projects/tennis/tennis-data/data/sackmann/tennis_atp/')

# Keep only matches from 2016 onwards
to_use = df.loc[df['tourney_date'].dt.year >= 2016]

In [3]:
# Let's optimise the parameters:
# Let's not make it surface specific to start with
from jax_elo.core import optimise_elo, EloParams, calculate_ratings_history
from jax_elo.utils.encoding import encode_players

# Starting prior covariance: a single skill with variance 100^2
start_cov_mat = 100**2 * jnp.eye(1)

# The optimiser does not work on the covariance matrix directly. As
# discussed in the parse_theta section, it uses the triangular elements of
# the cholesky decomposition so the matrix stays a valid covariance matrix.
start_cov_elts = get_starting_elts(start_cov_mat)

start_theta = dict(bo5_factor=jnp.array(0.), cov_mat=start_cov_elts)

start_params = EloParams(theta=start_theta)

In [4]:
# Encode the winner and loser names of every match
winner_ids, loser_ids, unique_players = encode_players(
    to_use['winner_name'].values,
    to_use['loser_name'].values,
)

In [5]:
n_matches = len(winner_ids)

# No surfaces are used here, so every row of a is simply [1, -1]:
a = jnp.ones((n_matches, 2)) * jnp.array([1., -1.])

# y is more interesting: it has to carry a best of five indicator.
# We flag only the slams. This isn't quite right, since some other
# tournaments are best of five too, but it keeps things simple.
is_bo5 = to_use['tourney_name'].isin([
    'Australian Open',
    'Roland Garros',
    'Wimbledon',
    'US Open',
]).values

# A 2D float array is needed, so convert and add a trailing axis:
y = is_bo5.astype(float).reshape(-1, 1)

In [6]:
# Everything is in place, so fit the parameters:
opt_params, opt_results = optimise_elo(
    start_params,
    bo5_functions,
    winner_ids,
    loser_ids,
    a,
    y,
    len(unique_players),
    tol=1e-5,
)

theta: {'bo5_factor': array(0.), 'cov_mat': DeviceArray([[10000.000001]], dtype=float64)}
cov_mat: [[10000.000001]]
theta: {'bo5_factor': Traced<ShapedArray(float64[])>with<JVPTrace(level=1/1)>
  with primal = Traced<ShapedArray(float64[]):JaxprTrace(level=-1/1)>
       tangent = Traced<ShapedArray(float64[]):JaxprTrace(level=0/1)>, 'cov_mat': Traced<ShapedArray(float64[1,1])>with<JVPTrace(level=1/1)>
  with primal = Traced<ShapedArray(float64[1,1]):JaxprTrace(level=-1/1)>
       tangent = Traced<ShapedArray(float64[1,1]):JaxprTrace(level=0/1)>}
cov_mat: Traced<ShapedArray(float64[1,1])>with<JVPTrace(level=1/1)>
  with primal = Traced<ShapedArray(float64[1,1]):JaxprTrace(level=-1/1)>
       tangent = Traced<ShapedArray(float64[1,1]):JaxprTrace(level=0/1)>
theta: {'bo5_factor': array(1.00962673), 'cov_mat': DeviceArray([[9994.50940957]], dtype=float64)}
cov_mat: [[9994.50940957]]
theta: {'bo5_factor': array(0.26884322), 'cov_mat': DeviceArray([[9998.53782003]], dtype=float64)}
cov_mat: 

In [7]:
# The estimated optimal bo5 factor, i.e. the 'bo5_factor' entry of the
# optimised theta
opt_params.theta['bo5_factor']

array(0.52662679)

In [8]:
# The estimated prior covariance matrix, i.e. the 'cov_mat' entry of the
# optimised theta.
# Here we only have one skill, so it's 1x1
opt_params.theta['cov_mat']

DeviceArray([[5704.06147268]], dtype=float64)

### Using the final parameters to predict & evaluate

In [9]:
# Run the fitted model over the matches. The history entries hold the
# 'prior_win_prob' values used below, and final_ratings is looked up by
# player name.
history, final_ratings = calculate_ratings_history(
    to_use['winner_name'],
    to_use['loser_name'],
    a,
    y,
    bo5_functions,
    opt_params,
)

  0%|          | 0/10850 [00:00<?, ?it/s]
  0%|          | 0/10850 [00:00<?, ?it/s][A
  0%|          | 1/10850 [00:00<1:55:47,  1.56it/s][A
  0%|          | 36/10850 [00:00<1:20:56,  2.23it/s][A
  1%|          | 73/10850 [00:00<56:36,  3.17it/s]  [A
  1%|          | 135/10850 [00:00<39:29,  4.52it/s][A
  2%|▏         | 189/10850 [00:01<27:36,  6.44it/s][A
  2%|▏         | 253/10850 [00:01<19:17,  9.16it/s][A
  3%|▎         | 307/10850 [00:01<13:32, 12.98it/s][A
  4%|▎         | 380/10850 [00:01<09:29, 18.40it/s][A
  4%|▍         | 479/10850 [00:01<06:37, 26.07it/s][A
  5%|▌         | 572/10850 [00:01<04:39, 36.80it/s][A
  6%|▌         | 669/10850 [00:01<03:16, 51.72it/s][A
  7%|▋         | 751/10850 [00:01<02:21, 71.20it/s][A
  8%|▊         | 826/10850 [00:01<01:42, 97.39it/s][A
  8%|▊         | 899/10850 [00:02<01:15, 131.37it/s][A
  9%|▉         | 972/10850 [00:02<00:57, 170.75it/s][A
 10%|▉         | 1040/10850 [00:02<00:44, 219.69it/s][A
 10%|█         | 1108/10850

In [10]:
# As expected, the optimal factor is greater than zero.
# To see its effect, compare one matchup at a slam and away from a slam:
p1, p2 = 'Novak Djokovic', 'Rafael Nadal'

p1_final_rating, p2_final_rating = final_ratings[p1], final_ratings[p2]

a = jnp.array([1, -1])

# The only difference between the two calls is the bo5 indicator in y:
print(
    'Best of five: ',
    jnp.round(
        calculate_win_prob_bo5(
            p1_final_rating, p2_final_rating, a, [1.], opt_params),
        3))
print(
    'Best of three: ',
    jnp.round(
        calculate_win_prob_bo5(
            p1_final_rating, p2_final_rating, a, [0.], opt_params),
        3))

Best of five:  0.63
Best of three:  0.593


In [11]:
# How well does this predict? Take the mean log of the 'prior_win_prob'
# recorded for each match in the history:
win_probs = jnp.stack([entry['prior_win_prob'] for entry in history])
jnp.log(win_probs).mean()

DeviceArray(-0.62765106, dtype=float64)

In [12]:
from jax_elo.utils.elo import optimise_static_k, compute_elo_ratings

# Baseline: fit basic Elo with a single constant k, ignoring slams
k, success = optimise_static_k(to_use['winner_name'], to_use['loser_name'])
elo_pred = compute_elo_ratings(
    to_use['winner_name'],
    to_use['loser_name'],
    lambda _: k,
)
elo_win_probs = [match['winner_prob'] for match in elo_pred]
jnp.log(jnp.array(elo_win_probs)).mean()

DeviceArray(-0.62949445, dtype=float64)

In [13]:
import numpy as onp

# So we're just a bit better, as you'd expect. What about at slams?
# Put both sets of win probabilities next to the match data (on a copy):
result_df = to_use.assign(
    bo5_elo=onp.array(win_probs),
    elo=onp.array(elo_win_probs),
)

In [14]:
# Attach the best of five indicator so the results can be grouped by it
result_df['is_bo5'] = is_bo5

In [15]:
# So both predict better at slams, and the best of five version has a bigger edge there.

In [16]:
# Mean log win probability of basic Elo, split by the is_bo5 flag:
(
    result_df
    .groupby('is_bo5')
    .apply(lambda group: onp.mean(onp.log(group['elo'])))
    .round(3)
)

is_bo5
False   -0.647
True    -0.557
dtype: float64

In [17]:
# Mean log win probability of the best of five model, split by the is_bo5 flag:
(
    result_df
    .groupby('is_bo5')
    .apply(lambda group: onp.mean(onp.log(group['bo5_elo'])))
    .round(3)
)

is_bo5
False   -0.646
True    -0.551
dtype: float64