In [1]:
import time

import numpy as np
import laserhockey.hockey_env as h_env
from TD3 import TD3

In [2]:
# Shared hockey environment (NORMAL mode); `rollout` plays its games in it by default.
env = h_env.HockeyEnv(mode=h_env.HockeyEnv.NORMAL)

# Scripted baseline opponents. Only their bound `act` methods are kept, so they
# plug into `rollout` exactly like the learned policies do.
weak_basic_opponent, strong_basic_opponent = (
    h_env.BasicOpponent(weak=is_weak).act for is_weak in (True, False)
)

def rollout(p1, p2, num_games=10, render=False, game_env=None):
    """Play ``num_games`` games between two policies and tally the outcomes.

    Args:
        p1: Left-player policy; called with the observation returned by
            ``reset()``/``step()`` and must return that player's action.
        p2: Right-player policy; called with ``obs_agent_two()``.
        num_games: Number of games (episodes) to play.
        render: If True, ``render()`` is called before every step.
        game_env: Environment to play in. Defaults to the notebook-level
            ``env``. It is closed when this function returns *or* raises.

    Returns:
        ``(wins, defeats, draws)`` as floats, counted from p1's side:
        ``info["winner"] == 1`` is a win, ``-1`` a defeat, ``0`` a draw.
    """
    # Only fall back to the global env when none was supplied.
    game_env = env if game_env is None else game_env

    # Slot index is info["winner"] + 1  ->  [defeats, draws, wins].
    counter = np.zeros(3)
    try:
        for _ in range(num_games):
            state_l, done = game_env.reset(), False
            state_r = game_env.obs_agent_two()
            while not done:
                if render:
                    game_env.render()
                action_l = p1(state_l)
                action_r = p2(state_r)
                # Joint action: left player's action followed by right player's.
                state_l, _, done, info = game_env.step(np.hstack([action_l, action_r]))
                state_r = game_env.obs_agent_two()
            counter[info["winner"] + 1] += 1
    finally:
        # Release the environment even if a policy raises or the kernel is interrupted.
        game_env.close()

    defeats, draws, wins = counter
    return wins, defeats, draws

In [3]:
# TD3 agents
def get_TD3_policy(name, model_dir="./models"):
    """Load a saved TD3 agent and return its action function.

    Args:
        name: Model name; weights are loaded from ``f"{model_dir}/{name}"``.
        model_dir: Directory holding the saved models (default ``"./models"``).

    Returns:
        The loaded agent's bound ``act`` method, usable as a ``rollout`` policy.
    """
    # Architecture arguments are fixed here; presumably they must match the
    # saved checkpoints (18-dim observation, 4-dim action) — confirm.
    policy = TD3(state_dim=18, action_dim=4, hidden_dim=256, max_action=1.0, normalize_obs=True)
    policy.load(f"{model_dir}/{name}")
    return policy.act

# Load the three saved TD3 agents, in the same order as the names on the left.
TD3_policy, SP_TD3_policy, aSP_TD3_policy = (
    get_TD3_policy(model_name) for model_name in ('TD3', 'SP-TD3', 'aSP-TD3')
)

### Select the players:

In [4]:
# p1 plays the left side and is the one whose wins are counted; p2 is its opponent.
p1, p2 = SP_TD3_policy, strong_basic_opponent

### Observe some games:

In [5]:
# Short pause before the first rendered game (presumably to give time to bring
# the render window into view or start a recording — confirm intent).
RENDER_START_DELAY_S = 5
time.sleep(RENDER_START_DELAY_S)
rollout(p1, p2, render=True)

(10.0, 0.0, 0.0)

### Print the win-rate:

In [6]:
num_games = 100
wins, defeats, draws = rollout(p1, p2, num_games)

# rollout() returns float counts, hence the int() casts in the count rows.
rule = '-------------------'
rates = '/'.join(f'{count / num_games:.2f}' for count in (wins, defeats, draws))
report_lines = [
    f'# Games:   {num_games:8d}',
    rule,
    f'# Wins:    {int(wins):8d}',
    f'# Defeats: {int(defeats):8d}',
    f'# Ties:    {int(draws):8d}',
    rule,
    f'=> ({rates})',
]
print('\n'.join(report_lines))

# Games:        100
-------------------
# Wins:          98
# Defeats:        1
# Ties:           1
-------------------
=> (0.98/0.01/0.01)
