In [None]:
import numpy as np
import os,sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.train import load_rnn
from rnn_scripts.utils import su_stats
from matplotlib.colors import colorConverter as cc
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load sweep
# --------------
# Which trained-model sweep to read. Alternative: "sweep" (default initialisation).
SWEEP_SUBDIR = "sweep_osc"  # Oscillatory initialisation
model_dir = os.getcwd() + "/../models/" + SWEEP_SUBDIR + "/"

# Bytes form of the path: os.listdir on a bytes path yields bytes entries,
# which the processing cell decodes with os.fsdecode.
directory = os.fsencode(model_dir)
    

In [None]:
# Process all models in sweep
# ---------------------------
# For every saved "<model>_state_dict.pkl" in `model_dir`, load the model with
# `load_rnn`, compute the statistics returned by `su_stats`, and append one
# record per model to the column-oriented dict `Data`.
#
# NOTE: the "accuracy" column stores training_params["val_loss"][-1], i.e. the
# final validation loss (lower is better). The key name is kept because the
# plotting cell reads Data["accuracy"].

Data = {"model": [],
        "rank": [],
        "reg": [],
        "readout": [],
        "angs_dist": [],
        "means_dist": [],
        "accuracy": []}

# sorted(): os.listdir returns entries in arbitrary, filesystem-dependent order;
# sorting makes the record order (and therefore the figure) reproducible.
for file in sorted(os.listdir(directory)):
    filename = os.fsdecode(file)
    if not filename.endswith("state_dict.pkl"):
        continue
    # Drop the separator plus suffix ("_state_dict.pkl", 15 chars) to get the
    # model prefix that load_rnn expects.
    model = filename[:-len("_state_dict.pkl")]
    print(model)
    rnn, params, task_params, training_params = load_rnn(model_dir + model)
    angs_dist, means_dist = su_stats(task_params, rnn, normalize=True, out_nonlinearity=True)
    Data["model"].append(model)
    Data["rank"].append(params['rank'])
    Data["reg"].append(training_params["offset_reg_cost_masked"])
    Data["readout"].append(params["out_nonlinearity"])
    Data["angs_dist"].append(angs_dist)
    Data["means_dist"].append(means_dist)
    Data["accuracy"].append(training_params["val_loss"][-1])  # final validation loss


In [None]:
# Plot sweep results
# ------------------
# Top row: one scatter panel per rank (rank 1, 2, 3, full rank). Each point is a
# model: x = mean |rate difference|, y = mean |phase difference|.
# Bottom row: box plots of final validation loss per training condition.

fig, ax = plt.subplots(2, 4, figsize=(4.5, 3))
marker_size = 15

def measure(x):
    """Summarise a distance array as the mean of its absolute values."""
    return np.mean(np.abs(x))
measure_phase = measure
measure_rate = measure

# Data["accuracy"] holds the final validation loss (lower is better): only models
# whose loss is BELOW this threshold are drawn in the top row.
acc_tr = .2
plt_colors = ["tab:blue", "tab:green", "tab:orange", "tab:purple"]
ms = ['^', 'o', 'P', '*']
alpha = .2
edge_alpha = 1

# Training conditions: ((offset reg cost, readout nonlinearity), legend label).
# The position in this list is the condition index, which selects colour, marker,
# box-plot position and legend entry.
conditions = [((0, 'tanh'), 'tanh-out, no-reg'),
              ((1, 'tanh'), 'tanh-out, reg'),
              ((0, 'identity'), 'lin-out, no-reg'),
              ((1, 'identity'), 'lin-out, reg')]

def condition_of(reg, readout):
    """Return the condition index for a model's (reg, readout) pair.

    Raises ValueError for an unknown combination instead of silently reusing
    the previous model's condition.
    """
    for idx, ((c_reg, c_readout), _) in enumerate(conditions):
        if reg == c_reg and readout == c_readout:
            return idx
    raise ValueError(f"Unknown condition: reg={reg!r}, readout={readout!r}")

# Seeded generator so the random stacking order of overlapping markers, and thus
# the saved figure, is reproducible across runs.
rng = np.random.default_rng(0)

# accs[rank_ind][acc_ind] collects validation losses for the bottom-row box plots.
accs = [[[] for _ in range(4)] for _ in range(4)]
for i, model in enumerate(Data["model"]):
    angs_dist = Data["angs_dist"][i]
    means_dist = Data["means_dist"][i]
    acc_ind = condition_of(Data["reg"][i], Data["readout"][i])
    zorder = int(rng.integers(0, 100))
    # rank < 1 is plotted in the last ("full rank") column.
    rank_ind = 3 if Data["rank"][i] < 1 else Data["rank"][i] - 1

    if Data["accuracy"][i] < acc_tr:
        ax[0, rank_ind].scatter(measure_rate(means_dist), measure_phase(angs_dist),
                                marker=ms[acc_ind],
                                color=cc.to_rgba(plt_colors[acc_ind], alpha=alpha),
                                s=marker_size, zorder=zorder,
                                edgecolors=cc.to_rgba(plt_colors[acc_ind], alpha=edge_alpha))
    accs[rank_ind][acc_ind].append(Data["accuracy"][i])

# NOTE(review): limits span ALL models, including those excluded from the
# scatter by acc_tr — confirm this is intended.
xlim = max(measure_rate(mean) for mean in Data["means_dist"])
ylim = max(measure_phase(ang) for ang in Data["angs_dist"])

# Shared axis labels (set once, on the left-most panels only).
ax[1, 0].set_ylabel("validation loss")
ax[0, 0].set_xlabel("rate difference")
ax[0, 0].set_ylabel("phase difference")

for i in range(4):
    # Bottom row: one box per training condition.
    for j in range(4):
        c = plt_colors[j]
        ax[1, i].boxplot(accs[i][j], positions=[j], widths=.6, patch_artist=True,
                         boxprops=dict(facecolor=cc.to_rgba(c, alpha=alpha), color=c),
                         capprops=dict(color=c),
                         whiskerprops=dict(color=c),
                         medianprops=dict(color=c),
                         flierprops={'marker': 'o', 'markersize': 1, 'markerfacecolor': c, 'markeredgecolor': c})
    ax[1, i].set_ylim(0, .5)
    ax[1, i].spines['bottom'].set_visible(False)
    ax[1, i].set_xticks([])
    ax[1, i].set_yticks([0, .25, .5])
    ax[1, i].set_yticklabels([])
    ax[1, i].set_box_aspect(1)

    # Top row: scatter panel styling.
    ax[0, i].set_xlim(-.2, xlim * 1.1)
    ax[0, i].set_ylim(-.1, ylim * 1.1)
    ax[0, i].set_yticks([0, 1, 2])
    ax[0, i].spines['bottom'].set_position(('outward', 3))
    ax[0, i].spines['left'].set_position(('outward', 3))
    ax[0, i].set_box_aspect(1)
    ax[0, i].set_title("rank " + str(i + 1) if i < 3 else "full rank")
    if i > 0:
        ax[0, i].set_yticklabels([])  # only the left-most panel keeps y tick labels

ax[1, 0].set_yticklabels(["0", ".25", ".5"])

# Manual legend: one marker per training condition.
legend_marker_size = 4
handles = [Line2D([0], [0], label=label, marker=ms[k], markersize=legend_marker_size,
                  markeredgecolor=cc.to_rgba(plt_colors[k], alpha=edge_alpha),
                  markerfacecolor=cc.to_rgba(plt_colors[k], alpha=alpha), linestyle='')
           for k, (_, label) in enumerate(conditions)]

fig.subplots_adjust(hspace=.6)
# Attached to the last panel (what plt.legend() used implicitly); bbox_to_anchor
# is in that panel's axes coordinates.
ax[1, 3].legend(handles=handles, bbox_to_anchor=[2.7, 3])
fig.savefig("../figures/sweep.svg")