In [27]:
from pyClarion import (Agent, Input, Choice, ChunkStore, FixedRules, 
    Family, Atoms, Atom, BaseLevel, Pool, NumDict, Event, Priority)
from datetime import timedelta

import logging
import sys

import numpy as np
import pandas as pd
from typing import *

from utils import RuleWBLA

# Route pyClarion's scheduler/system log records to the notebook's stdout at DEBUG level.
logger = logging.getLogger("pyClarion.system")
logger.setLevel(logging.DEBUG)

# Re-running this cell in the same kernel must not stack another stdout handler
# (each extra handler would print every log line once more), so tag ours by name
# and only attach it if it is not already present.
_STDOUT_HANDLER_NAME = "pyClarion-notebook-stdout"
if not any(h.get_name() == _STDOUT_HANDLER_NAME for h in logger.handlers):
    _stdout_handler = logging.StreamHandler(sys.stdout)
    _stdout_handler.set_name(_STDOUT_HANDLER_NAME)
    logger.addHandler(_stdout_handler)

In [29]:
class Brick(Atoms):
    """The four brick shapes used in the construction task."""
    horizontal: Atom
    vertical: Atom
    mirror_L: Atom
    half_T: Atom

class GridRows(Atoms):
    """Row labels r1-r6 of the construction grid (six rows)."""
    r1: Atom
    r2: Atom
    r3: Atom
    r4: Atom
    r5: Atom
    r6: Atom

class GridCols(Atoms):
    """Column labels c1-c6 of the construction grid (six columns)."""
    c1: Atom
    c2: Atom
    c3: Atom
    c4: Atom
    c5: Atom
    c6: Atom

class QueryRel(Atoms):
    """Spatial relations a query can ask about.

    In the response rules the relation describes ``io.query_block`` relative to
    ``io.query_block_reference`` (e.g. "half_T right of horizontal").
    """
    left: Atom
    above: Atom
    right: Atom
    below: Atom

class SignalTokens(Atoms):
    """Construction control tokens, carried in the ``io.construction_signal`` slot.

    ``Participant.resolve`` ends construction when the selected choice carries
    ``stop_construction``.
    """
    stop_construction: Atom
    continue_construction: Atom

class IO(Atoms):
    """Slot (dimension) atoms describing the task's inputs and outputs.

    * ``input_shapeK`` / ``target_shapeK``: which brick fills shape slot K.
      In ``init_participant_response_rules`` the target slots are filled as
      1 = half_T, 2 = mirror_L, 3 = vertical, 4 = horizontal.
    * ``<input|target>_shapeK_rowJ`` / ``..._colJ``: the grid row / column
      (a ``GridRows`` / ``GridCols`` value) of the J-th cell of that shape.
      Each shape is described by three (row, col) pairs.
    * ``construction_signal``: holds a ``SignalTokens`` value.
    * ``query_block``, ``query_block_reference``, ``query_relation``: the posed
      relational query (two ``Brick`` values and a ``QueryRel`` value).
    * ``output``: the ``Response`` to the query.
    """
    # Which brick occupies each input shape slot.
    input_shape1: Atom
    input_shape2: Atom
    input_shape3: Atom
    input_shape4: Atom

    # Which brick occupies each target shape slot.
    target_shape1: Atom
    target_shape2: Atom
    target_shape3: Atom
    target_shape4: Atom

    # Cell positions (row/col of cells 1-3) for each input shape.
    input_shape1_row1: Atom
    input_shape1_row2: Atom
    input_shape1_row3: Atom
    input_shape1_col1: Atom
    input_shape1_col2: Atom
    input_shape1_col3: Atom

    input_shape2_row1: Atom
    input_shape2_row2: Atom
    input_shape2_row3: Atom
    input_shape2_col1: Atom
    input_shape2_col2: Atom
    input_shape2_col3: Atom

    input_shape3_row1: Atom
    input_shape3_row2: Atom
    input_shape3_row3: Atom
    input_shape3_col1: Atom
    input_shape3_col2: Atom
    input_shape3_col3: Atom

    input_shape4_row1: Atom
    input_shape4_row2: Atom
    input_shape4_row3: Atom
    input_shape4_col1: Atom
    input_shape4_col2: Atom
    input_shape4_col3: Atom

    # Cell positions (row/col of cells 1-3) for each target shape.
    target_shape1_row1: Atom
    target_shape1_row2: Atom
    target_shape1_row3: Atom
    target_shape1_col1: Atom
    target_shape1_col2: Atom
    target_shape1_col3: Atom

    target_shape2_row1: Atom
    target_shape2_row2: Atom
    target_shape2_row3: Atom
    target_shape2_col1: Atom
    target_shape2_col2: Atom
    target_shape2_col3: Atom

    target_shape3_row1: Atom
    target_shape3_row2: Atom
    target_shape3_row3: Atom
    target_shape3_col1: Atom
    target_shape3_col2: Atom
    target_shape3_col3: Atom

    target_shape4_row1: Atom
    target_shape4_row2: Atom
    target_shape4_row3: Atom
    target_shape4_col1: Atom
    target_shape4_col2: Atom
    target_shape4_col3: Atom

    # Control, query and response slots.
    construction_signal: Atom
    query_block: Atom
    query_block_reference: Atom
    query_relation: Atom
    output: Atom

class Response(Atoms):
    """Yes/no answer to a relational query, carried in the ``io.output`` slot."""
    yes: Atom
    no: Atom

In [None]:
class BrickTaskData(Family):
    """Family bundling every atom group of the brick task.

    An instance is created in ``Participant.__init__`` and exposed as
    ``Participant.d``; all inputs and rule stores are keyed on it.
    """
    bricks: Brick
    grid_rows: GridRows
    grid_cols: GridCols
    query_rel: QueryRel
    signal_cons: SignalTokens
    io: IO
    response: Response

In [None]:
class Participant(Agent):
    """pyClarion agent for the brick construction / spatial-relation response task.

    Two parallel subsystems are wired up with the same pipeline
    (rules -> base-level activations -> pool -> choice):

    * ``response_*``: matched against ``response_input``; selects a response rule.
    * ``search_space_*``: matched against ``construction_input``; selects
      construction rules, and construction stops when the STOP signal is chosen.

    NOTE(review): ``RuleWBLA`` (and its ``bla_main`` site) comes from the local
    ``utils`` module; its exact semantics are not visible here.
    """
    d: BrickTaskData
    construction_input: Input
    response_input: Input
    response_rules: RuleWBLA
    search_space_rules: RuleWBLA
    response_blas: BaseLevel
    search_space_blas: BaseLevel
    response_pool: Pool
    search_space_pool: Pool
    response_choice: Choice
    search_space_choice: Choice
    #TODO: fill in the rest here

    def __init__(self, name: str) -> None:
        # p is handed to every component below (Pool.p / Pool.params read from it);
        # e is handed only to the BaseLevel components. TODO: document their exact roles.
        p = Family()
        e = Family()
        d = BrickTaskData()

        super().__init__(name, p=p, e=e, d=d)
        self.d = d
        with self:
            self.construction_input = Input("construction_input", (d, d))
            self.response_input = Input("response_input", (d, d))
            self.response_rules = RuleWBLA("response_rules", p=p, r=d, c=d, d=d, v=d, sd=1e-4)
            self.search_space_rules = RuleWBLA("search_space_rules", p=p, r=d, c=d, d=d, v=d, sd=1e-4)

            self.response_blas = BaseLevel("blas", p, e, self.response_rules.rules.lhs.chunks)
            self.search_space_blas = BaseLevel("search_space_blas", p, e, self.search_space_rules.rules.lhs.chunks)

            # Pools sum their inputs (func=NumDict.sum); they are meant to combine
            # base-level activations with rule condition activations.
            self.response_pool = Pool("pool", p, self.response_rules.rules.lhs.chunks, func=NumDict.sum)
            self.search_space_pool = Pool("search_space_pool", p, self.search_space_rules.rules.lhs.chunks, func=NumDict.sum)

            self.response_choice = Choice("choice", p, self.response_rules.rules.lhs.chunks)
            self.search_space_choice = Choice("search_space_choice", p, self.search_space_rules.rules.lhs.chunks)

        # Base-level activations are driven by what each Choice selects.
        self.response_blas.input = self.response_choice.main
        self.search_space_blas.input = self.search_space_choice.main

        # Rule conditions are matched bottom-up against the corresponding input.
        # TODO: how do the rules apply a change to the input if that is never wired up?
        self.response_rules.rules.lhs.bu.input = self.response_input.main
        self.search_space_rules.rules.lhs.bu.input = self.construction_input.main

        # Construction choices must be represented back in the input so more rules can fire.
        # TODO(review): this rebinds the choice's `main` site to the construction input's
        # site (after blas was wired to the old one above); confirm this is the intended wiring.
        self.search_space_choice.main = self.construction_input.main

        # Base-level activations enter each pool on a log scale (floored at 1e-8, default 0).
        self.response_pool["blas"] = (
            self.response_blas.main,
            lambda d: d.bound_min(x=1e-8).log().with_default(c=0.0))
        self.search_space_pool["blas"] = (
            self.search_space_blas.main,
            lambda d: d.bound_min(x=1e-8).log().with_default(c=0.0))

        # RuleWBLA takes base-level information into account via its `bla_main` site.
        self.response_rules.bla_main = self.response_pool.main
        self.search_space_rules.bla_main = self.search_space_pool.main

        # Each Choice selects among its rules' top-down (rhs) activations.
        self.response_choice.input = self.response_rules.rules.rhs.td.main
        # Fixed: previously referenced `self.search_rules`, which is never defined.
        self.search_space_choice.input = self.search_space_rules.rules.rhs.td.main

        # Parameter keyed by each pool's "blas" input (presumably its weight in the sum).
        with self.response_pool.params[0].mutable():
            self.response_pool.params[0][~self.response_pool.p["blas"]] = 2e-1
        with self.search_space_pool.params[0].mutable():
            self.search_space_pool.params[0][~self.search_space_pool.p["blas"]] = 2e-1

        # Have the BaseLevel components ignore the nil chunk.
        # Fixed: previously went through `self.store`, an attribute this agent does not have.
        self.response_blas.ignore.add(~self.response_rules.rules.lhs.chunks.nil)
        self.search_space_blas.ignore.add(~self.search_space_rules.rules.lhs.chunks.nil)

    def resolve(self, event: Event) -> None:
        """Advance both subsystems by reacting to completed events.

        response:      lhs.bu update -> blas update -> choice trigger
        search space:  lhs.bu update -> blas update -> choice trigger
                       -> on selection, end construction if STOP was chosen
        end_construction event: clear the remaining system queue.
        """
        if event.source == self.response_rules.rules.lhs.bu.update:  # TODO: is this right?
            self.response_blas.update()  # timestep update
        elif event.source == self.response_blas.update:
            self.response_choice.trigger()  # result is polled from the experiment loop
        elif event.source == self.search_space_rules.rules.lhs.bu.update:
            self.search_space_blas.update()
        elif event.source == self.search_space_blas.update:
            # TODO: the selection is reflected in choice.main but not in the bu of the
            # search-space rules.
            self.search_space_choice.trigger()
        elif event.source == self.search_space_choice.select_complete:
            # Check whether the selected rule is the STOP rule.
            cur_choice = self.search_space_choice.poll()
            if cur_choice[~self.d.io.construction_signal] == ~self.d.signal_cons.stop_construction:  # TODO: check this
                # Fixed: end_construction requires `dt`; calling it without one raised TypeError.
                self.end_construction(dt=timedelta())
        elif event.source == self.end_construction:
            # Drop pending events. The target-grid representation built so far
            # (correct or incorrect) is the outcome of construction.
            self.system.queue.clear()

    def start_construct_trial(self,
        dt: timedelta,
        priority: Priority = Priority.PROPAGATION
    ) -> None:
        """Schedule the start of a construction trial after delay `dt`."""
        self.system.schedule(self.start_construct_trial, dt=dt, priority=priority)

    def end_construction(self,
                         dt: timedelta,
                         priority: Priority = Priority.PROPAGATION
                         ) -> None:
        """Schedule the end of construction after delay `dt`.

        When the event fires, `resolve` clears the system queue.
        """
        self.system.schedule(self.end_construction, dt=dt, priority=priority)

    def start_response_trial(self,
        dt: timedelta,
        priority: Priority = Priority.PROPAGATION
    ) -> None:
        """Schedule the start of a response trial after delay `dt`."""
        self.system.schedule(self.start_response_trial, dt=dt, priority=priority)

    def finish_response_trial(self,
        dt: timedelta,
        priority: Priority = Priority.PROPAGATION
    ) -> None:
        """Schedule the end of a response trial after delay `dt`."""
        self.system.schedule(self.finish_response_trial, dt=dt, priority=priority)

    #TODO: providing feedback on constructions
    #TODO: running search for the correct construction
    #TODO: RL for the construction task + rule extraction from results of RL on the bottom level.

# Knowledge initialization: response rules

In [None]:
def init_participant_response_rules(participant: Participant) -> None:
    d = participant.d
    io = d.io
    bricks = d.bricks
    grid_rows = d.grid_rows
    grid_cols = d.grid_cols
    query_rel = d.query_rel
    response = d.response

    #response rules when given a query and some blocks
    
    # HALF_T X HORIZONTAL
    #i = 0, horizontal is next to the potruding part of the shape, else it is next to the flat part at the bottom
    half_T_right_of_horizontal = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.right if not switcharoo else query_rel.left)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col+2-i}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+3-i}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+4-i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 6) for col in range(1, 3 + i)
    ]

    half_T_left_of_horizontal = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.left if not switcharoo else query_rel.right)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col-3}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col-2}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col-1}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 6) for col in range(4, 6)
    ]
    """
    TODO:
    Check the actual relation for
    1 1
    1 2 2 2 is this ontopness or beside? 
    TODO: Similarly for the other relations involving L like shapes
    """

    half_T_below_horizontal = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.below if not switcharoo else query_rel.above)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row-1}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row-1}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row-1}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col+i}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+1+i}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+2+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(2, 6) for col in range(1, 3+i)
    ]

    half_T_above_horizontal = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.above if not switcharoo else query_rel.below)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row+2}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row+2}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+2}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+2}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for row in range(1, 4) for col in range(1, 5)
]

    #HALF_T X VERTICAL

    half_T_left_vertical = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.left if not switcharoo else query_rel.right)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape3 ** bricks.vertical
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row+i}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row+i+1}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row+i+2}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col+2-i}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+2-i}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col+2-i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 4-i) for col in range(1, 5+i)
    ]

    half_T_right_vertical = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.right if not switcharoo else query_rel.left)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape3 ** bricks.vertical
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row+i}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row+i+1}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row+i+2}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col-1}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col-1}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 4-i) for col in range(2, 6)
    ]

    half_T_below_vertical = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.below if not switcharoo else query_rel.above)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape3 ** bricks.vertical
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row-3}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row-2}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row-3}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col+i}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+i}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(4, 6) for col in range(1, 6)
    ]

    half_T_above_vertical = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.above if not switcharoo else query_rel.below)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape3 ** bricks.vertical
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row+1-i+1}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row+1-i+2}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row+1-i+3}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col+i}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+i}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 3) for col in range(1, 6)
    ]

    #HALF_T X MIRROR_L
    half_T_left_mirror_L = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.mirror_L)
        + io.query_block_reference ** (bricks.mirror_L if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.left if not switcharoo else query_rel.right)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape2_row1 ** grid_rows[f"r{row-i}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1-i}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1-i}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col+1+2*i}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col+2*i}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col+1+2*i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1+i, 6) for col in range(1, 6-2*i)
    ]

    # 1 1 2
    # 1 2 2

    #       2
    # 1 1 2 2
    # 1
    half_T_right_mirror_L = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.mirror_L)
        + io.query_block_reference ** (bricks.mirror_L if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.right if not switcharoo else query_rel.left)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape2_row1 ** grid_rows[f"r{row-i}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1-i}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1-i}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col-2}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1+i, 6) for col in range(3, 6)
    ]

    half_T_below_mirror_L = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.mirror_L)
        + io.query_block_reference ** (bricks.mirror_L if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.below if not switcharoo else query_rel.above)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape2_row1 ** grid_rows[f"r{row-2}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row-1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row-1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col+1+i}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col+i}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col+1+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(3, 6) for col in range(1, 6-i)
    ]

    half_T_above_mirror_L = [
        (+ io.query_block ** (bricks.half_T if not switcharoo else bricks.mirror_L)
        + io.query_block_reference ** (bricks.mirror_L if not switcharoo else bricks.half_T)
        + io.query_relation ** (query_rel.above if not switcharoo else query_rel.below)
        + io.target_shape1 ** bricks.half_T
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape1_row1 ** grid_rows[f"r{row}"]
        + io.target_shape1_row2 ** grid_rows[f"r{row}"]
        + io.target_shape1_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape1_col1 ** grid_cols[f"c{col}"]
        + io.target_shape1_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape1_col3 ** grid_cols[f"c{col}"]
        + io.target_shape2_row1 ** grid_rows[f"r{row+2}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+3}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+3}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for row in range(1, 4) for col in range(2, 6)
    ]

    #MIRROR_L X HORIZONTAL
    mirror_L_left_horizontal  = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.left if not switcharoo else query_rel.right)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col+1}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+2}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+3}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 6) for col in range(2, 4)
    ]

    mirror_L_right_horizontal = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.right if not switcharoo else query_rel.left)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col-3-i}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col-2-i}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col-1-i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 6) for col in range(4+i, 7)
    ]
    """
    2 2 2 1
        1 1

            1
    2 2 2 1 1
    """

    mirror_L_above_horizontal = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.below if not switcharoo else query_rel.above)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row+2}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row+2}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+2}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col-1+i}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+i}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+1+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 5) for col in range(2, 5-i)
    ]
    """
    1
    1 1
    2 2 2

    1
    1 1
    2 2 2
    """

    mirror_L_below_horizontal = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.horizontal)
        + io.query_block_reference ** (bricks.horizontal if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.below if not switcharoo else query_rel.above)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape4 ** bricks.horizontal
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row-1}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row-1}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row-1}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col-2}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for row in range(2, 6) for col in range(3, 7)
    ]
    """
    2 2 2
        1
    1 1
    """

    #MIRROR_L X VERTICAL
    mirror_L_left_vertical  = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.left if not switcharoo else query_rel.right)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape3 ** bricks.vertical
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row+i}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row+i+1}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row+i+2}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col+1}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+1}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col+1}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 5-i) for col in range(2, 6)
    ]

    """
    1 2
    1 1 2
        2

    1 
    1 1 2
        2
        2
    """

    mirror_L_right_vertical = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.right if not switcharoo else query_rel.left)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape3 ** bricks.vertical
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row-1+i}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row+i}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row+1+i}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col-2}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col-2}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col-2}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(3) for row in range(max(1, 2-i), 6-i) for col in range(3, 6)
    ]

    """
    2   
    2   1
    2 1 1

    2   1
    2 1 1
    2

        1
    2 1 1
    2
    2
    """

    mirror_L_above_vertical = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.above if not switcharoo else query_rel.below)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape3 ** bricks.vertical
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row+2}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row+3}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row+4}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col-1+i}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col-1+i}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col-1+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(1, 3) for col in range(2, 6)
    ]

    """
    1
    1 1
    2
    2
    2 

    1
    1 1
    2
    2
    2
    """

    mirror_L_below_vertical = [
        (+ io.query_block ** (bricks.mirror_L if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.mirror_L)
        + io.query_relation ** (query_rel.below if not switcharoo else query_rel.above)
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape3 ** bricks.vertical
        + io.target_shape2_row1 ** grid_rows[f"r{row}"]
        + io.target_shape2_row2 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_row3 ** grid_rows[f"r{row+1}"]
        + io.target_shape2_col1 ** grid_cols[f"c{col}"]
        + io.target_shape2_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape2_col3 ** grid_cols[f"c{col}"]
        + io.target_shape3_row1 ** grid_rows[f"r{row-2-i}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row-1-i}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row-i}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col-1+i}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col-1+i}"]
        + io.target_shape3_col3 ** grid_cols[f"c{col-1+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(2) for row in range(4, 6) for col in range(2, 6)
    ]

    """
    2
    2
    2 1
    1 1

    2
    2
    2
    1
    1 1
    """

    #HORIZONTAL X VERTICAL
    horizontal_left_vertical  = [
        (+ io.query_block ** (bricks.horizontal if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.horizontal)
        + io.query_relation ** (query_rel.left if not switcharoo else query_rel.right)
        + io.target_shape3 ** bricks.horizontal
        + io.target_shape4 ** bricks.vertical
        + io.target_shape3_row1 ** grid_rows[f"r{row}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+1}"] 
        + io.target_shape3_col3 ** grid_cols[f"c{col+2}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row-2+i}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row-1+i}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col+3}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+3}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+3}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(3) for row in range(3-i, 7-i) for col in range(1, 4)
    ]

    """
        2
        2
    1 1 1 2

        2
    1 1 1 2
        2

    1 1 1 2
        2
        2      
    """

    horizontal_right_vertical = [
        (+ io.query_block ** (bricks.horizontal if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.horizontal)
        + io.query_relation ** (query_rel.right if not switcharoo else query_rel.left)
        + io.target_shape3 ** bricks.horizontal
        + io.target_shape4 ** bricks.vertical
        + io.target_shape3_row1 ** grid_rows[f"r{row}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+1}"] 
        + io.target_shape3_col3 ** grid_cols[f"c{col+2}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row-2+i}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row-1+i}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+i}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col-1}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col-1}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col-1}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(3) for row in range(3-i, 7-i) for col in range(2, 5)
    ]

    """
    2
    2
    2 1 1 1

    2
    2 1 1 1
    2

    2
    2
    2 1 1 1
    """

    horizontal_above_vertical = [
        (+ io.query_block ** (bricks.horizontal if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.horizontal)
        + io.query_relation ** (query_rel.above if not switcharoo else query_rel.below)
        + io.target_shape3 ** bricks.horizontal
        + io.target_shape4 ** bricks.vertical
        + io.target_shape3_row1 ** grid_rows[f"r{row}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+1}"] 
        + io.target_shape3_col3 ** grid_cols[f"c{col+2}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row+1}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row+2}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row+3}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col+i}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+i}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+1+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(3) for row in range(1, 4) for col in range(1, 4)
    ]

    """
    1 1 1
    2
    2
    2

    1 1 1
    2
    2
    2

    1 1 1
        2
        2
        2
    """

    horizontal_below_vertical = [
        (+ io.query_block ** (bricks.horizontal if not switcharoo else bricks.vertical)
        + io.query_block_reference ** (bricks.vertical if not switcharoo else bricks.horizontal)
        + io.query_relation ** (query_rel.below if not switcharoo else query_rel.left)
        + io.target_shape3 ** bricks.horizontal
        + io.target_shape4 ** bricks.vertical
        + io.target_shape3_row1 ** grid_rows[f"r{row}"]
        + io.target_shape3_row2 ** grid_rows[f"r{row}"]
        + io.target_shape3_row3 ** grid_rows[f"r{row}"]
        + io.target_shape3_col1 ** grid_cols[f"c{col}"]
        + io.target_shape3_col2 ** grid_cols[f"c{col+1}"] 
        + io.target_shape3_col3 ** grid_cols[f"c{col+2}"]
        + io.target_shape4_row1 ** grid_rows[f"r{row-3}"]
        + io.target_shape4_row2 ** grid_rows[f"r{row-2}"]
        + io.target_shape4_row3 ** grid_rows[f"r{row-1}"]
        + io.target_shape4_col1 ** grid_cols[f"c{col+i}"]
        + io.target_shape4_col2 ** grid_cols[f"c{col+i}"]
        + io.target_shape4_col3 ** grid_cols[f"c{col+i}"]
        >>
        + io.output ** response.yes) for switcharoo in (True, False) for i in range(3) for row in range(4, 7) for col in range(1, 5)
    ]
    """
    2
    2
    2
    1 1 1

    2
    2
    2
    1 1 1

        2
        2
        2
    1 1 1
    """

    participant.fixed_rules.compile(*(half_T_left_of_horizontal + half_T_left_vertical + half_T_left_mirror_L + 
                                      half_T_right_of_horizontal + half_T_right_vertical + half_T_right_mirror_L +
                                      half_T_below_horizontal + half_T_below_vertical + half_T_below_mirror_L +
                                      half_T_above_horizontal + half_T_above_vertical + half_T_above_mirror_L +
                                      mirror_L_left_horizontal + mirror_L_left_vertical +
                                      mirror_L_right_horizontal + mirror_L_right_vertical +
                                      mirror_L_below_horizontal + mirror_L_below_vertical +
                                      mirror_L_above_horizontal + mirror_L_above_vertical +
                                      horizontal_left_vertical + horizontal_left_vertical + horizontal_left_vertical +
                                      horizontal_right_vertical + horizontal_right_vertical + horizontal_right_vertical +
                                      horizontal_above_vertical + horizontal_above_vertical + horizontal_above_vertical +
                                      horizontal_below_vertical + horizontal_below_vertical + horizontal_below_vertical))
    

In [31]:
def init_participant_construction_rules(participant: Participant) -> None:
    """Compile the construction-phase control rules into ``participant``.

    A single rule is currently installed: when the four target-shape slots
    are bound to half_T, mirror_L, vertical and horizontal (in that slot
    order), the rule binds ``io.construction_signal`` to
    ``stop_construction``.

    Args:
        participant: Must expose the task atoms as ``participant.d`` (with
            ``io``, ``bricks`` and ``signal_cons``) and a
            ``search_space_rules`` store with a ``compile`` method.
    """
    d = participant.d
    io = d.io
    bricks = d.bricks
    con_signal = d.signal_cons

    # END_CONSTRUCTION RULE: once all four blocks have been placed, stop construction.
    # TODO: can a dimension be left unbound (the empty state, i.e. 0.0 activation)?
    stop_construction_rule = (
        + io.target_shape1 ** bricks.half_T
        + io.target_shape2 ** bricks.mirror_L
        + io.target_shape3 ** bricks.vertical
        + io.target_shape4 ** bricks.horizontal
        >>
        + io.construction_signal ** con_signal.stop_construction
    )

    participant.search_space_rules.compile(stop_construction_rule)

In [24]:
def present_stimulus(d:BrickTaskData, stim_grid: np.ndarray):
    """Encode a brick-grid stimulus as a single construction-input chunk.

    ``stim_grid`` holds 0 for empty cells and the ids 1-4 for the cells of the
    half_T, mirror_L, vertical and horizontal bricks respectively. For every
    brick present, the result binds ``io.input_shape{k}`` to the brick atom
    and ``io.input_shape{k}_row{j}`` / ``io.input_shape{k}_col{j}``
    (j = 1..3) to the grid row/column atoms of the brick's first three cells,
    taken in the order returned by ``np.where``.

    Args:
        d: Task data exposing the ``io``, ``bricks``, ``grid_rows`` and
            ``grid_cols`` atom families.
        stim_grid: 2-D grid of brick ids. Grid index 0 maps to ``r1``/``c1``.

    Returns:
        The accumulated chunk expression for all bricks, or ``None`` when the
        grid contains no bricks.

    Raises:
        KeyError: if the grid contains an id outside 1-4.
        ValueError: if a brick occupies fewer than three cells.
    """
    # Use the atoms carried by ``d`` (the same families the rules are built from).
    brick_map = {1: d.bricks.half_T, 2: d.bricks.mirror_L,
                 3: d.bricks.vertical, 4: d.bricks.horizontal}
    row_map = {k: getattr(d.grid_rows, f"r{k}") for k in range(1, 7)}
    col_map = {k: getattr(d.grid_cols, f"c{k}") for k in range(1, 7)}

    stim_bricks = np.unique(stim_grid)
    stim_bricks = stim_bricks[stim_bricks != 0]  # 0 marks an empty cell

    in_send_val = None
    for brick_value in stim_bricks:
        brick = int(brick_value)  # plain int so f-string atom names never read "1.0"
        brick_atom = brick_map[brick]
        row_indices, col_indices = np.where(stim_grid == brick_value)
        if len(row_indices) < 3:
            raise ValueError(
                f"brick {brick} occupies {len(row_indices)} cell(s); expected at least 3")

        # Same binding order as before: shape, row1..row3, col1..col3.
        # np.where indices are 0-based, grid atoms are 1-based.
        bindings = [getattr(d.io, f"input_shape{brick}") ** brick_atom]
        bindings += [getattr(d.io, f"input_shape{brick}_row{j + 1}") ** row_map[int(row_indices[j]) + 1]
                     for j in range(3)]
        bindings += [getattr(d.io, f"input_shape{brick}_col{j + 1}") ** col_map[int(col_indices[j]) + 1]
                     for j in range(3)]

        for binding in bindings:
            # The very first binding starts the chunk via unary ``+``; the rest extend it.
            in_send_val = +binding if in_send_val is None else in_send_val + binding

    return in_send_val

def load_trial(d: BrickTaskData, trial, t_type="test", q_type="query",
               data_dir="/Users/mishaal/personalproj/clarion_replay/processed"):
    """Load one trial's stimulus grid and build its stimulus and query chunks.

    Args:
        d: Task data exposing ``io``, ``bricks``, ``query_rel``, ``grid_rows``
            and ``grid_cols``.
        trial: Session-table row. Must provide ``Grid_Name``; test trials also
            need ``Q_Relation``, ``Q_Brick_Left`` and ``Q_Brick_Right`` (ids 1-4).
        t_type: ``"test"`` uses the query stored in ``trial``. ``"train"``
            (with ``q_type == "query"``) draws a random query. Also selects
            the stimulus sub-directory.
        q_type: Query mode for training trials. Anything other than
            ``"query"`` yields no query.
        data_dir: Root of the processed data. The grid is read from
            ``{data_dir}/{t_type}_data/{t_type}_stims/{Grid_Name}.npy``.

    Returns:
        ``(chunk_grid, chunk_test)``: the stimulus chunk built by
        ``present_stimulus``, and the query chunk (``()`` when there is no query).
    """
    grid_name = trial["Grid_Name"]

    stim_grid = np.load(f"{data_dir}/{t_type}_data/{t_type}_stims/{grid_name}.npy")
    chunk_grid = present_stimulus(d, stim_grid)

    # Relation ids 1-4 follow the QueryRel declaration order: left, above, right, below.
    query_map = {1: d.query_rel.left, 2: d.query_rel.above,
                 3: d.query_rel.right, 4: d.query_rel.below}
    brick_map = {1: d.bricks.half_T, 2: d.bricks.mirror_L,
                 3: d.bricks.vertical, 4: d.bricks.horizontal}

    if t_type == "test":
        relation = trial["Q_Relation"]
        block, reference = trial["Q_Brick_Left"], trial["Q_Brick_Right"]
    elif t_type == "train" and q_type == "query":
        # two distinct blocks, then one relation (a scalar, so it can key query_map)
        block, reference = np.random.choice([1, 2, 3, 4], 2, replace=False)
        relation = np.random.choice([1, 2, 3, 4])
    else:
        return chunk_grid, ()

    chunk_test = (+ d.io.query_relation ** query_map[relation]
                  + d.io.query_block ** brick_map[block]
                  + d.io.query_block_reference ** brick_map[reference])
    return chunk_grid, chunk_test

# Simulation

In [None]:
def run_participant_session(participant: Participant, session_df: pd.DataFrame, session_type="train", q_type="query"):
    """Drive ``participant`` through every trial in ``session_df``.

    The rule sets are compiled once up front. The system event queue is then
    advanced, and each event is dispatched as follows:
    - new construction trial: load the next row and send its grid to
      ``construction_input``;
    - end of construction: start the response trial;
    - start of the response trial: send the query chunk;
    - response-rule update: select a choice;
    - selection: record the choice and finish the trial.
    The loop ends when the queue empties or the trials run out.

    Args:
        participant: The simulated participant (event sources, inputs, rule
            stores and ``d`` task data).
        session_df: One row per trial, in presentation order (see
            ``load_trial`` for the required columns).
        session_type: Forwarded to ``load_trial`` as ``t_type``.
        q_type: Forwarded to ``load_trial`` as ``q_type``.

    Returns:
        A list of ``(event_time, choice)`` tuples, one per selected response.
    """
    results = []
    # Knowledge initialization
    init_participant_response_rules(participant)
    init_participant_construction_rules(participant)

    pending_trials = (trial for _, trial in session_df.iterrows())

    participant.start_construct_trial(timedelta())
    while participant.system.queue:
        event = participant.system.advance()
        if event.source == participant.start_construct_trial:
            trial = next(pending_trials, None)
            if trial is None:
                break
            grid_stimulus, test_query = load_trial(
                participant.d, trial, t_type=session_type, q_type=q_type)
            # TODO: add a construction timeout (needs proper timing constraints
            # for the various events in the queue).
            participant.construction_input.send(grid_stimulus)
        elif event.source == participant.end_construction:
            participant.start_response_trial(timedelta())  # TODO: check the actual time delays
        elif event.source == participant.start_response_trial:
            participant.response_input.send(test_query)
        elif event.source == participant.response_rules.rules.rhs.td.update:
            participant.choice.select()
        elif event.source == participant.choice.select:
            # TODO: also record which search-space rules fired (the choice's
            # sample attribute may hold this).
            results.append((event.time, participant.choice.poll()))
            participant.finish_response_trial(timedelta())
        elif event.source == participant.end_response:
            participant.start_construct_trial(timedelta())

    return results

In [24]:
# Probe: np.where reports matching cells in row-major (C) order. present_stimulus
# relies on this ordering when assigning a brick's cells to row1..row3 / col1..col3.
probe_grid = np.zeros((6, 3), dtype=int)
probe_grid[[0, 1, 1], [2, 1, 2]] = 1
np.where(probe_grid == 1)

(array([0, 1, 1]), array([2, 1, 2]))