In [None]:
import jax.numpy as jnp
from jax import jit, grad, tree_multimap
import optax
from jax_util import optimise, get_target, get_log_loss, predict_proba, update_ratings, scan_dataset, negative_average_log_loss, EloRatingNet

__EPS__ = 1e-12
learning_rate = 0.01
n_gradient_steps = 500

In [None]:
# Example: outcome probabilities and log loss for a single match
teamA_rating = 2.1
teamB_rating = 1.1
allow_draw = True
params = {"gamma": 0.1, "beta": 1, "alpha": 0}
scores = (3, 1)

y = get_target(scores)
probabilities = predict_proba(params, teamA_rating, teamB_rating, allow_draw)

# Index 0: team A wins, 1: draw, 2: team B wins (as labelled in the message).
# `:.0%` renders the probability as a whole-number percentage; the former
# `.round(2) * 100` could print floating-point noise such as `56.99999999999999%`.
print(
    f"The probability of team A to win is {float(probabilities[0]):.0%}, "
    f"team B to win is {float(probabilities[2]):.0%}, "
    f"and the draw is {float(probabilities[1]):.0%}"
)
print(f"Team A wins the game, the log loss is {get_log_loss(y, probabilities)}")

In [None]:
# Example: one rating update after a single match
rating = jnp.array([100, 100], dtype=float)

# positions of the two teams, presumably indices into `rating` — confirm in jax_util
teamA_idx = 0
teamB_idx = 1

# NOTE(review): "kappa" is only consumed by update_ratings; presumably the
# update step size (K-factor) — confirm against jax_util.
params = {"gamma": 0.1, "beta": 1, "alpha": 0, "kappa": 1}

y = 1  # NOTE(review): match-outcome label; its encoding is defined by jax_util.get_target
allow_draw = True

new_rating = update_ratings(params, teamA_idx, teamB_idx, y, rating, allow_draw)
print(f"Rating was {rating}, the new rating is {new_rating}")

In [None]:
# Example: evaluate the loss of a fresh model on a tiny hand-written dataset

# a new two-team model together with its initial parameters
elo_model = EloRatingNet(n_teams=2)
params = elo_model.init_params()

# Toy dataset with 2 teams and 2 matches:
#   match 1: team A (index 0) plays team B (index 1); A wins 3-1.
#   match 2: team B plays team A; the match ends in a 1-1 draw.
toy_dataset = dict(
    team_index=jnp.array([[0, 1], [1, 0]]),
    scores=jnp.array([[3.0, 1.0], [1.0, 1.0]]),
)

# log loss of each match, as reported by scan_dataset
output = scan_dataset(params, toy_dataset)
print(f"The log-loss of the two matches is {output['loss_history']}")

# dataset-level objective computed by negative_average_log_loss
loss = negative_average_log_loss(params, toy_dataset)
print(f"The dataset negative log loss is {loss}")

In [None]:
# Jax magic gradient: differentiate the loss with respect to its first argument
# (the params pytree) and JIT-compile the resulting gradient function.
negative_average_log_loss_grad_fn = jit(
    grad(negative_average_log_loss, argnums=0)
)

grads = negative_average_log_loss_grad_fn(params, toy_dataset)
grads

In [None]:
# Example of one gradient descent step "handmade"

# init params
params = elo_model.init_params()

# Build a new dict rather than assigning into `params` while iterating over
# its items: the object returned by init_params() is never mutated in place
# (in case the model keeps a reference to it), and each entry is derived from
# `val` consistently.
params = {
    key: (
        # list-valued parameters (the r0 list): step element-wise, then stack
        jnp.array([v - learning_rate * g for v, g in zip(val, grads[key])])
        if isinstance(val, list)
        # float / array parameters: plain SGD step
        else val - learning_rate * grads[key]
    )
    for key, val in params.items()
}
print("Parameters after the handmade gradient step:")
params

In [None]:
# Example of one gradient descent step with tree_map
# jax.tree_multimap is deprecated and removed in recent JAX releases;
# jax.tree_util.tree_map accepts several pytrees and is its drop-in replacement.
# (The `tree_multimap` name imported in the first cell is not used here.)
from jax.tree_util import tree_map

# init params
params = elo_model.init_params()
# p <- p - lr * g, applied leaf by leaf over the matching params / grads pytrees
params = tree_map(lambda p, g: p - learning_rate * g, params, grads)
print("Parameters after the tree_map gradient step:")
params

In [None]:
# Example of one gradient descent step with optax

# start again from freshly initialised parameters
params = elo_model.init_params()
# 1. choose the optimiser: plain SGD, exactly as in the previous cells
tx = optax.sgd(learning_rate=learning_rate)
# 2. build the optimiser state for these parameters
opt_state = tx.init(params)
# 3. turn the gradients into a pytree of updates, together with the new optimiser state
#    (SGD ignores `params`; passing it keeps the call valid for transforms that need it)
updates, opt_state = tx.update(grads, opt_state, params)
# 4. add the updates to the parameters, i.e. jax.tree_map(lambda p, u: p + u, params, updates)
params = optax.apply_updates(params, updates)
print("Parameters after the Optax gradient step:")
params

In [None]:
# example: full gradient-descent optimisation on our toy_dataset

elo_model = EloRatingNet(n_teams=2)
# reuse the configured learning rate instead of repeating the 0.01 literal
optimiser = optax.sgd(learning_rate=learning_rate)

params = elo_model.init_params()
# NOTE(review): the number of steps is decided inside jax_util.optimise; the
# n_gradient_steps constant (500) is not passed here — confirm they agree.
params = optimise(params, optimiser, toy_dataset)

In [None]:
import pandas as pd
from jax_util import EloDataset

# load the matches in chronological order
data = (
    pd.read_csv("../input/barclays-premier-league/match.csv")
    .sort_values("match_date")
    # NOTE(review): drops the earliest match after sorting; the reason is not
    # visible in this notebook — confirm it is intentional.
    .iloc[1:]
)

# club lookup table indexed by club_id (the club_id column is kept as well)
clubs = pd.read_csv("../input/barclays-premier-league/club.csv").set_index(
    "club_id", drop=False
)
# translate the numeric team ids of every match into club names
data = data.assign(
    home_team_name=clubs.loc[data["home_team_id"], "club_name"].values,
    away_team_name=clubs.loc[data["away_team_id"], "club_name"].values,
)

# Build the EloDataset; the match dates are passed as its time index.
football_data = EloDataset(
    valid_date="2014-06-01",  # first validation date; training uses every earlier match
    test_date="2018-06-01",  # first test date
    time=pd.DatetimeIndex(data["match_date"]),
)

# hand over the team names and the goals scored by each side
football_data.prepare_data(
    data[["home_team_name", "away_team_name"]],
    data[["home_team_goals", "away_team_goals"]],
)


In [None]:
# Football matches can end in a draw, so the model is built with allow_draw=True.
model = EloRatingNet(allow_draw=True)

# SGD with lr = 0.05, followed by a transform that keeps the parameters non-negative.
optimiser = optax.chain(
    optax.sgd(learning_rate=0.05),
    optax.keep_params_nonnegative(),
)

# Fit the ratings. NOTE(review): early_stopping / verbose semantics live in
# jax_util (presumably patience and logging interval) — confirm there.
model.fit_parameters(
    football_data,
    optimiser,
    max_step=1000,
    early_stopping=100,
    verbose=50,
)

model.ratings.head()

In [None]:
# Rating history of two clubs, then the model's prediction for that pairing.
model.plot_rating_history(['Liverpool', 'Manchester City'])
# Only a cell's last expression is echoed by Jupyter; without an explicit print
# this prediction would be computed and silently discarded.
print(model.predict_proba('Liverpool', 'Manchester City'))

# sgdr_schedule: chain several warmup_cosine_decay_schedule cycles, one dict of
# keyword arguments per cycle. `i + 1` is the 1-based cycle number.
n_schedules = 10
schedule = optax.sgdr_schedule(
    [
        dict(
            init_value=0.05 / (i + 1),  # lr at the start of the cycle
            peak_value=0.3 / (i + 1),  # lr reached once the warmup is over
            warmup_steps=30,  # steps taken to climb from init_value to peak_value
            decay_steps=100,  # total cycle length in steps, warmup included
            end_value=0.05 / (i + 2),  # lr at the end of the cycle (= next cycle's init_value)
        )
        # range(n_schedules) builds exactly n_schedules cycles; the former
        # range(1, n_schedules) built one cycle fewer, leaving the last
        # 100 plotted steps flat at the final end_value.
        for i in range(n_schedules)
    ]
)

# plot: n_schedules cycles of decay_steps each span exactly 10 * 100 = 1000 steps.
# NOTE: this overrides the n_gradient_steps = 500 set in the first cell.
n_gradient_steps = 1000
_ = (
    pd.Series([schedule(i) for i in range(n_gradient_steps)])
    .astype(float)
    .plot(
        title="Chained warmup cosine decay schedule",
        xlabel="gradient step",
        ylabel="learning rate",
        figsize=(11, 7),
    )
)