In [1]:
import pandas as pd
import numpy as np
import jax.numpy as jnp
import jax
from jax import jit
from tqdm import tqdm
import matplotlib.pyplot as plt
import optax
from jax.scipy.stats import poisson

Poisson = jax.jit(poisson.pmf)
        
def coord_dict():
    coord_dict = {}
    for i in range(10):
        for j in range(10):
            if not 'HomeWin' in coord_dict and not 'AwayWin' in coord_dict and not 'Draw' in coord_dict:
                coord_dict['HomeWin'] = []
                coord_dict['Draw'] = []
                coord_dict['AwayWin'] = []
            if not 'BTTS_YES' in coord_dict and not 'BTTS_NO' in coord_dict:
                coord_dict['BTTS_YES'] = []
                coord_dict['BTTS_NO'] = []
            H = i
            A = j
            if H>A:
                coord_dict['HomeWin'].append((i,j))
            elif H==A:
                coord_dict['Draw'].append((i,j))
            else:
                coord_dict['AwayWin'].append((i,j))

            if H>0 and A>0:
                coord_dict['BTTS_YES'].append((i,j))
            else:
                coord_dict['BTTS_NO'].append((i,j))

            for ix in range(10):
                s = ix+0.5
                if not 'Under_'+str(s) in coord_dict and not 'Over_'+str(s) in coord_dict:
                    coord_dict['Under_'+str(s)] = []
                    coord_dict['Over_'+str(s)] = []
                if H+A<s:
                    coord_dict['Under_'+str(s)].append((i,j))
                else:
                    coord_dict['Over_'+str(s)].append((i,j))

            for ix in range(-8,8):
                s = ix+0.5
                if not 'AH_Home_'+str(s) in coord_dict and not 'AH_Away_'+str(s) in coord_dict:
                    coord_dict['AH_Home_'+str(s)] = []
                    coord_dict['AH_Away_'+str(s)] = []
                if H+s>A:
                    coord_dict['AH_Home_'+str(s)].append((i,j))
                else:
                    coord_dict['AH_Away_'+str(s)].append((i,j))
    for key in coord_dict:
        coord_dict[key] = tuple(zip(*coord_dict[key]))
    return coord_dict

cd = coord_dict()

@jit
def loss_fn(params, true):
    params = jax.nn.relu(params)
    home = Poisson(jnp.arange(10),params[0])
    away = Poisson(jnp.arange(10),params[1])
    matrix = jnp.outer(home,away)
    # odds1x2, ou_odds, ah_odds, bts_odds = odds
    global cd
    pred = jnp.concatenate([
    jnp.array([matrix[cd['HomeWin']].sum(),matrix[cd['Draw']].sum(),matrix[cd['AwayWin']].sum()]),
    jnp.array([(matrix[cd['Over_'+str(x[0])]].sum(), matrix[cd['Under_'+str(x[0])]].sum()) for x in ou_odds]).reshape(-1,),
    jnp.array([(matrix[cd['AH_Home_'+str(x[0])]].sum(), matrix[cd['AH_Away_'+str(x[0])]].sum()) for x in ah_odds]).reshape(-1,),
    jnp.array([matrix[cd['BTTS_YES']].sum(),matrix[cd['BTTS_NO']].sum()])])
    return (jnp.sqrt(((1/true-pred)**2).mean()))
    # return (jnp.sqrt(((true-1/pred)**2).mean()))
    # return (jnp.abs(1/true-pred)).mean()
    
grad_fn = jax.jit(jax.grad(loss_fn))

In [2]:
#### BOOKMAKERS ODDS

### MATCH https://www.betexplorer.com/soccer/england/premier-league-2021-2022/brentford-arsenal/863eg7q9/

odds1x2 = np.array([3.88, 3.27, 2.10])

ou_odds = np.array([[ 0.5,1.08,8.23],
                    [ 1.5,1.42,2.9],
                    [ 2.5,2.34,1.61],
                    [ 3.5,4.28,1.23],
                    [ 4.5,8.82,1.07],
                    [ 5.5,16.77,1.02],
                    [ 6.5,38.71,1.01],
                    [ 7.5,45.0,1.01]])

bts_odds = np.array([2.0,1.79])

ah_odds = np.array([[-2.5,34.25,1.02],
                    [-1.5, 9.59,1.07],
                    [-0.5,3.73,1.28],
                    [ 0.5,1.79,2.1 ],
                    [ 1.5,1.25,4.04],
                    [ 2.5,1.07,9.47],
                    [ 3.5,1.02,17.6 ]])

#### BOOKMAKERS ODDS

true = jnp.concatenate([odds1x2,ou_odds[:,1:].reshape(-1,),ah_odds[:,1:].reshape(-1,),bts_odds])

In [3]:
key = jax.random.PRNGKey(0)
params = jax.random.uniform(key, shape=(2,), minval=0.5, maxval=1.5)

train_loss = [1,]
valid_loss = []

es_params = [params,]

learning_rate = 0.1

optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# compute_loss = lambda params, x, y: optax.l2_loss(forward(params, X_train), y_train)
counter = 0
for epoch in range(1000):
    t_loss = loss_fn(params, true)
# loss_fn(params, X, y)
    if t_loss>=train_loss[-1]:
        counter+=1
    elif t_loss>=train_loss[-1]:
        counter=0    

    if counter>50:
        break

    es_params.append(params) 
    train_loss.append(t_loss)

    grads = grad_fn(params, true)
    if epoch%25==0:
        print('Epoch: ',epoch,', Loss: ',t_loss)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
params = es_params[np.array(train_loss).argmin()]

Epoch:  0 , Loss:  0.049465436
Epoch:  25 , Loss:  0.022535106
Epoch:  50 , Loss:  0.022016855
Epoch:  75 , Loss:  0.022010088
Epoch:  100 , Loss:  0.022002917
