In [None]:
%load_ext autoreload
%autoreload 2
# %matplotlib widget

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('paper.mplstyle')
from matplotlib.lines import Line2D
import os
import warnings
import torch
from tqdm import tqdm

warnings.filterwarnings("ignore")

from fun.gp_hsh import GP_HSH, CustomSquareWarpedGP
from fun.grid_ops import get_grid_idcs, gen_reg_grid
import utils as u

In [29]:
def set_size(width_pt, fraction=1, subplots=(1, 1)):
    """Compute a golden-ratio figure size in inches from a width given in points.

    Parameters
    ----------
    width_pt : float
        Available width in points (e.g. the text width of the target document).
    fraction : float, optional
        Share of ``width_pt`` that the figure should occupy.
    subplots : tuple, optional
        ``(rows, columns)`` of the subplot grid; the height is scaled by
        ``rows / columns``.

    Returns
    -------
    tuple
        ``(width_in, height_in)``, suitable for ``figsize``.
    """
    pt_to_inch = 1 / 72.27        # 72.27 pt per inch
    phi = (5**.5 - 1) / 2         # golden ratio, used as height/width of one panel

    width_in = (width_pt * fraction) * pt_to_inch
    grid_aspect = subplots[0] / subplots[1]
    height_in = width_in * phi * grid_aspect

    return (width_in, height_in)

# Shared colour list: the first five "Set2" colours followed by a grey entry.
colors = [*sns.color_palette("Set2")[:5], 'darkgrey']

In [None]:
# Load the experiment results, shorten the 'random' label and rescale the two
# distance columns in one chain.
# (The factors 90 and 30 presumably undo a normalisation of angles and
# positions — confirm against the code that wrote the pickle.)
# NOTE: unpickling can execute arbitrary code; only load trusted files.
result = (
    pd.read_pickle('data/hs_hunting_exp.pkl')
    .replace('random', 'rand.')
    .assign(
        dist_ang=lambda frame: frame['dist_ang'] * 90,
        dist_loc=lambda frame: frame['dist_loc'] * 30,
    )
)
result.head(5)

## Initialization

In [None]:
%matplotlib widget
linestyles=['-','--']
inits = np.unique(result['n_init'])[::-1]
fig, axs = plt.subplots(2, 3, figsize=(set_size(483.7)[0],3), sharex=True, sharey=True)
axs = axs.flatten()
hue = 'acquisition_function'
n_inits = np.unique(result['n_init'])[::-1]

for i, af in enumerate(np.unique(result['acquisition_function'])):
    colors2 = [sns.light_palette(colors[i], n_colors=10)[k] for k in [4,9]]
    res = result[(result['acquisition_function']==af)]
    sns.lineplot(data=res, hue='n_init', x="n_samples", y='nmse', ax=axs[i], legend=False, linewidth=.7, markersize=1, palette=colors2, style='n_init', style_order=n_inits)
    sns.despine()
    axs[i].set_ylabel(r'$NRMSE$')
    axs[i].set_xlabel('number of samples')
    axs[i].set_xticks(np.arange(10,51,10))
    axs[i].set_title(np.unique(result['acquisition_function'])[i], y=.8)

# Add a second legend for n_init with manually defined lines
handles_n_init = [Line2D([0], [0], color='black', linestyle=ls, label=f'n = {n_init}', linewidth=.7) 
                  for ls, n_init, j in zip(linestyles, n_inits, range(len(n_inits)))]

axs[1].legend(handles=handles_n_init, title='initialization points', loc='upper center', bbox_to_anchor=(0.5, 1.5), ncol=2, frameon=False)
fig.savefig('figures/init.pdf', transparent=True, bbox_inches="tight", pad_inches=0.01)
plt.show()

## Method Comparison

In [None]:
%matplotlib widget
fig, axs = plt.subplots(2, 3, figsize=(set_size(483.7)[0],4.3))
hue = 'acquisition_function'
y_labels = [r'$NRMSE$', r'$d_{x,y}$ (mm)', r"$d_\theta$ (°)"]
metrics = ['nmse', 'dist_loc', 'dist_ang']
legend=[False, True, False]

colors = sns.set_palette(sns.color_palette(sns.color_palette("Set2")[:5] + ['darkgrey']))
res = result[(result['n_init']==10)]

for i, metric in enumerate(metrics):
    
    sns.lineplot(data=res, x="n_samples", y=metric, hue=hue, ax=axs[0,i], legend=legend[i], style=hue, dashes=False, linewidth=.7, markersize=1, palette=colors, hue_order=np.unique(result[hue]))
    axs[0,i].set_xlabel('number of samples')
    axs[0,i].set_xticks(np.arange(10,51,10))
    axs[0,i].set_ylabel(y_labels[i])

    res30 = res[res['n_samples']==30]
    sns.boxplot(data=res30, x=hue, y=metric, ax=axs[1,i], hue=hue, hue_order=np.unique(res30[hue]), showfliers=False, order=np.unique(res30[hue]))
    axs[1,i].set_ylabel(y_labels[i])
    axs[1,i].set_xlabel('')
    axs[1,i].set_xticklabels(axs[1,i].get_xticklabels(),rotation=45)

sns.despine()
sns.move_legend(axs[0,1], "upper center", frameon=False, bbox_to_anchor=(0.5, 1.5), ncol=3, title=None)
axs[0,1].set_ylim(1,8)
axs[0,2].set_ylim(3,12)
axs[0,0].set_ylim(0,.6)

axs[0,0].text(0, 1.1*axs[0,0].get_ylim()[1], '(A)', ha='center')
axs[1,0].text(-0.5, 1.1*axs[1,0].get_ylim()[1], '(B)', ha='center')
plt.subplots_adjust(wspace=.45, hspace=.5)
fig.savefig('figures/real_data.pdf', transparent=True, bbox_inches="tight", pad_inches=0.01)
plt.show()

# Sampling strategies

In [None]:
# data specs
n_max_samples = 30
subj = '009'
n_init = 5
seed = 3

search_r = 30
a_res = 10
s_res = 1

# Seed the global RNGs so the stochastic parts of this cell (posterior draws,
# and any torch/numpy randomness inside model fitting or the acquisition steps,
# assuming they use the global generators) are repeatable across runs.
np.random.seed(seed)
torch.manual_seed(seed)

# get experimental data
# NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
res = np.load(os.path.join('data', 'result_'+subj+'.npy'), allow_pickle=True)[()]
meps_gt = res['meps']
grid_gt = res['grid']  # measured grid: column 0 is used as an angle, columns 1: as a position
pos_gt = grid_gt[:,1:]/30
rads_gt = (grid_gt[:,0]-90)/90
locs_gt = np.hstack((pos_gt, rads_gt.reshape(-1,1)))

# generate test grid
# NOTE(review): the measured positions are divided by the literal 30, the test
# grid by search_r. The two only agree while search_r == 30.
grid = gen_reg_grid(search_r,s_res,a_res)
pos = grid[:,1:]/search_r
rads = (grid[:,0]-90)/90
locs = np.hstack((pos, rads.reshape(-1,1)))

# get initial grid points
grid_idcs = get_grid_idcs(locs, num_points=n_init, seed=seed)

# GT model botorch
train_X = torch.tensor(locs_gt, dtype=torch.float32)
train_Y = torch.tensor(meps_gt, dtype=torch.float32).reshape(-1,1)
locs_torch = torch.tensor(locs, dtype=torch.float32)

model = CustomSquareWarpedGP(train_X, train_Y)
model.fit(n_restarts=5)

# Reference centre of gravity computed from the ground-truth model on the test grid.
cog_gt = u.cog(locs, model.predict(locs_torch)[0].flatten().detach().numpy(), percentile=90)

# Alphabetically sorted acquisition functions, with the random baseline last.
acquisition_functions = sorted(['UCB', 'EI', 'TS', 'KG', 'MVE'])
acquisition_functions += ['random']

# One finished optimizer per acquisition function, in the order above.
opts = []
for acquisition_function in acquisition_functions:

    optimization = GP_HSH(all_locs=locs_torch, acquisition_function=acquisition_function)

    for i in tqdm(range(n_max_samples), desc=acquisition_function):
        if i<len(grid_idcs):
            # the first len(grid_idcs) samples come from the fixed initial grid
            next_state = torch.tensor(grid_idcs[i])
        else:
            next_state = optimization.sample_state()
        # Simulated response: one posterior draw of the ground-truth model at the
        # chosen location, squared (presumably to invert the model's square-root
        # warping — confirm against CustomSquareWarpedGP).
        return_value = model.posterior(locs_torch[next_state:next_state+1]).sample()[0,:,0]**2

        optimization.update(next_state, return_value)

    opts.append(optimization)


In [None]:
%matplotlib widget
import matplotlib.colors as mcolors
fig, axs = plt.subplots(2,int(np.ceil(len(opts)/2)), figsize=(set_size(483.7)[0],5), sharex=True, sharey=True)
cmap = plt.get_cmap('viridis')
axs=axs.flatten()

for i, optimization in enumerate(opts):
    data = grid[optimization.sampled_loc_idcs]
    returns =optimization.sampled_returns
    for d, r in zip(data, returns):
        axs[i].arrow(d[1], d[2], 3*np.cos(d[0]*np.pi/180), 3*np.sin(d[0]*np.pi/180), length_includes_head=True,
            head_width=2, head_length=1, width=.01, color=cmap(r/meps_gt.max()))
    
    circle = plt.Circle((0,0), 30, fill=False)
    axs[i].add_patch(circle)

    axs[i].arrow(cog_gt[0]*30, cog_gt[1]*30, 8*np.cos((cog_gt[2]*90+90)*np.pi/180), 8*np.sin((cog_gt[2]*90+90)*np.pi/180), length_includes_head=True,
                 head_width=5, head_length=3, width=1, lw=.2, fc='r', ec='k')
    
    cog_model = u.cog(locs, optimization.gp_model.predict(locs_torch)[0].flatten().detach().numpy(), percentile=90)
    
    axs[i].arrow(cog_model[0]*30, cog_model[1]*30, 8*np.cos((cog_model[2]*90+90)*np.pi/180), 8*np.sin((cog_model[2]*90+90)*np.pi/180), length_includes_head=True,
                 head_width=5, head_length=3, width=1, lw=.2, fc='cornflowerblue', ec='k')


    axs[i].spines['right'].set_visible(False)
    axs[i].spines['top'].set_visible(False)
    axs[i].set_title(acquisition_functions[i])
    axs[i].set_aspect('equal', adjustable='box')
    axs[i].set_xticks(np.linspace(-30,30,5))
    axs[i].set_yticks(np.linspace(-30,30,5))
    if i>2:
        axs[i].set_xlabel('$x$')
    if i%3 == 0:
        axs[i].set_ylabel('$y$')

norm = mcolors.Normalize(vmin=0, vmax=meps_gt.max())
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar_ax = fig.add_axes([0.93, 0.4, 0.01, 0.2])
cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.ax.set_title(r"$r$", pad=10, fontdict={'fontsize':10})

fig.savefig('figures/sampling_strategies_no_init.pdf', transparent=True, bbox_inches="tight", pad_inches=0.01)
plt.show()

# Response Amplitude

In [54]:
subjs = ['002','003','004','006','007','008', '009']  # subject '010' is currently left out


def _max_responses(sub, ref_angle=45):
    """Fit the ground-truth GP for one subject and return its maximal predictions.

    The work is wrapped in a function so that its intermediates (grid, locs,
    model, ...) do not overwrite the same-named notebook globals that other
    cells read.

    Parameters
    ----------
    sub : str
        Subject id used to build the file name ``data/result_<sub>.npy``.
    ref_angle : float, optional
        Grid angle (first grid column) of the restricted evaluation set.

    Returns
    -------
    tuple of float
        ``(max over the whole regular grid, max over grid points at ref_angle)``.

    Raises
    ------
    ValueError
        If the regular grid contains no point at ``ref_angle``.
    """
    # load data
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    res = np.load(os.path.join('data', 'result_'+sub+'.npy'), allow_pickle=True)[()]
    meps_gt = res['meps']
    grid_gt = res['grid']
    locs_gt = np.hstack((grid_gt[:,1:]/30, ((grid_gt[:,0]-90)/90).reshape(-1,1)))

    # regular evaluation grid, normalised the same way as the measured grid
    grid = gen_reg_grid(30,1,22.5)
    locs = np.hstack((grid[:,1:]/30, ((grid[:,0]-90)/90).reshape(-1,1)))

    # GT model botorch
    train_X = torch.tensor(locs_gt, dtype=torch.float32)
    train_Y = torch.tensor(meps_gt, dtype=torch.float32).reshape(-1,1)
    locs_torch = torch.tensor(locs, dtype=torch.float32)

    # Tolerant angle match instead of exact float equality; fail loudly if empty,
    # because the max over an empty tensor would raise a far less helpful error.
    at_ref_angle = np.isclose(grid[:,0], ref_angle)
    if not at_ref_angle.any():
        raise ValueError(f'regular grid has no points at angle {ref_angle}')
    locs_ref = locs_torch[torch.from_numpy(at_ref_angle)]

    model = CustomSquareWarpedGP(train_X, train_Y)
    model.fit(n_restarts=5)

    max_all = model.predict(locs_torch)[0].max().item()
    max_ref = model.predict(locs_ref)[0].max().item()
    return max_all, max_ref


# Collect plain records and build the frame once instead of growing it row by row.
# Each subject contributes an 'all' row followed by a '45' row.
records = []
for sub in subjs:
    max_all, max_45 = _max_responses(sub, ref_angle=45)
    records.append({'max': max_all, 'sub': sub, 'type': 'all'})
    records.append({'max': max_45, 'sub': sub, 'type': '45'})

max_vals = pd.DataFrame(records, columns=['max', 'sub', 'type'])

In [None]:
# Per-subject comparison of the maximal response over all grid points vs. the
# maximal response at the 45° grid angle.
# The 'all' and '45' rows are paired by position, i.e. both subsets must list the
# subjects in the same order (true for how max_vals is built row by row).
max_all = max_vals.loc[max_vals['type']=='all', 'max'].to_numpy(dtype=float)
max_45 = max_vals.loc[max_vals['type']=='45', 'max'].to_numpy(dtype=float)

rel_diffs = max_all / max_45   # ratio all / 45
abs_diffs = max_all - max_45

# NOTE(review): a median is reported together with a standard deviation —
# confirm this pairing is intended (vs. mean +- std or median +- IQR).
print(r'Rel. difference: %.2f +- %.2f' % (np.median(rel_diffs), np.std(rel_diffs)))
print(r'Abs. difference: %.2f +- %.2f' % (np.median(abs_diffs), np.std(abs_diffs)))

fig, ax = plt.subplots(figsize=(set_size(483.7)[0]/3,2))
# Draw on `ax` explicitly instead of relying on the implicit current axes.
sns.boxplot(data=rel_diffs, ax=ax, showfliers=False)
sns.stripplot(data=rel_diffs, ax=ax, linewidth=1, color='grey')
sns.despine()
ax.set_xticks([])
ax.set_ylabel(r'$\displaystyle\frac{r^*_{all}}{r^*_{45^\circ}}$', rotation=0, ha='right', fontsize=10)
fig.savefig('figures/rel_diff.pdf', transparent=True, bbox_inches="tight", pad_inches=0.01)
plt.show()