In [None]:
import numpy as np
import torch
import sys

sys.path.append("../")
from fixed_points.find_fixed_points_analytic import find_fixed_points_analytic
from fixed_points.constrained_scify import run_scify
from py_rnn.train import load_rnn
import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline

In [None]:
# Extract the loadings of the trained RNN and express them in the
# (A, W1, W2, h1, h2) form passed to the fixed-point finders.
rnn_osc, model_params, task_params, training_params = load_rnn(
    "../data/student_teacher/reach_rnn"
)

# Length of a, h1 and z. NOTE(review): hard-coded to 2; presumably the rank of
# the RNN (number of m/n loading vectors) — confirm against the loaded model.
latent_dim = 2

alpha = rnn_osc.rnn.dt / rnn_osc.rnn.tau
z = np.ones(latent_dim)  # not referenced by later cells; kept for compatibility

# .cpu() is a no-op for CPU tensors and lets .numpy() succeed if the model was
# loaded onto a GPU; .clone() makes sure the resulting arrays do not share
# memory with the model's parameters.
W2 = rnn_osc.rnn.m.detach().cpu().clone().numpy()
# n scaled by 1/n_rec and by alpha; the division already yields a new tensor.
W1 = (rnn_osc.rnn.n.detach().cpu() / model_params["n_rec"]).numpy() * alpha

# Diagonal of A is 1 - alpha for every latent dimension.
decay = 1 - alpha
a = np.ones(latent_dim) * decay
A = np.diag(a)

# Recurrent bias, cloned like m so it is not a view onto the parameter.
h2 = rnn_osc.rnn.b_rec.detach().cpu().clone().numpy()
h1 = np.zeros(latent_dim)

In [None]:
# Exact reference: enumerate all fixed points with the analytic method.
# The bias is handed over negated (-h2); the call signature is defined in
# fixed_points.find_fixed_points_analytic.
(D_list, D_inds, z_list, n_inverses_an) = find_fixed_points_analytic(
    a,
    W1,
    W2,
    h1,
    -h2,
)
# Total number of fixed points found analytically (the ground-truth count).
true_n_fps = len(z_list)

In [None]:
# Run approximate methods: scify without constraints ("approximate") and with
# constraints ("combined") for a range of inverse budgets, repeated
# n_iterations times to get a distribution over the number of fps found.

n_iterations = 20  # run multiple times to get distribution over n fps found
n_inverses_maxs = np.arange(1000, 7001, 500)
# Some large number so scify will run till n_inverses_max is reached (this is
# asserted below for the unconstrained run). Kept integral: int(10e6) == 10_000_000.
max_outer = int(10e6)
round_dec = 4  # 2 seems to be too little (some fps are not distinguished)
# NOTE(review): no RNG seed is set; if run_scify is stochastic, the resulting
# distribution differs between runs — seed here if exact reproducibility matters.


def _count_fps(constrain, n_inverses_max):
    """Run scify once with the module-level A, W1, W2, h1, h2 and count fps.

    Parameters
    ----------
    constrain : bool
        Passed through to ``run_scify`` (False: "approximate", True: "combined").
    n_inverses_max : int
        Budget of inverses for this run.

    Returns
    -------
    n_fps : int
        ``len(dyn_objects[0])``, the number of fixed points returned.
    n_inverses : int
        Number of inverses ``run_scify`` reports having used.
    """
    dyn_objects, _eigenvals, n_inverses = run_scify(
        A,
        W1,
        W2,
        h1,
        h2,
        constrain=constrain,
        n_inverses_max=n_inverses_max,
        round_dec=round_dec,
        outer_loop_iterations=max_outer,
    )
    return len(dyn_objects[0]), n_inverses


all_results = []
all_results_constrain = []
for i in range(n_iterations):
    print("iteration ", i)
    results = []
    results_constrain = []
    for n_inverses_max in n_inverses_maxs:
        n_fps, n_inverses = _count_fps(False, n_inverses_max)
        results.append(n_fps)
        # The unconstrained run must exhaust its full inverse budget.
        assert n_inverses == n_inverses_max
        n_fps_constrain, _ = _count_fps(True, n_inverses_max)
        results_constrain.append(n_fps_constrain)
    all_results.append(results)
    all_results_constrain.append(results_constrain)

# Shape (n_iterations, len(n_inverses_maxs)).
all_results = np.array(all_results)
all_results_constrain = np.array(all_results_constrain)

In [None]:
# get min and mix fps found for plotting
mean = np.mean(all_results, axis=0)
max = np.max(all_results, axis=0)
min = np.min(all_results, axis=0)
mean_constrain = np.mean(all_results_constrain, axis=0)
max_constrain = np.max(all_results_constrain, axis=0)
min_constrain = np.min(all_results_constrain, axis=0)

In [None]:
# make plot
n_start = 2  # combined method has threshold cost
with mpl.rc_context(fname="matplotlibrc"):

    plt.figure(figsize=(1, 1))
    plt.plot(n_inverses_maxs, mean, label="approximate", marker="o", color="C0")
    plt.plot(
        n_inverses_maxs[n_start:],
        mean_constrain[n_start:],
        label="combined",
        marker="o",
        color="C1",
    )
    plt.fill_between(n_inverses_maxs, min, max, alpha=0.2, color="C0")
    plt.fill_between(
        n_inverses_maxs[n_start:],
        min_constrain[n_start:],
        max_constrain[n_start:],
        alpha=0.2,
        color="C1",
    )
    plt.scatter(
        n_inverses_an,
        true_n_fps,
        zorder=1000,
        color="purple",
        marker="*",
        s=100,
        label="analytic",
    )
    plt.gca().set_box_aspect(1)
    plt.legend(loc="upper right", bbox_to_anchor=(2.1, 1))
    plt.xlabel("# inverses")
    plt.ylabel("# fixed points found")
    plt.yticks([0, 10, true_n_fps])
    plt.ylim(10, true_n_fps + 0.75)
    plt.xticks([1000, 4000, 7000])
    plt.xlim(1000, 7000)
    plt.savefig("../figures/FigFP.pdf")  # , bbox_inches="tight")