In [None]:
# Analyzing gradients: the same atomic proposition (signal > 0) wrapped both in
# the recurrent `EventuallyRecurrent` and in the plain `Eventually` formula.
pred = Predicate('x', lambda sig: sig)  # identity predicate: evaluates the raw signal
rec = EventuallyRecurrent(pred > 0.0)
ev = Eventually(pred > 0.0)


In [None]:
T = 10  # number of timesteps in the toy signal
# Float ramp 0, 1, ..., T-1 (float so it can be differentiated); its maximum is
# at the last timestep. For a random signal use jnp.array(np.random.randn(T)).
signal = 1.0 * jnp.arange(T)
signal_flip = signal[::-1]  # time-reversed copy of `signal`


In [None]:
# Settings shared by the gradient experiments below.
approx_method, temperature = "logsumexp", 1.0  # forwarded to maxish / robustness calls
axis, keepdims = 0, True  # reduction settings read by test_grad's maxish call



def test_grad(signal, temperature, approx_method):
    """Soft 'eventually' robustness of `signal`, computed by a backward recurrence.

    Walks the signal from its last sample to its first. At each step it reduces
    [running value, current sample] with `maxish`. The squeezed result is what
    `jax.grad` differentiates: despite its name, this function computes the
    robustness, not the gradient.

    Args:
        signal: array of samples, iterated along its first axis
            (assumed 1-D — TODO confirm).
        temperature: softening temperature forwarded to `maxish`.
        approx_method: approximation scheme forwarded to `maxish`.

    Note:
        Reads the notebook-level globals `axis` and `keepdims` at call time.
        The running value is concatenated with each new sample, which
        presumably requires `keepdims=True`.
    """
    def soft_max(values):
        return maxish(values, axis=axis, keepdims=keepdims,
                      approx_method=approx_method, temperature=temperature)

    running = jnp.zeros(0)  # empty seed: the first step reduces over the last sample alone
    for sample in jnp.flip(signal):
        running = soft_max(jnp.concat([running, jnp.array([sample])]))
    return running.squeeze()


# Per-timestep gradient of the plain `Eventually` robustness w.r.t. `signal`.
mask_grads = jax.grad(ev.robustness)(signal, approx_method=approx_method, temperature=temperature)
# Per-timestep gradient of the hand-unrolled recurrence (the series plot_gradients draws).
rec_grads = jax.grad(test_grad)(signal, temperature, approx_method)
# Gradient of the library `EventuallyRecurrent` robustness, taken w.r.t. `signal_flip`.
# Stored under its own name so it does not clobber `rec_grads`.
# NOTE(review): entries are indexed by position in `signal_flip`, not `signal`;
# flip before comparing with mask_grads (confirm against rec.robustness semantics).
rec_grads_flipped = jax.grad(rec.robustness)(signal_flip, approx_method=approx_method, temperature=temperature)


In [None]:
def rec_robustness(signal, approx_method, temperature, padding=1E9):
    """Robustness of the recurrent formula `rec`, evaluated on the time-reversed signal.

    Calls `rec` on `jnp.flip(signal)` and returns the last entry of its output.
    `padding` is forwarded unchanged; its meaning is defined by `rec`.
    """
    outputs = rec(
        jnp.flip(signal),
        approx_method=approx_method,
        temperature=temperature,
        padding=padding,
    )
    return outputs[-1]

In [None]:
def plot_gradients(temperature, approx_method):
    """Bar-plot per-timestep gradients of two 'eventually' robustness computations.

    Draws into the current pyplot axes:
    - gradients of `ev.robustness` ("Mask");
    - gradients of the hand-unrolled recurrence `test_grad` ("Rec."), side by side;
    - a dashed vertical line at the argmax of the notebook-level `signal`.

    Args:
        temperature: softening temperature forwarded to both computations; also
            shown in the panel title.
        approx_method: approximation scheme forwarded to both computations.
    """
    mask_grads = jax.grad(ev.robustness)(signal, approx_method=approx_method, temperature=temperature)
    rec_grads = jax.grad(test_grad)(signal, temperature, approx_method)

    steps = jnp.arange(len(signal))
    bar_offset = 0.125  # symmetric shift so the two bar series sit side by side
    plt.bar(steps - bar_offset, mask_grads, label="Mask", width=0.22, alpha=0.5)
    plt.bar(steps + bar_offset, rec_grads, label="Rec.", width=0.22, alpha=0.5)
    height = max(mask_grads.max(), rec_grads.max())
    plt.vlines(jnp.argmax(signal), 0, height, zorder=-5, linestyle='--', label='True max')
    plt.grid()
    # Title from the argument (not a caller's loop variable), so it is correct
    # both in the subplot loop and under `interact`.
    plt.title("$\\tau$ = %.2f" % temperature)

    # Label only the outer edges of a subplot grid; a standalone axes gets both labels.
    ax = plt.gca()
    spec = ax.get_subplotspec() if hasattr(ax, "get_subplotspec") else None
    if spec is None or spec.is_last_row():
        plt.xlabel("Timestep", labelpad=-2)
    if spec is None or spec.is_first_col():
        plt.ylabel("Gradient", labelpad=-2)

In [None]:
# One panel per softening temperature on a 2x2 grid.
# NOTE: the loop variables `i` and `temp` remain defined in the notebook namespace.
plt.figure(figsize=(5, 3))
for i, temp in enumerate((0.1, 1.0, 2.0, 10.0)):
    plt.subplot(2, 2, i + 1)  # subplot indices are 1-based
    plot_gradients(temp, approx_method)

# A single legend on the last panel; every panel draws the same labelled series.
plt.legend(loc="upper left")
plt.tight_layout()
# To export: plt.savefig("figs/softmax_gradients.png", dpi=200, transparent=True)

In [None]:
# Interactive explorer: re-runs plot_gradients whenever the temperature or method changes.
temp_slider = widgets.FloatSlider(
    value=1., min=0.1, max=10., step=0.1,
    description='temperature:',
)
approx_method_slider = widgets.Dropdown(  # a dropdown despite the name; kept for compatibility
    options=['true', "logsumexp", "softmax"],
    value='true',
    description='approx method:',
    disabled=False,
)
# interact() displays the widget itself and returns the wrapped function; the
# trailing `;` stops that function's repr from being printed below the widget.
interact(plot_gradients, temperature=temp_slider, approx_method=approx_method_slider);
