In [None]:
from pyClarion import (Agent, Input, Choice, ChunkStore, FixedRules, 
    Family, Atoms, Atom)
from datetime import timedelta

import logging
import sys

import numpy as np

# Route pyClarion's system log (event trace) to the notebook's stdout.
logger = logging.getLogger("pyClarion.system")
logger.setLevel(logging.DEBUG)
# Attach the stdout handler only once: re-running this cell would otherwise
# add another handler each time and every log line would be printed N times.
if not any(
    isinstance(handler, logging.StreamHandler)
    and getattr(handler, "stream", None) is sys.stdout
    for handler in logger.handlers
):
    logger.addHandler(logging.StreamHandler(sys.stdout))

In [None]:
class Brick(Atoms):
    """Brick types that can appear in a stimulus grid.

    Stimulus grids encode bricks as integers; ``present_stimulus`` and
    ``load_trial`` map 1 -> half_T, 2 -> mirror_L, 3 -> vertical and
    4 -> horizontal. ``empty`` presumably corresponds to grid value 0
    (TODO confirm -- no visible code maps anything to it).
    """
    empty: Atom
    horizontal: Atom
    vertical: Atom
    mirror_L: Atom
    half_T: Atom

class GridRows(Atoms):
    """Grid row positions, one atom per row (r_one .. r_six).

    ``present_stimulus`` maps 1-based row indices 1..6 onto these atoms, so
    the grid is assumed to have at most six rows.
    """
    r_one: Atom
    r_two: Atom
    r_three: Atom
    r_four: Atom
    r_five: Atom
    r_six: Atom

class GridCols(Atoms):
    """Grid column positions, one atom per column (c_one .. c_six).

    ``present_stimulus`` maps 1-based column indices 1..6 onto these atoms,
    so the grid is assumed to have at most six columns.
    """
    c_one: Atom
    c_two: Atom
    c_three: Atom
    c_four: Atom
    c_five: Atom
    c_six: Atom

class QueryRel(Atoms):
    """Spatial relations a query can ask about.

    ``load_trial`` maps relation codes 1 -> left, 2 -> ontop, 3 -> right,
    4 -> below.
    """
    left: Atom
    ontop: Atom
    right: Atom
    below: Atom

class IO(Atoms):
    """Attribute atoms used to describe stimuli, targets and queries.

    In chunk expressions these appear on the left of ``**`` with a value
    atom on the right (e.g. ``io.input_shape1 ** bricks.vertical``).

    For each of up to four bricks K there is
      * ``input_shapeK``              -- the brick type,
      * ``input_shapeK_row1..row3``   -- rows of the brick's first three cells,
      * ``input_shapeK_col1..col3``   -- columns of those cells.
    The ``target_shapeK*`` atoms mirror that layout for target bricks.
    ``query_relation``, ``query_block`` and ``query_block_reference`` describe
    a query; ``output`` is presumably the response slot (TODO confirm).
    """
    input_shape1: Atom
    input_shape2: Atom
    input_shape3: Atom
    input_shape4: Atom

    target_shape1: Atom
    target_shape2: Atom
    target_shape3: Atom
    target_shape4: Atom

    input_shape1_row1: Atom
    input_shape1_row2: Atom
    input_shape1_row3: Atom
    input_shape1_col1: Atom
    input_shape1_col2: Atom
    input_shape1_col3: Atom

    input_shape2_row1: Atom
    input_shape2_row2: Atom
    input_shape2_row3: Atom
    input_shape2_col1: Atom
    input_shape2_col2: Atom
    input_shape2_col3: Atom

    input_shape3_row1: Atom
    input_shape3_row2: Atom
    input_shape3_row3: Atom
    input_shape3_col1: Atom
    input_shape3_col2: Atom
    input_shape3_col3: Atom

    input_shape4_row1: Atom
    input_shape4_row2: Atom
    input_shape4_row3: Atom
    input_shape4_col1: Atom
    input_shape4_col2: Atom
    input_shape4_col3: Atom

    target_shape1_row1: Atom
    target_shape1_row2: Atom
    target_shape1_row3: Atom
    target_shape1_col1: Atom
    target_shape1_col2: Atom
    # Was missing: every other shape has three column slots.
    target_shape1_col3: Atom

    target_shape2_row1: Atom
    target_shape2_row2: Atom
    target_shape2_row3: Atom
    target_shape2_col1: Atom
    target_shape2_col2: Atom
    target_shape2_col3: Atom

    target_shape3_row1: Atom
    target_shape3_row2: Atom
    target_shape3_row3: Atom
    target_shape3_col1: Atom
    target_shape3_col2: Atom
    target_shape3_col3: Atom

    target_shape4_row1: Atom
    target_shape4_row2: Atom
    target_shape4_row3: Atom
    target_shape4_col1: Atom
    target_shape4_col2: Atom
    target_shape4_col3: Atom

    query_block: Atom
    query_block_reference: Atom
    query_relation: Atom
    output: Atom

class Response(Atoms):
    """Possible answers to a query: yes or no."""
    yes: Atom
    no: Atom

In [None]:
class BrickTaskData(Family):
    """Family bundling all atom groups of the brick task.

    Groups are reached by attribute name on an instance, e.g. ``d.io`` or
    ``d.bricks``.
    """
    bricks: Brick
    grid_rows: GridRows
    grid_cols: GridCols
    query_rel: QueryRel
    io: IO
    response: Response

In [15]:
# Build the agent: an Input process whose output feeds the bottom-up (bu)
# side of a ChunkStore. Every keyspace argument uses the task family `d`
# (meaning of ChunkStore's c/d/v arguments: see the pyClarion docs).
d = BrickTaskData()
with Agent("agent", d=d) as agent:
    data_in = Input("data_in", (d, d))
    chunks = ChunkStore("chunks", c=d, d=d, v=d)
    # Wire the chunk store's bottom-up input to data_in's main output.
    chunks.bu.input = data_in.main

In [16]:
# Inspect the processes registered with the agent's system; ChunkStore
# contributes its own TopDown (td) and BottomUp (bu) sub-processes.
agent.system.procs

[<Agent 'agent' at 0x106c92710>,
 <Input 'data_in' at 0x1075a1a90>,
 <ChunkStore 'chunks' at 0x1075a1810>,
 <TopDown 'chunks.td' at 0x1075a1bd0>,
 <BottomUp 'chunks.bu' at 0x1075a1d10>]

# Knowledge init

In [None]:
# Short module-level aliases for the family's atom groups; the stimulus and
# trial helpers below read several of these as globals.
(io, bricks, grid_rows, grid_cols, query_rel, response) = (
    d.io,
    d.bricks,
    d.grid_rows,
    d.grid_cols,
    d.query_rel,
    d.response,
)

In [19]:
def present_stimulus(stim_grid):
    """Encode the bricks on a stimulus grid as one pyClarion chunk expression.

    Each distinct non-zero value in ``stim_grid`` is treated as one brick.
    For the K-th brick (K = 1..4, in ascending order of brick value) the
    expression gains a term binding ``io.input_shapeK`` to the brick type,
    plus terms binding ``io.input_shapeK_row1..3`` / ``io.input_shapeK_col1..3``
    to the 1-based row/column of the brick's first three cells (in the
    row-major order returned by ``np.where``).

    Parameters
    ----------
    stim_grid : np.ndarray
        2-D integer grid. 0 is treated as an empty cell; values 1..4 are
        brick ids. Row/column indices must fall in 1..6 (``row_map`` /
        ``col_map``), and each brick must cover at least three cells,
        otherwise indexing raises.

    Returns
    -------
    The accumulated chunk expression, or ``None`` if the grid holds no
    bricks. Bricks beyond the fourth are ignored (``IO`` only has four
    input-shape slots).
    """
    brick_map = {1: bricks.half_T, 2: bricks.mirror_L, 3: bricks.vertical, 4: bricks.horizontal}
    row_map = {1: grid_rows.r_one, 2: grid_rows.r_two, 3: grid_rows.r_three, 4: grid_rows.r_four, 5: grid_rows.r_five, 6: grid_rows.r_six}
    col_map = {1: grid_cols.c_one, 2: grid_cols.c_two, 3: grid_cols.c_three, 4: grid_cols.c_four, 5: grid_cols.c_five, 6: grid_cols.c_six}
    max_shapes = 4

    # Distinct brick ids present on the grid; filter on the array itself so
    # the empty-cell value 0 is dropped.
    stim_bricks = np.unique(stim_grid)
    stim_bricks = stim_bricks[stim_bricks != 0]

    in_send_val = None
    for k, brick in enumerate(stim_bricks[:max_shapes], start=1):
        rows, cols = np.where(stim_grid == brick)
        row_indices = rows + 1  # 1-based, to match row_map keys
        col_indices = cols + 1  # 1-based, to match col_map keys

        prefix = f"input_shape{k}"
        pairs = [(prefix, brick_map[brick])]
        pairs += [(f"{prefix}_row{j}", row_map[row_indices[j - 1]]) for j in (1, 2, 3)]
        pairs += [(f"{prefix}_col{j}", col_map[col_indices[j - 1]]) for j in (1, 2, 3)]

        # Accumulate left-to-right: the very first term is opened with unary
        # `+`, every later term is appended with binary `+`, so every brick
        # (including the fourth) extends the same expression.
        for attr, value in pairs:
            term = getattr(io, attr) ** value
            in_send_val = +term if in_send_val is None else in_send_val + term
    return in_send_val

def load_trial(trial, data_in, t_type="test", q_type="query",
               data_dir="/Users/mishaal/personalproj/clarion_replay/processed",
               rng=None):
    """Build the stimulus and query chunk expressions for one trial.

    Parameters
    ----------
    trial : mapping
        Must provide ``"Grid_Name"``; for ``t_type == "test"`` also
        ``"Q_Relation"`` (1..4), ``"Q_Brick_Left"`` and ``"Q_Brick_Right"``
        (brick ids 1..4).
    data_in :
        Currently unused; kept for call compatibility.
    t_type : str
        ``"test"`` or ``"train"``; selects the stimulus sub-directory and how
        the query is built.
    q_type : str
        For training trials, a query is generated only when this is
        ``"query"``.
    data_dir : str
        Root of the processed data. The stimulus is read from
        ``{data_dir}/{t_type}_data/{t_type}_stims/{Grid_Name}.npy``.
        The default is the author's local path; override it elsewhere.
    rng : optional
        Object with a NumPy-style ``choice`` method (e.g.
        ``np.random.default_rng(seed)``) used for random training queries.
        Defaults to the legacy global ``np.random`` state.

    Returns
    -------
    (chunk_grid, chunk_test)
        ``chunk_grid`` comes from ``present_stimulus``. ``chunk_test`` is the
        query expression, or ``()`` when no query is built.
    """
    if rng is None:
        rng = np.random
    grid_name = trial["Grid_Name"]

    stim_grid = np.load(f"{data_dir}/{t_type}_data/{t_type}_stims/{grid_name}.npy")
    chunk_grid = present_stimulus(stim_grid)

    query_map = {1: query_rel.left, 2: query_rel.ontop, 3: query_rel.right, 4: query_rel.below}
    brick_map = {1: bricks.half_T, 2: bricks.mirror_L, 3: bricks.vertical, 4: bricks.horizontal}

    if t_type == "test":
        chunk_test = ( + io.query_relation ** query_map[trial["Q_Relation"]]
                      + io.query_block ** brick_map[trial["Q_Brick_Left"]]
                      + io.query_block_reference ** brick_map[trial["Q_Brick_Right"]])
    elif t_type == "train" and q_type == "query":
        # Two distinct blocks and one relation, chosen at random.
        blocks = rng.choice([1, 2, 3, 4], 2, replace=False)
        # Draw a scalar (no size argument): a size-1 array is unhashable and
        # cannot be used as a dict key.
        relation = rng.choice([1, 2, 3, 4])
        chunk_test = ( + io.query_relation ** query_map[relation]
                      + io.query_block ** brick_map[blocks[0]]
                      + io.query_block_reference ** brick_map[blocks[1]])
    else:
        chunk_test = ()
    return chunk_grid, chunk_test

In [None]:
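# TODO: define response rules for a query over a set of blocks.
# The commented-out example below is a stale placeholder: it references
# `color`, `shape` and `number` families that this notebook does not define.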

# chunk_defs = [
#     + io.input ** color.blu
#     + io.input ** shape.tria
#     + io.input ** number.one, # blue triangle

#     + io.input ** color.red
#     + io.input ** shape.tria
#     + io.input ** number.one,

#     + io.input ** color.grn
#     + io.input ** shape.squr
#     + io.input ** number.two,
# ]

# chunks.compile(*chunk_defs)

In [8]:
# Run the simulation: keep calling `advance()` while the system's event
# queue is non-empty. `event` holds the last value returned by `advance()`.
while agent.system.queue:
    event = agent.system.advance()

event 0x0000 00:00:00.00 096 0 chunks.compile
    Added the following new chunk(s)
    chunk d:chunks:_0
        + io.input ** color.blu
        + io.input ** shape.tria
        + io.input ** number.one
    chunk d:chunks:_1
        + io.input ** color.red
        + io.input ** shape.tria
        + io.input ** number.one
    chunk d:chunks:_2
        + io.input ** color.grn
        + io.input ** shape.squr
        + io.input ** number.two
event 0x0000 00:00:00.00 096 2 chunks.compile_weights
event 0x0000 00:00:00.00 064 1 data_in.send
event 0x0000 00:00:00.00 064 3 chunks.bu.update
event 0x0000 00:00:00.00 064 4 chunks.update


In [11]:
# Show the first element of the bottom-up process's main output: a NumDict
# holding one value per chunk in the store.
print(chunks.bu.main[0])

NumDict 'd:chunks:?' c=0.0
    d:chunks:_0 0.75
    d:chunks:_1 0.5
    d:chunks:_2 -0.25
