In [31]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 
import nbimporter
%run -i plotting.ipynb

In [32]:
def _save_sweep_plot(x, y, path, x_label, y_label, colors, labels=None, color=None):
    """Draw `y` against `x` with the given colour cycle and save it to `path` at 300 dpi.

    If `labels` is given, it is attached to the plotted lines and a legend is drawn.
    If `color` is given, it overrides the colour cycle for the plotted line(s).
    The figure is closed after saving so repeated sweeps do not accumulate open figures.
    """
    fig, ax = plt.subplots()
    ax.set_prop_cycle('color', colors)

    plot_kwargs = {"linewidth": 3}
    if labels is not None:
        plot_kwargs["label"] = labels
    if color is not None:
        plot_kwargs["color"] = color
    ax.plot(x, y, **plot_kwargs)

    if labels is not None:
        ax.legend()
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # tight_layout must run after labels/legend exist, otherwise it cannot
    # make room for them.
    fig.tight_layout(pad=1.9)
    fig.savefig(path, dpi=300)
    plt.close(fig)


def run_experiment(exp_name, parameter_list, change_index, niters, x_axis, time_scale=1e-3):
    """Sweep one model parameter, run the model for each value and save summary plots.

    Starting from a fixed baseline parameter list, the entry at the swept
    index is replaced by each value of `parameter_list` in turn, and
    `plot_results` (defined in plotting.ipynb) is called for it. For every run
    this function records:
      * the column-wise maximum of `e`,
      * the first time-step index at which each column reaches that maximum,
      * the maximum of `r`.

    Three figures are written to ``results/<exp_name>/``:
      * ``max.png``       - maximum of E(P_k, t) against the swept parameter;
      * ``max_time.png``  - time at which that maximum is reached;
      * ``max_r.png``     - maximum of R(t).
    Finally, the index (or indices) of `parameter_list` giving the largest
    maximum of R is printed.

    Parameters
    ----------
    exp_name : str
        Experiment name; used as the output sub-directory and passed to
        `plot_results`.
    parameter_list : sequence of float
        Values to sweep. Must be non-empty. Each value is passed to
        `plot_results` as ``round(param, 2)``.
    change_index : int or list/tuple of two ints
        Index in the baseline parameter list to sweep. If a pair ``[i, j]``
        is given, index ``i`` is swept and index ``j`` is set to `exp_name`.
    niters : int
        Forwarded unchanged to `plot_results`.
    x_axis : str
        Label of the x axis in all three figures.
    time_scale : float, optional
        Factor converting a time-step index of `e` into the unit plotted in
        ``max_time.png`` (default ``1e-3``).

    Raises
    ------
    ValueError
        If `parameter_list` is empty.
    """
    import os

    if len(parameter_list) == 0:
        raise ValueError("parameter_list must contain at least one value")

    # Baseline parameters. params[0] holds the P_k values used for the legend.
    # NOTE(review): the meaning of the other entries is defined by
    # plot_results in plotting.ipynb - confirm there.
    params = [np.linspace(0.1, 1, 5), 2000, 0.2 * np.ones(5), 2.5, 0.9, 0, 2,
              "two_spike", "threshold", "all_start", "average"]

    if isinstance(change_index, (list, tuple)):
        # The second index receives the experiment name.
        params[change_index[1]] = exp_name
        sweep_index = change_index[0]
    else:
        sweep_index = change_index

    n_runs = len(parameter_list)
    n_k = len(params[0])
    max_e = np.zeros((n_runs, n_k))
    when_max_e = np.zeros((n_runs, n_k))
    max_r = np.zeros(n_runs)

    for i, param in enumerate(parameter_list):
        params[sweep_index] = param
        _, e, r = plot_results(params, exp_name, round(param, 2), niters)
        # assumes e is (time, P_k) - TODO confirm against plot_results
        max_e[i, :] = np.max(e, axis=0)
        # argmax gives one index per column (first occurrence on ties), so
        # peak times stay aligned with their P_k column.
        when_max_e[i, :] = np.argmax(e, axis=0)
        max_r[i] = np.max(r)

    colors = sns.color_palette("crest", 5)
    # The first P_k column is omitted from the per-P_k plots; labels and data
    # must drop the same column.
    labels = ["$P_k = " + str(b) + "$" for b in params[0][1:]]

    out_dir = "results/{}".format(exp_name)
    os.makedirs(out_dir, exist_ok=True)

    _save_sweep_plot(parameter_list, max_e[:, 1:],
                     os.path.join(out_dir, "max.png"),
                     x_axis, "Maximum value of $E(P_k,t)$",
                     colors, labels=labels)

    _save_sweep_plot(parameter_list, when_max_e[:, 1:] * time_scale,
                     os.path.join(out_dir, "max_time.png"),
                     x_axis, "Time taken to reach maximum value of $E(P_k,t)$",
                     colors, labels=labels)

    _save_sweep_plot(parameter_list, max_r,
                     os.path.join(out_dir, "max_r.png"),
                     x_axis, "Maximum value of $R(t)$",
                     colors, color=colors[-1])

    print(np.where(max_r == np.max(max_r)))