In [1]:
from env.SoccerActionsEnv import SoccerActionsEnv

import pandas as pd
import numpy as np

import lib.draw as draw
import matplotlib.pyplot as plt
from tqdm import tqdm
from random import randint, random

import optuna

In [2]:
def calculate_horizontal_position(x, y):
    """Return the row index of ``y`` inside the column band that contains ``x``.

    For ``x < 0.75`` the band is cut into 4 equal rows (indices 0-3).
    From ``x = 0.75`` onward it is cut into 6 unequal rows (indices 0-5).
    A value lying exactly on an edge belongs to the row above that edge,
    because the comparison is ``y < edge``.
    """
    if x < 0.75:
        row_edges = (0.25, 0.5, 0.75)
    else:
        # NOTE(review): these look like the penalty-area / goal-area lines in
        # normalised pitch coordinates -- confirm against the pitch drawing.
        row_edges = (0.2037, 0.3653, 0.50, 0.6347, 0.7963)
    for row, upper in enumerate(row_edges):
        if y < upper:
            return row
    # y is beyond the last edge: top row of the band.
    return len(row_edges)

def calculate_square(x, y):
    """Map a normalised ``(x, y)`` location to a flat grid-cell index.

    The x axis is split into 8 columns by ``col_edges``. A value lying
    exactly on an edge falls into the column to its right, because the
    comparison is ``x < edge``. Column ``k`` owns a contiguous index range
    starting at ``col_offsets[k]``:
    - the first five columns (x < 0.75) span 4 indices each;
    - the last three columns span 6 indices each.
    The position inside the column comes from
    ``calculate_horizontal_position``.
    """
    # NOTE(review): 0.8428 / 0.9476 look like penalty-area / goal-area depth
    # lines in normalised coordinates -- confirm.
    col_edges = (0.1666, 0.3333, 0.5, 0.6666, 0.75, 0.8428, 0.9476)
    col_offsets = (0, 4, 8, 12, 16, 20, 26, 32)
    col = next((k for k, edge in enumerate(col_edges) if x < edge), len(col_edges))
    return col_offsets[col] + calculate_horizontal_position(x, y)

def test_model(action, r, a):
    env = SoccerActionsEnv(randomized_start=True, end_on_xg=True)
    obs = env.reset()

    saving_rewards = []
    for i in tqdm(range(20000)):
        pos = calculate_square(obs[0], obs[1])
        obs, rewards, done, info = env.step([action[pos], r[pos], a[pos]])
        if done:
            saving_rewards.append(rewards)
            env.reset()

    return np.mean(saving_rewards)

In [3]:
# Baseline: evaluate uniformly random per-cell policies.
# Seed Python's RNG (it draws the random policies below) so the baseline is
# reproducible. numpy's global RNG is seeded too in case SoccerActionsEnv uses
# it -- NOTE(review): confirm which RNG the env's randomized start draws from.
from random import seed

SEED = 0
N_RANDOM_POLICIES = 20
seed(SEED)
np.random.seed(SEED)

# Grid-cell centres, in the index order produced by calculate_square.
# Loop-invariant, so built once rather than on every iteration.
xxs = [8.33, 8.33, 8.33, 8.33, 25, 25, 25, 25, 41.66, 41.66, 41.66, 41.66, 58.33, 58.33, 58.33, 58.33, 70.83, 70.83, 70.83, 70.83, 79.64, 79.64, 79.64, 79.64, 79.64, 79.64, 89.52, 89.52, 89.52, 89.52, 89.52, 89.52, 97.38, 97.38, 97.38, 97.38, 97.38, 97.38]
yys = [12.5, 37.5, 62.5, 87.5, 12.5, 37.5, 62.5, 87.5, 12.5, 37.5, 62.5, 87.5, 12.5, 37.5, 62.5, 87.5, 12.5, 37.5, 62.5, 87.5, 10.19, 28.45, 43.27, 56.73, 71.55, 89.81, 10.19, 28.45, 43.27, 56.73, 71.55, 89.81, 10.19, 28.45, 43.27, 56.73, 71.55, 89.81]
n_cells = len(xxs)
assert n_cells == len(yys) == 38, "grid must match the 38 cells of calculate_square"

res = []
for _ in range(N_RANDOM_POLICIES):
    df = pd.DataFrame({'x': xxs, 'y': yys, 'i': range(n_cells)})
    df['action'] = [randint(0, 1) for _ in range(n_cells)]
    df['r'] = [random() for _ in range(n_cells)]
    df['a'] = [random() for _ in range(n_cells)]

    res.append(test_model(df.action, df.r, df.a))

100%|██████████| 20000/20000 [00:05<00:00, 3767.06it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3861.41it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3414.21it/s]
100%|██████████| 20000/20000 [00:07<00:00, 2855.17it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3653.32it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3604.76it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3699.46it/s]
100%|██████████| 20000/20000 [00:06<00:00, 3307.73it/s]
100%|██████████| 20000/20000 [00:04<00:00, 4063.73it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3574.26it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3489.76it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3648.74it/s]
100%|██████████| 20000/20000 [00:04<00:00, 4038.81it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3389.69it/s]
100%|██████████| 20000/20000 [00:06<00:00, 3297.82it/s]
100%|██████████| 20000/20000 [00:06<00:00, 3045.39it/s]
100%|██████████| 20000/20000 [00:05<00:00, 3773.55it/s]
100%|██████████| 20000/20000 [00:04<00:00, 4194.

In [6]:
# Per-policy mean terminal rewards, followed by their overall mean (the random baseline).
(res, np.mean(res))

([0.009424858535869414,
  0.008734602896329179,
  0.008709480269200744,
  0.007146156173013849,
  0.010254079414527642,
  0.006477043866865235,
  0.007768977161797744,
  0.008304101965142118,
  0.010981788948600564,
  0.007131876298444237,
  0.007952305206155766,
  0.0073512640217286845,
  0.010349178410716215,
  0.008382657269042838,
  0.00804690247582992,
  0.005431808999909281,
  0.011213075458818117,
  0.010726977348035717,
  0.006896863240031581,
  0.006793860176454982],
 0.00840389290682569)