In [1]:
%load_ext autoreload
%autoreload 2



In [2]:
import jax
from stljax.formula import *
import numpy as np


In [2]:
class UntilRecurrent2(STL_Formula):
    """
    The Until STL operator U. Subformula1 U_[a,b] subformula2
    Arg:
        subformula1: subformula for lhs of the Until operation
        subformula2: subformula for rhs of the Until operation
        interval: time interval [a,b] where a, b are indices along the time dimension. It is up to the user to keep track of what the timestep is.
        overlap: If overlap=True, then the last time step that ϕ is true, ψ starts being true. That is, sₜ ⊧ ϕ and sₜ ⊧ ψ at a common time t. If overlap=False, when ϕ stops being true, ψ starts being true. That is sₜ ⊧ ϕ and sₜ+₁ ⊧ ψ, but sₜ ¬⊧ ψ
    """

    def __init__(self, subformula1, subformula2, interval=None, overlap=True):
        super().__init__()
        self.subformula1 = subformula1
        self.subformula2 = subformula2
        self.interval = interval
        if overlap == False:
            self.subformula2 = Eventually(subformula=subformula2, interval=[0,1])
        self.LARGE_NUMBER = 1E9

    def robustness_trace(self, signal, **kwargs):
        """
        Computing robustness trace of subformula1 U subformula2  (see paper)

        Args:
            signal: input signal for the formula. If using Expressions to define the formula, then inputs a tuple of signals corresponding to each subformula. If using Predicates to define the formula, then inputs is just a single jnp.array. Not need for different signals for each subformula. Expected signal is size [batch_size, time_dim, x_dim]
            time_dim: axis for time_dim. Default: 1
            kwargs: Other arguments including time_dim, approx_method, temperature

        Returns:
            robustness_trace: jnp.array. Same size as signal.
        """


        # TODO (karenl7) this really assumes axis=1 is the time dimension. Can this be generalized?
        time_dim = 0  # assuming signal is [time_dim,...]
        LARGE_NUMBER = self.LARGE_NUMBER

        if isinstance(signal, tuple):
            # for formula defined using Expression
            assert signal[0].shape[time_dim] == signal[1].shape[time_dim]
            trace1 = self.subformula1(signal[0], **kwargs)
            trace2 = self.subformula2(signal[1], **kwargs)
            n_time_steps = signal[0].shape[time_dim] # TODO: WIP
        else:
            # for formula defined using Predicate
            trace1 = self.subformula1(signal, **kwargs)
            trace2 = self.subformula2(signal, **kwargs)
            n_time_steps = signal.shape[time_dim] # TODO: WIP

        Alw = AlwaysRecurrent(subformula=Identity(name=str(self.subformula1)))
        LHS = jnp.permute_dims(jnp.repeat(jnp.expand_dims(trace2, -1), n_time_steps, axis=-1), [1,0])  # [sub_signal, t_prime]
        RHS = jnp.ones_like(LHS) * -LARGE_NUMBER  # [sub_signal, t_prime]

        # Case 1, interval = [0, inf]
        if self.interval == None:
            for i in range(n_time_steps):
                RHS = RHS.at[i:,i].set(Alw(trace1[i:], **kwargs))

        # Case 2 and 4: self.interval is [a, b], a ≥ 0, b < ∞
        elif self.interval[1] < jnp.inf:
            a = self.interval[0]
            b = self.interval[1]
            for i in range(n_time_steps):
                end = i+b+1
                RHS = RHS.at[i+a:end,i].set(Alw(trace1[i:end], **kwargs)[a:])

        # Case 3: self.interval is [a, np.inf), a ≂̸ 0
        else:
            a = self.interval[0]
            for i in range(n_time_steps):
                RHS = RHS.at[i+a:,i].set(Alw(trace1[i:], **kwargs)[a:])

        return maxish(minish(jnp.stack([LHS, RHS], axis=-1), axis=-1, keepdims=False, **kwargs), axis=-1, keepdims=False, **kwargs)

    def robustness(self, signal, **kwargs):
        """
        Computes the robustness value. Extracts the last entry along time_dim of robustness trace.

        Args:
            signal: jnp.array or Expression. Expected size [bs, time_dim, state_dim]
            kwargs: Other arguments including time_dim, approx_method, temperature

        Return: jnp.array, same as input with the time_dim removed.
        """
        return self.__call__(signal, **kwargs)[-1]
        # return jnp.rollaxis(self.__call__(signal, **kwargs), time_dim)[-1]

    def _next_function(self):
        """ next function is the input subformulas. For visualization purposes """
        return [self.subformula1, self.subformula2]

    def __str__(self):
        return  "(" + str(self.subformula1) + ")" + " U " + "(" + str(self.subformula2) + ")"



In [390]:
# analyzing gradients
interval = None
pred = Predicate('x', lambda x: x)
rec = UntilRecurrent(pred > 0., pred < 5., interval=interval)
# rec2 = UntilRecurrent2(pred > 0., pred > 0., interval=interval)

ev = Until(pred > 0., pred < 5., interval=interval)

rec_jit = jax.jit(rec.robustness)
# rec2_jit = jax.jit(rec2.robustness)
ev_jit = jax.jit(ev.robustness)


In [388]:
T = 128
bs = 32
signal = jnp.array(np.random.randn(bs, T)) * 1.
signal_flip = jnp.flip(signal)


In [389]:
%timeit jax.vmap(jax.grad(rec_jit))(signal)
%timeit jax.vmap(jax.grad(rec2_jit))(signal)
%timeit jax.vmap(jax.grad(ev_jit))(signal)

KeyboardInterrupt: 

In [23]:
%timeit jax.vmap(rec_jit)(signal)
%timeit jax.vmap(rec2_jit)(signal)
%timeit jax.vmap(ev_jit)(signal)

130 μs ± 634 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
132 μs ± 2.52 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
1.22 ms ± 7.96 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
T = 8
bs = 32
# signal = jnp.array(np.random.randn(T)) * 1.
signal = jnp.arange(T)
signal_flip = jnp.flip(signal)


In [5]:
pred = Predicate('x', lambda x: x)

subformula1 = pred > 0.
subformula2 = pred < 5.
interval = [2, T-1]

# TODO (karenl7) this really assumes axis=1 is the time dimension. Can this be generalized?
time_dim = 0  # assuming signal is [time_dim,...]
LARGE_NUMBER = 10

if isinstance(signal, tuple):
    # for formula defined using Expression
    assert signal[0].shape[time_dim] == signal[1].shape[time_dim]
    trace1 = subformula1(signal[0])
    trace2 = subformula2(signal[1])
    n_time_steps = signal[0].shape[time_dim] # TODO: WIP
else:
    # for formula defined using Predicate
    trace1 = subformula1(signal)
    trace2 = subformula2(signal)
    n_time_steps = signal.shape[time_dim] # TODO: WIP

Alw = AlwaysRecurrent(subformula=Identity(name=str(subformula1)))
LHS = jnp.permute_dims(jnp.repeat(jnp.expand_dims(trace2, -1), n_time_steps, axis=-1), [1,0])  # [sub_signal, t_prime]
RHS = jnp.ones_like(LHS) * -LARGE_NUMBER  # [sub_signal, t_prime]

trace1_padded = jnp.concatenate([trace1, jnp.repeat(trace1[-1], trace1.shape[0]-1)])

def f_(carry, i):
    y_ = jnp.ones(2 * T) * -LARGE_NUMBER
    subsignal_padded = jax.lax.dynamic_slice(trace1_padded, (i,), (T,))
    update_values = Alw(subsignal_padded)
    y = jax.lax.dynamic_update_slice(y_, update_values[:interval[1] - interval[0] + 1], (i + interval[0],))
    return carry, y

jax.lax.scan(f_, 0, jnp.arange(n_time_steps))[1].T[:T]

# Case 1, interval = [0, inf]
if interval == None:
    # for i in range(n_time_steps):
    #     RHS = RHS.at[i:,i].set(Alw(trace1[i:]))
    
    def f_(carry, i):
        y_ = jnp.ones(2 * T) * -LARGE_NUMBER
        subtrace1 = jax.lax.dynamic_slice(trace1_padded, (i,), (T,))
        update_values = Alw(subtrace1)
        y = jax.lax.dynamic_update_slice(y_, update_values, (i,))
        return carry, y

    RHS = jax.lax.scan(f_, 0, jnp.arange(n_time_steps))[1].T[:T]
elif interval[1] < jnp.inf:
    a = interval[0]
    b = interval[1]
    for i in range(n_time_steps):
        end = i+b+1
        RHS = RHS.at[i+a:end,i].set(Alw(trace1[i:end])[a:])
else:
    hidden_state_ = ev.M @ hidden_state + ev.b * x
    hx = jnp.concatenate([hidden_state, x], axis=time_dim)                             # [rnn_dim+1, ]
    input_ = hx[:ev.steps]                               # [self.steps, ]
    output = maxish(input_, time_dim)               # [1, ]
RHS

Array([[-10., -10., -10., -10., -10., -10., -10., -10.],
       [-10., -10., -10., -10., -10., -10., -10., -10.],
       [  0., -10., -10., -10., -10., -10., -10., -10.],
       [  0.,   1., -10., -10., -10., -10., -10., -10.],
       [  0.,   1.,   2., -10., -10., -10., -10., -10.],
       [  0.,   1.,   2.,   3., -10., -10., -10., -10.],
       [  0.,   1.,   2.,   3.,   4., -10., -10., -10.],
       [  0.,   1.,   2.,   3.,   4.,   5., -10., -10.]],      dtype=float32, weak_type=True)

In [7]:
signal_padded

Array([0., 1., 2., 3., 4., 5., 6., 7., 7., 7., 7., 7., 7., 7., 7., 7.],      dtype=float32)

In [8]:
signal_padded = jnp.concatenate([jnp.arange(T) , jnp.ones(T) * (T-1)])

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, jnp.stack(ys)

def f_(carry, i):
    y_ = jnp.ones(2 * T) * -LARGE_NUMBER
    subsignal_padded = jax.lax.dynamic_slice(signal_padded, (i,), (T,))
    update_values = Alw(subsignal_padded)
    y = jax.lax.dynamic_update_slice(y_, update_values[:interval[1] - interval[0] + 1], (i + interval[0],))
    return carry, y
scan(f_, 0, jnp.arange(n_time_steps))[1].T[:T]
# jax.lax.scan(f_, 0, jnp.arange(n_time_steps))[1].T[:T]

Array([[-10., -10., -10., -10., -10., -10., -10., -10.],
       [-10., -10., -10., -10., -10., -10., -10., -10.],
       [  0., -10., -10., -10., -10., -10., -10., -10.],
       [  0.,   1., -10., -10., -10., -10., -10., -10.],
       [  0.,   1.,   2., -10., -10., -10., -10., -10.],
       [  0.,   1.,   2.,   3., -10., -10., -10., -10.],
       [  0.,   1.,   2.,   3.,   4., -10., -10., -10.],
       [  0.,   1.,   2.,   3.,   4.,   5., -10., -10.]], dtype=float32)

In [31]:
y_ = jnp.ones(2 * T) * -LARGE_NUMBER
subsignal = jax.lax.dynamic_slice(signal, (i,), (T,))
update_values = Alw(subsignal)
update_values
y = jax.lax.dynamic_update_slice(y_, update_values, (i + 3,))
y


Array([-10., -10., -10., -10., -10., -10., -10., -10.,   0.,   0.,   0.,
         0.,   0.,   0.,   1.,   2.], dtype=float32)

In [397]:
signal1 = trace1
signal2 = trace2
signal1, signal2


(Array([-0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.], dtype=float32, weak_type=True),
 Array([ 5.,  4.,  3.,  2.,  1.,  0., -1., -2.], dtype=float32, weak_type=True))

In [471]:
hidden_dim = T
M = jnp.diag(jnp.ones(hidden_dim-1), k=1)
b = jnp.zeros(hidden_dim)
b = b.at[-1].set(1)


interval = [0, 4]
if interval is None:
    start_idx = 0
    end_idx = hidden_dim
else:
    if interval[1] is jnp.inf:
        start_idx = 0
    else:
        start_idx = -interval[1] - 1

    if interval[0] == 0:
        end_idx = hidden_dim
    else: 
        end_idx = -interval[0]

end_idx, start_idx, signal[start_idx:end_idx], signal
@jax.jit
def cell(state, hidden):
    x1, x2 = state
    h1, h2 = hidden
    h1_new = minish(jnp.stack([M @ h1 + b * x1, x1 * jnp.ones(hidden_dim)]), axis=0, keepdims=False)
    h2_new = M @ h2 + b * x2
    # z = minish(jnp.stack([h1_new, h2_new]), axis=0, keepdims=False)[start_idx:end_idx]
    z = minish(jnp.stack([h1_new, h2_new]), axis=0, keepdims=False)
    def g_(carry, x):
        carry = maxish(jnp.array([carry, x]), axis=0, keepdims=False)
        return carry, carry
    output, _ = jax.lax.scan(g_,  -10, z[start_idx:end_idx])
    # output = jax.lax.scan(g_, jnp.ones(1) * -10, z[start_idx:end_idx])
    # output = maxish(z, axis=0)
    return output, (h1_new, h2_new)
    

In [472]:

h1 = -10 * jnp.ones_like(trace1)
h2 = -10 * jnp.ones_like(trace2)


os = []
h1s = [h1]
h2s = [h2]
for (x1,x2) in zip(signal1, signal2):
    o, (h1, h2) = cell((x1, x2), (h1, h2))
    os.append(o)
    h1s.append(h1)
    h2s.append(h2)
jnp.stack(os)

Array([0., 1., 2., 2., 2., 2., 2., 2.], dtype=float32)

In [459]:
rec = UntilRecurrent(pred > 0., pred < 5., interval=interval, )
rec.LARGE_NUMBER = 10
rec(signal)

Array([-0.,  1.,  2.,  2.,  2.,  2.,  2.,  2.], dtype=float32)

In [67]:
def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

class UntilRecurrent3(STL_Formula):
    
    def __init__(self, subformula1, subformula2, interval=None, overlap=True):
        super().__init__()
        self.subformula1 = subformula1
        self.subformula2 = subformula2
        self.interval = interval
        if overlap == False:
            self.subformula2 = Eventually(subformula=subformula2, interval=[0,1])
        self.LARGE_NUMBER = 1E9
        self.Alw = AlwaysRecurrent(subformula=Identity(name=str(self.subformula1)))
        
        if self.interval is None:
            self.hidden_dim = None
        elif interval[1] == jnp.inf:
            self.hidden_dim = None
        else:
            self.hidden_dim = interval[1] + 1
            
        

        
    def _initialize_hidden_state(self, signal, padding=None, **kwargs):
        time_dim = 0  # assuming signal is [time_dim,...]        

        if isinstance(signal, tuple):
            # for formula defined using Expression
            assert signal[0].shape[time_dim] == signal[1].shape[time_dim]
            trace1 = self.subformula1(signal[0], **kwargs)
            trace2 = self.subformula2(signal[1], **kwargs)
            n_time_steps = signal[0].shape[time_dim]
        else:
            # for formula defined using Predicate
            trace1 = self.subformula1(signal, **kwargs)
            trace2 = self.subformula2(signal, **kwargs)
            n_time_steps = signal.shape[time_dim]
            
        # compute hidden dim if signal length was needed
        if self.hidden_dim is None:
            self.hidden_dim = n_time_steps
        if self.interval is None: 
            self.interval = [0, n_time_steps - 1]
        elif self.interval[1] == jnp.inf:
            self.interval[1] = n_time_steps - 1
            
        self.ones_array = jnp.ones(self.hidden_dim)
            
        # set shift operations given hidden_dim
        self.M = jnp.diag(jnp.ones(self.hidden_dim-1), k=1)
        self.b = jnp.zeros(self.hidden_dim)
        self.b = self.b.at[-1].set(1)
        
        pad_value = -self.LARGE_NUMBER
            
        h1 = pad_value * self.ones_array
        h2 = pad_value * self.ones_array
        return (h1, h2), trace1, trace2
        
    def _get_interval_indices(self):
        start_idx = -self.hidden_dim
        end_idx = -self.interval[0]

        return start_idx, (None if end_idx  == 0 else end_idx)

    def _cell(self, state, hidden, **kwargs):
        x1, x2 = state
        h1, h2 = hidden
        h1_shift = self.M @ h1 + self.b * x1
        h1_new = jnp.flip(self.Alw(jnp.flip(h1_shift), **kwargs))
        # h1_new = minish(jnp.stack([self.M @ h1 + self.b * x1, 
        #                            x1 * self.ones_array]), 
        #                 axis=0, keepdims=False, **kwargs)
        h2_new = self.M @ h2 + self.b * x2
        start_idx, end_idx = self._get_interval_indices()
        z = minish(jnp.stack([h1_new, h2_new]), axis=0, keepdims=False, **kwargs)[start_idx:end_idx]
        
            
        def g_(carry, x):
            carry = maxish(jnp.array([carry, x]), axis=0, keepdims=False, **kwargs)
            return carry, carry

        # output, _ = scan(g_,  -self.LARGE_NUMBER, z)
        # output = maxish(z[start_idx:end_idx], axis=0, keepdims=False, **kwargs)
        output, _ = jax.lax.scan(g_,  -self.LARGE_NUMBER, z[start_idx:end_idx])

        return output, (h1_new, h2_new)
    
    def robustness_trace(self, signal, padding=None, **kwargs):
        """
        Function to run a signal through a cell T times, where T is the length of the signal in the time dimension

        Args:
            signal: input signal, size = [time_dim,]
            time_dim: axis corresponding to time_dim. Default: 0
            kwargs: Other arguments including time_dim, approx_method, temperature

        Return:
            outputs: list of outputs
            states: list of hidden_states
        """
        # outputs = []
        # states = []
        hidden_state, trace1, trace2 = self._initialize_hidden_state(signal, padding=padding, **kwargs) # [hidden_dim]
        
        
        def f_(hidden, state):
            o, hidden = self._cell(state, hidden, **kwargs)
            return hidden, o
        # os = []
        # for (t1, t2) in zip(trace1, trace2):
        #     hidden_state, output = f_(hidden_state, (t1, t2))
        #     os.append(output)
        # return jnp.stack(os)
        
        # _, outputs_stack = scan(f_, hidden_state, jnp.stack([trace1, trace2], axis=1))
        
        

        _, outputs_stack = jax.lax.scan(f_, hidden_state, jnp.stack([trace1, trace2], axis=1))
        return outputs_stack

    def robustness(self, signal, **kwargs):
        """
        Computes the robustness value. Extracts the last entry along time_dim of robustness trace.

        Args:
            signal: jnp.array or Expression. Expected size [bs, time_dim, state_dim]
            kwargs: Other arguments including time_dim, approx_method, temperature

        Return: jnp.array, same as input with the time_dim removed.
        """
        return self.__call__(signal, **kwargs)[-1]
        # return jnp.rollaxis(self.__call__(signal, **kwargs), time_dim)[-1]
    def _next_function(self):
        """ next function is the input subformulas. For visualization purposes """
        return [self.subformula1, self.subformula2]

    def __str__(self):
        return  "(" + str(self.subformula1) + ")" + " U " + "(" + str(self.subformula2) + ")"



In [1]:
%load_ext autoreload
%autoreload 2

In [2]:


import jax
from stljax.formula import *
import numpy as np

In [36]:
pred = Predicate('x', lambda x: x)
interval = [3,6]
phi = pred > 0
psi = pred < 0
# rec3 = UntilRecurrent3(phi, psi, interval=interval)
rec = UntilRecurrent(phi, psi, interval=interval)
until = Until(phi, psi, interval=interval)
alw = AlwaysRecurrent(pred > 0)


In [43]:
T = 64
signal = jnp.array(np.random.randn(T))
# signal = jnp.arange(T) * 1.
signal_flip = jnp.flip(signal)
kwargs = {"approx_method" : "logsumexp", "temperature" : .1}

In [44]:
jnp.isclose(jnp.flip(rec(signal_flip, **kwargs)),  until(signal, **kwargs))

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True], dtype=bool)

In [45]:
jnp.isclose(jax.grad(until.robustness)(signal, **kwargs), jnp.flip(jax.grad(rec.robustness)(signal_flip, **kwargs)))

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True], dtype=bool)

In [22]:
def foo(signal, **kwargs):
    signal_flip = jnp.flip(signal)
    h1 = jnp.flip(AlwaysRecurrent(pred > 0.)(jnp.flip(phi(signal_flip)), **kwargs))
    h2 = psi(signal_flip, **kwargs)
    return maxish(minish(jnp.stack([h1, h2]), axis=0, keepdims=False, **kwargs), axis=0, keepdims=False, **kwargs)

In [46]:
T = 3
signal = jnp.array(np.random.randn(T))
# signal = jnp.arange(T) * 1.
signal_flip = jnp.flip(signal)

In [47]:
def foo(signal, **kwargs):
    signal_flip = jnp.flip(signal)
    h1 = jnp.flip(AlwaysRecurrent(pred > 0.)(jnp.flip(phi(signal_flip)), **kwargs))
    h2 = psi(signal_flip, **kwargs)
    return maxish(minish(jnp.stack([h1, h2]), axis=0, keepdims=False, **kwargs), axis=0, keepdims=False, **kwargs)

In [48]:
kwargs = {"approx_method" : "true", "temperature" : 1.}

foo(signal, **kwargs), rec(signal_flip, **kwargs)[-1]

(Array(-0.43033767, dtype=float32), Array(-0.81656384, dtype=float32))

In [49]:
until(signal, **kwargs), rec(signal_flip, **kwargs)

(Array([-0.43033767, -0.43033767, -0.43033767], dtype=float32),
 Array([-0.43033767, -0.43033767, -0.81656384], dtype=float32))

In [13]:
until(signal), jnp.flip(rec(signal_flip))

(Array([-5.73977537e-04, -5.73977537e-04, -5.73977537e-04, -5.73977537e-04,
        -5.73977537e-04, -1.61394596e+00, -4.42734033e-01, -6.88218832e-01,
        -6.13067687e-01, -8.75325859e-01, -1.05930471e+00, -4.99185435e-02,
        -4.99185435e-02, -4.99185435e-02, -3.03746879e-01, -1.93259704e+00,
        -1.00921392e+00, -6.05401397e-02, -6.05401397e-02, -6.36790395e-02,
        -1.01843126e-01, -5.11364974e-02, -7.50980735e-01, -3.18519950e-01,
        -3.18519950e-01, -4.13318336e-01, -4.13318336e-01, -5.32822967e-01,
        -1.83786780e-01, -1.86913177e-01, -1.86913177e-01, -1.86913177e-01,
        -1.86913177e-01, -2.53182411e-01, -2.53182411e-01, -2.53563792e-01,
        -2.53563792e-01, -2.53563792e-01, -5.73892951e-01, -5.73892951e-01,
        -9.22766149e-01, -1.37707424e+00, -1.37707424e+00, -3.25560033e-01,
        -1.04137695e+00, -1.04137695e+00, -1.50983977e+00, -1.94225740e+00,
        -1.28538325e-01, -1.28538325e-01], dtype=float32),
 Array([-5.73977537e-04, -5.7

In [14]:
jax.grad(until.robustness)(signal, **kwargs), jnp.flip(jax.grad(rec.robustness)(signal_flip, **kwargs))

(Array([ 8.0887549e-02,  7.1398571e-02,  5.2506059e-02,  4.4118419e-02,
         3.8021035e-02,  3.9438665e-02,  2.5784096e-02,  2.5979640e-02,
         2.2618465e-02,  2.0733533e-02,  1.8976128e-02,  1.4920744e-02,
         1.2367797e-02,  1.2386629e-02,  1.0970533e-02,  1.3010621e-02,
         1.0763054e-02,  7.7401712e-03,  8.0887312e-03,  7.5811846e-03,
         7.0417966e-03,  6.4791893e-03,  6.5016444e-03,  5.2549215e-03,
         4.9364488e-03,  4.2471630e-03,  4.6072234e-03,  4.3215789e-03,
         3.8389773e-03,  3.2825232e-03,  2.9449437e-03,  2.6948359e-03,
         2.6464318e-03,  2.3802374e-03,  2.3305425e-03,  1.9873180e-03,
         1.7358368e-03,  1.6275193e-03,  1.2770338e-03,  1.2412656e-03,
         1.0411248e-03,  8.1419945e-04,  1.0423237e-03,  7.6544331e-04,
         4.7755538e-04,  3.9527827e-04,  2.5827708e-04,  3.1924795e-04,
         9.8878321e-05, -2.8274153e-06], dtype=float32),
 Array([ 7.59146959e-02,  1.02493875e-01,  5.28910011e-02,  2.42940970e-02,
   

In [177]:
foo(signal, **kwargs), goo(signal, **kwargs), jax.grad(foo)(signal, **kwargs), jax.grad(goo)(signal, **kwargs)


(Array(4.9634266, dtype=float32),
 Array(4.9634266, dtype=float32),
 Array([0.10051197, 0.08580505, 0.05046772, 0.02427344, 0.01971575,
        0.0054059 ], dtype=float32),
 Array([0.10051195, 0.08580505, 0.05046772, 0.02427343, 0.01971575,
        0.0054059 ], dtype=float32))

In [175]:
def foo(signal, **kwargs):
    # return jnp.concatenate([minish(jnp.concatenate([minish(s1[:i+1], axis=0), s2[i:i+1]]), axis=0) for (i, s1i) in enumerate(s1)])
    s1, s2 = signal, -signal
    return maxish(jnp.concatenate([minish(jnp.concatenate([minish(s1[:i+1], axis=0, **kwargs), s2[i:i+1]]), axis=0, **kwargs) for (i, s1i) in enumerate(s1)]), axis=0, keepdims=False, **kwargs)

def goo(signal, **kwargs):
    s1, s2 = signal, -signal
    T = signal.shape[0]
    m1 = jnp.ones([2*T, T]) * 1E9
    for i in range(signal.shape[0]):
        m1 = m1.at[:i+1,i].set(s1[:i+1])
    return maxish(minish(jnp.stack([minish(m1, axis=0, keepdims=False, **kwargs), s2]), axis=0, keepdims=False, **kwargs), axis=0, keepdims=False, **kwargs)
        
        

In [176]:
foo(signal, **kwargs), goo(signal, **kwargs), jax.grad(foo)(signal, **kwargs), jax.grad(goo)(signal, **kwargs)

(Array(4.9634266, dtype=float32),
 Array(4.9634266, dtype=float32),
 Array([0.10051197, 0.08580505, 0.05046772, 0.02427344, 0.01971575,
        0.0054059 ], dtype=float32),
 Array([0.10051195, 0.08580505, 0.05046772, 0.02427343, 0.01971575,
        0.0054059 ], dtype=float32))

In [109]:
fill_value = -1E2
def foo1(x, **kwargs):
    s = jnp.concatenate([x, jnp.array([fill_value] * 1)])
    return maxish(s, axis=0, keepdims=False, **kwargs)

def foo2(x, **kwargs):
    s = jnp.concatenate([x, jnp.array([fill_value] * 10)])
    return maxish(s, axis=0, keepdims=False, **kwargs)


approx_method="logsumexp"
temperature = .01
jax.grad(foo1)(signal, approx_method=approx_method, temperature=temperature), \
    jax.grad(foo2)(signal, approx_method=approx_method, temperature=temperature)

(Array([0.29642978, 0.29811913, 0.2959793 ], dtype=float32),
 Array([0.14931634, 0.15016729, 0.14908943], dtype=float32))

In [78]:
LARGE_NUMBER = 1E9

jnp.flip(jax.grad(rec3.robustness)(signal_flip, approx_method=approx_method, temperature=temperature)), \
    jnp.flip(jax.grad(rec.robustness)(signal_flip, approx_method=approx_method, temperature=temperature)), \
        jax.grad(until.robustness)(signal * 1., approx_method=approx_method, temperature=temperature)

(Array([-0.14885019, -0.22690026, -0.1048732 ], dtype=float32),
 Array([-0.01470247, -0.19181027, -0.06849375], dtype=float32),
 Array([-0.0147025 , -0.19181027, -0.06849375], dtype=float32))

In [21]:
jnp.flip(rec3(signal_flip)),\
    jnp.flip(rec(signal_flip)),\
        until(signal)

(Array([-1.0590287e+00, -1.0590287e+00,  1.4155310e-01,  1.4155310e-01,
         1.4155310e-01, -1.0285740e+00, -1.0285740e+00, -1.0285740e+00,
        -1.0000000e+09, -1.0000000e+09], dtype=float32),
 Array([-1.0590287e+00, -1.0590287e+00,  1.4155310e-01,  1.4155310e-01,
         1.4155310e-01, -1.0285740e+00, -1.0285740e+00, -1.0285740e+00,
        -1.0000000e+09, -1.0000000e+09], dtype=float32),
 Array([-1.0590287e+00, -1.0590287e+00,  1.4155310e-01,  1.4155310e-01,
         1.4155310e-01, -1.0285740e+00, -1.0285740e+00, -1.0285740e+00,
        -1.0000000e+09, -1.0000000e+09], dtype=float32))

In [442]:
def foo(signal, **kwargs):
    # return jnp.concatenate([minish(jnp.concatenate([minish(s1[:i+1], axis=0), s2[i:i+1]]), axis=0) for (i, s1i) in enumerate(s1)])
    s1, s2 = signal, -signal
    return maxish(jnp.concatenate([minish(jnp.concatenate([minish(s1[:i+1], axis=0, **kwargs), s2[i:i+1]]), axis=0, **kwargs) for (i, s1i) in enumerate(s1)]), axis=0, keepdims=False, **kwargs)

def goo(signal, **kwargs):
    s1, s2 = signal, -signal
    m1 = jnp.ones([6, 3]) * 1E9
    for i in range(signal.shape[0]):
        m1 = m1.at[:i+1,i].set(s1[:i+1])
    return maxish(minish(jnp.stack([minish(m1, axis=0, keepdims=False, **kwargs), s2]), axis=0, keepdims=False, **kwargs), axis=0, keepdims=False, **kwargs)
        
        

In [443]:
kwargs = {"approx_method" : "logsumexp", "temperature" : .1}
foo(signal[:3], **kwargs), goo(signal[:3], **kwargs)

(Array(0.77383757, dtype=float32), Array(0.77383757, dtype=float32))

In [444]:
jax.grad(foo)(signal[:3], **kwargs), jax.grad(goo)(signal[:3], **kwargs)

(Array([ 0.13727403,  0.08534604, -0.00567825], dtype=float32),
 Array([ 0.13727403,  0.08534605, -0.00567825], dtype=float32))

In [454]:
rec3 = UntilRecurrent3(pred > 0, pred < 0, interval=None)
rec = UntilRecurrent(pred > 0, pred < 0, interval=None)
until = Until(pred > 0, pred < 0, interval=None)

In [457]:
rec3.robustness(signal[:3])

(Array([-1.000000e+09, -1.000000e+09,  3.660597e-01], dtype=float32), Array([-1.000000e+09, -1.000000e+09, -3.660597e-01], dtype=float32))
(Array([-1.0000000e+09, -1.1143353e+00, -1.1143353e+00], dtype=float32), Array([-1.0000000e+09, -3.6605969e-01,  1.1143353e+00], dtype=float32))
(Array([-1.1143353, -1.1143353,  0.5115534], dtype=float32), Array([-0.3660597,  1.1143353, -0.5115534], dtype=float32))


Array(-0.5115534, dtype=float32)

In [458]:
signal[:3]

Array([ 0.3660597, -1.1143353,  0.5115534], dtype=float32)

In [451]:
jax.grad(until.robustness)(signal[:3], **kwargs),\
    jax.grad(rec3.robustness)(signal[:3], **kwargs)

(Traced<ConcreteArray([-1.000000e+09 -1.000000e+09 -6.565412e+00], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([-1.000000e+09, -1.000000e+09, -6.565412e+00], dtype=float32)
  tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x33f117dd0>, in_tracers=(Traced<ShapedArray(float32[2,3]):JaxprTrace(level=1/0)>, Traced<ConcreteArray([ True  True  True], dtype=bool):JaxprTrace(level=1/0)>, Traced<ConcreteArray(0.10000000149011612, dtype=float32):JaxprTrace(level=1/0)>, Traced<ConcreteArray([[4.4376106 4.4376106 1.       ]
 [0.        0.        1.       ]], dtype=float32):JaxprTrace(level=1/0)>, Traced<ConcreteArray([4.4376106 4.4376106 2.       ], dtype=float32):JaxprTrace(level=1/0)>, Traced<ConcreteArray(0.10000000149011612, dtype=float32):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x341b482c0; to 'JaxprTracer' at 0x341b4b9d0>], out_avals=[Sha

(Array([ 0.13727404,  0.08534604, -0.00567825], dtype=float32),
 Array([0.04494014, 0.15934917, 0.24514408], dtype=float32))

In [392]:
approx_method = "logsumexp", temperature = .1

In [384]:

foo(signal, approx_method=approx_method, temperature=temperature), \
    Until(pred > 0, pred < 0, interval=None).robustness(signal, approx_method=approx_method, temperature=temperature), \
        UntilRecurrent3(pred > 0, pred < 0, interval=None).robustness(signal_flip, approx_method=approx_method, temperature=temperature)



(Array(-0.36597282, dtype=float32),
 Array(-0.36597282, dtype=float32),
 Array(-0.36606702, dtype=float32))

In [388]:
approx_method = "logsumexp"
jax.grad(foo)(signal, approx_method=approx_method, temperature=temperature), \
    jax.grad(Until(pred > 0, pred < 0, interval=None).robustness)(signal, approx_method=approx_method, temperature=temperature), \
        jnp.flip(jax.grad(UntilRecurrent3(pred > 0, pred < 0, interval=None).robustness)(signal_flip, approx_method=approx_method, temperature=temperature))

(Array([-9.9715149e-01,  1.2852616e-03, -1.3487851e-06, -1.4053319e-04,
        -8.5113323e-05,  1.6848637e-05, -1.8629791e-09,  1.9984489e-12,
        -1.1765958e-12,  1.6355593e-10], dtype=float32),
 Array([-9.9715149e-01,  1.2852616e-03, -1.3487851e-06, -1.4053320e-04,
        -8.5113323e-05,  1.6848637e-05, -1.8629791e-09,  1.9984489e-12,
        -1.1765958e-12,  1.6355593e-10], dtype=float32),
 Array([-9.9611372e-01,  1.0050855e-03, -1.3500086e-06, -1.4066583e-04,
        -8.5193635e-05,  1.5187554e-05, -1.8647366e-09,  2.7083767e-12,
        -1.1107879e-12,  3.2738981e-10], dtype=float32))

In [88]:
UntilRecurrent3(pred > 0, pred < 0, interval=None).robustness(signal_flip, approx_method=approx_method, temperature=temperature)


Array(0.15078372, dtype=float32)

In [89]:
jax.grad(foo)(signal, approx_method=approx_method, temperature=temperature)

Array([ 3.78097028e-01,  1.11037388e-01, -3.25710773e-02,  1.04564965e-01,
        2.04220619e-02, -3.12179327e-05,  3.38417664e-02,  1.11742159e-02,
        1.78319607e-02, -2.81446660e-03], dtype=float32)

In [53]:
i = 9
minish(s1[:i], axis=0)

Array([-1.04302], dtype=float32)

In [54]:
minish(np.concatenate([minish(s1[:i], axis=0), s2[i-1:i]]), axis=0)

Array([-1.04302], dtype=float32)

In [260]:
class TemporalOperator(STL_Formula):

    def __init__(self, subformula, interval=None):
        super().__init__()
        self.subformula = subformula
        self.interval = interval
        
        if self.interval is None:
            self.hidden_dim = None
        elif interval[1] == jnp.inf:
            self.hidden_dim = None
        else:
            self.hidden_dim = interval[1] + 1
        self.LARGE_NUMBER = 1E9
        self.operation = None

    def _get_interval_indices(self):
        if self.interval is None:
            start_idx = 0
            end_idx = self.hidden_dim
        else:
            if self.interval[1] is jnp.inf:
                start_idx = 0
            else:
                start_idx = -self.interval[1] - 1

            if self.interval[0] == 0:
                end_idx = self.hidden_dim
            else: 
                end_idx = -self.interval[0]
                
        return start_idx, end_idx

    def _run_cell(self, signal, padding=None, **kwargs):

        hidden_state = self._initialize_hidden_state(signal, padding=padding) # [hidden_dim]
        def f_(hidden, state):
            hidden, o = self._cell(state, hidden, **kwargs)
            return hidden, o

        _, outputs_stack = jax.lax.scan(f_, hidden_state, signal)
        return outputs_stack

    def _initialize_hidden_state(self, signal, padding=None):

        if padding == "last":
            pad_value = jax.lax.stop_gradient(signal)[0]
        elif padding == "mean":
            pad_value = jax.lax.stop_gradient(signal).mean(0)
        else:
            pad_value = self.sign * self.LARGE_NUMBER
            
        n_time_steps = signal.shape[0] 
            
        if self.hidden_dim is None:
            self.hidden_dim = n_time_steps 
            
        self.M = jnp.diag(jnp.ones(self.hidden_dim-1), k=1)
        self.b = jnp.zeros(self.hidden_dim)
        self.b = self.b.at[-1].set(1)
        
        h0 = jnp.ones(self.hidden_dim) * pad_value

        return h0

    def _cell(self, state, hidden, **kwargs):
    
        h_new = self.M @ hidden + self.b * state
        start_idx, end_idx = self._get_interval_indices()        
        output = self.operation(h_new[start_idx:end_idx], axis=0, keepdims=False, **kwargs)
        
        return h_new, output
        

    def robustness_trace(self, signal, padding=None, **kwargs):

        trace = self.subformula(signal, **kwargs)
        outputs = self._run_cell(trace, padding, **kwargs)
        return outputs                  

    def robustness(self, signal, **kwargs):
        return self.__call__(signal, **kwargs)[-1]


    def _next_function(self):
        return [self.subformula]

class AlwaysRecurrent2(TemporalOperator):

    def __init__(self, subformula, interval=None):
        super().__init__(subformula=subformula, interval=interval)
        self.operation = minish
        self.sign = 1.

    def __str__(self):
        return "◻ " + str(self._interval) + "( " + str(self.subformula) + " )"
        
class EventuallyRecurrent2(TemporalOperator):

    def __init__(self, subformula, interval=None):
        super().__init__(subformula=subformula, interval=interval)
        self.operation = maxish
        self.sign = -1.
        
    def __str__(self):
        return "♢ " + str(self._interval) + "( " + str(self.subformula) + " )"

In [261]:
interval = None
aa = EventuallyRecurrent2(phi, interval=interval)
al = Eventually(phi, interval=interval)

In [262]:
jnp.flip(jax.grad(aa.robustness)(signal_flip, approx_method="logsumexp", temperature=temperature)) - jax.grad(al.robustness)(signal, approx_method="logsumexp", temperature=temperature)

Array([-1.3234890e-23, -6.7762636e-21, -5.9604645e-08, -1.6155871e-27,
       -2.7755576e-17,  0.0000000e+00, -1.4551915e-11, -8.2718061e-25,
       -7.1054274e-15, -1.6940659e-21], dtype=float32)

In [263]:
jnp.flip(aa(signal_flip, approx_method="logsumexp", temperature=temperature)) - al(signal, approx_method="logsumexp", temperature=temperature)

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)