In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from stljax.formula import *
from stljax.viz import *

import functools

## NOTE
If using Expressions to define formulas, `stljax` expects input signals to be of size `[time_dim]`.
If using Predicates to define formulas, `stljax` expects input signals to be of size `[time_dim, state_dim]` where `state_dim` is the expected input size of your predicate function.



In [2]:
def compute_distance_to_origin(states):
    return jnp.linalg.norm(states[...,:2], axis=-1, keepdims=False)

In [3]:
T = 10
compute_distance_to_origin(jnp.ones([T, 2]))

Array([1.4142135, 1.4142135, 1.4142135, 1.4142135, 1.4142135, 1.4142135,
       1.4142135, 1.4142135, 1.4142135, 1.4142135], dtype=float32)

## Using Expressions
Expressions are placeholders for input signals. Specifically, it is assuming the signal is already a 1D array, such as the output of a predicate function. 

This is useful if you have signals from predicates computed already. 

In general, this is useful for readability and visualization.

In [4]:
distance_to_origin_exp = Expression("magnitude", value=None) # can define an Expression without setting values for the expression right now
formula_exp = Eventually(distance_to_origin_exp < 0.5) # can define an STL formula given an expression, again, the value of the expression does not need to be set yet


formula_exp(distance_to_origin_exp) # <---- this will throw an error since the expression does not have values set yet



AssertionError: Input Expression does not have numerical values

In [5]:
# so let's go ahead and set a value for the expression
T = 5
states = jnp.array(np.random.randn(T, 2))
states_norm = compute_distance_to_origin(states)   # compute distance to origin

distance_to_origin_exp.set_value(states_norm)   # set value for Expression

# compute robustness trace
formula_exp(distance_to_origin_exp) # <---- this will no longer throw an error since the expression has a value set

# alternatively, we can directly plug any jnp.array and evaluate the robustness without 
states2 = jnp.array(np.random.randn(T, 2))
states_norm2 = compute_distance_to_origin(states2)   # compute distance to origin
formula_exp(states_norm2) 



Array([ 0.24982172,  0.24982172,  0.24982172,  0.24982172, -0.9981775 ],      dtype=float32)

We can compute the robustness value (instead of trace) and take the derivative

In [6]:
robustness = formula_exp.robustness(states_norm) 
print(f"Robustness value: {robustness:.3f}\n")

gradient = jax.grad(formula_exp.robustness)(states_norm) 
print(f"Gradient of robustness value w.r.t. input:\n {gradient}")


Robustness value: -0.041

Gradient of robustness value w.r.t. input:
 [-0. -0. -1. -0. -0.]


We can apply a smooth max/min approximation by selecting a `approx_method` and `temperature`.
The default `approx_method` is `true`.


In [7]:
approx_method = "logsumexp"  # or "softmax"
temperature = 1. # needs to be > 0

robustness = formula_exp.robustness(states_norm, approx_method=approx_method, temperature=temperature) 
print(f"Robustness value: {robustness:.3f}\n")

gradient = jax.grad(formula_exp.robustness)(states_norm, approx_method=approx_method, temperature=temperature) 
print(f"Gradient of robustness value w.r.t. input:\n {gradient}") # <----- gradients are spread across different values

Robustness value: 1.030

Gradient of robustness value w.r.t. input:
 [-0.09700805 -0.2143781  -0.34265578 -0.23554076 -0.1104174 ]


For formulas that are defined with two different Expressions, we need to be careful about the signals we are feeding in.

In [8]:
# if both subformulas use the same signal, then we can do this
phi = (distance_to_origin_exp > 0) & (distance_to_origin_exp < 0.5)  
phi(states_norm)


# if the formula depends on two different signals, then we need to provide the two signals as tuple
distance_to_origin_exp = Expression("magnitude", value=None)
speed_exp = Expression("speed", value=None)

phi = (distance_to_origin_exp > 0) & (speed_exp < 0.5)  

phi(states_norm) # <--- Will give WRONG ANSWER


speed = jnp.array(np.random.randn(T))
input_correct_order = (states_norm, speed)
input_wrong_order = (speed, states_norm)
phi(input_correct_order) # <--- Will give desired answer
phi(input_wrong_order) # <--- Will give WRONG ANSWER since the ordering of the input does not correspond to how phi is defined




Array([-1.8956004 , -1.1781825 , -0.04106313, -0.41590554, -1.1735218 ],      dtype=float32)

## Using Predicates
Predicates are the functions that an N-D signal is passed through and its outputs are then passed through each operation of the STL formula.
We can construct an STL formula by specifying the predicate functions and the connectives and temporal operations.


In [9]:
distance_to_origin_pred = Predicate("magnitude", predicate_function=compute_distance_to_origin) # define a predicate function with a name and the function
formula_pred = Eventually(distance_to_origin_pred < 0.5) # define the STL formula

# so let's go ahead and set a value for the input N-D array which will be the input into the predicate function.
T = 5
states = jnp.array(np.random.randn(T, 2))  # 2D signal
output_from_using_predicate = formula_pred(states) # compute distance to origin INSIDE 


# NOTE: this is equivalent to the following with expressions
states_norm = compute_distance_to_origin(states)   # computes distance to origin OUTSIDE 
output_from_using_expression = formula_exp(states_norm) 


# check if we get the same answer
jnp.isclose(output_from_using_predicate, output_from_using_expression)

Array([ True,  True,  True,  True,  True], dtype=bool)

Similarly, we can compute the robustness value (instead of trace) and take the derivative. 

In [10]:
approx_method = "logsumexp"  # or "softmax"
temperature = 1. # needs to be > 0

robustness = formula_pred.robustness(states, approx_method=approx_method, temperature=temperature) 
print(f"Robustness value: {robustness:.3f}\n")

gradient = jax.grad(formula_pred.robustness)(states, approx_method=approx_method, temperature=temperature) 
print(f"Gradient of robustness value w.r.t. input:\n {gradient}") # <----- gradients are spread across different values

Robustness value: 0.951

Gradient of robustness value w.r.t. input:
 [[-0.05231676  0.04973718]
 [-0.02268155 -0.20005907]
 [ 0.23647855  0.29368952]
 [ 0.21631472 -0.10630771]
 [-0.0618272   0.08902159]]


Note that when taken gradients with formulas defined with predicates, the input is the N-D signal which is passed into the predicate function and other robustness formulas. That is to say, the gradient will be influenced by the choice of the predicate. 

To get the same gradient output when using Expressions, we need to do the following:

In [11]:
def foo(states):
    states_norm = compute_distance_to_origin(states)   # compute distance to origin
    return formula_exp.robustness(states_norm, approx_method=approx_method, temperature=temperature) 

jax.grad(foo)(states)

Array([[-0.05231676,  0.04973718],
       [-0.02268155, -0.20005907],
       [ 0.23647855,  0.29368952],
       [ 0.21631472, -0.10630771],
       [-0.0618272 ,  0.08902159]], dtype=float32)