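### Test notebook for the WCST (Wisconsin Card Sorting Test) engine

Interactive smoke tests for `WcstSession`. Each section deals cards, makes selections and inspects the per-trial history. The sections vary the card/rule generators, the reward probability and the block-switching condition.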

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np

from wcst_session import WcstSession
from block_switching_conditions import short_condition, create_bernoulli_condition
from card_generators import RandomCardGenerator
from rule_generators import RandomRuleGeneratorHuman


In [3]:
# Baseline session: no card/rule generators are passed, so it deals the default
# 4-card hands described by (shape, color, pattern).
# In this session a correct pick returns correct_value and an incorrect pick returns incorrect_value.
sess = WcstSession(
    correct_value=10,
    incorrect_value=-4,
    block_switching_condition=short_condition,
    random_seed=42,
)

In [4]:
# Current hand: one row per card (4 here). The columns are (shape, color, pattern),
# matching the Item<k>Shape/Color/Pattern columns of dump_history() below.
sess.get_cards_text()

array([['TRIANGLE', 'YELLOW', 'ESCHER'],
       ['STAR', 'MAGENTA', 'SWIRL'],
       ['SQUARE', 'CYAN', 'RIPPLE'],
       ['CIRCLE', 'GREEN', 'POLKADOT']], dtype='<U8')

In [5]:
# Pick card 2 (SQUARE/CYAN/RIPPLE). The call returns (is_correct, reward).
# The history below shows the rule is POLKADOT, and only card 3 has it, so this pick is incorrect.
sess.make_selection(2)

(False, -4)

In [6]:
# A new hand is dealt after every selection.
sess.get_cards_text()

array([['STAR', 'MAGENTA', 'POLKADOT'],
       ['TRIANGLE', 'YELLOW', 'ESCHER'],
       ['CIRCLE', 'CYAN', 'RIPPLE'],
       ['SQUARE', 'GREEN', 'SWIRL']], dtype='<U8')

In [7]:
# Card 2 is RIPPLE, while the POLKADOT card is 0 -> incorrect.
sess.make_selection(2)

(False, -4)

In [8]:
sess.get_cards_text()

array([['CIRCLE', 'GREEN', 'ESCHER'],
       ['STAR', 'CYAN', 'SWIRL'],
       ['TRIANGLE', 'MAGENTA', 'RIPPLE'],
       ['SQUARE', 'YELLOW', 'POLKADOT']], dtype='<U8')

In [9]:
# Card 3 is the POLKADOT card -> correct, and the reward is correct_value.
sess.make_selection(3)

(True, 10)

In [10]:
sess.get_cards_text()

array([['SQUARE', 'CYAN', 'POLKADOT'],
       ['CIRCLE', 'GREEN', 'ESCHER'],
       ['STAR', 'YELLOW', 'RIPPLE'],
       ['TRIANGLE', 'MAGENTA', 'SWIRL']], dtype='<U8')

In [11]:
# Card 0 is the POLKADOT card -> correct.
sess.make_selection(0)

(True, 10)

In [12]:
sess.get_cards_text()

array([['STAR', 'CYAN', 'RIPPLE'],
       ['SQUARE', 'GREEN', 'SWIRL'],
       ['TRIANGLE', 'MAGENTA', 'ESCHER'],
       ['CIRCLE', 'YELLOW', 'POLKADOT']], dtype='<U8')

In [13]:
# Card 1 is SWIRL, while the POLKADOT card is 3 -> incorrect.
sess.make_selection(1)

(False, -4)

In [14]:
sess.get_cards_text()

array([['SQUARE', 'CYAN', 'SWIRL'],
       ['STAR', 'MAGENTA', 'POLKADOT'],
       ['TRIANGLE', 'GREEN', 'RIPPLE'],
       ['CIRCLE', 'YELLOW', 'ESCHER']], dtype='<U8')

In [15]:
# Card 0 is SWIRL, while the POLKADOT card is 1 -> incorrect.
sess.make_selection(0)

(False, -4)

In [16]:
# Per-trial log of everything above. Each row has the trial/block bookkeeping, the
# response, the chosen item, the active rule, the reward and the full hand shown on
# that trial. All six trials stayed in block 0.
sess.dump_history()

Unnamed: 0,TrialNumber,BlockNumber,TrialAfterRuleChange,Response,ItemChosen,CurrentRule,Reward,Item0Shape,Item0Color,Item0Pattern,Item1Shape,Item1Color,Item1Pattern,Item2Shape,Item2Color,Item2Pattern,Item3Shape,Item3Color,Item3Pattern
0,0,0,0,Incorrect,2,POLKADOT,-4,TRIANGLE,YELLOW,ESCHER,STAR,MAGENTA,SWIRL,SQUARE,CYAN,RIPPLE,CIRCLE,GREEN,POLKADOT
1,1,0,1,Incorrect,2,POLKADOT,-4,STAR,MAGENTA,POLKADOT,TRIANGLE,YELLOW,ESCHER,CIRCLE,CYAN,RIPPLE,SQUARE,GREEN,SWIRL
2,2,0,2,Correct,3,POLKADOT,10,CIRCLE,GREEN,ESCHER,STAR,CYAN,SWIRL,TRIANGLE,MAGENTA,RIPPLE,SQUARE,YELLOW,POLKADOT
3,3,0,3,Correct,0,POLKADOT,10,SQUARE,CYAN,POLKADOT,CIRCLE,GREEN,ESCHER,STAR,YELLOW,RIPPLE,TRIANGLE,MAGENTA,SWIRL
4,4,0,4,Incorrect,1,POLKADOT,-4,STAR,CYAN,RIPPLE,SQUARE,GREEN,SWIRL,TRIANGLE,MAGENTA,ESCHER,CIRCLE,YELLOW,POLKADOT
5,5,0,5,Incorrect,0,POLKADOT,-4,SQUARE,CYAN,SWIRL,STAR,MAGENTA,POLKADOT,TRIANGLE,GREEN,RIPPLE,CIRCLE,YELLOW,ESCHER


### Test WcstSession with only 2 cards and 1 feature dimension

In [17]:
# Minimal session: 2 cards with a single "Color" dimension whose values are RED/BLUE.
# NOTE(review): the positional args (42, 2, 1) presumably mean (seed, n_cards, n_dims).
# That reading is consistent with the 2x1 hands displayed below; TODO confirm against the
# RandomCardGenerator / RandomRuleGeneratorHuman signatures.
# feature_names / dim_names supply the labels shown by get_cards_text() and dump_history().
sess = WcstSession(
    correct_value=10, 
    incorrect_value=-4, 
    block_switching_condition=short_condition, 
    card_generator=RandomCardGenerator(42, 2, 1),
    rule_generator=RandomRuleGeneratorHuman(42, 2, 1),
    feature_names=np.array(["RED", "BLUE"]),
    dim_names=["Color"],
    random_seed=42
)

In [18]:
# Two cards, each with one Color value.
sess.get_cards_text()

array([['BLUE'],
       ['RED']], dtype='<U4')

In [19]:
# Choosing the BLUE card is scored correct.
sess.make_selection(0)

(True, 10)

In [20]:
sess.get_cards_text()

array([['RED'],
       ['BLUE']], dtype='<U4')

In [21]:
# BLUE again, now at index 1 -> correct.
sess.make_selection(1)

(True, 10)

In [22]:
# The same hand can be dealt on consecutive trials.
sess.get_cards_text()

array([['RED'],
       ['BLUE']], dtype='<U4')

In [23]:
# RED -> incorrect.
sess.make_selection(0)

(False, -4)

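### Test WCST with probabilistic reward

With `prob_reward_matches=0.5`, the returned reward no longer always matches correctness. In the history below, trial 0 is `Correct` but pays -4, and trial 4 is `Incorrect` but pays 10.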

In [24]:
# Same 2-card / 1-dimension setup and seeds as the previous session, plus
# prob_reward_matches=0.5.
# NOTE(review): judging by the history below, the reward presumably agrees with
# correctness only with probability prob_reward_matches. A correct pick can pay -4 and an
# incorrect pick can pay 10. TODO confirm against WcstSession.
sess = WcstSession(
    correct_value=10, 
    incorrect_value=-4, 
    prob_reward_matches=0.5,
    block_switching_condition=short_condition, 
    card_generator=RandomCardGenerator(42, 2, 1),
    rule_generator=RandomRuleGeneratorHuman(42, 2, 1),
    feature_names=np.array(["RED", "BLUE"]),
    dim_names=["Color"],
    random_seed=42
)

In [25]:
# The generator seeds are the same as in the previous session, so the first hand matches it.
sess.get_cards_text()

array([['BLUE'],
       ['RED']], dtype='<U4')

In [26]:
# BLUE is correct, yet the reward is incorrect_value: in this session the reward
# does not always follow correctness.
sess.make_selection(0)

(True, -4)

In [28]:
sess.get_cards_text()

array([['RED'],
       ['BLUE']], dtype='<U4')

In [29]:
# Correct and rewarded.
sess.make_selection(1)

(True, 10)

In [30]:
sess.get_cards_text()

array([['RED'],
       ['BLUE']], dtype='<U4')

In [31]:
# RED -> incorrect and not rewarded.
sess.make_selection(0)

(False, -4)

In [32]:
sess.get_cards_text()

array([['RED'],
       ['BLUE']], dtype='<U4')

In [33]:
sess.make_selection(0)

(False, -4)

In [34]:
sess.get_cards_text()

array([['BLUE'],
       ['RED']], dtype='<U4')

In [35]:
# RED -> incorrect, yet rewarded with correct_value.
sess.make_selection(1)

(False, 10)

In [36]:
# Response and Reward disagree on trials 0 and 4. CurrentRule is BLUE throughout.
sess.dump_history()

Unnamed: 0,TrialNumber,BlockNumber,TrialAfterRuleChange,Response,ItemChosen,CurrentRule,Reward,Item0Color,Item1Color
0,0,0,0,Correct,0,BLUE,-4,BLUE,RED
1,1,0,1,Correct,1,BLUE,10,RED,BLUE
2,2,0,2,Incorrect,0,BLUE,-4,RED,BLUE
3,3,0,3,Incorrect,0,BLUE,-4,RED,BLUE
4,4,0,4,Incorrect,1,BLUE,10,BLUE,RED


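### Test with probabilistic block switching

This section uses `create_bernoulli_condition(0.2)` as the block-switching condition and displays `sess.current_block` after each selection to show when the block advances. Note: execution counts restart at 12 here, so these outputs came from a separate kernel run. Re-run the notebook top to bottom to refresh them.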

In [12]:
# Seeded like the other sessions in this notebook (random_seed=42), so that a fresh
# Restart & Run All can reproduce this section instead of dealing new hands every time.
# NOTE(review): this assumes random_seed also governs the Bernoulli block-switch draws.
# TODO confirm: if create_bernoulli_condition keeps its own RNG, it needs a seed as well.
# The outputs below were recorded before the seed was added, so they will change on the next run.
sess = WcstSession(
    correct_value=10,
    incorrect_value=-4,
    block_switching_condition=create_bernoulli_condition(0.2),
    random_seed=42,
)

In [13]:
# No generators were passed, so this is the default 4-card, 3-dimension hand.
sess.get_cards_text()

array([['TRIANGLE', 'YELLOW', 'POLKADOT'],
       ['STAR', 'GREEN', 'ESCHER'],
       ['CIRCLE', 'CYAN', 'SWIRL'],
       ['SQUARE', 'MAGENTA', 'RIPPLE']], dtype='<U8')

In [14]:
sess.make_selection(0)

(False, -4)

In [15]:
# Block counter after the selection. Under create_bernoulli_condition(0.2) it advances
# stochastically, presumably with probability 0.2 per trial (TODO confirm).
sess.current_block

1

In [16]:
sess.get_cards_text()

array([['STAR', 'MAGENTA', 'ESCHER'],
       ['TRIANGLE', 'GREEN', 'SWIRL'],
       ['CIRCLE', 'YELLOW', 'RIPPLE'],
       ['SQUARE', 'CYAN', 'POLKADOT']], dtype='<U8')

In [17]:
sess.make_selection(0)

(False, -4)

In [18]:
# No block switch on this trial.
sess.current_block

1

In [19]:
sess.get_cards_text()

array([['SQUARE', 'MAGENTA', 'POLKADOT'],
       ['STAR', 'GREEN', 'RIPPLE'],
       ['CIRCLE', 'YELLOW', 'SWIRL'],
       ['TRIANGLE', 'CYAN', 'ESCHER']], dtype='<U8')

In [20]:
sess.make_selection(0)

(True, 10)

In [21]:
# The block advanced again.
sess.current_block

2