# Testing grounds

## Setup

### Imports

In [4]:
from main import load_agent_model, train
from gymnasium.wrappers.flatten_observation import FlattenObservation
from footsies_gym.envs.footsies import FootsiesEnv
from footsies_gym.wrappers.normalization import FootsiesNormalized
from footsies_gym.wrappers.action_comb_disc import FootsiesActionCombinationsDiscretized
from footsies_gym.wrappers.statistics import FootsiesStatistics

### Environment

In [5]:
# Base environment: points at the game build and uses a frame delay of 0.
footsies_env = FootsiesEnv(game_path="../Footsies-Gym/Build/FOOTSIES.x86_64", frame_delay=0)

# Keep a handle on the statistics wrapper so its metrics can be read later.
statistics = FootsiesStatistics(footsies_env)

# Stack the remaining wrappers from the inside out:
# normalization -> observation flattening -> discretized action combinations.
env = FootsiesNormalized(statistics)
env = FlattenObservation(env)
env = FootsiesActionCombinationsDiscretized(env)

## Brisket testing

In [6]:
import torch
from agents.brisket.agent import FootsiesAgent as BrisketAgent

In [7]:
brisket = BrisketAgent(
    # The agent is sized from the wrapped environment's spaces.
    observation_space=env.observation_space, action_space=env.action_space,
    # Testing configuration: epsilon, its decay rate and its minimum are all 0
    # (presumably disables exploration -- confirm against BrisketAgent), and
    # log_run is switched off.
    epsilon=0, epsilon_decay_rate=0, min_epsilon=0,
    log_run=False,
)

In [8]:
# Load the saved model registered under the name "brisket" into the agent
# (the recorded run prints "Agent loaded"). Where the model is read from is
# decided inside main.load_agent_model and is not visible in this notebook.
load_agent_model(brisket, "brisket")

Agent loaded


In [6]:
# Start a new episode; `obs` is the initial observation inspected below.
obs, info = env.reset()

In [24]:
# Build a float32 tensor from the observation and reshape it to a single row
# (shape 1 x N) so it can be passed as a batch of one.
obs_t = torch.reshape(torch.tensor(obs, dtype=torch.float32), (1, -1))

In [25]:
# Display the batched observation tensor.
obs_t

tensor([[ 0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000,
          1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0417,  0.0417,
         -2.0000,  2.0000]])

In [30]:
# Q-value of every discrete action for the single observation in obs_t.
# The action count is read from the environment's action space instead of a
# hard-coded 8, so this cell keeps working if the action set changes.
# NOTE(review): assumes env.action_space is a gymnasium Discrete space
# (i.e. exposes `.n`) -- confirm against FootsiesActionCombinationsDiscretized.
[
    brisket.q_value(obs_t, brisket.action_oh(action))
    for action in range(env.action_space.n)
]

[0.110465869307518,
 0.08165869116783142,
 0.11638332903385162,
 0.09342239797115326,
 0.12450563162565231,
 0.1109839603304863,
 0.10539047420024872,
 0.10376224666833878]

### Brisket training

In [9]:
# Train the agent on the wrapped environment; the progress bar in the recorded
# run shows 1000 iterations taking about 20 minutes.
# NOTE(review): 1000 is presumably the number of episodes -- confirm in main.train.
train(brisket, env, 1000)

100%|██████████| 1000/1000 [20:13<00:00,  1.21s/it]


In [15]:
# Total of the special-move metric collected by the statistics wrapper
# (presumably one entry per episode, going by the attribute name).
sum(statistics.metric_special_moves_per_episode)

264

In [16]:
# Hard-reset the underlying (unwrapped) environment.
# NOTE(review): presumably restarts the game after training -- confirm in
# FootsiesEnv.hard_reset.
env.unwrapped.hard_reset()

In [17]:
# Begin a fresh episode after the hard reset and capture its first observation.
obs, info = env.reset()

In [18]:
# Display the initial observation of the new episode.
obs

array([ 0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  1.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.04166667,  0.04166667,
       -2.        ,  2.        ])

In [19]:
# Advance the environment one step with discrete action 1, keeping only the
# new observation; the other four returned values are discarded.
# NOTE(review): assigning to `_` shadows IPython's last-output variable for the
# rest of the session.
obs, _, _, _, _ = env.step(1)

In [20]:
# Display the observation returned by the step above.
obs

array([ 0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.08333334,
       -2.        ,  2.        ])