# Testing grounds

## Setup

### Imports

In [92]:
import os
import numpy as np
from main import load_agent_model, train
from gymnasium.wrappers.flatten_observation import FlattenObservation
from footsies_gym.envs.footsies import FootsiesEnv
from footsies_gym.wrappers.normalization import FootsiesNormalized
from footsies_gym.wrappers.action_comb_disc import FootsiesActionCombinationsDiscretized
from footsies_gym.wrappers.statistics import FootsiesStatistics
from importlib import reload

### Environment

In [2]:
# Keyword-argument presets meant to be unpacked into the environment constructor.

# Preset for interactive sessions: no frame delay, no fast-forward, a player-controlled
# opponent slot and on-screen ("human") rendering.
human_testing_kwargs = dict(
    frame_delay=0,
    fast_forward=False,
    vs_player=True,
    render_mode="human",
)

# Preset for ordinary automated runs: only the frame delay is pinned (to 0).
normal_testing_kwargs = dict(frame_delay=0)

In [32]:
# Base environment: points at the game build (path relative to the notebook's working
# directory) and writes its log to out.log in the current directory, overwriting any
# previous log file.
footsies_env = FootsiesEnv(
    game_path="../Footsies-Gym/Build/FOOTSIES.x86_64",
    log_file=os.path.join(os.getcwd(), "out.log"),
    log_file_overwrite=True,
    **normal_testing_kwargs,
)

# The statistics wrapper sits directly on the base environment; its handle is kept so
# the collected metrics can be read later in the notebook.
statistics = FootsiesStatistics(footsies_env)

# Stack the remaining wrappers from the inside out:
# normalization -> observation flattening -> discretized action combinations.
env = FootsiesNormalized(statistics)
env = FlattenObservation(env)
env = FootsiesActionCombinationsDiscretized(env)

## Environment testing

In [3]:
# Smoke test: play five episodes with sampled actions and print every non-zero reward.
for e in range(5):
    print("Env reset")
    obs, info = env.reset()
    print("A")

    # Every episode takes at least one step, so the end-of-episode check can live at the
    # bottom of the loop instead of relying on pre-initialised flags.
    while True:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        if reward != 0.0:
            print(reward)
        if terminated or truncated:
            break

Env reset
A
0.3
0.7
Env reset
A
0.3
0.7
Env reset
A
-1.0
Env reset
A
-0.3
0.3
1.0
Env reset
A
0.3
0.7


In [4]:
# Close the wrapped environment now that the random-play smoke test has finished.
env.close()

In [25]:
# Call hard_reset() on the innermost (unwrapped) environment, bypassing the wrappers.
# NOTE(review): presumably this restarts the underlying game process — confirm against FootsiesEnv.
env.unwrapped.hard_reset()

## Brisket testing

In [117]:
import torch
from agents.brisket.agent import FootsiesAgent as BrisketAgent
from agents.brisket.loggables import get_loggables as get_brisket_loggables
from agents.logger import TrainingLoggerWrapper

Reload the Brisket agent and logger modules in case their source files were edited after they were first imported.

In [116]:
import agents.brisket.agent
import agents.logger

# Re-execute both modules so that edits made to their source since the first import
# take effect in this kernel.
reload(agents.brisket.agent)
reload(agents.logger)

# reload() does not rebind names that were brought in with `from ... import ...`, so the
# notebook-level aliases would otherwise keep pointing at the pre-reload classes.
BrisketAgent = agents.brisket.agent.FootsiesAgent
TrainingLoggerWrapper = agents.logger.TrainingLoggerWrapper

# Objects created before the reload (e.g. an existing agent instance) still use the old
# classes and have to be re-created to pick up the changes.

# Keep the reloaded module as the cell's final expression so its repr is still displayed.
agents.logger

<module 'agents.logger' from '/home/martinho/projects/footsies-agents/agents/logger.py'>

In [118]:
# Instantiate the Brisket agent, passing it the wrapped environment's observation and
# action spaces.
brisket = BrisketAgent(
    observation_space=env.observation_space,
    action_space=env.action_space,
    
    # For testing: uncomment to set all three epsilon-related arguments to 0
    # (presumably this disables exploration — confirm against the agent's constructor).
    # epsilon=0,
    # epsilon_decay_rate=0,
    # min_epsilon=0,
)

In [119]:
# Wrap the agent in TrainingLoggerWrapper and rebind `brisket` to the wrapper.
# NOTE(review): because the wrapper replaces `brisket`, re-running this cell without first
# re-running the agent-construction cell wraps the already-wrapped agent again.
brisket = TrainingLoggerWrapper(
    brisket,
    10,  # second positional argument — presumably the logging interval; TODO confirm
    cummulative_reward=True,
    win_rate=True,
    test_states_number=100,
    # Evaluated before the rebinding above takes effect, so on a first run it receives
    # the unwrapped agent.
    **get_brisket_loggables(brisket),
)

In [120]:
# Run the agent's preprocess step against the environment.
# NOTE(review): what preprocess() does with the environment is not visible here — confirm
# whether it steps or resets `env`.
brisket.preprocess(env)

In [8]:
# Load a previously saved model into the agent using main.load_agent_model.
# The second argument is presumably the name the model was saved under — TODO confirm.
load_agent_model(brisket, "brisket")

Agent loaded


In [17]:
# Reset the environment to obtain an initial observation for the Q-value inspection below.
obs, info = env.reset()

In [24]:
# Convert the observation to a float32 tensor and reshape it into a single row, i.e. (1, N).
obs_t = torch.tensor(obs, dtype=torch.float32).reshape(1, -1)

In [25]:
# Display the observation tensor built in the previous cell.
obs_t

tensor([[ 0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000,
          1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0417,  0.0417,
         -2.0000,  2.0000]])

In [30]:
# Query the agent's Q-value for the current observation under each action index 0..7;
# action_oh(a) is presumably the one-hot encoding of action `a` — TODO confirm.
# NOTE(review): the 8 is hard-coded; presumably it equals the size of env.action_space — confirm.
[brisket.q_value(obs_t, brisket.action_oh(a)) for a in range(8)]

[0.110465869307518,
 0.08165869116783142,
 0.11638332903385162,
 0.09342239797115326,
 0.12450563162565231,
 0.1109839603304863,
 0.10539047420024872,
 0.10376224666833878]

### Brisket training

In [121]:
# Train the agent on the environment with main.train.
# The third argument is presumably the number of episodes (the progress bar in the
# recorded output counts 100 iterations) — TODO confirm.
train(brisket, env, 100)

100%|██████████| 100/100 [00:57<00:00,  1.73it/s]


In [15]:
# Sum the statistics wrapper's metric_special_moves_per_episode values
# (presumably one special-move count per episode — confirm against FootsiesStatistics).
sum(statistics.metric_special_moves_per_episode)

264

In [16]:
# Call hard_reset() on the innermost (unwrapped) environment after the training run.
# NOTE(review): presumably this restarts the underlying game process — confirm against FootsiesEnv.
env.unwrapped.hard_reset()

In [17]:
# Start a fresh episode and keep its initial observation for inspection.
obs, info = env.reset()

In [18]:
# Display the initial observation returned by the reset above.
obs

array([ 0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  1.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.04166667,  0.04166667,
       -2.        ,  2.        ])

In [19]:
# Take a single step with action index 1 and keep only the resulting observation;
# the other four values returned by step() are discarded.
obs, _, _, _, _ = env.step(1)

## ...