In [19]:
%load_ext autoreload
%autoreload 2

import numpy as np
import gymnasium as gym
from gymnasium import spaces
import wcst_gym
from wcst_gym.n_hot_observation import NHotObservation
from wcst_session import WcstSession


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Test the env

In [20]:
# Build a session from two positional arguments. Presumably these are
# correct_value=1 and incorrect_value=0 (the keyword names used in the
# VectorEnv cell of this notebook) — TODO confirm the positional order.
sess = WcstSession(1, 0)
# Create the "WcstSession-v0" environment via gym.make, handing it that session.
env = gym.make("WcstSession-v0", wcst_session=sess)

In [21]:
# Reset without an explicit seed and display the returned (observation, info) pair.
env.reset(seed=None)

(array([[ 0,  6, 10],
        [ 1,  5,  9],
        [ 3,  4,  8],
        [ 2,  7, 11]]),
 {'current_rule': 10})

In [22]:
# Inspect the shape declared by the environment's observation space.
env.observation_space.shape

(4, 3)

In [23]:
# Take action 3; the displayed tuple is (observation, reward, terminated, truncated, info).
env.step(3)

(array([[ 1,  6,  9],
        [ 3,  5, 10],
        [ 0,  4, 11],
        [ 2,  7,  8]]),
 0,
 False,
 False,
 {'current_rule': 10})

In [24]:
# Take action 1 and display the resulting step tuple.
env.step(1)

(array([[ 3,  6,  9],
        [ 2,  4, 10],
        [ 1,  5,  8],
        [ 0,  7, 11]]),
 1,
 False,
 False,
 {'current_rule': 10})

In [25]:
# Take action 2 and display the resulting step tuple.
env.step(2)

(array([[ 2,  5, 10],
        [ 1,  6,  8],
        [ 0,  7,  9],
        [ 3,  4, 11]]),
 0,
 False,
 False,
 {'current_rule': 10})

### Test the N-hot observation wrapper

In [26]:
# Wrap the existing environment in NHotObservation.
# NOTE: `env` is rebound to the wrapped object, so this cell is not idempotent —
# running it a second time would wrap the already-wrapped environment again.
env = NHotObservation(env)

In [27]:
# Reset through the wrapper and display the wrapped (observation, info) pair.
env.reset()

(array([[0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]], dtype=int8),
 {'current_rule': 11})

### Test the N-hot observation wrapper interactions with VectorEnvs

In [24]:
# One session object, built with explicit keyword arguments, backs the vectorised test.
test_sess = WcstSession(correct_value=1, incorrect_value=0)


def make_wcst():
    """Environment factory handed to SyncVectorEnv.

    Returns the "WcstSession-v0" environment, constructed around the
    notebook-level ``test_sess``, wrapped in ``NHotObservation``.
    """
    # NOTE(review): every call reuses the same ``test_sess``. That is harmless with a
    # single factory, but presumably sub-environments would share session state if
    # more factories were added to the list — confirm before scaling up.
    return NHotObservation(gym.make("WcstSession-v0", wcst_session=test_sess))


# A synchronous vector env holding the single sub-environment produced by the factory.
envs = gym.vector.SyncVectorEnv([make_wcst])

In [25]:
# Reset the vector env and display the returned (observations, infos) pair.
envs.reset()

(array([[[1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
         [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
         [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0]]], dtype=int8),
 {'current_rule': array([3]), '_current_rule': array([ True])})