In [1]:
import numpy as np
from scipy import interpolate
from matplotlib import pyplot as plt

import threading
import time

from demo_util import gabor_util
ContrastLandscape = gabor_util.ContrastLandscape

In [5]:
from functools import wraps
import ipywidgets as widgets


# Stimulus-generation parameters. Most are forwarded unchanged to
# gabor_util.generate_images by sample_image; grid_size and cell_size also
# fix the size of the blank images in sample_images_and_labels.
grid_size = 16
cell_size = 32
kernel_size = 65
scale = 4
wavelength = 8
start_distance = 128
num_points = 12  # requested only for the image that contains the path
# Angles are written in degrees and converted to radians here.
path_angle, angle_noise = np.deg2rad(45), np.deg2rad(10)
random_phase = True

# Contrast-landscape parameters (also forwarded by sample_image).
contrast_landscape = ContrastLandscape.FIXED
contrast_grid_size = (10, 10)
min_contrast, max_contrast = 0.0, 1.0
contrast_epsilon = 0.4
contrast_smooth = 0.0

# Trial-sampling parameters.
align_phase_prob = 0.5  # compared to a uniform draw to pick the phase condition
has_path_prob = 0.5  # NOTE(review): not referenced by the code in this cell


def sample_image(has_path, align_phase):
    """Generate a single stimulus image using the module-level parameters.

    Args:
        has_path: when true, ``num_points`` path elements are requested;
            otherwise the generator is asked for zero path points.
        align_phase: forwarded as ``align_phase`` to the generator.

    Returns:
        The second element of the 4-tuple returned by
        ``gabor_util.generate_images``.
    """
    path_points = num_points if has_path else 0
    generator_kwargs = dict(
        seed=None,
        grid_size=grid_size,
        cell_size=cell_size,
        kernel_size=kernel_size,
        scale=scale,
        wavelength=wavelength,
        start_distance=start_distance,
        num_points=path_points,
        path_angle=path_angle,
        angle_noise=angle_noise,
        random_phase=random_phase,
        align_phase=align_phase,
        contrast_landscape=contrast_landscape,
        contrast_grid_size=contrast_grid_size,
        min_contrast=min_contrast,
        max_contrast=max_contrast,
        generate_contrast_image=False,
        contrast_epsilon=contrast_epsilon,
        contrast_smooth=contrast_smooth,
    )
    _, stimulus, _, _ = gabor_util.generate_images(**generator_kwargs)
    return stimulus


def sample_images_and_labels():
    """Sample one two-interval trial.

    Exactly one of the two stimulus images, chosen uniformly at random,
    contains a path. Both stimuli share the same phase condition, and each
    is followed by an all-zero blank image.

    Returns:
        ``(images, path_index, condition)`` where ``images`` is
        ``[stimulus_0, blank, stimulus_1, blank]``, ``path_index`` (0 or 1)
        is the index of the stimulus containing the path, and ``condition``
        is ``"aligned"`` or ``"not-aligned"``.
    """
    path_index = int(np.random.randint(2))
    # ``align_phase_prob`` is the probability of the "aligned" condition.
    # (The previous ``>`` comparison inverted that, which went unnoticed only
    # because the default value is 0.5.)
    align_phase = np.random.rand() < align_phase_prob
    condition = "aligned" if align_phase else "not-aligned"

    side = grid_size * cell_size
    # One blank array is reused for both slots; it is only displayed, never
    # modified.
    blank = np.zeros((side, side))
    images = [
        sample_image(path_index == 0, align_phase),
        blank,
        sample_image(path_index == 1, align_phase),
        blank,
    ]
    return images, path_index, condition
    

class ExperimentState:
    """Running tally of correct and wrong answers for one condition."""

    def __init__(self, max_iteration=3):
        # ``max_iteration`` is accepted for signature compatibility only;
        # it is not used.
        self.reset()

    @property
    def iteration(self):
        """Number of answers recorded so far (correct plus wrong)."""
        return self._wrong + self._correct

    def add_correct(self):
        """Count one correct answer."""
        self._correct += 1

    def add_wrong(self):
        """Count one wrong answer."""
        self._wrong += 1

    def reset(self):
        """Zero both counters."""
        self._correct, self._wrong = 0, 0

    def __str__(self):
        # Before any answer is recorded, divide by 1 so the rates show 0.0%.
        denominator = max(self.iteration, 1)
        return "correct: {0:d} ({1:.1%}), wrong: {2:d} ({3:.1%})".format(
            self._correct,
            self._correct / denominator,
            self._wrong,
            self._wrong / denominator,
        )
    

class Experiment:
    """Bookkeeping for the "aligned" / "not-aligned" path-detection experiment.

    Stores the trial currently on screen (its images, the index of the image
    with the path, and its condition name) and one ExperimentState tally per
    condition.
    """

    def __init__(self, max_iteration=3):
        self._max_iter = max_iteration
        # Insertion order ("aligned" first) determines the order in __str__.
        self._conditions = {
            name: ExperimentState() for name in ("aligned", "not-aligned")
        }
        self._images = self._answer = self._cond = None

    @property
    def images(self):
        """Images of the current trial, or None before any trial is set."""
        return self._images

    @property
    def iteration(self):
        """Total number of answers recorded across all conditions."""
        return sum(tally.iteration for tally in self._conditions.values())

    def set_max_iter(self, value):
        """Set the number of answers after which should_stop() is true."""
        self._max_iter = value

    def reset(self):
        """Zero every condition's tally (the current trial is left as is)."""
        for tally in self._conditions.values():
            tally.reset()

    def should_stop(self):
        """Return True once at least ``max_iter`` answers are recorded."""
        return self._max_iter <= self.iteration

    def record_answer(self, path_index):
        """Score ``path_index`` against the current trial's correct answer."""
        tally = self._conditions[self._cond]
        answered_correctly = self._answer == path_index
        if answered_correctly:
            tally.add_correct()
        else:
            tally.add_wrong()

    def update_images(self, images, answer, condition):
        """Make ``images`` / ``answer`` / ``condition`` the current trial."""
        self._images, self._answer, self._cond = images, answer, condition

    def __str__(self):
        summary_lines = (
            "{}: {}".format(name, tally)
            for name, tally in self._conditions.items()
        )
        return "\n".join(summary_lines)
    
    
def _make_button(description, tooltip, disabled):
    """Build a default-styled ipywidgets Button."""
    return widgets.Button(
        description=description,
        disabled=disabled,
        button_style='',
        tooltip=tooltip,
    )


# Experiment controls: only Start is clickable initially.
start_button = _make_button('Start', 'Start experiment', disabled=False)
reset_button = _make_button('Reset', 'Reset experiment', disabled=True)

# Answer buttons; display_stimuli enables them once a trial has been shown.
first_button = _make_button(
    'Path in 1st image',
    'Click here if you believe you saw a path in the first image.',
    disabled=True,
)
second_button = _make_button(
    'Path in 2nd image',
    "Click here if you believe you saw a path in the second image.",
    disabled=True,
)

# Number of answers to collect, and a progress bar sized to match it.
num_iterations = widgets.BoundedIntText(
    value=20,
    min=1,
    max=1000,
    step=1,
    description='# iterations:',
)
progress = widgets.IntProgress(
    value=0,
    min=0,
    max=num_iterations.value,
    step=1,
    description='Progress:',
)

# When True, display_stimuli prints the running tallies after each trial.
debug_mode = False

# Shared experiment state mutated by the button callbacks.
experiment = Experiment()

# `out` receives the stimuli; `final_out` receives the end-of-run summary.
out, final_out = widgets.Output(), widgets.Output()

@out.capture(clear_output=True, wait=True)
def display_image(image):
    """Draw ``image`` in the ``out`` widget, replacing what was there before.

    Gray levels are mapped over the fixed value range [-1, 1], so a zero
    image always renders as the same middle gray.
    """
    imshow_options = dict(cmap="gray", origin="upper", vmin=-1., vmax=1.)
    plt.figure(figsize=(8, 8))
    plt.imshow(image, **imshow_options)
    plt.show()


def display_stimuli(images, image_duration):
    """Present ``images`` in order, ``image_duration`` seconds apart.

    The answer and reset buttons are disabled before the first image is
    drawn and stay disabled until every image in ``images`` has been shown.
    Previously they were disabled only after the first image was already
    visible, which left a window in which the subject could answer early.

    Args:
        images: sequence of images passed to display_image in order; may be
            empty, in which case nothing is drawn.
        image_duration: seconds to wait before replacing one image with the
            next (there is no wait after the last image).
    """
    controls = (first_button, second_button, reset_button)
    for button in controls:
        button.disabled = True

    for index, image in enumerate(images):
        if index:
            time.sleep(image_duration)
        display_image(image)

    for button in controls:
        button.disabled = False

    if debug_mode:
        print(experiment)
    
def start_experiment():
    """Begin a run: lock Start, sample the first trial and present it."""
    start_button.disabled = True

    experiment.set_max_iter(num_iterations.value)
    images, answer, condition = sample_images_and_labels()
    experiment.update_images(images, answer, condition)

    seconds_per_image = 1.
    display_stimuli(experiment.images, seconds_per_image)
    
    
def reset_experiment():
    """Return the controls, the tallies and the progress bar to their start state."""
    start_button.disabled = False
    for button in (first_button, second_button, reset_button):
        button.disabled = True

    experiment.set_max_iter(num_iterations.value)
    experiment.reset()
    final_out.clear_output()
    progress.value = 0
    

def step_experiment(image_choice):
    """Record the subject's answer, then either finish or present the next trial.

    Args:
        image_choice: index (0 or 1) of the image the subject believes
            contained the path.

    When the configured number of answers has been reached, the answer
    buttons are disabled and the per-condition summary is printed to
    ``final_out``. Otherwise a new trial is sampled and presented.
    """
    experiment.record_answer(image_choice)
    progress.value = experiment.iteration

    if experiment.should_stop():
        first_button.disabled = True
        second_button.disabled = True
        with final_out:
            print(experiment)
        return

    # Sample only when the trial will actually be shown. The previous code
    # generated a full trial even on the final step, and it was discarded.
    experiment.update_images(*sample_images_and_labels())
    display_stimuli(experiment.images, 1.)
    
    
def update_max_progress(change):
    """Observer for ``num_iterations``: keep the progress bar's maximum in sync."""
    requested_iterations = change["new"]
    progress.max = requested_iterations
    
    
# Resize the progress bar whenever the requested iteration count changes.
num_iterations.observe(update_max_progress, names='value')

# Control buttons; the clicked Button instance is ignored.
start_button.on_click(lambda _clicked: start_experiment())
reset_button.on_click(lambda _clicked: reset_experiment())

# Answer buttons report the index of the image the subject picked.
first_button.on_click(lambda _clicked: step_experiment(0))
second_button.on_click(lambda _clicked: step_experiment(1))

control_buttons = widgets.HBox([start_button, reset_button])
answer_buttons = widgets.HBox([first_button, second_button])

# Show a blank (all-zero) placeholder until the experiment starts. The old
# code generated a full trial just to show its trailing blank image.
# start_experiment samples its own first trial, so that trial was never used.
placeholder_side = grid_size * cell_size
display_image(np.zeros((placeholder_side, placeholder_side)))

# Layout, top to bottom: settings, progress, controls, stimulus, answers,
# final summary.
ui = widgets.VBox([
    num_iterations,
    progress,
    control_buttons,
    out,
    answer_buttons,
    final_out,
])
ui.layout.align_items = "center"
display(ui)

VBox(children=(BoundedIntText(value=20, description='# iterations:', max=1000, min=1), IntProgress(value=0, de…