### Test the current WCST session implementation with an extended feature dimension

In [8]:
%load_ext autoreload
%autoreload 2

import numpy as np

from wcst_session import WcstSession
from block_switching_conditions import short_condition, create_bernoulli_condition
from card_generators import RandomCardGenerator
from rule_generators import RandomRuleGeneratorHuman
from constants import EXTENDED_DIM_FEATURE_NAMES, EXTENEDED_DIM_NAMES


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:

# Session configuration. Each card shows one feature from each of NUM_DIMS
# dimensions, and each rule is a single feature (e.g. 'SOLID'). With 4
# dimensions of 4 features each, there are 16 possible rules; this matches
# the 16-wide N-hot observation further down.
NUM_CARDS = 4
NUM_DIMS = 4
NUM_RULES = 16
SESSION_SEED = 42

sess = WcstSession(
    correct_value=1,
    incorrect_value=0,
    block_switching_condition=short_condition,
    # NOTE(review): the leading positional `None` is passed through unchanged.
    # Its meaning (presumably an RNG/seed slot) should be confirmed in the generators.
    card_generator=RandomCardGenerator(None, num_cards=NUM_CARDS, num_dims=NUM_DIMS),
    rule_generator=RandomRuleGeneratorHuman(None, num_rules=NUM_RULES, num_dims=NUM_DIMS),
    feature_names=EXTENDED_DIM_FEATURE_NAMES,
    # `EXTENEDED_DIM_NAMES` is the name exported by `constants` (sic); keep it as imported.
    dim_names=EXTENEDED_DIM_NAMES,
    random_seed=SESSION_SEED,
)

In [10]:
# Cards on offer for the first trial: one row per card, one column per feature dimension.
first_trial_cards = sess.get_cards_text()
first_trial_cards

array([['SQUARE', 'MAGENTA', 'SWIRL', 'SOLID'],
       ['STAR', 'YELLOW', 'POLKADOT', 'DOTTED'],
       ['TRIANGLE', 'GREEN', 'ESCHER', 'DASHED'],
       ['CIRCLE', 'CYAN', 'RIPPLE', 'DASHDOT']], dtype='<U8')

In [11]:
# Deliberately pick a non-matching card. In the recorded run the first rule is SOLID
# (see CurrentRule in the dumped history), and card 2 is TRIANGLE / GREEN / ESCHER / DASHED,
# so this trial is scored incorrect.
deliberately_wrong_card = 2
sess.make_selection(deliberately_wrong_card)

(False, 0)

In [13]:
# Peek at the rule the session is currently scoring selections against.
active_rule = sess.get_current_rule_text()
active_rule

'SOLID'

In [14]:
# Cards for the next trial; the layout differs from the first trial's after the selection above.
second_trial_cards = sess.get_cards_text()
second_trial_cards

array([['TRIANGLE', 'MAGENTA', 'SWIRL', 'DASHDOT'],
       ['STAR', 'YELLOW', 'POLKADOT', 'DASHED'],
       ['CIRCLE', 'CYAN', 'RIPPLE', 'DOTTED'],
       ['SQUARE', 'GREEN', 'ESCHER', 'SOLID']], dtype='<U8')

In [15]:
# Select the card that carries the current rule's feature, rather than a hard-coded
# index read off the printed cards. In the recorded run this is card 3
# (SQUARE / GREEN / ESCHER / SOLID), and the selection is rewarded.
# Assumes each feature appears on exactly one card, as in the layouts shown above.
cards_text = sess.get_cards_text()
rule_text = sess.get_current_rule_text()
matching_cards = np.flatnonzero((cards_text == rule_text).any(axis=1))
assert matching_cards.size == 1, f"expected exactly one card showing {rule_text!r}"
sess.make_selection(int(matching_cards[0]))

(True, 1)

In [16]:
# Per-trial log: trial/block counters, response, chosen card, rule, reward and each card's features.
# NOTE(review): the leading `Unnamed: 0` column looks like a CSV round-trip that kept the
# index; confirm in WcstSession.dump_history.
trial_history = sess.dump_history()
trial_history

Unnamed: 0,TrialNumber,BlockNumber,TrialAfterRuleChange,Response,ItemChosen,CurrentRule,Reward,Item0Shape,Item0Color,Item0Pattern,...,Item1Pattern,Item1Outline,Item2Shape,Item2Color,Item2Pattern,Item2Outline,Item3Shape,Item3Color,Item3Pattern,Item3Outline
0,0,0,0,Incorrect,2,SOLID,0,SQUARE,MAGENTA,SWIRL,...,POLKADOT,DOTTED,TRIANGLE,GREEN,ESCHER,DASHED,CIRCLE,CYAN,RIPPLE,DASHDOT
1,1,0,1,Correct,3,SOLID,1,TRIANGLE,MAGENTA,SWIRL,...,POLKADOT,DASHED,CIRCLE,CYAN,RIPPLE,DOTTED,SQUARE,GREEN,ESCHER,SOLID


### Check that the Gym environment works as well

In [22]:
import gymnasium as gym
from gymnasium import spaces
import wcst_gym
from wcst_gym.n_hot_observation import NHotObservation


In [18]:
# Build the Gym environment around the session created above. Presumably
# `import wcst_gym` is what registers the "WcstSession-v0" id.
# NOTE(review): `sess` has already played two trials above; confirm that reset() restarts it.
env = gym.make(id="WcstSession-v0", wcst_session=sess)

In [19]:
# Gymnasium's reset returns (observation, info). With seed=None the layout
# may differ between runs, depending on how the env handles seeding.
reset_observation, reset_info = env.reset(seed=None)
reset_observation, reset_info

(array([[ 1,  7, 10, 14],
        [ 0,  4,  9, 13],
        [ 2,  6, 11, 15],
        [ 3,  5,  8, 12]]),
 {'current_rule': 3})

In [20]:
# In the recorded output: one row per card, one column per dimension (entries are feature indices).
observation_shape = env.observation_space.shape
observation_shape

(4, 4)

In [21]:
# In the recorded run, card 3 of the reset observation contains feature 3, which is
# the current rule, so this step earns reward 1.
action = 3
observation, reward, terminated, truncated, info = env.step(action)
observation, reward, terminated, truncated, info

(array([[ 0,  6, 10, 14],
        [ 2,  4,  8, 15],
        [ 3,  5, 11, 12],
        [ 1,  7,  9, 13]]),
 1,
 False,
 False,
 {'current_rule': 3})

In [23]:
# Switch to an N-hot observation encoding. The isinstance guard keeps this cell
# idempotent: re-running it must not stack a second NHotObservation wrapper on `env`.
if not isinstance(env, NHotObservation):
    env = NHotObservation(env)

In [24]:
# With the N-hot wrapper, each card becomes a 0/1 vector over all features
# (in the recorded output, 16 wide with one hot entry per dimension).
nhot_observation, nhot_info = env.reset()
nhot_observation, nhot_info

(array([[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1]], dtype=int8),
 {'current_rule': 7})