In [1]:
# Notebook setup: load IPython's autoreload extension and set it to mode 2 so
# edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import psyneulink as pnl
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import sys; sys.path.append('./shape-naming/')
import stroop_model

## Plotting Utility

In [6]:
# Figure size handed to plt.subplots by plot_by_condition (matplotlib reads it
# as (width, height)).
DEFAULT_FIGURE_SIZE = (18, 10)
# Default titles for the three side-by-side panels, in left-to-right order.
DEFAULT_TITLES = ('72 epochs', '504 epochs', '2520 epochs')
# Keys used to look up each result mapping; their order is also the order of
# the x-axis tick labels.
ORDERED_KEYS = ('control', 'conflict', 'congruent')

def plot_by_condition(first_results, second_results=None, third_results=None,
                     figsize=DEFAULT_FIGURE_SIZE, titles=DEFAULT_TITLES,
                     ordered_keys=ORDERED_KEYS):
    """Draw up to three result sets side by side in a single 1x3 figure.

    Parameters
    ----------
    first_results : sequence
        Result triple for the left panel; always plotted. It is forwarded
        unchanged to ``plot_single_result_set``.
    second_results, third_results : sequence or None, optional
        Result triples for the middle and right panels. A panel whose result
        set is ``None`` is left empty.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    titles : sequence of str, optional
        Panel titles, indexed by panel position (0, 1, 2). Only the entries
        for panels that are actually plotted are read.
    ordered_keys : sequence, optional
        Keys forwarded to ``plot_single_result_set``.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Use the caller-supplied figsize; previously the argument was ignored and
    # DEFAULT_FIGURE_SIZE was always used.
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.patch.set_facecolor('#DDDDDD')
    fig.suptitle('Figure 12', fontsize=24)

    # The first panel is mandatory; the remaining two are drawn only when a
    # result set was provided for them.
    plot_single_result_set(axes[0], first_results, titles[0], ordered_keys)
    optional_results = (second_results, third_results)
    for panel_index, results in enumerate(optional_results, start=1):
        if results is not None:
            plot_single_result_set(axes[panel_index], results,
                                   titles[panel_index], ordered_keys)

    plt.show()
    
def plot_single_result_set(ax, results, title, ordered_keys):
    """Plot mean +/- standard deviation per key for one result set on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes-like object
        Target axes; ``errorbar``, ``set_xticks``, ``set_xticklabels``,
        ``tick_params``, ``legend`` and ``set_title`` are called on it.
    results : sequence of length 3
        ``(shape_naming, color_naming, _)``. The first two are indexed by
        each entry of ``ordered_keys`` and must yield values accepted by
        ``np.average`` / ``np.std``. The third element is ignored.
    title : str
        Title for the axes.
    ordered_keys : sequence
        Keys to plot, in x-axis order; also used as tick labels.

    Returns
    -------
    None
    """
    shape_naming, color_naming, _ = results

    shape_naming_avg = [np.average(shape_naming[key]) for key in ordered_keys]
    shape_naming_std = [np.std(shape_naming[key]) for key in ordered_keys]

    color_naming_avg = [np.average(color_naming[key]) for key in ordered_keys]
    color_naming_std = [np.std(color_naming[key]) for key in ordered_keys]

    # One x position per key. This used to be a hard-coded np.arange(3), which
    # mismatched the data and tick labels whenever ordered_keys did not have
    # exactly three entries.
    x_values = np.arange(len(ordered_keys))

    ax.errorbar(x_values, shape_naming_avg, yerr=shape_naming_std, marker='o',
                markersize=10, capsize=10, label='shape')
    ax.errorbar(x_values, color_naming_avg, yerr=color_naming_std, marker='s',
                markersize=10, capsize=10, label='color')

    ax.set_xticks(x_values)
    ax.set_xticklabels(ordered_keys)
    ax.tick_params(length=15, labelsize=16)
    ax.legend(fontsize=16)
    ax.set_title(title, dict(fontsize=16))

## Create a model, train it, and test it

In [None]:
# Instantiate StroopModel with no arguments, i.e. its constructor defaults.
stroop = stroop_model.StroopModel()

In [4]:
# Call the model's train() method and keep its return value.
train_output = stroop.train()

In [5]:
# Call the model's test() method and keep its return value.
test_output = stroop.test()