In [None]:
# imports
import os
import imp
import numpy as np
import pandas as pd
import theano
import lasagne
import loading
from training import *
from network import *
from architectures import *
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bayes_mvs, entropy, linregress, spearmanr

# settings
sns.set_style('white')
sns.set_context('poster')
colors = sns.color_palette()

%matplotlib inline

# aliases
L = lasagne.layers
nl = lasagne.nonlinearities
T = theano.tensor
bmvs = bayes_mvs

In [None]:
# Keyword-argument presets shared by the plots in this notebook.

# Histograms of per-subject CV NLL: raw counts (Axes.hist's default).
# Do not add a 'normed' key: it was removed from Axes.hist in Matplotlib 3.1;
# pass 'density': True instead if normalised histograms are wanted.
histkws = {
    'alpha': .8, 'edgecolor': 'white',
    'bins': np.arange(1.2, 3.0, .1),
}

# Marker-only line plots (no connecting lines).
scatterkws = {
    'marker': 'o', 'markersize': 7, 'linestyle': 'None', 'alpha': .8
}

# Board heatmaps: square cells, no colorbar or tick labels, light grey ->
# colors[0] ramp on a fixed [0, 1] range (presumably move probabilities).
# NOTE(review): seaborn appears to ignore n_colors when as_cmap=True — confirm.
hmkws = {
    'cbar': False, 'cmap': sns.blend_palette([(.95, .95, .95), colors[0]], n_colors=16, as_cmap=True),
    'square': True, 'xticklabels': False, 'yticklabels': False,
    'vmin': 0, 'vmax': 1
}

# Large black-edged circles used to draw pieces / moves on top of a heatmap.
boardplotkws = {
    'marker': 'o', 'markersize': 20, 'markeredgecolor': 'black', 'markeredgewidth': 2,
    'linestyle': 'None'
}

def show_net_response(pos_idx, ax, net=net):
    """Draw the network's output for position ``Xs[pos_idx]`` with the board on top.

    The output of ``net.output_fn`` for the single position is reshaped to a
    4 x 9 grid, its rows are flipped (``[::-1]``), and it is drawn as a
    heatmap with ``hmkws``. The pieces in the two channels of ``Xs`` are then
    overlaid as black and white markers, and the move ``ys[pos_idx]``
    (unravelled on the same 4 x 9 grid) as a ``colors[2]`` marker.

    Parameters
    ----------
    pos_idx : int
        Index into the global arrays ``Xs`` (positions) and ``ys`` (moves).
    ax : matplotlib Axes
        Axes to draw into.
    net : object, optional
        Anything with an ``output_fn``; defaults to the global ``net`` as it
        was when this function was defined.

    Returns
    -------
    None
    """
    position = Xs[pos_idx, :, :, :]
    move_map = net.output_fn(Xs[pos_idx:pos_idx+1, :, :, :]).reshape([4, 9])
    sns.heatmap(move_map[::-1, :], ax=ax, **hmkws)

    # An even piece count draws channel 0 in black, an odd count channel 1.
    black_ch = 0 if position.sum() % 2 == 0 else 1
    white_ch = 1 - black_ch

    for channel, shade in ((black_ch, 'black'), (white_ch, 'white')):
        rows, cols = np.where(position[channel, :, :] == 1)
        ax.plot(cols + .5, rows + .5, color=shade, **boardplotkws)

    move_row, move_col = np.unravel_index(ys[pos_idx], (4, 9))
    ax.plot(move_col + .5, move_row + .5, color=colors[2], **boardplotkws)
    plt.setp(ax, frame_on=False)

# Swatch preview (16 steps) of the light-grey -> colors[0] ramp behind hmkws['cmap'].
sns.palplot(sns.blend_palette([(.95, .95, .95), colors[0]], 16))

In [None]:
archname = 'multiconvX'  # results column (model) analysed in the cells below


def countpieces(row):
    """Sum the integer values of every square in ``row['bp'] + row['wp']``.

    For occupancy strings of '0'/'1' characters this is the number of pieces
    on the board.

    Parameters
    ----------
    row : mapping
        Must provide 'bp' and 'wp' (e.g. a DataFrame row); presumably the
        per-square occupancy of the two players — TODO confirm.

    Returns
    -------
    int
        Total of ``int(square)`` over all squares of both players.
    """
    return sum(int(square) for square in row['bp'] + row['wp'])
# Piece count per position.
# NOTE(review): this assignment aligns on the row index — assumes df and
# pretrain_tidy share the same index; confirm.
pretrain_tidy['npieces'] = df.apply(countpieces, axis=1)
# Average the sub-columns under each top-level column name into one column.
# DataFrame.mean(axis=1, level=0) was removed in pandas 2.0; grouping the
# transposed frame by level 0 is the documented replacement, and sort=False
# keeps the original column order as the old implementation did.
# Assumes the selected columns are numeric — TODO confirm.
ptt = (
    pretrain_tidy[[archname, 'subject', 'group', 'npieces']]
    .T.groupby(level=0, sort=False).mean()
    .T
)
# Mean `archname` CV NLL per (piece count, subject): rows npieces, columns subject.
ptt_piecepiv = ptt.pivot_table(index='npieces', values=archname, columns='subject', aggfunc='mean')

In [None]:
def _bmvs_mean_ci(results):
    """Unpack mean estimates and credible bounds from a list of bayes_mvs results.

    Parameters
    ----------
    results : list
        Return values of ``bmvs(sample, alpha=...)``; element ``[0]`` of each
        is the mean result ``(statistic, (lower, upper))``.

    Returns
    -------
    means, lower, upper : np.ndarray
        One entry per element of ``results``.
    """
    means = np.array([res[0][0] for res in results])
    lower = np.array([res[0][1][0] for res in results])
    upper = np.array([res[0][1][1] for res in results])
    return means, lower, upper


N_SUBJECTS = 40       # subjects are indexed 0..39
N_PIECE_COUNTS = 36   # rows 0..35 of ptt_piecepiv (pieces on the 4 x 9 board)

fig, axes = plt.subplots(3, 2, figsize=(12, 14), squeeze=False)

# (0, 0) Distribution of per-subject CV NLL: convnet vs. 'H search' model.
ax = axes[0, 0]
ax.hist(ptpiv.values, color=colors[0], label='Convnet', **histkws)
ax.hist(defarray.mean(axis=1), color=colors[1], label='H search', **histkws)
ax.legend(loc=0)
plt.setp(ax, xlabel='CV NLL', ylabel='# Subjects')

# (0, 1) Per-subject mean CV NLL, subjects ordered by the default ('H search')
# fit; the band is the 95% credible interval of the default model's mean.
ax = axes[0, 1]
subject_ids = np.arange(N_SUBJECTS)
ptemp = pretrain_tidy[['subject', archname]]
mos = [bmvs(ptemp.loc[ptemp['subject'] == i, archname].values, alpha=.95) for i in subject_ids]
means, lbs, ubs = _bmvs_mean_ci(mos)

dmos = [bmvs(defarray[i, :], alpha=.95) for i in subject_ids]
dmeans, dlbs, dubs = _bmvs_mean_ci(dmos)

orderidx = defarray.mean(axis=1).argsort()

ax.plot(subject_ids, dmeans[orderidx], color=colors[1], **scatterkws)
ax.fill_between(subject_ids, y1=dlbs[orderidx], y2=dubs[orderidx], alpha=.25, color=colors[1])
ax.plot(subject_ids, means[orderidx], color=colors[0], **scatterkws)
plt.setp(ax, xlabel='Subject (ranked by default fit)', ylabel='CV NLL')

# (1, 0) Convnet fit as a function of the number of pieces on the board.
ax = axes[1, 0]
piece_counts = np.arange(N_PIECE_COUNTS)
mos = [bmvs(ptt_piecepiv.loc[i, :].values, alpha=.95) for i in piece_counts]
means, lbs, ubs = _bmvs_mean_ci(mos)

# log(36 - i) - NLL: fit relative to a uniform guess over 36 - i options
# (presumably the empty squares). Computed for reference; not plotted.
means_corrected = -(means + np.log(1 / (piece_counts + 1)[::-1]))
# -means is the (negated NLL) log-likelihood, so label the axis as such.
ax.plot(piece_counts, -means)
plt.setp(ax, xlabel='# pieces', ylabel='CV log-likelihood (-NLL)')

# (1, 1) Network output for an empty board (all-zero input).
ax = axes[1, 1]
blank = np.zeros([1, 2, 4, 9])
sns.heatmap(net.output_fn(blank).reshape([4, 9])[::-1, :], ax=ax, **hmkws)
plt.setp(ax, frame_on=False)

# (2, 0) and (2, 1) Network output for two example positions.
ax = axes[2, 0]
show_net_response(0, ax=ax)

ax = axes[2, 1]
show_net_response(12, ax=ax)

sns.despine();

In [None]:
# Compiled functions returning intermediate activations of net.net for a batch
# of positions: layers 1, 2 and 3 of the layer list (shown in the plots below
# as the raw filter response, post-PReLU and post-pooling respectively).
_layers = L.get_all_layers(net.net)
filter_layer = _layers[1]
filter_output, scaled_foutput, pooled_foutput = [
    L.get_output(_layers[k], deterministic=True) for k in (1, 2, 3)
]

filter_output_fn, scaled_foutput_fn, pooled_foutput_fn = [
    theano.function([net.input_var], expr)
    for expr in (filter_output, scaled_foutput, pooled_foutput)
]

# First parameter array collected up to filter_layer; presumably its weights,
# shaped (n_filters, n_channels, rows, cols) given how it is indexed later —
# TODO confirm.
filters = L.get_all_param_values(filter_layer)[0]
del _layers

# Default imshow options for the filter plots: unsmoothed pixels and a
# red/blue diverging colormap. With no 'vmin' here, show_filter_output picks
# symmetric limits itself; uncomment the line below to force a fixed range.
imshowkws = dict(
    interpolation='nearest',
    # vmin=-1, vmax=1,
    cmap=sns.diverging_palette(20, 240, n=15, s=99, as_cmap=True),
)

def show_filter_output(pos_idx, func=filter_output_fn, filter_idx=None, ax=None, imshowkws=imshowkws):
    """Image one layer's activation for the board position ``Xs[pos_idx]``.

    Parameters
    ----------
    pos_idx : int
        Index into the global position array ``Xs``.
    func : callable
        Compiled function mapping a batch of positions to a 4-D activation
        array; defaults to the raw filter outputs.
    filter_idx : int or None
        Channel to show; ``None`` shows the sum over all channels.
    ax : matplotlib Axes or None
        Axes to draw into; the current axes when ``None``.
    imshowkws : dict
        Extra ``ax.imshow`` keyword arguments. If it has no ``'vmin'``, a
        symmetric range is used: +/-7.6 for ``filter_output_fn``, +/-1 for any
        other function.

    Returns
    -------
    matplotlib Axes
        The axes drawn into.
    """
    if ax is None:
        ax = plt.gca()

    # Evaluate once and keep the single batch element.
    activations = func(Xs[pos_idx:pos_idx+1, :, :, :])[0]
    if filter_idx is None:
        fout = activations.sum(axis=0)
    else:
        fout = activations[filter_idx, :, :]

    if 'vmin' in imshowkws:
        ax.imshow(fout, **imshowkws)
    else:
        # Symmetric limits put zero at the colormap's midpoint.
        lim = 7.6 if func == filter_output_fn else 1
        ax.imshow(fout, vmin=-lim, vmax=lim, **imshowkws)
    plt.setp(ax, frame_on=False, xticklabels=[], yticklabels=[], xlabel='Filter response')

    return ax

# Swatch preview (11 steps) of the red/blue diverging palette behind imshowkws['cmap'].
sns.palplot(sns.diverging_palette(20, 240, s=99, n=11))

In [None]:
# One row per first-layer filter (rows 1..32) plus a summary row 0. Columns:
# post-pooling response, post-PReLU response, raw filter response, and the
# filter's two input-channel kernels.
n_filters = 32
fig, axes = plt.subplots(n_filters + 1, 5, figsize=(24, 136), squeeze=False)

pos_idx = 50  # board position to visualise (index into Xs)
kernel_lim = 2.6  # symmetric colour limit for the kernel weights

for i in np.arange(n_filters):
    row = axes[i + 1]

    ax = row[0]
    if i % 2 == 0:
        # Pooled channel i // 2 is shown on even rows only (assumes the pooling
        # layer merges consecutive filter pairs — TODO confirm). Floor division
        # keeps the index an integer; i / 2 is a float under Python 3 and numpy
        # rejects float indices.
        show_filter_output(pos_idx, filter_idx=i // 2, ax=ax, func=pooled_foutput_fn)
        plt.setp(ax, xlabel='Filter response (post-pooling)')
    else:
        plt.setp(ax, frame_on=False, xticklabels=[], yticklabels=[])

    ax = row[1]
    show_filter_output(pos_idx, filter_idx=i, ax=ax, func=scaled_foutput_fn)
    plt.setp(ax, xlabel='Filter response (post-PReLu)')

    ax = row[2]
    show_filter_output(pos_idx, filter_idx=i, ax=ax)
    plt.setp(ax, xlabel='Filter response')

    # Kernels of filter i for input channel 0 ('Own') and channel 1 ('Opp').
    for col, channel, label in ((3, 0, 'Own filter'), (4, 1, 'Opp filter')):
        ax = row[col]
        ax.imshow(filters[i, channel, :, :], vmin=-kernel_lim, vmax=kernel_lim, **imshowkws)
        plt.setp(ax, frame_on=False, xticklabels=[], yticklabels=[], xlabel=label)


# Summary row: summed post-pooling response, the network output, and blanks.
ax = axes[0, 0]
kws = {'interpolation': 'nearest', 'cmap': hmkws['cmap']}
# NOTE(review): kws has no 'vmin', so show_filter_output uses [-1, 1] for this
# summed response; a fixed range of vmin=-46, vmax=-7 was tried previously —
# check the colours are not saturated.
show_filter_output(pos_idx, func=pooled_foutput_fn, ax=ax, imshowkws=kws)
plt.setp(ax, xlabel='Filter response (sum, post-pooling)')

ax = axes[0, 1]
show_net_response(pos_idx, ax=ax)
ax.invert_yaxis()

ax = axes[0, 2:]
plt.setp(ax, frame_on=False, xticklabels=[], yticklabels=[])

sns.despine()

In [None]:
Xx, yy, Ss, G, Np = loading.unpack_data(df)

# Pieces on the board per position, using the countpieces() helper defined in
# the earlier cell next to `archname` (not redefined here).
df['npieces'] = df.apply(countpieces, axis=1)

# Standardise response times within every (subject, npieces) group:
# (rt - group mean) / group std, with pandas' sample std (ddof=1).
# Every group is covered, including each subject's largest piece count.
# Groups with a single trial or constant RT give NaN.
rt_groups = df.groupby(['subject', 'npieces'])['rt']
df['mc rt'] = (df['rt'] - rt_groups.transform('mean')) / rt_groups.transform('std')

In [None]:
# Entropy of the response-time distribution for each board occupancy: the RTs
# of each piece count are binned into a 1000-bin histogram and
# scipy.stats.entropy (which normalises the counts) is applied.
g = df.groupby('npieces')
rt_hists = g['rt'].apply(lambda x: np.histogram(x, bins=1000)[0])
rt_hists = rt_hists.map(entropy)
# Order by # legal moves = 36 - npieces for 1..36 (npieces 35 down to 0), so
# element k lines up with per-legal-move statistics indexed the same way
# (e.g. meanents). Piece counts with no positions become NaN.
rt_hists = rt_hists.reindex(36 - np.arange(1, 37)).values

In [None]:
# Smaller, more transparent markers for the dense per-position scatter plots.
# NOTE(review): this rebinds the notebook-wide `scatterkws`; earlier plotting
# cells re-run after this one will pick up this style.
scatterkws = {
    'marker': 'o', 'markersize': 6, 'linestyle': 'None', 'alpha': .15
}

predictions = net.output_fn(Xx)

# Empty squares on the 36-square board, used as the number of legal moves.
numlegal = 36 - df['npieces'].values
df['numlegal'] = numlegal
# Equals log(1 / numlegal): minus the entropy of a uniform choice among the
# legal moves. Kept for reference; not used below.
numlegalent = numlegal * (1 / numlegal) * np.log(1 / numlegal)
# scipy.stats.entropy (natural log, rows normalised) of each prediction row.
entropies = np.apply_along_axis(entropy, axis=1, arr=predictions)

response_times = df['rt'].values

v1 = (response_times < 60000)  # RT under 60000 ms
v2 = (numlegal < 32)           # more than 4 pieces on the board
valid = np.where(v1 & v2)[0]

# Mean entropy and 95% normal-approximation half-width per # legal moves
# (1..36); counts with no positions give NaN without empty-slice warnings.
legal_range = np.arange(1, 37)
aggents = [entropies[numlegal == k] for k in legal_range]
meanents = np.array([m.mean() if m.size else np.nan for m in aggents])
sements = 1.96 * np.array([m.std() / np.sqrt(m.size) if m.size else np.nan for m in aggents])

# The regression uses `valid` so the RT cut matches the RT scatter plots below.
print("Entropy vs log RT\n", linregress(entropies[valid], np.log(response_times[valid])), '\n')
print("Spearman R Numlegal vs Entropy\n", spearmanr(numlegal, entropies), '\n')

fig, axes = plt.subplots(2, 2, figsize=(20, 12), squeeze=False)
v = np.where(v1)[0]

ax = axes[0, 0]
ax.plot(entropies[v], response_times[v], **scatterkws)
plt.setp(ax, xlabel='Prediction entropy', ylabel='Response time (ms)')

ax = axes[0, 1]
ax.plot(numlegal[v], response_times[v], **scatterkws)
plt.setp(ax, xlabel='# Legal moves', ylabel='Response time (ms)')

ax = axes[1, 0]
ax.plot(numlegal, entropies, **scatterkws)
ax.plot(legal_range, meanents, linewidth=5, label='Mean')
ax.fill_between(
    legal_range, y1=meanents + sements, y2=meanents - sements,
    alpha=.5, color=colors[1]
)
ax.legend(loc=0)
plt.setp(ax, xlabel='# Legal moves', ylabel='Prediction entropy')

# Element-wise pairing: rt_hists must be ordered by # legal moves 1..36,
# like meanents, for this plot to be meaningful.
ax = axes[1, 1]
ax.plot(meanents, rt_hists, marker='o', linestyle='none', markersize=10)
plt.setp(ax, xlabel='Mean Entropy per # legal moves', ylabel='RT Entropy per # legal moves')

sns.despine();