In [None]:
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm_notebook
from scipy.optimize import root

import warnings

# Silence every warning category for the rest of the session.
# NOTE(review): presumably meant to hide numpy overflow/invalid-value
# RuntimeWarnings from trajectories that blow up, but it also hides any other
# warning (e.g. deprecations) - consider narrowing the filter.
warnings.filterwarnings('ignore')

%matplotlib inline

In [None]:
# Default parameters shared by every function in this notebook.
default_tau = 0.15  # Euler step size `tau` of the discretised map
default_std = 2     # standard deviation of the Gaussian used to draw the initial state
default_a = 1       # Brusselator parameter `a`
default_b = 2       # Brusselator parameter `b`

In [None]:
def brusselator_euler(x, tau=default_tau, a=default_a, b=default_b):
    """Apply one explicit Euler step of the Brusselator system.

    The underlying vector field is
        dx/dt = x^2 * y - (b + 1) * x + a
        dy/dt = -x^2 * y + b * x

    Parameters
    ----------
    x : sequence or array
        Current state; ``x[0]`` is the x-component and ``x[1]`` the y-component.
    tau : float
        Euler step size.
    a, b : float
        Brusselator parameters.

    Returns
    -------
    numpy.ndarray
        The next state ``[x_next, y_next]``.
    """
    u, v = x[0], x[1]
    # The cubic coupling term appears with opposite signs in the two equations.
    coupling = u ** 2 * v
    du = coupling - (b + 1) * u + a
    dv = -coupling + b * u
    return np.array([u + tau * du, v + tau * dv])

In [None]:
def get_initial(a=default_a, b=default_b, std=default_std):
    """Draw a random initial state around the point ``(a, b / a)``.

    ``(a, b / a)`` is the fixed point of the Brusselator map: both increments
    in ``brusselator_euler`` vanish there. Each coordinate is sampled
    independently from a normal distribution with standard deviation ``std``
    using numpy's global random state.

    Returns
    -------
    list
        ``[x0, y0]``.
    """
    x_star = a
    y_star = b / a
    # Draw x first, then y, so the consumption order of the global RNG is fixed.
    x0 = np.random.normal(x_star, std)
    y0 = np.random.normal(y_star, std)
    return [x0, y0]

In [None]:
def gen_states(tau=default_tau, iter_step=3000, a=default_a, b=default_b, std=default_std):
    """Iterate the Euler-discretised Brusselator map from a random start.

    Parameters
    ----------
    tau : float
        Euler step size.
    iter_step : int
        Number of map iterations to perform.
    a, b : float
        Brusselator parameters.
    std : float
        Standard deviation used by ``get_initial`` for the starting point.

    Returns
    -------
    numpy.ndarray
        Array with ``iter_step + 1`` rows (initial state followed by every
        iterate) and two columns ``(x, y)``.
    """
    current = get_initial(a=a, b=b, std=std)
    trajectory = [current]
    for _ in range(iter_step):
        current = brusselator_euler(x=current, tau=tau, a=a, b=b)
        trajectory.append(current)
    return np.array(trajectory)

In [None]:
def gen_states_true(tau=default_tau, iter_step=3000, a=default_a, b=default_b, std=default_std, max_allowed_step_size=1):
    """Generate a trajectory that did not blow up numerically.

    Trajectories are drawn with ``gen_states`` (each from a fresh random
    initial state) until one contains only finite values. Both NaN and +/-inf
    are rejected, so a trajectory that overflows on its very last step is
    discarded as well.

    Parameters
    ----------
    tau, iter_step, a, b, std
        Forwarded unchanged to ``gen_states``.
    max_allowed_step_size : float
        Accepted for backward compatibility only; it is not used to filter
        trajectories.

    Returns
    -------
    numpy.ndarray
        Trajectory as returned by ``gen_states``, with every entry finite.

    Notes
    -----
    There is no retry limit: if every initial state diverges for the given
    parameters, this function does not return.
    """
    while True:
        states = gen_states(tau=tau, iter_step=iter_step, a=a, b=b, std=std)
        if np.isfinite(states).all():
            return states

In [None]:
def J_curve0(x, tau=default_tau, a=default_a, b=default_b):
    """Return y on the curve where the Jacobian determinant of the Euler map is zero.

    For the map in ``brusselator_euler`` the determinant vanishes where
        2 * tau * x * y = tau * (1 - tau) * x^2 + tau * (b + 1) - 1,
    which is solved here for ``y``.

    Parameters
    ----------
    x : float or numpy.ndarray
        x-coordinate(s); the expression divides by ``x``, so it is singular at 0.
    tau : float
        Euler step size.
    a : float
        Unused; present so the signature matches the other helpers.
    b : float
        Brusselator parameter.

    Returns
    -------
    float or numpy.ndarray
        y-coordinate(s) on the curve, same shape as ``x``.
    """
    numerator = tau * (1 - tau) * x ** 2 + tau * (b + 1) - 1
    denominator = 2 * x * tau
    return numerator / denominator

In [None]:
def analytical_preimages(states, tau=default_tau, a=default_a, b=default_b):
    """Compute the preimages of each state under the Euler Brusselator map.

    Adding the two update equations of ``brusselator_euler`` gives the linear
    relation ``x_n + y_n = x + y + tau * (a - x)``. Substituting it into the
    x-update yields a cubic in the preimage coordinate ``x``:

        A x^3 + B x^2 + C x + D = 0
        A = tau - tau^2
        B = tau^2 * a - tau * (x_n + y_n)
        C = tau * (b + 1) - 1
        D = x_n - tau * a

    The ``a`` factor in ``B`` is required for ``a != 1``; for ``a == 1`` the
    value is unchanged.

    Parameters
    ----------
    states : numpy.ndarray
        Array whose columns 0 and 1 hold the ``x_n`` and ``y_n`` of the points
        to invert.
    tau : float
        Euler step size.
    a, b : float
        Brusselator parameters.

    Returns
    -------
    numpy.ndarray
        One ``(2, n_roots)`` entry per state: row 0 holds the preimage x
        values, row 1 the matching y values. With three roots per state the
        result has shape ``(len(states), 2, 3)``. Roots may be complex.
        They are sorted with ``np.sort``, then columns 1 and 2 are swapped
        if needed so that column 1 holds whichever of those two roots is
        closer to ``x_n``.
    """
    x_n, y_n = states[:, 0], states[:, 1]
    A = tau - tau ** 2
    B = tau ** 2 * a - tau * x_n - tau * y_n
    C = tau * b + tau - 1
    D = x_n - tau * a
    preimages = []
    for i in range(states.shape[0]):
        x_pre = np.sort(np.roots([A, B[i], C, D[i]]))
        # Keep in slot 1 the root (of the last two) nearest to the image point.
        if np.abs(x_pre[1] - x_n[i]) > np.abs(x_pre[2] - x_n[i]):
            x_pre = x_pre[[0, 2, 1]]
        # Solve the x-update equation for y given each preimage x.
        y_pre = (x_n[i] - x_pre + (b + 1) * tau * x_pre - tau * a) / (tau * x_pre ** 2)
        pre = np.vstack([x_pre, y_pre])
        preimages.append(pre)
    return np.array(preimages)

In [None]:
# 2x2 figure, one panel per value of b (a stays at its default): the last
# iterates of a finite trajectory, the J_0 curve and the two additional
# preimage branches of those iterates.
fig, ax = plt.subplots(figsize=(12, 12), ncols=2, nrows=2)
b = [2, 2.5, 3, 3.2]

# x-grids for the two branches of J_0 (x < 0 and x > 0; x = 0 is excluded
# because J_curve0 divides by x). They do not depend on b, so build them once.
J0_x1 = np.linspace(-10, 0, 10000, endpoint=False)
J0_x2 = np.linspace(10, 0, 10000, endpoint=False)[::-1]

for i in range(2):
    for j in range(2):
        indice_b = 2 * i + j
        states = gen_states_true(b=b[indice_b])
        # Keep only the tail of the trajectory.
        states = states[-1000:, :]

        J0_y1 = J_curve0(J0_x1, b=b[indice_b])
        J0_y2 = J_curve0(J0_x2, b=b[indice_b])

        preimages = analytical_preimages(states, b=b[indice_b])

        ax[i, j].scatter(J0_x1, J0_y1, color='red', alpha=0.5, label=r'$J_0$', s=10, zorder=99)
        ax[i, j].scatter(J0_x2, J0_y2, color='red', alpha=0.5, s=10, zorder=99)

        ax[i, j].scatter(states[:-1, 0], states[:-1, 1], alpha=0.7, color='orange', label=r'$\Gamma$, $F^{-1}(\Gamma)$', s=10, zorder=99)

        # Split the (N, 2, 3) preimage array by root index; each piece is (N, 2).
        pre1, pre2, pre3 = preimages[:, :, 0], preimages[:, :, 1], preimages[:, :, 2]
        # A preimage is usable only if both of its coordinates are real.
        real_1 = np.isreal(pre1).all(axis=1)
        real_2 = np.isreal(pre2).all(axis=1)
        real_3 = np.isreal(pre3).all(axis=1)
        all_real = real_1 & real_2 & real_3

        # The array may have complex dtype even where the imaginary part is
        # zero; pass the real part explicitly instead of relying on an
        # implicit (and warning-suppressed) cast inside matplotlib.
        ax[i, j].scatter(pre3[all_real, 0].real, pre3[all_real, 1].real, alpha=0.7, color='green', label=r'$F^{-1}(\Gamma)^{\prime}$', s=10, zorder=50)
        ax[i, j].scatter(pre1[all_real, 0].real, pre1[all_real, 1].real, alpha=0.7, color='purple', label=r'$F^{-1}(\Gamma)^{\prime \prime}$', s=10, zorder=50)

        # Legend only in the first panel.
        if not i and not j:
            ax[i, j].legend(fontsize=16)

        ax[i, 0].set_ylabel('$y$')
        ax[0, j].set_xlabel('$x$')

        ax[i, j].set_xlim(-5, 10)
        ax[i, j].set_ylim(-5, 10)
        # The trajectories above use the default `a`, so show that value.
        ax[i, j].set_title(r'$(a, b) = $ {}'.format((default_a, b[indice_b])), fontsize=20)

        # Hide tick labels while keeping the tick marks.
        ax[i, j].set_xticklabels(['' for x in ax[i, j].get_xticks()])
        ax[i, j].set_yticklabels(['' for y in ax[i, j].get_yticks()])

plt.tight_layout()

# plt.savefig('inv.pdf', bbox_inches='tight')