In [1]:
import numpy as np
from scipy import interpolate
from matplotlib import pyplot as plt

from demo_util import gabor_util
# Notebook-level alias for gabor_util's contrast-landscape options, so the
# cells below can write ContrastLandscape.FIXED (and combine members with `|`)
# without the module prefix.
ContrastLandscape = gabor_util.ContrastLandscape

In [3]:
from functools import wraps
import ipywidgets as widgets


# ---------------------------------------------------------------------------
# Stimulus settings for the path-detection experiment.
# ---------------------------------------------------------------------------

# Grid / element geometry (cell size and Gabor scale/wavelength are in px).
grid_size = 16
cell_size = 32
kernel_size = 65
scale = 4
wavelength = 8

# Path geometry: number of path steps and angles (stored in radians).
start_distance = 128
num_points = 12
path_angle = np.deg2rad(45)
angle_noise = np.deg2rad(10)

# Phase and contrast configuration.
random_phase = True
contrast_landscape = ContrastLandscape.FIXED
contrast_grid_size = (10, 10)
min_contrast = 0.0
max_contrast = 1.0
contrast_epsilon = 0.4
contrast_smooth = 0.0

# Per-trial randomization probabilities.
align_phase_prob = 0.5
has_path_prob = 0.5


# Snapshot of the generator settings above, taken when this cell runs. A later
# cell of this notebook rebinds several of the same global names (scale,
# wavelength, cell_size, num_points, path_angle, angle_noise, min_contrast,
# max_contrast) to ipywidgets objects; reading the globals at click time would
# then hand widget objects to the generator and break the experiment UI.
_EXPERIMENT_IMAGE_KWARGS = dict(
    grid_size=grid_size,
    cell_size=cell_size,
    kernel_size=kernel_size,
    scale=scale,
    wavelength=wavelength,
    start_distance=start_distance,
    path_angle=path_angle,
    angle_noise=angle_noise,
    random_phase=random_phase,
    contrast_landscape=contrast_landscape,
    contrast_grid_size=contrast_grid_size,
    min_contrast=min_contrast,
    max_contrast=max_contrast,
    contrast_epsilon=contrast_epsilon,
    contrast_smooth=contrast_smooth,
)
_EXPERIMENT_NUM_POINTS = num_points


def sample_image_and_label():
    """Draw one random trial for the experiment.

    Whether the trial contains a path is drawn with probability
    ``has_path_prob``; whether the Gabor phases are aligned is drawn
    independently with probability ``align_phase_prob``.

    Returns:
        (image, has_path, condition): the path-plus-background image from
        ``gabor_util.generate_images``, the ground-truth flag, and the
        condition name ("aligned" or "not-aligned").
    """
    # ``rand() < p`` is true with probability p. The previous code used
    # align_phase_prob for has_path as well and ``>``, i.e. probability 1 - p.
    has_path = np.random.rand() < has_path_prob
    align_phase = np.random.rand() < align_phase_prob
    condition = "aligned" if align_phase else "not-aligned"

    _, bg_path_image, _, _ = gabor_util.generate_images(
        seed=None,
        # A path-less trial is requested as a path with zero points.
        num_points=_EXPERIMENT_NUM_POINTS if has_path else 0,
        align_phase=align_phase,
        generate_contrast_image=False,
        **_EXPERIMENT_IMAGE_KWARGS,
    )

    return bg_path_image, has_path, condition


class ExperimentState:
    """Signal-detection tallies for a single experimental condition.

    Trials that contain a path are "positive" and trials without one are
    "negative". Each answer is recorded through exactly one of ``hit``,
    ``miss``, ``false_alarm`` or ``reject``.
    """

    def __init__(self, max_iteration=3):
        # ``max_iteration`` is accepted for backward compatibility only and is
        # ignored; the trial budget is held by ``Experiment``.
        self.reset()

    @property
    def iteration(self):
        """Total number of trials recorded (positive + negative)."""
        return self._num_neg + self._num_pos

    def hit(self):
        """Record a path trial correctly answered "path"."""
        self._hits += 1
        self._num_pos += 1

    def miss(self):
        """Record a path trial wrongly answered "no path"."""
        self._num_pos += 1

    def false_alarm(self):
        """Record a path-less trial wrongly answered "path"."""
        self._false_alarm += 1
        self._num_neg += 1

    def reject(self):
        """Record a path-less trial correctly answered "no path"."""
        self._num_neg += 1

    def reset(self):
        """Zero every counter; also used by ``__init__`` so both stay in sync."""
        self._hits = 0
        self._num_pos = 0
        self._false_alarm = 0
        self._num_neg = 0

    def __str__(self):
        return "hits: {}, num pos: {}, false alarms: {}, num neg: {}".format(
            self._hits, self._num_pos, self._false_alarm, self._num_neg)
    

class Experiment:
    """Path-detection experiment split into "aligned" / "not-aligned" conditions.

    Keeps the stimulus currently on screen together with its ground-truth
    label and condition, and scores each answer into the ExperimentState of
    that condition.
    """

    def __init__(self, max_iteration=3):
        # Trial budget counted over both conditions together (see should_stop).
        self._max_iter = max_iteration
        self._conditions = {
            name: ExperimentState() for name in ("aligned", "not-aligned")
        }
        # Current stimulus, its ground truth (path present?) and its condition.
        self._image = self._label = self._cond = None

    @property
    def image(self):
        """The stimulus image currently being shown."""
        return self._image

    @property
    def iteration(self):
        """Number of answers recorded across all conditions."""
        return sum(state.iteration for state in self._conditions.values())

    def reset(self):
        """Zero the tallies of every condition."""
        for state in self._conditions.values():
            state.reset()

    def should_stop(self):
        """True once the trial budget has been used up."""
        return self._max_iter <= self.iteration

    def record_answer(self, has_path):
        """Score an answer ("path" if ``has_path``) against the current stimulus."""
        state = self._conditions[self._cond]
        if self._label and has_path:
            state.hit()
        elif self._label:
            state.miss()
        elif has_path:
            state.false_alarm()
        else:
            state.reject()

    def update_image(self, image, label, condition):
        """Make ``image`` the current stimulus with its label and condition."""
        self._image = image
        self._label = label
        self._cond = condition

    def __str__(self):
        return "\n".join(
            f"{name}: {state}" for name, state in self._conditions.items()
        )
    
    
def _make_button(description, tooltip, disabled=True):
    """Build an unstyled ipywidgets Button."""
    return widgets.Button(
        description=description,
        disabled=disabled,
        button_style='',
        tooltip=tooltip,
    )


# Only Start is clickable at first; reset_experiment() enables the others.
start_button = _make_button('Start', 'Start experiment', disabled=False)
reset_button = _make_button('Reset', 'Reset experiment')
path_present = _make_button('Path', 'Click here if you believe you see a path.')
path_missing = _make_button('No path', "Click here if you don't see a path.")

# When True, the running tallies are printed below each new stimulus.
debug_mode = False

# Output area the experiment callbacks render into, and the experiment itself.
out = widgets.Output()
experiment = Experiment()


def display_image(image):
    """Show ``image`` as an 8x8-inch grayscale figure on a fixed [-1, 1] scale.

    The fixed vmin/vmax keep brightness comparable between trials instead of
    auto-scaling each image.
    """
    imshow_options = dict(cmap="gray", origin="upper", vmin=-1.0, vmax=1.0)
    plt.figure(figsize=(8, 8))
    plt.imshow(image, **imshow_options)
    plt.show()

    
@out.capture(clear_output=True, wait=True)
def reset_experiment():
    """Clear all tallies, show a fresh stimulus and enable answering.

    Bound to both the Start and Reset buttons. Output goes into the ``out``
    widget, replacing whatever it showed before.
    """
    experiment.reset()
    image, label, condition = sample_image_and_label()
    experiment.update_image(image, label, condition)

    display_image(experiment.image)
    if debug_mode:
        print(experiment)

    # Answers are accepted from now on; restarting goes through Reset.
    for answer_button in (path_present, path_missing):
        answer_button.disabled = False
    start_button.disabled = True
    reset_button.disabled = False


@out.capture(clear_output=True, wait=True)
def step_experiment(has_path):
    """Score the answer for the current stimulus, then advance or finish.

    Args:
        has_path: True when the participant clicked "Path", False for
            "No path".

    After recording the answer, the next stimulus is shown. Once the
    experiment's trial budget is used up (``experiment.should_stop()``), the
    final tallies are printed instead. The answer buttons then stay disabled;
    the Reset button starts a new run. Previously ``should_stop`` was never
    consulted, so the experiment never ended.
    """
    # Answer buttons are locked while the answer is processed and the next
    # stimulus is generated.
    path_present.disabled = True
    path_missing.disabled = True

    experiment.record_answer(has_path)

    if experiment.should_stop():
        print("Experiment finished.")
        print(experiment)
        return

    experiment.update_image(*sample_image_and_label())
    display_image(experiment.image)

    if debug_mode:
        print(experiment)

    path_present.disabled = False
    path_missing.disabled = False
    
    
def _on_restart_click(_button):
    """Start/Reset click handler: begin a fresh run."""
    reset_experiment()


def _on_path_click(_button):
    """"Path" click handler: the participant reports a path."""
    step_experiment(True)


def _on_no_path_click(_button):
    """"No path" click handler: the participant reports no path."""
    step_experiment(False)


start_button.on_click(_on_restart_click)
reset_button.on_click(_on_restart_click)
path_present.on_click(_on_path_click)
path_missing.on_click(_on_no_path_click)

# Layout: Start/Reset on top, stimulus output in the middle, answers below.
control_buttons = widgets.HBox([start_button, reset_button])
answer_buttons = widgets.HBox([path_present, path_missing])
ui = widgets.VBox([control_buttons, out, answer_buttons])
ui.layout.align_items = "center"
display(ui)

VBox(children=(HBox(children=(Button(description='Start', style=ButtonStyle(), tooltip='Start experiment'), Bu…

In [42]:
%matplotlib inline
from ipywidgets import interactive
import ipywidgets as widgets
import functools

# Generator seed. Capped at 2**32 - 1 (the original bound was the float 2e32,
# ~2 * 10**32, almost certainly a typo for 2**32). NOTE(review): the cap
# assumes gabor_util seeds NumPy with this value, and NumPy's legacy seeding
# only accepts 0 .. 2**32 - 1; confirm.
seed = widgets.BoundedIntText(
    value=283214663,
    min=0,
    max=2**32 - 1,
    step=1,
)

# Gabor element shape, in pixels.
scale = widgets.IntSlider(
    value=4,
    min=1,
    max=32,
    step=1,
    continuous_update=False,
)
wavelength = widgets.IntSlider(
    value=8,
    min=1,
    max=32,
    step=1,
    continuous_update=False,
)
cell_size = widgets.IntSlider(
    value=32,
    min=8,
    max=64,
    step=1,
    continuous_update=False,
)

# Path geometry: length in steps; angle and noise in degrees (plotting_function
# converts them to radians).
num_points = widgets.IntSlider(
    value=12,
    min=3,
    max=32,
    step=1,
    continuous_update=False,
)
path_angle = widgets.FloatSlider(
    value=45,
    min=0,
    max=180,
    step=1.,
    continuous_update=False,
)
angle_noise = widgets.FloatSlider(
    value=10.,
    min=0,
    max=45,
    step=1.0,
    continuous_update=False,
)

# Contrast-landscape toggles, combined into ContrastLandscape flags.
path_contrast = widgets.ToggleButton(
    value=False,
    description='Path',
    tooltip="Make the path's contrast random",
)
bg_contrast = widgets.ToggleButton(
    value=False,
    description='Background',
    tooltip="Make the background's contrast random",
)
shared_random = widgets.ToggleButton(
    value=False,
    description='Shared Random',
    tooltip=("Make the path's random contrast equal to the "
             "background's random contrast."),
)
min_contrast = widgets.FloatSlider(
    value=0.,
    min=0.,
    max=1.,
    step=0.02,
    continuous_update=False,
)
max_contrast = widgets.FloatSlider(
    value=1.,
    min=0.,
    max=1,
    step=0.02,
    continuous_update=False,
)
contrast_resolution = widgets.IntSlider(
    value=20,
    min=3,
    max=30,
    step=1,
    continuous_update=False,
)

# Per-column phase toggles; by default the left column shows unaligned and the
# right column aligned random phases.
left_col_phase = widgets.ToggleButton(
    value=True,
    description='Random phase',
    tooltip="Use random phases",
)
left_col_align = widgets.ToggleButton(
    value=False,
    description='Align phase',
    tooltip="Align the gabor window to match the phase",
)
right_col_phase = widgets.ToggleButton(
    value=True,
    description='Random phase',
    tooltip="Use random phases",
)
right_col_align = widgets.ToggleButton(
    value=True,
    description='Align phase',
    tooltip="Align the gabor window to match the phase",
)

def update_max_range(*args):
    """Keep the max-contrast slider from dropping below the chosen min contrast."""
    lower_bound = min_contrast.value
    max_contrast.min = lower_bound


min_contrast.observe(update_max_range, names='value')

# Generator settings that have no widget in this explorer.
grid_size = 16
kernel_size = 65
start_distance = 128


def plotting_function(path_contrast, bg_contrast, shared_random, contrast_resolution,
                      path_angle, angle_noise, **kwargs):
    """Render one explorer column from the current widget values.

    The three contrast toggles are OR-ed onto ``ContrastLandscape.FIXED``.
    Angles arrive in degrees and are converted to radians. Every other widget
    value (seed, scale, phase toggles, ...) is forwarded via ``kwargs`` to
    ``gabor_util.generate_images``. NOTE(review): ``plot_images`` is not
    defined in the visible cells; presumably it comes from another cell.
    """
    landscape = ContrastLandscape.FIXED
    toggles = (
        (path_contrast, ContrastLandscape.RANDOM_PATH),
        (bg_contrast, ContrastLandscape.RANDOM_BACKGROUND),
        (shared_random, ContrastLandscape.SHARED_RANDOM),
    )
    for enabled, flag in toggles:
        if enabled:
            landscape = landscape | flag

    generated = gabor_util.generate_images(
        grid_size=grid_size,
        kernel_size=kernel_size,
        start_distance=start_distance,
        contrast_grid_size=(contrast_resolution, contrast_resolution),
        contrast_landscape=landscape,
        generate_contrast_image=True,
        path_angle=np.deg2rad(path_angle),
        angle_noise=np.deg2rad(angle_noise),
        **kwargs,
    )
    # Distinct names so the boolean toggle parameters are not shadowed.
    path_image, bg_path_image, path_contrast_image, bg_contrast_image = generated
    plot_images(path_image, bg_path_image, path_contrast_image, bg_contrast_image)
    
widget_dict = {
        "seed": seed,
        "scale": scale,
        "wavelength": wavelength,
        "cell_size": cell_size,
        "num_points": num_points,
        "path_angle": path_angle,
        "angle_noise": angle_noise,
        "path_contrast": path_contrast,
        "bg_contrast": bg_contrast,
        "shared_random": shared_random,
        "min_contrast": min_contrast,
        "max_contrast": max_contrast,
        "contrast_resolution": contrast_resolution,
    }

# Both columns share the common widgets but have their own phase toggles, so
# two phase settings can be compared side by side.
left_widget_dict = {
    **widget_dict,
    "random_phase": left_col_phase,
    "align_phase": left_col_align,
}

left_buttons = widgets.HBox([left_col_phase, left_col_align])
left_buttons.layout.justify_content = "center"

left_out = widgets.VBox([
    left_buttons,
    widgets.interactive_output(plotting_function, left_widget_dict),
])

right_widget_dict = {
    **widget_dict,
    "random_phase": right_col_phase,
    "align_phase": right_col_align,
}

right_buttons = widgets.HBox([right_col_phase, right_col_align])
right_buttons.layout.justify_content = "center"

right_out = widgets.VBox([
    right_buttons,
    # Must be the right column's own dict: wiring left_widget_dict here (as
    # before) made the right column's phase toggles have no effect.
    widgets.interactive_output(plotting_function, right_widget_dict),
])

# The two columns side by side.
out = widgets.HBox([left_out, right_out])
out.layout.justify_content = "space-around"

def _labeled_row(text, *controls):
    """Return an HBox holding a text Label followed by ``controls``."""
    return widgets.HBox([widgets.Label(text), *controls])


gabor_widget = widgets.VBox([
    _labeled_row('Gabor scale (px):', scale),
    _labeled_row('Gabor wavelength (px):', wavelength),
])
path_widget = widgets.VBox([
    _labeled_row('Path length (# steps):', num_points),
    _labeled_row('Path angles (deg):', path_angle),
    _labeled_row('Path angle noise (deg):', angle_noise),
])
contrast_widget = widgets.VBox([
    _labeled_row('Random Contrast:', path_contrast, bg_contrast, shared_random),
    _labeled_row('Max contrast:', max_contrast),
    _labeled_row('Min contrast:', min_contrast),
    _labeled_row('Contrast grid', contrast_resolution),
])


# One tab per settings section, titled in the same order as its children.
tab_titles = ["Gabor", "Path", "Contrast"]
tabs = widgets.Tab(children=[gabor_widget, path_widget, contrast_widget])
for i, title in enumerate(tab_titles):
    tabs.set_title(i, title)
tabs.layout.margin = "1em"

# Seed and cell size sit above the tabs; the plot columns go underneath.
widget_rows = widgets.VBox([
    widgets.HBox([
        widgets.Label('Seed:'), seed,
        widgets.Label('Cell size (px):'), cell_size,
    ]),
    tabs,
])
all_outputs = widgets.VBox([widget_rows, out])
all_outputs.layout.align_items = "center"
display(all_outputs)

VBox(children=(VBox(children=(HBox(children=(Label(value='Seed:'), BoundedIntText(value=283214663, max=2000000…