In [1]:
import random
import os
import json
import datetime
import zipfile
from collections import Counter
import scipy.stats
import numpy as np
from tqdm.notebook import tqdm


from markdown import markdown
import textwrap
import pprint
from itertools import permutations, cycle, product
import copy

from construal_shifting import gridutils, sampsat
from construal_shifting.utils import maze_code
from msdm.domains import GridWorld

from frozendict import frozendict

## Experiment variables

In [2]:
# Each maze record is a dict with at least 'idx' and 'maze'; training records
# also carry a 'bias' field (only 'coarse' and 'fine' are used below).
# Files are opened with `with` so the handles are closed after parsing.
with open("./mazes/test_mazes.json", 'r') as maze_file:
    test_maze_records = json.load(maze_file)
test_mazes = {
    f"{r['idx']}-test": r['maze']
    for r in test_maze_records
}
# Duplicate 'idx' values would silently collapse entries, so check the counts.
assert len(test_mazes) == 4, len(test_mazes)

with open("./mazes/training_mazes.json", 'r') as maze_file:
    training_mazes = json.load(maze_file)
coarse_train_mazes = {
    f"{r['idx']}-{r['bias']}-train": r['maze']
    for r in training_mazes if r['bias'] == 'coarse'
}
assert len(coarse_train_mazes) == 8, len(coarse_train_mazes)
fine_train_mazes = {
    f"{r['idx']}-{r['bias']}-train": r['maze']
    for r in training_mazes if r['bias'] == 'fine'
}
assert len(fine_train_mazes) == 8, len(fine_train_mazes)

In [3]:
PSITURKAPP_CONFIG_DIR = "../../psiturkapp/static/config/"
EXPERIMENT_CODE_VERSION = "shifting-exp-1"
MAIN_OBSTACLE_COLOR = "rgba(173, 216, 230, 1.)"
# Same light blue at half opacity, used for broken (passable) obstacle pieces.
BROKEN_OBSTACLE_COLOR = "rgba(173, 216, 230, .5)"
# Uppercase letters label solid obstacle tiles (walls); the lowercase form of
# each label marks that obstacle's broken, passable pieces. Both the wall list
# and the colour map are derived from this one string so they cannot drift
# apart (previously 'H' was a wall with no colour assigned).
OBSTACLE_LABELS = "ABCDEFGH"
initialPoints = 100
dollarsPerPoint = .1/100  # i.e. 100 points = $0.10
default_trialparams = {
    "navigationText": "&nbsp;",

    "participantStarts": True,
    "showPoints": False,
    "initialPoints": initialPoints,
    "dollarsPerPoint": dollarsPerPoint,
    "goalCountdown": True,
    "hideObstaclesOnMove": False,
    "hideBrokenPiecesOnMove": False,

    "TILE_SIZE": 40,
    # NOTE(review): presumably the goal-square countdown before / after the
    # first move, in milliseconds -- confirm against the GridNavigation plugin.
    "INITIALGOAL_COUNTDOWN_MS": 120000,
    "GOAL_COUNTDOWN_MS": 1000
}
default_taskparams = {
    "absorbing_features":["$",],
    "wall_features":["#", *OBSTACLE_LABELS],
    "default_features":[".",],
    "initial_features":["@",],
    "feature_rewards": {
        "$": 0
    },
    # Costs are negative rewards; the instructions display them sign-flipped.
    "step_cost": -1,
    "wall_bump_cost": -10,
    "feature_colors": {
        "@": "white",
        ".": "white",
        **{c: MAIN_OBSTACLE_COLOR for c in OBSTACLE_LABELS},
        **{c: BROKEN_OBSTACLE_COLOR for c in OBSTACLE_LABELS.lower()},
        '#': 'black',
        '$': 'yellow'
    }
}


## Functions/objects for generating code

In [4]:
def generate_latin_square(n, rng):
    """Draw a random n x n Latin square by rejection sampling.

    A candidate square is built from `n` distinct permutations of ``range(n)``
    sampled with ``rng`` (so every row is a permutation). For each column,
    ``sampsat.condition`` is told whether that column's values are distinct.
    ``sampsat.rejection`` then handles candidates that fail; presumably it
    resamples until all conditions hold.

    The acceptance rate falls off very quickly as `n` grows, so this is only
    practical for small `n`. Here it is used with the 4 test mazes.

    Returns the accepted square as a list of `n` rows, each a list of ints.
    """
    all_perms = list(permutations(range(n)))

    def propose():
        square = np.array(rng.sample(all_perms, k=n))
        for column in square.T:
            sampsat.condition(len(set(column)) == len(column))
        return [list(r) for r in square]

    return sampsat.rejection(propose, debug=True)

def latin_square_cycle(vals, rng):
    """Yield orderings of `vals` forever, one Latin-square row at a time.

    A fresh random Latin square of size ``len(vals)`` is drawn each time the
    previous one's rows are used up. Across any complete square, each value
    therefore appears once in every position.
    """
    while True:
        square = generate_latin_square(len(vals), rng)
        yield from ([vals[i] for i in row] for row in square)

def block_shuffle_cycle(vals, rng, no_contiguous_index=True):
    """Yield the elements of `vals` forever, in independently shuffled blocks.

    Each block is a fresh permutation of all of `vals` (shuffled with
    ``rng.shuffle``), so every element appears exactly once per block.

    Parameters
    ----------
    vals : sequence
        Values to cycle through. Must be non-empty; its length is read once.
    rng : random.Random
        Source of randomness.
    no_contiguous_index : bool
        If True, a block is reshuffled until its first index differs from the
        last index of the previous block, so the same *index* never appears
        twice in a row. The check is on indices, not values: with duplicate
        values (e.g. ``[True, True, False, False]``), equal values can still be
        adjacent. The constraint is skipped when `vals` has a single element,
        because it can never be satisfied there.

    Raises
    ------
    ValueError
        On the first ``next()`` if `vals` is empty.
    """
    n_vals = len(vals)
    if n_vals == 0:
        raise ValueError("block_shuffle_cycle requires at least one value")
    # With one value, every block starts with the index that ended the
    # previous block, so enforcing the constraint would loop forever.
    enforce_no_repeat = no_contiguous_index and n_vals > 1
    last_idx = None
    while True:
        idxs = list(range(n_vals))
        while True:
            rng.shuffle(idxs)
            if not enforce_no_repeat or idxs[0] != last_idx:
                break
        yield from [vals[i] for i in idxs]
        last_idx = idxs[-1]

def cycle_n(arr, n):
    """Yield exactly `n` items from `arr`, wrapping around to the start as needed.

    Item k is ``arr[k % len(arr)]``, so `arr` must support ``len`` and integer
    indexing (and be non-empty whenever ``n > 0``).
    """
    yield from (arr[k % len(arr)] for k in range(n))

def swap_start_goal_func(tile_array):
    """Exchange the start ('@') and goal ('$') tiles in a maze string.

    `tile_array` is a single string; the caller joins the maze rows with
    newlines before calling. The swap is a single character translation, so
    unlike a placeholder-based swap it cannot corrupt a placeholder character
    (e.g. '!') that already appears in the maze.
    """
    return tile_array.translate(str.maketrans("@$", "$@"))

def feature_locations(tile_array, f):
    """Return the locations of every tile labelled `f` in a maze, as a frozenset.

    `tile_array` is handed straight to msdm's ``GridWorld``. Callers in this
    notebook pass a list of row strings. The element type of the result is
    whatever ``GridWorld.feature_locations[f]`` contains. A frozenset is
    returned so results can be compared and hashed (e.g. by
    ``sampsat.no_repeat``).
    """
    grid = GridWorld(tile_array=tile_array)
    locations = grid.feature_locations[f]
    return frozenset(locations)

In [5]:
def countdict_to_array(countdict):
    """Convert a dict of tuple-keyed counts into a dense numpy array.

    Every key is a tuple of the same length `d`; position ``i`` of a key
    indexes axis ``i`` of the output. The distinct values seen at each position
    are sorted and assigned consecutive indices, so the result has shape
    ``(n_distinct_0, ..., n_distinct_{d-1})``. Key combinations missing from
    `countdict` are left as 0.

    Raises
    ------
    ValueError
        If `countdict` is empty, since the number of dimensions cannot be
        inferred.
    """
    if not countdict:
        raise ValueError("countdict_to_array requires a non-empty dict")
    ndim = len(next(iter(countdict)))
    axis_values = [set() for _ in range(ndim)]
    for key in countdict:
        for axis, val in enumerate(key):
            axis_values[axis].add(val)
    # value -> position along its axis, in sorted order
    axis_index = [{val: i for i, val in enumerate(sorted(vals))} for vals in axis_values]

    arr = np.zeros([len(index) for index in axis_index])
    for key, count in countdict.items():
        arr[tuple(axis_index[axis][val] for axis, val in enumerate(key))] = count
    return arr

In [6]:
from dataclasses import dataclass

@dataclass
class ExperimentGenerator:
    """Generates the complete experiment config consumed by the psiTurk app.

    All randomness is drawn from `rng`, which is shared with every
    TimelineGenerator this object creates.
    """
    rng : random.Random
    EXPERIMENT_CODE_VERSION : str
    cutoff_time_min : str
    expectedTime : str
    preloadImages : list
    MAIN_OBSTACLE_COLOR : str  # not read by ExperimentGenerator/TimelineGenerator

    default_trialparams : dict
    default_taskparams : dict
    fine_train_mazes : dict
    coarse_train_mazes : dict
    test_mazes : dict

    n_train_orders_per_cond : int
    # No annotation: this is a plain class attribute, NOT a dataclass field.
    transform_names = ['base', 'rot90', 'rot180', 'rot270', 'vflip', 'hflip', 'trans', 'rtrans']
    reversed_training : bool
    test_swap_start_goals : bool

    def generate_experiment(self):
        """Build the experiment config: global params, preload images and all timelines.

        Each (training condition, training-order index) pair gets its own
        TimelineGenerator. Each generator receives the next test-maze order
        from an endless stream whose items are two Latin-square rows
        concatenated, so every timeline sees each test maze twice, once per
        half.
        """
        param_grid = {
            "train_cond": ["ffffffffffff", "cccccccccccc"],
            "train_order_i": list(range(self.n_train_orders_per_cond)),
        }
        param_names = list(param_grid)
        assignments = [
            dict(zip(param_names, combo))
            for combo in product(*param_grid.values())
        ]

        maze_names = sorted(self.test_mazes.keys())
        # On every step, map() pulls from the first cycle and then the second,
        # exactly like `next(first) + next(second)`. Both cycles draw from
        # self.rng lazily, interleaved with each TimelineGenerator's own RNG
        # use, so this ordering must not change.
        test_orders = map(
            lambda first_half, second_half: first_half + second_half,
            latin_square_cycle(maze_names, self.rng),
            latin_square_cycle(maze_names, self.rng),
        )

        timelines = []
        for tlg_params in tqdm(assignments):
            generator = TimelineGenerator(
                exp=self,
                test_maze_name_order=next(test_orders),
                **tlg_params,
            )
            timelines.extend(generator.generate_timelines())

        params = {
            "cutoff_time_min": self.cutoff_time_min,
            "expectedTime": self.expectedTime,
            "recruitment_platform": "prolific",
            "EXPERIMENT_CODE_VERSION": self.EXPERIMENT_CODE_VERSION,
        }
        return {
            "params": params,
            "preloadImages" : self.preloadImages,
            "timelines" : timelines
        }
        


In [7]:
@dataclass
class TimelineGenerator:
    """Builds the jsPsych timeline(s) for one experimental cell.

    A cell is one (training condition, training-order index) combination plus
    a fixed order of test mazes. All of these are supplied by the parent
    ExperimentGenerator (`exp`), whose `rng` drives every random choice here.
    """
    exp : ExperimentGenerator
    # One character per training trial: 'f' = fine-biased maze, 'c' = coarse-biased maze.
    train_cond : str
    # Distinguishes repeated timelines of the same condition; not read inside this class.
    train_order_i : int
    # Test maze names (keys of exp.test_mazes) in presentation order.
    test_maze_name_order : list

    def generate_timelines(self):
        """Return the list of timelines for this cell.

        One timeline is always produced. If `exp.reversed_training` is set, a
        second timeline is appended that uses the same trial dicts (shared,
        not copied) but presents the training block in reverse order.
        """
        training, test = self.main_trials()
        start = [
            *self.setup(),
            *self.tutorial(),
            *self.comprehension_check(),
        ]
        end = [
            *self.post_task_questions(),
            *self.tear_down()
        ]
        timelines = []
        timelines.append([*start, *training, *test, *end])
        if self.exp.reversed_training:
            timelines.append([*start, *training[::-1], *test, *end])
        return timelines

    def main_trials(self):
        """Sample training and test trials so that consecutive trials never share a start location.

        Sampling goes through ``sampsat.rejection``, which presumably re-runs
        `gen` until every ``sampsat.condition`` holds. With reversed training
        enabled, the constraint must also hold for the reversed training
        order.

        NOTE(review): only start locations ('@') are constrained here, but
        sanity_checks also asserts that consecutive goal locations ('$')
        differ. A sample can pass here and still fail that check.
        """
        def gen():
            training = self.training_trials()
            test = self.test_trials()

            no_repeated_start = sampsat.no_repeat([feature_locations(t['taskparams']['tile_array'], '@') for t in training + test])
            sampsat.condition(no_repeated_start)
            if self.exp.reversed_training:
                no_repeated_start = sampsat.no_repeat([feature_locations(t['taskparams']['tile_array'], '@') for t in training[::-1] + test])
                sampsat.condition(no_repeated_start)

            return training, test

        training, test = sampsat.rejection(
            func=gen,
            debug=True
        )
        return training, test

    def training_trials(self):
        """Create one training trial config per character of `train_cond`.

        'f' takes the next fine-biased maze and 'c' the next coarse-biased maze,
        each from its own shuffled-block cycle. Transforms cycle through
        `exp.transform_names` in shuffled blocks, with no transform repeated
        back-to-back. Start/goal swapping is balanced 2:2 within each block of
        four trials.

        Raises
        ------
        ValueError
            If `train_cond` contains a character other than 'f' or 'c'.
        """
        rng = self.exp.rng
        coarse_maze_names_cyc = block_shuffle_cycle(
            list(self.exp.coarse_train_mazes.keys()),
            rng=rng,
            no_contiguous_index=False
        )
        fine_maze_names_cyc = block_shuffle_cycle(
            list(self.exp.fine_train_mazes.keys()),
            rng=rng,
            no_contiguous_index=False
        )
        train_maze_names = []
        for maze_type in self.train_cond:
            if maze_type == 'f':
                train_maze_names.append(next(fine_maze_names_cyc))
            elif maze_type == 'c':
                train_maze_names.append(next(coarse_maze_names_cyc))
            else:
                # A bare `raise` here (no active exception) would only produce
                # an unhelpful RuntimeError.
                raise ValueError(
                    f"Unknown maze type {maze_type!r} in train_cond "
                    f"{self.train_cond!r}; expected 'f' or 'c'"
                )
        transforms = block_shuffle_cycle(self.exp.transform_names, rng=rng)
        swap_start_goals = block_shuffle_cycle([True, True, False, False], rng=rng, no_contiguous_index=False)

        train_mazes = {**self.exp.coarse_train_mazes, **self.exp.fine_train_mazes}
        trial_configs = []
        for grid_name in train_maze_names:
            trial_config = self.create_jsPsych_trial_config(
                transform=next(transforms),
                swap_start_goal=next(swap_start_goals),
                tile_array=train_mazes[grid_name],
                grid_name=grid_name,
                roundtype=f"{self.train_cond}_training",
            )
            trial_configs.append(trial_config)
        return trial_configs

    def test_trials(self):
        """Create one test trial config per maze in `test_maze_name_order`.

        Transforms cycle in shuffled blocks, as in training. Start/goal
        swapping is balanced 2:2 per block of four when
        `exp.test_swap_start_goals` is set; otherwise it is never applied.
        """
        rng = self.exp.rng

        transforms = block_shuffle_cycle(self.exp.transform_names, rng=rng)
        if self.exp.test_swap_start_goals:
            swap_start_goals = block_shuffle_cycle([True, True, False, False], rng=rng, no_contiguous_index=False)
        else:
            swap_start_goals = cycle([False])

        trial_configs = []
        for maze_name in self.test_maze_name_order:
            trial_config = self.create_jsPsych_trial_config(
                transform=next(transforms),
                swap_start_goal=next(swap_start_goals),
                tile_array=self.exp.test_mazes[maze_name],
                grid_name=maze_name,
                roundtype="test",
            )
            trial_configs.append(trial_config)
        return trial_configs

    def create_jsPsych_trial_config(
        self,
        transform,
        swap_start_goal,
        tile_array,
        **kws
    ):
        """Convert trial parameters to a jsPsych GridNavigation trial configuration.

        `tile_array` is a list of row strings. If `swap_start_goal` is set, the
        '@' and '$' tiles are exchanged. Blank rows are dropped and rows are
        stripped, then the grid is passed through
        ``gridutils.transform_grid.<transform>``. Extra keyword arguments are
        merged into trialparams last, so they can override the defaults.
        """
        tile_array = '\n'.join(tile_array)
        if swap_start_goal:
            tile_array = swap_start_goal_func(tile_array)
        tile_array = [r.strip() for r in tile_array.split('\n') if len(r.strip())]
        tile_array = getattr(gridutils.transform_grid, transform)(tile_array)
        tconfig = {
            "type": "GridNavigation",
            "trialparams": {
                **self.exp.default_trialparams,
                **{
                    "transform": transform,
                    "swap_start_goal": swap_start_goal,
                },
                **kws
            },
            "taskparams": {
                **self.exp.default_taskparams,
                "tile_array": tile_array,
            }
        }
        return tconfig

    @property
    def condition(self):
        """Condition name for this timeline.

        This is the training condition, the same value tear_down stores as
        `condition_name`. The previous implementation referenced a
        non-existent `test_cond` attribute and always raised AttributeError.
        """
        return self.train_cond

    def tear_down(self):
        """Final trials: save the global store under this condition's name and leave fullscreen."""
        return [
            {
                "type": "SaveGlobalStore",
                "condition_name": self.train_cond,
            },
            {
                "type": "fullscreen",
                "fullscreen_mode": False
            },
        ]

    def setup(self):
        """Opening trials: a reCAPTCHA check, then enter fullscreen."""
        return [
            {
                "type": "reCAPTCHA"
            },
            {
                "type": "fullscreen",
                "fullscreen_mode": True
            },
        ]

    def tutorial(self):
        """Instruction screens interleaved with four practice GridNavigation rounds.

        The point values shown to participants come from the experiment's
        default trial and task params. The practice rounds themselves pay
        nothing (dollarsPerPoint 0.0).
        """
        default_trialparams = self.exp.default_trialparams
        default_taskparams = self.exp.default_taskparams
        initialPoints = default_trialparams['initialPoints']
        dollarsPerPoint = default_trialparams['dollarsPerPoint']
        # Costs are stored as negative rewards; flip the sign for display.
        step_cost = -default_taskparams['step_cost']
        wall_bump_cost = -default_taskparams['wall_bump_cost']
        return [
            {
                "type": "CustomInstructions",
                "instructions": markdown(textwrap.dedent(f"""
                    # Instructions
                    Thank you for participating in our experiment!

                    You will play a game where you control a blue circle on a grid.
                    You can move up, down, left, or right by pressing the __arrow keys__⬆️⬇️⬅️➡️.

                    <img src="static/images/bluedotgrid.png" width="150px">

                     The <span style='background-color: yellow;'><b>Yellow</b></span> tile with
                     the <span style="color: green"><b>green</b></span>
                     square is the goal 👀.

                    <img src="static/images/goalsquare.png" width="150px">

                    Before you take your first move, the green square will not shrink.
                    Once you make your first move, it will shrink quickly whenever you stand still
                    and reset when you move.

                    <br>

                    __Black__ tiles are walls that you cannot pass through ⛔️.

                    <br>

                    <b>Blue</b> tiles are
                    obstacles that might change
                    between different rounds. You cannot pass through these either 🚫.
                """)),
                "timing_post_trial": 1000,
                "continue_wait_time": 5000,
            },
            {
                "type": "GridNavigation",
                "trialparams": {
                    **default_trialparams,
                    "round": 0,
                    "roundtype": "practice",
                    "navigationText": textwrap.dedent(f"""
                        Get to the <span style='background-color: yellow;'>Yellow</span> goal. <br>
                        You cannot go through <span style='background-color: black;color: white'>Black</span> or
                        Blue tiles.
                    """),

                    "showPoints": False,
                    "initialPoints": 100,
                    "dollarsPerPoint": 0.0,
                },
                "taskparams": {
                    **default_taskparams,
                    "tile_array": [
                        '............$',
                        '.............',
                        '.............',
                        '.....###.....',
                        '.............',
                        '...#.....#CCC',
                        'AAA#BBB..#CCC',
                        'AAA#BBB..#CCC',
                        'AAA.BBB......',
                        '.....###.....',
                        '.............',
                        '.............',
                        '@............',
                    ],
                }
            },
            {
                "type": "CustomInstructions",
                "instructions": markdown(textwrap.dedent(f"""
                    # Instructions
                    Sometimes the <b>blue</b> obstacles will have
                    pieces broken off. These are lighter parts, and look like this:

                    <img src="static/images/broken_obstacle.png" width="350px">

                    Broken parts of obstacles do <b>not</b> block you, but unbroken parts still block.
                """)),
                "timing_post_trial": 1000,
                "continue_wait_time": 5000,
            },
            {
                "type": "GridNavigation",
                "trialparams": {
                    **default_trialparams,
                    "round": 1,
                    "roundtype": "practice",
                    "navigationText": textwrap.dedent(f"""
                        Get to the <span style='background-color: yellow;'>Yellow</span> goal. <br>
                        You cannot go through <span style='background-color: black;color: white'>Black</span> or
                        Blue obstacles
                        except for broken parts (lighter parts).
                    """),

                    "showPoints": False,
                    "initialPoints": 100,
                    "dollarsPerPoint": 0.0,
                },
                "taskparams": {
                    **default_taskparams,
                    "tile_array": [
                        '....AAA.....@',
                        '....AAA......',
                        '....aaa......',
                        '.....###.....',
                        '......bBB....',
                        '...#..BBB#...',
                        '...#..bbB#...',
                        '...#.....#CcC',
                        '..........CcC',
                        '.....###..CCc',
                        '.............',
                        '.............',
                        '$............',
                    ],
                }
            },
            {
                "type": "CustomInstructions",
                "instructions": markdown(textwrap.dedent(f"""
                    # Instructions
                    Great! In the main part of the experiment, you will start with <b>{initialPoints}</b> points
                    on each trial. Each step costs <b>{step_cost} {"point" if step_cost == 1 else "points"}</b> and crashing
                    into a wall or block costs <b>{wall_bump_cost} {"point" if wall_bump_cost == 1 else "points"}</b>.

                    If the green square disappears completely, you will receive <b>ZERO</b> points for that trial ☹️.

                    At the end of the experiment, we will add up all your points and calculate a bonus.<br>
                    <b>{initialPoints} points is worth {int(dollarsPerPoint*initialPoints*100)} cents</b>.

                    Next, you will do practice rounds where we show you your points
                    <br>(these will not be included in your bonus).
                """)),
                "timing_post_trial": 1000,
                "continue_wait_time": 5000,
            },
            {
                "type": "GridNavigation",
                "trialparams": {
                    **default_trialparams,
                    "round": 4,
                    "roundtype": "practice",
                    "navigationText": f"""
                        Each step costs <b>{step_cost} {"point" if step_cost == 1 else "points"}</b> and bumping
                        into a wall costs <b>{wall_bump_cost} {"point" if wall_bump_cost == 1 else "points"}</b>.
                    """,

                    "showPoints": True,
                    "initialPoints": 100,
                    "dollarsPerPoint": 0.0,
                },
                "taskparams": {
                    **default_taskparams,
                    "tile_array": [
                        '...AAa......@',
                        '...aAA.......',
                        '...aAA.......',
                        '.....###.....',
                        '.............',
                        '...#.....#...',
                        '...#bbB..#...',
                        '...#BBb..#...',
                        '....BBB......',
                        '.....###.....',
                        '......CCc....',
                        '......CCc....',
                        '$.....CCc....',
                    ],
                }
            },
            {
                "type": "GridNavigation",
                "trialparams": {
                    **default_trialparams,
                    "round": 5,
                    "roundtype": "practice",
                    "navigationText": f"""
                        Each step costs <b>{step_cost} {"point" if step_cost == 1 else "points"}</b> and bumping
                        into a wall costs <b>{wall_bump_cost} {"point" if wall_bump_cost == 1 else "points"}</b>.
                    """,

                    "showPoints": True,
                    "initialPoints": 100,
                    "dollarsPerPoint": 0.0,
                },
                "taskparams": {
                    **default_taskparams,
                    "tile_array": [
                        '@............',
                        '.........AaA.',
                        '.........AAa.',
                        '.....###.AaA.',
                        '....bBB......',
                        '...#BBB..#...',
                        'CCc#Bbb..#...',
                        'cCc#.....#...',
                        'CCC..........',
                        '.....###.....',
                        '.............',
                        '.............',
                        '............$',
                    ],
                }
            },
        ]

    def comprehension_check(self):
        """A CustomSurvey of comprehension questions that must all be answered correctly within 2 attempts."""
        default_trialparams = self.exp.default_trialparams
        initialPoints = default_trialparams['initialPoints']
        return [
            dict(
                type="CustomSurvey",
                preamble=markdown(textwrap.dedent(f"""
                    # Instructions

                    In the main part of the experiment, we will give you a series of mazes to navigate.
                    Try to reach the goal <b>without the green square disappearing</b>!
                    Remember, once you take your first move,
                    the green square shrinks when you stand still and resets when you move.

                    <br>

                    Remember that you start with <b>{initialPoints}</b> points on each
                    round and it costs points to move and even more points to bump into walls. If the
                    green square disappears completely, you always receive zero points on that round. <b>Points
                    are converted to a bonus at the end.</b>

                    <hr>
                    To continue, you must answer the following comprehension questions correctly within <b><u>2 tries</u></b>.
                    """)),
                maxAttempts=2,
                questions=[
                    {
                        "prompt": "When does the green square shrink?",
                        "options": ["When I stand still", "It does not shrink"],
                        "required": True,
                        "requireCorrect": True,
                        "correct": "When I stand still",
                        "name": "greenSquareShrinkCheck2",
                        "type": "multiple-choice"
                    },
                    {
                        "prompt": "The green square does not shrink before I start moving",
                        "options": ["True", "False"],
                        "required": True,
                        "requireCorrect": True,
                        "correct": "True",
                        "name": "greenSquareShrinkCheck1",
                        "type": "multiple-choice"
                    },
                    {
                        "prompt": "The blue obstacles:",
                        "options": ["Always block you", "Never block you", "Have broken pieces that do not block you"],
                        "required": True,
                        "requireCorrect": True,
                        "correct": "Have broken pieces that do not block you",
                        "name": "brokenObstacleCheck",
                        "type": "multiple-choice"
                    },
                    {
                        "prompt": "You lose extra points for bumping into things:",
                        "options": ["Yes", "No", "Sometimes"],
                        "required": True,
                        "requireCorrect": True,
                        "correct": "Yes",
                        "name": "bumpingCost",
                        "type": "multiple-choice"
                    },
                    {
                        "prompt": "You can still win points for a round even if the green square disappears.",
                        "options": ["Yes", "No", "Sometimes"],
                        "required": True,
                        "requireCorrect": True,
                        "correct": "No",
                        "name": "greenSquareCost",
                        "type": "multiple-choice"
                    },
                ]
            )
        ]

    def post_task_questions(self):
        """Debrief surveys: free-text comments, strategy questions, then age and gender."""
        return [
            {
                "type": 'CustomSurvey',
                "questions": [
                    {
                        "prompt": "Any general comments on how you performed the task?",
                        "required": True,
                        "name": "generalComments",
                        "type": "textbox",
                        "rows": 5,
                        "columns":50,
                    },
                ],
            },
            {
                "type": 'CustomSurvey',
                "questions": [
                    {
                        "prompt": "Did you think about the green square while moving?",
                        "required": True,
                        "options": ["Yes", "No"],
                        "name": "greenSquareThink",
                        "type": "multiple-choice"
                    },
                    {
                        "prompt": "Did you think about winning points when doing this task?",
                        "required": True,
                        "options": ["Yes", "No"],
                        "name": "winPointsThink",
                        "type": "multiple-choice"
                    },
                    {
                        "prompt": "Did you think about broken parts of blocks when doing this task?",
                        "required": True,
                        "options": ["Yes", "No"],
                        "name": "brokenBlocksThink",
                        "type": "multiple-choice"
                    },
                    {
                        "prompt": "Age",
                        "required": True,
                        "name": "age",
                        "type": "textbox",
                        "rows": 1,
                        "columns":10,
                    },
                    {
                        "prompt": "Gender",
                        "required": True,
                        "name": "gender",
                        "type": "textbox",
                        "rows": 1,
                        "columns":10,
                    },
                ],
            },
        ]

## Sanity checking code
- confirm the number of practice, training, and test trials
- check the number of distinct mazes used
- check that start/goal locations do not repeat on consecutive trials
- test that the index at which a grid appears is unbiased
- test that trial index and condition are unbiased
- test that transformations are unbiased
- test that the training/test trial index is unbiased
- test that each timeline is unique
- redundancy-check the expected conditions

In [8]:
from types import SimpleNamespace
from collections import defaultdict
class ConfigWrapper:
    """Read-only view over a full experiment config dict that has a 'timelines' list."""
    def __init__(self, config):
        self.config = config

    def counts(self):
        """Sum selected per-timeline count tables over every timeline in the config.

        Returns a SimpleNamespace with four ``defaultdict(int)`` tables:
        train_cond_grid_idx, test_cond_grid_idx, train_transforms and
        test_transforms. The per-timeline grid-name tables are not aggregated.
        """
        table_names = (
            "train_cond_grid_idx",
            "test_cond_grid_idx",
            "train_transforms",
            "test_transforms",
        )
        totals = SimpleNamespace(**{name: defaultdict(int) for name in table_names})
        for timeline in self.timelines():
            per_timeline = timeline.counts()
            for name in table_names:
                total_table = getattr(totals, name)
                for key, n in getattr(per_timeline, name).items():
                    total_table[key] += n
        return totals

    def timelines(self):
        """Wrap each timeline in the config in a TimelineWrapper, preserving order."""
        return [TimelineWrapper(tl) for tl in self.config['timelines']]
    
class TimelineWrapper:
    """Read-only view over one timeline: a list of jsPsych trial config dicts."""
    def __init__(self, timeline):
        self.timeline = timeline

    def trials(self):
        """Wrap each timeline entry in a jsPsychTrial tagged with its position in the timeline."""
        trials = []
        for ti, trial in enumerate(self.timeline):
            trials.append(jsPsychTrial(trial, trial_index=ti))
        return trials

    def main_trials(self):
        """GridNavigation trials whose roundtype is anything other than 'practice'."""
        trials = []
        for trial in self.trials():
            if trial.is_GridNavigation() and trial.roundtype() != 'practice':
                trials.append(trial)
        return trials

    # In the filters below, a missing roundtype (None) is treated as "no match"
    # rather than failing with a TypeError on `'...' in None`.

    def training_trials(self):
        """Main trials whose roundtype contains '_training'."""
        return [t for t in self.main_trials() if '_training' in (t.roundtype() or '')]

    def test_trials(self):
        """Main trials whose roundtype contains 'test'."""
        return [t for t in self.main_trials() if 'test' in (t.roundtype() or '')]

    def practice_trials(self):
        """GridNavigation trials whose roundtype contains 'practice'."""
        return [t for t in self.trials() if t.is_GridNavigation() and 'practice' in (t.roundtype() or '')]

    def condition(self):
        """The `condition_name` of the first SaveGlobalStore trial, or None if there is none."""
        for trial in self.trials():
            if trial.type() == "SaveGlobalStore":
                return trial.config['condition_name']

    def condition2(self):
        """Reconstruct the condition from the first letter of each training trial's grid type.

        This is a redundancy check against condition().
        """
        return ''.join([t.grid_type()[0] for t in self.training_trials()])

    def main_trial_param_seq(self):
        """Tuple of (grid_name, transform, swap_start_goal) for every main trial, in order."""
        return tuple([
            t.GridNavigation_params() for t in self.main_trials()
        ])

    def counts(self):
        """Tally transforms, grid names and (condition, grid_idx, trial_index) triples.

        Training and test trials are tallied into separate ``defaultdict(int)``
        tables, returned in a SimpleNamespace.
        """
        counts = SimpleNamespace(
            train_transforms=defaultdict(int),
            test_transforms=defaultdict(int),
            train_grid_names=defaultdict(int),
            test_grid_names=defaultdict(int),
            train_cond_grid_idx=defaultdict(int),
            test_cond_grid_idx=defaultdict(int),
        )
        # condition() scans the whole timeline; compute it once, not once per trial.
        condition = self.condition()
        for trial in self.training_trials():
            counts.train_transforms[trial.transform()] += 1
            counts.train_grid_names[trial.grid_name()] += 1
            counts.train_cond_grid_idx[(condition, trial.grid_idx(), trial.trial_index())] += 1

        for trial in self.test_trials():
            counts.test_transforms[trial.transform()] += 1
            counts.test_grid_names[trial.grid_name()] += 1
            counts.test_cond_grid_idx[(condition, trial.grid_idx(), trial.trial_index())] += 1
        return counts
    
class jsPsychTrial:
    """Accessor wrapper around a single jsPsych trial config dict.

    `trial_index` is the trial's position in its timeline (None if unknown).
    """
    def __init__(self, config, trial_index=None):
        self.config = config
        self._trial_index = trial_index

    def type(self):
        """The jsPsych plugin name stored under 'type'."""
        return self.config['type']

    def is_GridNavigation(self):
        """True for GridNavigation trials."""
        return self.type() == 'GridNavigation'

    def grid_name(self):
        """trialparams['grid_name'] (e.g. '3-fine-train'), or None if absent."""
        return self.config['trialparams'].get('grid_name')

    def grid_type(self):
        """Second '-'-separated field of the grid name (e.g. 'fine', 'coarse', 'test')."""
        parts = self.grid_name().split('-')
        return parts[1]

    def grid_idx(self):
        """Leading integer field of the grid name. Raises KeyError if the trial has no grid name."""
        name = self.config['trialparams']['grid_name']
        head, _, _ = name.partition('-')
        return int(head)

    def trial_index(self):
        return self._trial_index

    def roundtype(self):
        """trialparams['roundtype'] (e.g. 'practice', 'test', '<cond>_training'), or None."""
        return self.config['trialparams'].get('roundtype')

    def start_loc(self):
        """Locations of the start tile ('@') in this trial's maze."""
        return feature_locations(self.config['taskparams']['tile_array'], '@')

    def goal_loc(self):
        """Locations of the goal tile ('$') in this trial's maze."""
        return feature_locations(self.config['taskparams']['tile_array'], '$')

    def transform(self):
        return self.config['trialparams'].get('transform')

    def swap_start_goal(self):
        return self.config['trialparams'].get('swap_start_goal')

    def GridNavigation_params(self):
        """(grid_name, transform, swap_start_goal): the parameters that identify a main trial."""
        return (self.grid_name(), self.transform(), self.swap_start_goal())


In [9]:
def sanity_checks(config):
    """Assert structural and statistical invariants of a generated experiment config.

    Per timeline it checks:
    - the number of practice, training and test trials;
    - the number of distinct training and test mazes used;
    - that consecutive main trials never share a start or a goal location.

    Across timelines it checks:
    - main-trial sequences are unique;
    - condition counts equal ``EXP_COND_COUNT``;
    - chi-square tests at level ``ALPHA`` find no significant association
      between condition and trial position for any maze index;
    - transform usage shows no significant non-uniformity.

    Progress and the contingency tables are printed as the checks run.

    Raises:
        AssertionError: if any check fails.
    """
    EXP_TRAINING_TRIALS = 12
    EXP_TEST_TRIALS = 8
    EXP_PRACTICE_TRIALS = 4
    EXP_UNIQUE_TRAINING_MAZES = 8
    EXP_UNIQUE_TEST_MAZES = 4
    EXP_COND_COUNT = {
        "ffffffffffff": 64*2,
        "cccccccccccc": 64*2,
    }
    ALPHA = .05  # balance tests require p > ALPHA

    def test_timeline(tl):
        # Trial counts per phase.
        n_training = len(tl.training_trials())
        assert n_training == EXP_TRAINING_TRIALS, n_training
        n_test = len(tl.test_trials())
        assert n_test == EXP_TEST_TRIALS, n_test
        n_practice = len(tl.practice_trials())
        assert n_practice == EXP_PRACTICE_TRIALS, n_practice

        # Number of distinct mazes used in each phase.
        unique_training_mazes = {t.grid_name() for t in tl.training_trials()}
        assert len(unique_training_mazes) == EXP_UNIQUE_TRAINING_MAZES, (len(unique_training_mazes), EXP_UNIQUE_TRAINING_MAZES)
        unique_test_mazes = {t.grid_name() for t in tl.test_trials()}
        assert len(unique_test_mazes) == EXP_UNIQUE_TEST_MAZES, (len(unique_test_mazes), EXP_UNIQUE_TEST_MAZES)

        # Consecutive main trials must not repeat a start or a goal location.
        main_trials = tl.main_trials()
        start_locs = [t.start_loc() for t in main_trials]
        assert all(s1 != s2 for s1, s2 in zip(start_locs, start_locs[1:])), start_locs
        goal_locs = [t.goal_loc() for t in main_trials]
        assert all(g1 != g2 for g1, g2 in zip(goal_locs, goal_locs[1:])), goal_locs

    def check_unbiased(count_arr):
        # Presumably count_arr is indexed (condition, grid index, trial index),
        # mirroring the key order used when the counts are collected; each
        # slice along axis 1 is then a condition x trial-index table.
        for grid_i in range(count_arr.shape[1]):
            maze_idx_table = count_arr[:, grid_i, :]
            chi2stat, pval = scipy.stats.chi2_contingency(maze_idx_table)[:2]
            print(maze_idx_table, f"chisq={chi2stat:.2f}; p = {pval:.2f}")
            assert pval > ALPHA, (grid_i, chi2stat, pval)

    cf = ConfigWrapper(config)

    print("Testing expected counts in individual timelines")
    condition_count = Counter()
    timeline_param_seqs = []
    for tl in cf.timelines():
        print(".", end="")
        test_timeline(tl)

        param_seq = tl.main_trial_param_seq
        # NOTE(review): read without parentheses, unlike the other Timeline
        # accessors. If it is a method rather than a property/attribute, the
        # set below would hold distinct bound methods and the uniqueness check
        # would pass vacuously, so call it when it is callable.
        timeline_param_seqs.append(param_seq() if callable(param_seq) else param_seq)

        # Redundancy check: both condition derivations must agree.
        assert tl.condition() == tl.condition2(), (tl.condition(), tl.condition2())
        condition_count[tl.condition()] += 1
    print()

    print("Testing that all main trial sequences are unique across timelines")
    n_unique_seqs = len(set(timeline_param_seqs))
    assert len(timeline_param_seqs) == n_unique_seqs, (len(timeline_param_seqs), n_unique_seqs)
    print(f"{len(timeline_param_seqs)} unique seq")

    print("Testing expected condition counts")
    print(condition_count)
    assert condition_count == EXP_COND_COUNT, condition_count

    # Aggregate counts once and reuse them for every balance test below.
    cf_counts = cf.counts()

    print("Testing that training maze index and trial index are unbiased")
    check_unbiased(countdict_to_array(cf_counts.train_cond_grid_idx))

    print("Testing that test maze index and trial index are unbiased")
    check_unbiased(countdict_to_array(cf_counts.test_cond_grid_idx))

    print("Testing that uniform number of transforms are used")
    for transform_counts in (cf_counts.train_transforms, cf_counts.test_transforms):
        res = scipy.stats.chisquare(list(transform_counts.values()))
        assert res.pvalue > ALPHA, (dict(transform_counts), res.pvalue)

## Generate experiment config code

In [10]:
# Construct the generator for this experiment version. It receives a
# fixed-seed `random.Random`; assuming it draws all of its randomness from
# `rng` (TODO confirm), re-running this cell reproduces the same timelines.
expgen = ExperimentGenerator(
    rng=random.Random(233112),
    EXPERIMENT_CODE_VERSION=EXPERIMENT_CODE_VERSION,
    # These two values appear in config["params"] after generation.
    cutoff_time_min=45,
    expectedTime="10 minutes",
    # Static image paths handed to the generator (presumably for browser preloading).
    preloadImages=[
        "static/images/bluedotgrid.png",
        "static/images/goalsquare.png",
        "static/images/green_goal.png",
        "static/images/broken_obstacle.png"
    ],
    MAIN_OBSTACLE_COLOR=MAIN_OBSTACLE_COLOR,
    # Default trial/task parameters and the maze sets loaded at the top of the notebook.
    default_trialparams=default_trialparams,
    default_taskparams=default_taskparams,
    fine_train_mazes=fine_train_mazes,
    coarse_train_mazes=coarse_train_mazes,
    test_mazes=test_mazes,
    # Counterbalancing options. NOTE(review): `sanity_checks` expects 64*2
    # timelines per condition, presumably 64 orders doubled by
    # `reversed_training` -- confirm against ExperimentGenerator.
    n_train_orders_per_cond=64,
    reversed_training=True,
    test_swap_start_goals=True
)
config = expgen.generate_experiment()
# Stamp the config with its generation time, formatted like "Sep 12, 2022 09:52AM".
config["params"]["creation_datetime"] = datetime.datetime.now().strftime("%b %d, %Y %I:%M%p")

  0%|          | 0/128 [00:00<?, ?it/s]

427 runs
901 runs
16 runs
85 runs
70 runs
2 runs
125 runs
137 runs
13 runs
58 runs
1029 runs
29 runs
1030 runs
801 runs
397 runs
258 runs
274 runs
721 runs
847 runs
363 runs
108 runs
218 runs
519 runs
24 runs
392 runs
512 runs
6 runs
12 runs
107 runs
32 runs
159 runs
81 runs
246 runs
1 runs
131 runs
48 runs
1883 runs
578 runs
87 runs
571 runs
38 runs
33 runs
87 runs
867 runs
279 runs
218 runs
613 runs
15 runs
67 runs
242 runs
8 runs
101 runs
478 runs
75 runs
496 runs
93 runs
1183 runs
343 runs
77 runs
390 runs
110 runs
23 runs
222 runs
10 runs
63 runs
308 runs
329 runs
286 runs
89 runs
148 runs
2 runs
344 runs
635 runs
274 runs
27 runs
50 runs
329 runs
74 runs
481 runs
159 runs
215 runs
124 runs
94 runs
25 runs
260 runs
735 runs
51 runs
394 runs
8 runs
84 runs
22 runs
552 runs
119 runs
89 runs
25 runs
229 runs
405 runs
518 runs
74 runs
269 runs
78 runs
166 runs
701 runs
257 runs
165 runs
80 runs
168 runs
175 runs
1976 runs
822 runs
293 runs
69 runs
316 runs
80 runs
341 runs
895 runs
10

In [11]:
# Run every check in `sanity_checks`; a failing assertion raises and halts
# "Run All" before the next cell writes the config to disk.
print("-> Sanity checking configuration file")
sanity_checks(config)

-> Sanity checking configuration file
Testing expected counts in individual timelines
................................................................................................................................................................................................................................................................
Testing that all main trial sequences are unique across timelines
256 unique seq
Testing expected condition counts
Counter({'ffffffffffff': 128, 'cccccccccccc': 128})
Testing that training maze index and trial index are unbiased
[[15. 16. 20. 16. 14. 14. 14. 14. 16. 20. 16. 15.]
 [14. 11. 20. 17. 15. 20. 20. 15. 17. 20. 11. 14.]] chisq=4.13; p = 0.97
[[19. 16. 17. 17. 17. 14. 14. 17. 17. 17. 16. 19.]
 [12. 17. 17. 13. 18. 15. 15. 18. 13. 17. 17. 12.]] chisq=3.75; p = 0.98
[[17. 17. 15. 15. 19. 12. 12. 19. 15. 15. 17. 17.]
 [19. 13. 13. 19. 18. 14. 14. 18. 19. 13. 13. 19.]] chisq=2.87; p = 0.99
[[18. 14. 27. 15. 16. 14. 14. 16. 15. 27. 14. 18.]
 [21.

In [12]:
# Write the generated config as compact JSON to ./config.json, then archive it
# into the psiTurk static config directory twice: under a generic name and
# under a name tagged with EXPERIMENT_CODE_VERSION.
print(f"-> Saving configuration file in {PSITURKAPP_CONFIG_DIR}")

# Context managers guarantee the JSON is flushed and each archive's central
# directory is written, instead of relying on garbage collection to close them.
with open("config.json", "w") as config_file:
    json.dump(config, config_file, separators=(",", ":"))

for archive_name in ("config.json.zip", EXPERIMENT_CODE_VERSION + "-config.json.zip"):
    # os.path.join stays correct even if the directory lacks a trailing slash.
    with zipfile.ZipFile(
        os.path.join(PSITURKAPP_CONFIG_DIR, archive_name),
        mode="w",
        compression=zipfile.ZIP_DEFLATED,
    ) as archive:
        archive.write("config.json")

pprint.pprint(config["params"])

-> Saving configuration file in ../../psiturkapp/static/config/
{'EXPERIMENT_CODE_VERSION': 'shifting-exp-1',
 'creation_datetime': 'Sep 12, 2022 09:52AM',
 'cutoff_time_min': 45,
 'expectedTime': '10 minutes',
 'recruitment_platform': 'prolific'}
