In [None]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import functools

from stljax.formula import Predicate, DifferentiableAlways
from stljax.utils import smooth_mask


from matplotlib import animation
from IPython.display import HTML
from matplotlib import rc

# Typeset all figure text in a serif (Palatino) font rendered through LaTeX;
# this needs a working LaTeX installation on the machine running the notebook.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Palatino"],
        "text.usetex": True,
    }
)

# Switch JAX to 64-bit floats (its default is 32-bit).
jax.config.update("jax_enable_x64", True)


### Generate data

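Generate a batch of noisy trajectories, each resembling a "bump" over the normalized time interval $[t_\mathrm{start}, t_\mathrm{end}]$. The start time is jittered slightly per trajectory, Gaussian noise is added, and the whole signal is shifted down by 0.5.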

In [None]:
T = 20  # number of time steps per trajectory
fontsize = 14
true_t_start = 0.21  # ground-truth (normalized) interval start
true_t_end = 0.59  # ground-truth (normalized) interval end
bs = 32  # number of trajectories in the batch

# Split the key so the start-time jitter and the additive noise come from
# independent random streams; a JAX key must never be reused for two draws.
key = jax.random.key(1701)
key_jitter, key_noise = jax.random.split(key)

# Per-trajectory start times, jittered around the ground truth.
t_starts = true_t_start + 0.02 * jax.random.normal(key_jitter, shape=(bs,))
# smooth_mask is vmapped over its second argument only (the start time);
# presumably its signature is (T, t_start, t_end, scale) -- TODO confirm.
bumps = jax.vmap(smooth_mask, [None, 0, None, None])(T, t_starts, true_t_end, 3.0)
noise = 0.1 * jax.random.normal(key_noise, shape=(bs, T))
signal_data = bumps + noise - 0.5

fig, ax = plt.subplots(figsize=(5, 2))
ax.plot(jnp.linspace(0, 1, T), signal_data.T, color="black", alpha=0.2)
ax.set_xlabel("Normalized time", fontsize=fontsize, labelpad=-1)
ax.set_ylabel("Signal", fontsize=fontsize, labelpad=-1)
ax.grid()
fig.tight_layout()


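Define the STL formula $\square_{[t_\mathrm{start}, t_\mathrm{end}]}(x > 0)$ ("always $x > 0$" over a differentiable time interval), together with the loss function used to learn the interval endpoints.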

In [None]:
# Predicate on the raw signal value (identity map), used as "x > 0" below.
pred = Predicate("x", lambda x: x)
# "Always (x > 0)" with a differentiable time interval; the interval endpoints
# are supplied when the robustness is evaluated (see `loss`).
phi = DifferentiableAlways(pred > 0.0)

# JIT-compile the robustness. approx_method is a string, so it must be static.
# Note the trailing comma: ("approx_method") is just a string, while
# ("approx_method",) is the intended tuple of names.
phi_robustness_jit = jax.jit(phi.robustness, static_argnames=("approx_method",))


@functools.partial(jax.jit, static_argnames=("approx_method",))
def loss(
    signal_data, t_start, t_end, scale, approx_method, temperature, coeff, min_width=0.05
):
    """Hinge loss on STL robustness plus a term that rewards wider intervals.

    Args:
        signal_data: batch of trajectories; robustness is vmapped over axis 0.
        t_start, t_end: interval endpoints forwarded to ``phi.robustness``
            (normalized time in this notebook).
        scale: smoothing scale forwarded to ``phi.robustness``.
        approx_method: approximation name forwarded to ``phi.robustness``;
            static under jit.
        temperature: approximation temperature forwarded to ``phi.robustness``.
        coeff: weight of the interval-width term.
        min_width: intervals with ``t_start >= t_end - min_width`` are treated
            as degenerate and yield a NaN loss.

    Returns:
        ``mean(relu(-robustness)) + coeff * (t_start - t_end)``. Only negative
        robustness (violations) is penalized; the second term decreases as the
        interval widens.
    """
    rob_partial = functools.partial(
        phi_robustness_jit,
        t_start=t_start,
        t_end=t_end,
        scale=scale,
        approx_method=approx_method,
        temperature=temperature,
    )
    robustness_ = jax.vmap(rob_partial, [0])(signal_data)
    # Degenerate intervals are marked with NaN rather than given a finite loss.
    valid = t_start < (t_end - min_width)
    robustness = jax.nn.relu(-jnp.where(valid, robustness_, jnp.nan)).mean()
    return robustness + coeff * (t_start - t_end)


# Gradient of the loss w.r.t. (t_start, t_end); used to update the interval
# endpoints during optimization.
grad_loss = jax.jit(jax.grad(loss, argnums=(1, 2)), static_argnames=("approx_method",))


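## Set up and run gradient descent

The code updates unconstrained parameters `a` and `b` and maps them to the interval endpoints via $t_\mathrm{start} = \sigma(a)$ and $t_\mathrm{end} = \sigma(b)$. The smoothing scale, the temperature, and the width-penalty coefficient follow linear schedules over the run.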

In [None]:
approx_method = "logsumexp"
# Optimize the interval endpoints in logit space: t_start = sigmoid(a) and
# t_end = sigmoid(b), which keeps both inside (0, 1).
a = -2.0  # initial logit of t_start
b = 2.0  # initial logit of t_end
lr = 1e-2
max_steps = 5000
scale_start, scale_end = 0.1, 20  # smooth-mask scale schedule
temperature_start, temperature_end = 0.1, 20  # logsumexp temperature schedule
coeff_start, coeff_end = 0.1, 0.0  # interval-width penalty schedule
a_list = [a]  # logit history of t_start: initial value + one entry per step
b_list = [b]  # logit history of t_end: initial value + one entry per step


def linear_schedule(start, end, frac):
    """Linearly interpolate from `start` (frac = 0) to `end` (frac = 1)."""
    return (1 - frac) * start + frac * end


for i in range(max_steps):
    frac = i / max_steps
    s = linear_schedule(scale_start, scale_end, frac)
    t = linear_schedule(temperature_start, temperature_end, frac)
    c = linear_schedule(coeff_start, coeff_end, frac)
    t_start = jax.nn.sigmoid(a)
    t_end = jax.nn.sigmoid(b)
    ga, gb = grad_loss(signal_data, t_start, t_end, s, approx_method, t, c)
    # Chain rule through the sigmoid: d sigmoid(z)/dz = sigmoid(z) * (1 - sigmoid(z)).
    a -= lr * ga * t_start * (1 - t_start)
    b -= lr * gb * t_end * (1 - t_end)
    a_list.append(a)
    b_list.append(b)

a_list = jnp.stack(a_list)
b_list = jnp.stack(b_list)

## Visualize results

In [None]:
step_size = 50
coeff = 0.1

def visualize_loss_landscape(ax,
    loss_func, signal, scale, approx_method, temperature, coeff
):
    N = 100
    fontsize = 14
    levels = 10
    starts, ends = jnp.meshgrid(jnp.linspace(0, 1, N), jnp.linspace(0, 1, N))
    losses = jax.vmap(loss_func, [None, 0, 0, None, None, None, None])(
        signal,
        starts.reshape([-1, 1]),
        ends.reshape([-1, 1]),
        scale,
        approx_method,
        temperature,
        coeff,
    ).reshape([N, N])

    ax.contourf(starts, ends, losses, levels=levels, cmap="jet", alpha=0.4)
    # plt.colorbar()
    if approx_method != "true":
        match approx_method:
            case "logsumexp":
                app = "LSE"
            case "softmax":
                app = "soft"
        ax.set_title(
            "Loss landscape \n  (c = %.2f, $\\tau_\\mathrm{%s}$ = %.2f)"
            % (scale, app, temperature)
        )
    else:
        ax.set_title("Loss landscape \n (c = %.2f)" % (scale))
    ax.set_xlabel("$a$", fontsize=fontsize, labelpad=-3)
    ax.set_ylabel("$b$", fontsize=fontsize, labelpad=-3)
    ax.grid(zorder=-5, alpha=0.2)


def visualize_results(ax, i, approx_method):
    a, b = jax.nn.sigmoid(a_list[i]), jax.nn.sigmoid(b_list[i])
    j = i / max_steps
    s = (1 - j) * scale_start + j * scale_end
    t = (1 - j) * temperature_start + j * temperature_end
    ell = loss(signal_data, a, b, s, approx_method, t, coeff)

    # plt.figure(figsize=(4,3))

    visualize_loss_landscape(ax, loss, signal_data, s, approx_method, t, coeff)
    ax.plot(
        jax.nn.sigmoid(a_list[::step_size]),
        jax.nn.sigmoid(b_list[::step_size]),
        "o-",
        markersize=3,
        color="black",
        linewidth=1,
    )
    current_loss = loss(signal_data, a, b, s, approx_method, t, coeff)
    gt_loss = loss(signal_data, true_t_start, true_t_end, s, approx_method, t, coeff)
    ax.scatter(
        [a],
        [b],
        marker="o",
        s=40,
        color="magenta",
        label="Current: %.3f" % current_loss,
        edgecolor="black",
        zorder=4,
    )
    ax.scatter(
        [true_t_start],
        [true_t_end],
        marker="*",
        s=100,
        color="orange",
        label="Ground truth: %.3f" % gt_loss,
        edgecolor="black",
        zorder=4,
    )

    ax.vlines(a, 0, 1, color="red", linestyle="--")
    ax.hlines(b, 0, 1, color="blue", linestyle="--")
    if approx_method != "true":
        match approx_method:
            case "logsumexp":
                app = "LSE"
            case "softmax":
                app = "soft"
        ax.set_title(
            f"Step {i}, " + "Loss = %.5f (c = %.2f, $\\tau_\\mathrm{%s}$ = %.2f)" % (ell, s, app, t)
        )
    else:
        ax.set_title(f"Step {i}, " + "Loss = %.5f (c = %.2f)" % (ell, s))
    ax.legend(loc="lower right")


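Animate the optimization. Each frame shows the loss landscape at a given step, the optimizer's trajectory, the current interval estimate, and the ground-truth interval.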

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
# Frame spacing for the animation; visualize_results also reads this global
# to thin the plotted trajectory.
step_size = 100


def animate(idx):
    """Redraw `ax` for optimization step `idx`."""
    ax.clear()
    visualize_results(ax, idx, approx_method)


# a_list/b_list hold max_steps + 1 entries (initial value plus one per step),
# so append max_steps explicitly to make the final iterate the last frame.
frames = [*range(0, max_steps, step_size), max_steps]
ani = animation.FuncAnimation(fig, animate, frames=frames, interval=100, repeat=False)

plt.close(fig)  # Prevent duplicate static plot in notebook output
HTML(ani.to_jshtml())  # Display the animation in the notebook