In [None]:
import os

import numpy as np
import pandas as pd
import scipy.stats as sts
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Seaborn theme and context are applied first; the explicit rcParams below
# are set afterwards, so they win for any key seaborn also touched.
sns.set_style('white')
sns.set_context('talk')

matplotlib.rcParams.update({
    # Tick, axis-label and legend font sizes.
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'axes.labelsize': 12,
    'legend.fontsize': 12,
    # Render mathtext with the same sans-serif family as regular text.
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Bitstream Vera Sans',
    'mathtext.default': 'rm',
    'mathtext.it': 'Bitstream Vera Sans:italic',
    'mathtext.bf': 'Bitstream Vera Sans:bold',
})

%matplotlib inline

In [None]:
# Collect the CSV files in ./results. os.listdir returns names in arbitrary
# order, so the listing is sorted to make every run pick the same files in the
# same order.
files = sorted(f for f in os.listdir('./results') if f.endswith('csv'))

# Only files whose name starts with 'app' carry a per-model `cnn_nll` column.
# Slicing [14:-4] drops the first 14 characters and the '.csv' suffix,
# e.g. 'appended_data_smart_32.csv' -> 'smart_32'.
model_files = [f for f in files if f.startswith('app')]
if not model_files:
    raise FileNotFoundError("no 'app*.csv' result files found in ./results")
names = [f[14:-4] for f in model_files]

# Take the shared trial columns from a model file rather than from files[0]:
# the directory also holds other CSVs, and an arbitrary first entry is not
# guaranteed to contain these columns.
d0 = pd.read_csv('./results/' + model_files[0])
D = d0.loc[:, ['subject', 'color', 'bp', 'wp', 'response', 'rt', 'splitg', 'n_pieces']].copy()

# One column of negative log-likelihoods per model, named after the model.
# NOTE(review): columns are attached positionally, which assumes every file
# lists the same trials in the same row order -- confirm.
for f, name in zip(model_files, names):
    _d = pd.read_csv('./results/' + f)
    if len(_d) != len(D):
        raise ValueError(
            '%s has %d rows, expected %d' % (f, len(_d), len(D))
        )
    D.loc[:, name] = _d.cnn_nll.values

D.head()

In [None]:
# Column names are supplied via `names=`, so the first line of the file is
# read as data rather than as a header.
L = pd.read_csv(
    './results/loglik_by_board_default.txt',
    names=['n_pieces', 'default'],
)

# Attach the default model's values to D as one more column.
# NOTE(review): positional assignment -- assumes L has the same row order as D.
D['default'] = L['default'].values

In [None]:
# Filter counts 1, 2, 4, ..., 256; the summary arrays are indexed by
# log2(number of filters), i.e. by position in `levels`.
levels = 2**np.arange(9)
types = ['smart', 'naive']
m_smart, m_naive, l_smart, l_naive, u_smart, u_naive = [
    np.zeros(len(levels)) for _ in range(6)
]

# One credibility level for every interval drawn on the same figure. The
# original passed alpha=.95 only for the default model and left the CNN models
# at bayes_mvs's own default, so the shaded bands were not comparable.
CI_ALPHA = .95

for name in names:
    # Column names look like '<type>_<n_filters>': the first five characters
    # are the type, everything after position 6 is the filter count.
    # The index must be a Python int: NumPy rejects float indices.
    n = int(round(np.log2(int(name[6:]))))
    if name[:5]=='naive':
        m, l, u = m_naive, l_naive, u_naive
    else:
        m, l, u = m_smart, l_smart, u_smart

    # bayes_mvs returns (mean, var, std) summaries; only the mean is used.
    # Each summary is (center, (lower, upper)).
    mstats, _, _ = sts.bayes_mvs(D.loc[:, name].values, alpha=CI_ALPHA)
    m[n] = mstats[0]
    l[n] = mstats[1][0]
    u[n] = mstats[1][1]

mstats_default, _, _ = sts.bayes_mvs(D.loc[:, 'default'].values, alpha=CI_ALPHA)
m_default, l_default, u_default = mstats_default[0], mstats_default[1][0], mstats_default[1][1]

In [None]:
# Pairwise comparisons as (question, first column of D, second column of D).
# NOTE(review): ttest_ind treats the two samples as independent, but both
# columns come from the same rows of D; a paired test (sts.ttest_rel) may be
# the more appropriate choice -- confirm before relying on these p-values.
comparisons = [
    ('Is 32 filters significantly better than 16?\n', 'smart_16', 'smart_32'),
    ('Is 64 better than 32?\n', 'smart_32', 'smart_64'),
    ('Just to check, is smart 32 better than naive 32?\n', 'smart_32', 'naive_32'),
]

for question, first, second in comparisons:
    print(question, sts.ttest_ind(D[first].values, D[second].values))

# Printed unconditionally; it is not derived from the test results above.
print('It is official: best CNN is 32 filter with rule knowledge')

In [None]:
# Single-panel figure: mean negative log-likelihood against log2(number of
# filters) for both CNN variants, with the default model as a flat reference.
fig, axes = plt.subplots(figsize=(8.5, 5), squeeze=False)
trends = axes[0, 0]

log2_filters = np.arange(9)

# (mean, lower, upper, legend label, colour, zorder of the shaded band)
model_series = [
    (m_smart, l_smart, u_smart, 'Rule knowledge', 'teal', 0),
    (m_naive, l_naive, u_naive, 'No rule knowledge', '#11AAAA', 1),
]
for mean, lower, upper, label, colour, band_z in model_series:
    trends.plot(log2_filters, mean, label=label, color=colour)
    trends.fill_between(
        log2_filters, lower, upper,
        zorder=band_z, alpha=.2, facecolor=colour
    )

# The default model has no filter count, so it is drawn as a horizontal line
# and band spanning the whole x range.
x_span = [0, 8]
trends.plot(x_span, [m_default] * 2, label='Default model', color='goldenrod')
trends.fill_between(
    x_span, [l_default] * 2, [u_default] * 2,
    alpha=.2, facecolor='goldenrod'
)

trends.set(
    xlabel=r'$\log_2{}$ Number of filters',
    ylabel='Negative log likelihood',
)
trends.legend()

sns.despine()
fig.tight_layout()
fig.savefig('./results/trends.png')

In [None]:
# Load the smart_32 result file on its own; B is the frame handed to
# show_output when the example boards are drawn.
B = pd.read_csv('./results/appended_data_smart_32.csv')
# Last expression of the cell: displays the column index of B.
B.columns

In [None]:
def reconstitute(x):
    """Convert an iterable of 36 int-convertible items into a 4x9 array.

    Each item of ``x`` (e.g. each character of a digit string) is passed
    through ``int`` and the values are reshaped row by row to ``(4, 9)``.
    """
    return np.array(list(map(int, x))).reshape(4, 9)


def show_output(pos, ax, df, show_zet=False):
    """Draw one row of ``df`` as an annotated 4x9 board on ``ax``.

    Parameters
    ----------
    pos
        Row label of ``df`` to draw (used with ``df.loc``).
    ax
        Matplotlib axes to draw into.
    df
        DataFrame with columns ``cnn_0`` .. ``cnn_35``, ``color``, ``bp``,
        ``wp`` and, when ``show_zet`` is true, ``response``.
    show_zet
        If true, also mark the square given by the row's ``response`` with an
        open circle.
    """
    positions = ['cnn_' + str(i) for i in range(36)]
    d = df.loc[pos, positions].values.reshape([4, 9]).astype(float)

    # Heatmap of the 36 cnn_* values, colour scale clipped to [0, .25].
    sns.heatmap(
        d, ax=ax,
        square=True, vmin=0, vmax=.25, cbar=False, linewidth=1, linecolor='black',
        cmap=sns.palettes.blend_palette(['#999999', '#66FF66'], n_colors=18, as_cmap=True),
        xticklabels=False, yticklabels=False, annot=True, fmt='.2f'
    )

    # NOTE(review): 'bp' is drawn in own_color and 'wp' in opp_color, so the
    # two piece sets swap colours when color != 0 -- presumably intended by
    # the data encoding; confirm.
    if df.loc[pos, 'color'] == 0:
        own_color, opp_color = 'black', 'white'
    else:
        own_color, opp_color = 'white', 'black'

    # Pieces sit at the centre of their cell. Array row r is placed at
    # y = 3.5 - r.
    # NOTE(review): that puts row 0 at the top only if the heatmap's y axis
    # increases upward; newer seaborn releases invert the y axis instead --
    # verify against the installed seaborn version.
    rows, cols = np.where(reconstitute(df.loc[pos, 'bp']) == 1)
    ax.scatter(.5 + cols, 3.5 - rows, c=own_color, s=400)

    rows, cols = np.where(reconstitute(df.loc[pos, 'wp']) == 1)
    ax.scatter(.5 + cols, 3.5 - rows, c=opp_color, s=400)

    if show_zet:
        # Read the response from the frame that was passed in. The original
        # read it from the notebook-level `B`, which ignored `df`.
        r = df.loc[pos, 'response']
        # `response` is treated as a flat row-major index into the 4x9 board.
        row, col = r // 9, r % 9
        ax.plot(
            col + .5, 3.5 - row,
            linestyle='None',
            marker='o', markersize=20, markerfacecolor='None',
            markeredgecolor='black', markeredgewidth=1
        )
        

In [None]:
# 4x2 grid of example boards taken from B: rows 0, 2, 4, ..., 14, each drawn
# with the recorded response marked.
fig = plt.figure(figsize=(8.5, 8))
gs = matplotlib.gridspec.GridSpec(4, 2)

for panel in range(8):
    grid_row, grid_col = divmod(panel, 2)
    ax = fig.add_subplot(gs[grid_row, grid_col])
    show_output(pos=2 * panel, ax=ax, df=B, show_zet=True)

sns.despine()

fig.tight_layout()
fig.savefig('examples.png', bbox_inches='tight')

In [None]:
# Stand-alone colour scale: a 1x18 strip using the same grey-to-green colormap
# and 0 .. .25 range as the example heatmaps.
fig, cplot = plt.subplots(1, 1, figsize=(8.5, 1))

cplot.imshow(
    np.linspace(0, .25, 18).reshape(1, 18),
    cmap=sns.palettes.blend_palette(['#999999', '#66FF66'], n_colors=18, as_cmap=True),
    interpolation="nearest"
)

# Label only the two outer edges of the strip. Tick positions are set before
# the labels so the labels attach to these two ticks; the original passed both
# through plt.setp and relied on matplotlib applying them in the right order.
cplot.set_xticks([-.5, 17.5])
cplot.set_xticklabels([0, .25])
cplot.yaxis.set_visible(False)

# The original `cplot.bbox.size[0] = 500` was dropped: `bbox.size` returns a
# freshly computed array, so writing into it never changed the axes.
sns.despine(ax=cplot, bottom=True, left=True)
fig.tight_layout()
fig.savefig('colorbar.png')

In [None]:
# Read the trace file and show it as the cell output; the frame is not stored.
pd.read_csv('./results/trace_0.csv')