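# Testing grounds

Scratch notebook for manually exercising the FOOTSIES Gym environment (with its wrapper stack) and the Brisket agent.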

## Setup

### Imports

In [1]:
import os
import numpy as np
from main import load_agent_model, train
from gymnasium.wrappers.flatten_observation import FlattenObservation
from footsies_gym.envs.footsies import FootsiesEnv
from footsies_gym.wrappers.normalization import FootsiesNormalized
from footsies_gym.wrappers.action_comb_disc import FootsiesActionCombinationsDiscretized
from footsies_gym.wrappers.statistics import FootsiesStatistics
from importlib import reload

### Environment

In [2]:
# Settings for a human-playable, rendered session (vs_player, render_mode="human").
human_testing_kwargs = dict(
    frame_delay=0,
    fast_forward=False,
    vs_player=True,
    render_mode="human",
)

# Settings for automated runs; these are the ones passed to FootsiesEnv below.
normal_testing_kwargs = dict(
    frame_delay=0,
    dense_reward=True,
)

In [3]:
# Base FOOTSIES environment configured with normal_testing_kwargs, logging to
# out.log in the current working directory (overwritten on each run).
footsies_env = FootsiesEnv(
    game_path="../Footsies-Gym/Build/FOOTSIES.x86_64",
    log_file=os.path.join(os.getcwd(), "out.log"),
    log_file_overwrite=True,
    **normal_testing_kwargs,
)

# Keep a direct handle on the statistics wrapper so its metrics
# (e.g. metric_special_moves_per_episode) can be read later.
statistics = FootsiesStatistics(footsies_env)

# Wrapper stack, inner to outer:
# normalize observations -> flatten them into one vector -> discretize action combinations.
env = FootsiesActionCombinationsDiscretized(
    FlattenObservation(FootsiesNormalized(statistics))
)

## Environment testing

In [9]:
# Smoke test: play 5 episodes with uniformly random actions and print every
# non-zero reward the environment emits.
for episode in range(5):
    print("Env reset")
    obs, info = env.reset()
    while True:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        if reward != 0.0:
            print(reward)
        if terminated or truncated:
            break

Env reset
0.3
0.7
Env reset
0.3
0.7
Env reset
0.3
0.7
Env reset
0.3
0.7
Env reset
0.3
0.7


In [10]:
# Call hard_reset() on the base FootsiesEnv instance (the innermost, unwrapped env).
# NOTE(review): what a "hard" reset does is defined in footsies_gym — not visible here.
footsies_env.hard_reset()

## Brisket testing

In [4]:
import torch
from agents.brisket.agent import FootsiesAgent as BrisketAgent
from agents.brisket.loggables import get_loggables as get_brisket_loggables
from agents.logger import TrainingLoggerWrapper

Reload the agent and logger modules in case their source files were edited after being imported.

In [62]:
# Reload the agent-related modules after editing their source.
# `reload` only refreshes the module objects themselves: names bound earlier via
# `from ... import ...` (BrisketAgent, get_brisket_loggables, TrainingLoggerWrapper)
# would keep pointing at the old code, so they are rebound explicitly below.
import agents.brisket.agent
import agents.brisket.loggables
import agents.logger

# Reload the agent first so the loggables module picks up the new agent code if it
# imports from it.
reload(agents.brisket.agent)
reload(agents.brisket.loggables)
reload(agents.logger)

BrisketAgent = agents.brisket.agent.FootsiesAgent
get_brisket_loggables = agents.brisket.loggables.get_loggables
TrainingLoggerWrapper = agents.logger.TrainingLoggerWrapper

<module 'agents.logger' from '/home/martinho/projects/footsies-agents/agents/logger.py'>

In [5]:
# Agent built against the fully wrapped environment's action/observation spaces.
# For testing (presumably to disable epsilon-greedy exploration), additionally pass:
#   epsilon=0, epsilon_decay_rate=0, min_epsilon=0
brisket = BrisketAgent(
    action_space=env.action_space,
    observation_space=env.observation_space,
)

In [6]:
# Wrap the agent with the training logger.
# Guarded so that re-running this cell does not wrap an already-wrapped agent again.
# Re-run the agent-construction cell first to start over from a raw agent.
if not isinstance(brisket, TrainingLoggerWrapper):
    brisket = TrainingLoggerWrapper(
        brisket,
        10,  # NOTE(review): positional arg of TrainingLoggerWrapper, presumably a logging interval — TODO confirm
        cummulative_reward=True,  # keep this spelling: it must match the wrapper's keyword name
        win_rate=True,
        test_states_number=100,
        **get_brisket_loggables(brisket),
    )

In [7]:
# Give the (logger-wrapped) agent the wrapped environment before training.
# NOTE(review): what preprocess() computes is defined in the agent/logger code — not visible here.
brisket.preprocess(env)

In [35]:
# Try to restore a previously saved model stored under the name "brisket";
# when none exists this only prints a message (see the output below).
load_agent_model(brisket, "brisket")

Can't load agent, there was no agent saved!


In [20]:
# Start a fresh episode to get an observation for probing the agent's Q-values.
obs, info = env.reset()

In [21]:
# Convert the observation to a float32 tensor flattened behind a leading batch
# dimension, i.e. shape (1, n_features).
obs_t = torch.tensor(obs, dtype=torch.float32).flatten().unsqueeze(0)

In [22]:
# Inspect the batched observation tensor.
obs_t

tensor([[ 1.0000,  1.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, -0.4545,  0.4545]])

In [25]:
# Q-value estimate for every action in the current observation.
# Uses the environment's action count instead of a hard-coded 8
# (assumes a Discrete action space, presumably what FootsiesActionCombinationsDiscretized exposes).
[brisket.agent.q_value(obs_t, brisket.agent.action_oh(a)) for a in range(env.action_space.n)]

[0.13985449075698853,
 0.07847494632005692,
 0.11667368561029434,
 0.04185425862669945,
 0.21139173209667206,
 0.13351216912269592,
 0.10901963710784912,
 0.15006758272647858]

### Brisket training

In [10]:
# Call hard_reset() on the base FootsiesEnv instance before training.
# NOTE(review): what a "hard" reset does is defined in footsies_gym — not visible here.
footsies_env.hard_reset()

In [9]:
# Train the agent on the wrapped environment.
# The third argument is presumably the number of episodes (the progress bar shows 10/10) — TODO confirm against main.train.
train(brisket, env, 10)

100%|██████████| 10/10 [00:07<00:00,  1.29it/s]


In [15]:
# Total of the special-move metric collected by the FootsiesStatistics wrapper
# (presumably one entry per recorded episode).
sum(statistics.metric_special_moves_per_episode)

264

In [16]:
# Hard reset reached through the wrapper stack; env.unwrapped should be the same
# object as footsies_env (assumes every wrapper is a standard gymnasium Wrapper).
env.unwrapped.hard_reset()

In [17]:
# Reset the wrapped environment to inspect an initial observation.
obs, info = env.reset()

In [18]:
# Display the normalized, flattened observation vector produced by the wrapper stack.
obs

array([ 0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  1.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.04166667,  0.04166667,
       -2.        ,  2.        ])

In [19]:
# Take a single step with action 1 (a specific discrete action index) and keep only the observation.
# Indexing avoids assigning to `_`: in IPython a user-assigned `_` shadows the
# output-history variable and stops `_`/`__`/`___` from being updated.
obs = env.step(1)[0]

## ...