In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
from stljax.formula import *
import numpy as np


In [13]:
class Temporal_Operator2(STL_Formula):
    """
    Base class for the recurrent Eventually and Always operators. It holds the bookkeeping
    (hidden-state size, window length, shift matrices) for a recurrent cell that performs
    dynamic programming along the time axis. Subclasses provide `_cell` and
    `_initialize_hidden_state`.

    Args:
        subformula: The subformula that the temporal operator is applied to.
        interval: The time interval that the temporal operator operates on. Default: None which means [0, jnp.inf]. Other options can be: [a, b] (b < jnp.inf), [a, jnp.inf] (a > 0). An explicit [0, jnp.inf] is treated exactly like None.

    NOTE: Assume that the interval is describing the INDICES of the desired time interval. The user is responsible for converting the time interval (in time units) into indices (integers) using knowledge of the time step size.
    """
    def __init__(self, subformula, interval=None):
        super().__init__()
        self.subformula = subformula
        # An explicit [0, inf] is the same operator as the default. Left as-is it would produce
        # hidden_dim = 0 and steps = inf, which neither the shift matrices below nor the
        # window slicing in the subclasses' `_cell` can use, so fold it into the None case.
        if interval is not None and len(interval) == 2 and interval[0] == 0 and interval[1] == jnp.inf:
            interval = None
        self.interval = interval
        self._interval = [0, jnp.inf] if self.interval is None else self.interval
        if not self.interval:
            # interval is [0, ∞): a single running value is enough.
            self.hidden_dim = 1
            self.steps = 1
        else:
            upper = self.interval[-1]
            # hidden_dim = end of interval, or its start when the interval is [a, ∞)
            self.hidden_dim = self.interval[0] if upper == jnp.inf else upper
            # steps = length of interval
            self.steps = upper - self.interval[0] + 1
        self.operation = None
        # Matrices that shift a vector and add a new entry at the end.
        self.M = jnp.diag(jnp.ones(self.hidden_dim-1), k=1)
        self.b = jnp.zeros(self.hidden_dim).at[-1].set(1)
        self.LARGE_NUMBER = 1E9

    def _cell(self, x, hidden_state, **kwargs):
        """
        This function describes the operation that takes place at each recurrent step.
        Must be overridden by subclasses.

        Args:
            x: the input at one time step, i.e. one of the length-1 pieces that `_run_cell` obtains by splitting the signal along axis 0
            hidden_state: the hidden state. It is either an array, or a tuple of arrays, depending on the interval chosen. Its leading axis has size hidden_dim.

        Return:
            output and next hidden_state

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("_cell is not implemented")

    def _run_cell(self, signal, padding, **kwargs):
        """
        Function to run a signal through a cell T times, where T is the length of the signal along axis 0

        Args:
            signal: input signal, size = [time_dim, ...]. The time axis is hard-coded to axis 0.
            padding: forwarded to `_initialize_hidden_state` to build the initial hidden state
            kwargs: forwarded unchanged to `_cell` at every step

        Return:
            outputs: list of per-step outputs
            states: list of per-step hidden_states
        """
        time_dim = 0  # assuming signal is [time_dim,...]
        hidden_state = self._initialize_hidden_state(signal, padding=padding) # [hidden_dim]
        outputs = []
        states = []
        # One length-1 slice per time step, fed through the cell in order.
        for x in jnp.split(signal, signal.shape[time_dim], time_dim):
            o, hidden_state = self._cell(x, hidden_state, **kwargs)
            outputs.append(o)
            states.append(hidden_state)
        return outputs, states

    def _robustness_trace(self, signal, padding, **kwargs):
        """
        Function to compute robustness trace of a temporal STL formula
        First, compute the robustness trace of the subformula, and use that as the input for the recurrent computation

        Args:
            signal: input signal, size = [time_dim, ...]. The time axis is hard-coded to axis 0.
            padding: forwarded to `_run_cell`
            kwargs: forwarded to both the subformula and `_run_cell`

        Returns:
            robustness_trace: the per-step outputs of `_run_cell` concatenated along axis 0.
        """
        time_dim = 0  # assuming signal is [time_dim,...]
        trace = self.subformula(signal, **kwargs)
        outputs, _ = self._run_cell(trace, padding, **kwargs)
        return jnp.concatenate(outputs, axis=time_dim)

    def robustness(self, signal, **kwargs):
        """
        Computes the robustness value: the last entry along axis 0 of the result of calling this formula on `signal`.

        Args:
            signal: input forwarded to `self(...)`; presumably [time_dim, ...] like the other methods -- TODO confirm against STL_Formula.__call__
            kwargs: forwarded to `self(...)`

        Return: the last entry of the result along axis 0.
        """
        return self(signal, **kwargs)[-1]

    def _next_function(self):
        """ next function is the input subformula. For visualization purposes """
        return [self.subformula]

class AlwaysRecurrent2(Temporal_Operator2):
    """
    The Always STL formula □_[a,b] subformula
    The robustness value is the minimum value of the input trace over a prespecified time interval

    Args:
        subformula: subformula that the Always operation is applied on
        interval: time interval [a,b] where a, b are indices along the time dimension. It is up to the user to keep track of what the timestep size is.
    """
    def __init__(self, subformula, interval=None):
        super().__init__(subformula=subformula, interval=interval)

    def _initialize_hidden_state(self, signal, padding):
        """
        Compute the initial hidden state.

        Args:
            signal: the input signal. Expected size [time_dim, ...]
            padding: "last" pads with signal[0], "mean" pads with the mean over axis 0 (gradients
                are stopped in both cases). Any other value, including a number, pads with
                LARGE_NUMBER.

        Returns:
            h0 of size [hidden_dim, ...], or the tuple (c0, h0) when the interval is [a, inf) with a > 0

        Notes:
        NOTE(review): a numeric `padding` (such as the default passed by `robustness_trace`) is
        not used as the pad value; only the two string options are recognised.
        """
        # Case 1, 2, 4
        # presumably the signal arrives time-reversed, so index 0 is the last time step -- TODO confirm
        if padding == "last":
            pad_value = jax.lax.stop_gradient(signal)[0]
        elif padding == "mean":
            pad_value = jax.lax.stop_gradient(signal).mean(0)
        else:
            # Large positive filler for the hidden state.
            pad_value = self.LARGE_NUMBER

        h0 = jnp.ones([self.hidden_dim, *signal.shape[1:]]) * pad_value

        # Case 3: if self.interval is [a, jnp.inf), then the hidden state is a tuple (like in an LSTM)
        if self._interval[1] == jnp.inf and self._interval[0] > 0:
            c0 = signal[:1]
            return (c0, h0)
        return h0

    def _cell(self, x, hidden_state, **kwargs):
        """
        One recurrent step; see Temporal_Operator2._cell

        Args:
            x: the length-1 slice of the subformula trace for the current step
            hidden_state: h for intervals None or [a, b]; the tuple (c, h) for [a, inf) with a > 0
            kwargs: forwarded to minish

        Returns:
            output for this step and the next hidden state
        """
        time_dim = 0  # assuming signal is [time_dim,...]
        # Case 1, interval = [0, inf]
        if self.interval is None:
            stacked = jnp.concatenate([hidden_state, x], axis=time_dim)      # [rnn_dim+1,]
            output = minish(stacked, time_dim, keepdims=True, **kwargs)      # [1,]
            # The running minimum is both the output and the next state.
            return output, output

        # Case 3: self.interval is [a, np.inf)
        if self._interval[1] == jnp.inf and self._interval[0] > 0:
            c, h = hidden_state
            ch = jnp.concatenate([c, h[:1]], axis=time_dim)                  # [2,]
            output = minish(ch, time_dim, keepdims=True, **kwargs)           # [1,]
            # Shift h by one slot and append x at the end.
            return output, (output, self.M @ h + self.b * x)

        # Case 2 and 4: self.interval is [a, b]
        hx = jnp.concatenate([hidden_state, x], axis=time_dim)               # [rnn_dim+1,]
        # Only the first `steps` entries fall inside the window.
        # NOTE(review): unlike the other cases, keepdims is not passed here; relies on minish's default -- TODO confirm
        output = minish(hx[:self.steps], time_dim, **kwargs)
        return output, self.M @ hidden_state + self.b * x

    def robustness_trace(self, signal, padding=1E6, **kwargs):
        """
        Function to compute robustness trace of a temporal STL formula
        First, compute the robustness trace of the subformula, and use that as the input for the recurrent computation

        Args:
            signal: input signal, size = [time_dim, ...]
            padding: forwarded to `_robustness_trace`; see `_initialize_hidden_state` for how it is interpreted
            kwargs: forwarded to `_robustness_trace`

        Returns:
            robustness_trace: whatever `_robustness_trace` returns for this signal.
        """
        return self._robustness_trace(signal, padding, **kwargs)

    def __str__(self):
        return f"◻ {self._interval}( {self.subformula} )"


class EventuallyRecurrent2(Temporal_Operator2):
    """
    The Eventually STL formula ♢_[a,b] subformula
    The robustness value is the maximum value of the input trace over a prespecified time interval

    Args:
        subformula: subformula that the Eventually operation is applied on
        interval: time interval [a,b] where a, b are indices along the time dimension. It is up to the user to keep track of what the timestep size is.
    """
    def __init__(self, subformula, interval=None):
        super().__init__(subformula=subformula, interval=interval)

    def _initialize_hidden_state(self, signal, padding):
        """
        Compute the initial hidden state.

        Args:
            signal: the input signal. Expected size [time_dim, ...]
            padding: "last" pads with signal[0], "mean" pads with the mean over axis 0 (gradients
                are stopped in both cases). Any other value, including a number, pads with
                -LARGE_NUMBER.

        Returns:
            h0 of size [hidden_dim, ...], or the tuple (c0, h0) when the interval is [a, inf) with a > 0

        Notes:
        NOTE(review): a numeric `padding` (such as the default passed by `robustness_trace`) is
        not used as the pad value; only the two string options are recognised.
        """
        # Case 1, 2, 4
        # presumably the signal arrives time-reversed, so index 0 is the last time step -- TODO confirm
        if padding == "last":
            pad_value = jax.lax.stop_gradient(signal)[0]
        elif padding == "mean":
            pad_value = jax.lax.stop_gradient(signal).mean(0)
        else:
            # Large negative filler for the hidden state.
            pad_value = -self.LARGE_NUMBER

        h0 = jnp.ones([self.hidden_dim, *signal.shape[1:]]) * pad_value

        # Case 3: if self.interval is [a, jnp.inf), then the hidden state is a tuple (like in an LSTM)
        if self._interval[1] == jnp.inf and self._interval[0] > 0:
            c0 = signal[:1]
            return (c0, h0)
        return h0

    def _cell(self, x, hidden_state, **kwargs):
        """
        One recurrent step; see Temporal_Operator2._cell

        Args:
            x: the length-1 slice of the subformula trace for the current step
            hidden_state: h for intervals None or [a, b]; the tuple (c, h) for [a, inf) with a > 0
            kwargs: forwarded to maxish

        Returns:
            output for this step and the next hidden state
        """
        time_dim = 0  # assuming signal is [time_dim,...]
        # Case 1, interval = [0, inf]
        if self.interval is None:
            stacked = jnp.concatenate([hidden_state, x], axis=time_dim)      # [rnn_dim+1, ]
            output = maxish(stacked, time_dim, keepdims=True, **kwargs)      # [1, ]
            # The running maximum is both the output and the next state.
            return output, output

        # Case 3: self.interval is [a, np.inf)
        if self._interval[1] == jnp.inf and self._interval[0] > 0:
            c, h = hidden_state
            ch = jnp.concatenate([c, h[:1]], axis=time_dim)                  # [2, ]
            output = maxish(ch, time_dim, keepdims=True, **kwargs)           # [1, ]
            # Shift h by one slot and append x at the end.
            return output, (output, self.M @ h + self.b * x)

        # Case 2 and 4: self.interval is [a, b]
        hx = jnp.concatenate([hidden_state, x], axis=time_dim)               # [rnn_dim+1, ]
        # Only the first `steps` entries fall inside the window.
        # NOTE(review): unlike the other cases, keepdims is not passed here; relies on maxish's default -- TODO confirm
        output = maxish(hx[:self.steps], time_dim, **kwargs)
        return output, self.M @ hidden_state + self.b * x

    def robustness_trace(self, signal, padding=-1E6, **kwargs):
        """
        Function to compute robustness trace of a temporal STL formula
        First, compute the robustness trace of the subformula, and use that as the input for the recurrent computation

        Args:
            signal: input signal, size = [time_dim, ...]
            padding: forwarded to `_robustness_trace`; see `_initialize_hidden_state` for how it is interpreted
            kwargs: forwarded to `_robustness_trace`

        Returns:
            robustness_trace: whatever `_robustness_trace` returns for this signal.
        """
        return self._robustness_trace(signal, padding, **kwargs)

    def __str__(self):
        return f"♢ {self._interval}( {self.subformula} )"


In [16]:

# analyzing gradients
# Build the same unbounded Eventually formula with three different operator
# classes and jit each one's `robustness` so they can be timed side by side.
interval = None
pred = Predicate('x', lambda x: x)

rec = EventuallyRecurrent(pred > 0., interval=interval)
rec_jit = jax.jit(rec.robustness)

rec2 = EventuallyRecurrent2(pred > 0., interval=interval)
rec2_jit = jax.jit(rec2.robustness)

ev = Eventually(pred > 0., interval=interval)
ev_jit = jax.jit(ev.robustness)



In [5]:
class UntilRecurrent2(STL_Formula):
    """
    The Until STL operator U. Subformula1 U_[a,b] subformula2
    Arg:
        subformula1: subformula for lhs of the Until operation
        subformula2: subformula for rhs of the Until operation
        interval: time interval [a,b] where a, b are indices along the time dimension. It is up to the user to keep track of what the timestep is.
        overlap: If overlap=True, then the last time step that ϕ is true, ψ starts being true. That is, sₜ ⊧ ϕ and sₜ ⊧ ψ at a common time t. If overlap=False, when ϕ stops being true, ψ starts being true. That is sₜ ⊧ ϕ and sₜ+₁ ⊧ ψ, but sₜ ¬⊧ ψ
    """

    def __init__(self, subformula1, subformula2, interval=None, overlap=True):
        super().__init__()
        self.subformula1 = subformula1
        self.subformula2 = subformula2
        self.interval = interval
        if not overlap:
            # Without overlap the rhs is wrapped in Eventually over [0, 1].
            self.subformula2 = Eventually(subformula=subformula2, interval=[0,1])
        self.LARGE_NUMBER = 1E9

    def robustness_trace(self, signal, **kwargs):
        """
        Computing robustness trace of subformula1 U subformula2  (see paper)

        Args:
            signal: input signal for the formula. Either a single array that is fed to both subformulas, or a tuple whose first two entries are fed to subformula1 and subformula2 respectively. The time axis is hard-coded to axis 0, and the two tuple entries must have the same length along it.
            kwargs: forwarded to the subformulas, to the internal Always operator, and to minish / maxish

        Returns:
            robustness_trace: minish over the stacked (LHS, RHS) pair followed by maxish over the last remaining axis.
        """
        time_dim = 0  # assuming signal is [time_dim,...]

        if isinstance(signal, tuple):
            # for formula defined using Expression
            assert signal[0].shape[time_dim] == signal[1].shape[time_dim], \
                "both signals must have the same number of time steps"
            trace1 = self.subformula1(signal[0], **kwargs)
            trace2 = self.subformula2(signal[1], **kwargs)
            n_time_steps = signal[0].shape[time_dim] # TODO: WIP
        else:
            # for formula defined using Predicate
            trace1 = self.subformula1(signal, **kwargs)
            trace2 = self.subformula2(signal, **kwargs)
            n_time_steps = signal.shape[time_dim] # TODO: WIP

        Alw = AlwaysRecurrent(subformula=Identity(name=str(self.subformula1)))
        # Every row of LHS is a copy of trace2. The two-axis permutation means trace2 must be 1-D here.
        LHS = jnp.permute_dims(jnp.repeat(jnp.expand_dims(trace2, -1), n_time_steps, axis=-1), [1,0])  # [sub_signal, t_prime]
        # Entries that no window below writes to keep a very negative value.
        RHS = jnp.ones_like(LHS) * -self.LARGE_NUMBER  # [sub_signal, t_prime]

        # The three interval cases differ only in the window offset `a` and in whether the
        # window has a finite end:
        #   Case 1, interval = [0, inf]:         a = 0, open-ended
        #   Case 2 and 4, [a, b], a ≥ 0, b < ∞:  window ends at i + b (inclusive)
        #   Case 3, [a, inf), a ≠ 0:             open-ended
        if self.interval is None:
            a, b = 0, jnp.inf
        else:
            a, b = self.interval[0], self.interval[1]
        bounded = b < jnp.inf

        for i in range(n_time_steps):
            end = i + b + 1 if bounded else None
            # Column i receives the Always trace of trace1 restricted to the window that
            # starts at step i, with its first `a` entries dropped.
            RHS = RHS.at[i+a:end, i].set(Alw(trace1[i:end], **kwargs)[a:])

        return maxish(minish(jnp.stack([LHS, RHS], axis=-1), axis=-1, keepdims=False, **kwargs), axis=-1, keepdims=False, **kwargs)

    def robustness(self, signal, **kwargs):
        """
        Computes the robustness value: the last entry along axis 0 of the result of calling this formula on `signal`.

        Args:
            signal: same input as `robustness_trace`
            kwargs: forwarded to `self(...)`

        Return: the last entry of the result along axis 0.
        """
        return self(signal, **kwargs)[-1]

    def _next_function(self):
        """ next function is the input subformulas. For visualization purposes """
        return [self.subformula1, self.subformula2]

    def __str__(self):
        return f"({self.subformula1}) U ({self.subformula2})"



In [6]:
# analyzing gradients
# Build the same unbounded Until formula with three different operator
# classes and jit each one's `robustness` so they can be timed side by side.
interval = None
pred = Predicate('x', lambda x: x)

rec = UntilRecurrent(pred > 0., pred > 0., interval=interval)
rec_jit = jax.jit(rec.robustness)

rec2 = UntilRecurrent2(pred > 0., pred > 0., interval=interval)
rec2_jit = jax.jit(rec2.robustness)

ev = Until(pred > 0., pred > 0., interval=interval)
ev_jit = jax.jit(ev.robustness)


In [21]:
# Random benchmark input: `bs` traces of `T` samples each, drawn from a
# standard normal (unseeded, so values differ between runs).
T, bs = 128, 32
signal = 1. * jnp.array(np.random.randn(bs, T))
# NOTE(review): no axis is given, so this reverses both the batch and the time axis.
signal_flip = jnp.flip(signal)


In [22]:
%timeit jax.vmap(jax.grad(rec_jit))(signal)
%timeit jax.vmap(jax.grad(rec2_jit))(signal)
%timeit jax.vmap(jax.grad(ev_jit))(signal)

857 μs ± 148 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
875 μs ± 98.6 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.66 ms ± 72.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
%timeit jax.vmap(rec_jit)(signal)
%timeit jax.vmap(rec2_jit)(signal)
%timeit jax.vmap(ev_jit)(signal)

130 μs ± 634 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
132 μs ± 2.52 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
1.22 ms ± 7.96 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
