In [1]:
import thermox
import jax
import jax.numpy as jnp
from scipy.linalg import solve, inv, expm

### In this notebook we show how to run three basic thermodynamic algorithms:
1. Thermodynamic linear solver: find $x$ such that $Ax = b$,
2. Thermodynamic matrix inverse: find $A^{-1}$,
3. Thermodynamic matrix exponential: find $\exp{(A)}$.

These algorithms are all based on the multivariate Ornstein-Uhlenbeck process, defined as
$$ dx = -A(x - b)\,dt + \mathcal{N}(0, 2D\,dt) $$
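The three algorithms all read off statistics of this process once it has relaxed. These are the standard OU stationary moments, assuming every eigenvalue of $A$ has a positive real part. They follow from setting the time derivatives of the mean and covariance to zero:

$$ \frac{d}{dt}\langle x\rangle = -A(\langle x\rangle - b) = 0 \;\Rightarrow\; \langle x\rangle = b, \qquad A\Sigma + \Sigma A^\top = 2D, $$
$$ \langle (x(t+\tau) - b)(x(t) - b)^\top \rangle = e^{-A\tau}\,\Sigma \quad (\tau \ge 0). $$

For symmetric $A$ and $D = \mathbb{I}$, the Lyapunov equation is solved by $\Sigma = A^{-1}$, because $AA^{-1} + A^{-1}A = 2\mathbb{I}$. The mean therefore gives linear solves, the equal-time covariance gives the inverse, and the two-time correlation gives $e^{-A\tau}A^{-1}$.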

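Let us start with solving a linear system $Ax = b$. In this case, $D = \mathbb{I}$ and the drift is $-(Ax - b) = -A(x - A^{-1}b)$. The process therefore fluctuates around the solution $A^{-1}b$, and the time average $\bar{x}$ of a trajectory estimates it.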
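The relax-and-average idea can be sketched in plain JAX, independently of thermox. This is a minimal sketch that assumes an Euler-Maruyama discretization with step `dt`, a burn-in period, and a small well-conditioned 3×3 matrix. The names `ou_time_average`, `A_small` and `b_small` are hypothetical illustrations, not thermox API, and this is not how `thermox.linalg.solve` is implemented internally.

```python
import jax
import jax.numpy as jnp


def ou_time_average(key, A, b, dt=0.05, num_steps=40_000, burnin=1_000):
    """Time average of an Euler-Maruyama discretization of
    dx = -(A x - b) dt + N(0, 2 I dt), whose stationary mean is A^{-1} b."""
    dim = b.shape[0]
    # One standard-normal draw per step and dimension, scaled to variance 2*dt.
    noise = jnp.sqrt(2.0 * dt) * jax.random.normal(key, (num_steps, dim))

    def step(x, xi):
        x_next = x - dt * (A @ x - b) + xi
        return x_next, x_next

    _, xs = jax.lax.scan(step, jnp.zeros(dim), noise)
    # Drop the initial transient (started at x = 0) before averaging.
    return xs[burnin:].mean(axis=0)


# Hypothetical small symmetric positive-definite example (eigenvalues 2, ~2.85, ~4.65).
A_small = jnp.array([[4.0, 1.0, 0.0],
                     [1.0, 3.0, 0.5],
                     [0.0, 0.5, 2.5]])
b_small = jnp.array([1.0, -2.0, 0.5])

x_bar = ou_time_average(jax.random.PRNGKey(0), A_small, b_small)
print(jnp.linalg.norm(x_bar - jnp.linalg.solve(A_small, b_small)))
```

The discretization does not bias the mean, because the fixed point of the update $x \mapsto x - \Delta t\,(Ax - b)$ is exactly $A^{-1}b$. The step only needs $\Delta t\,\lambda_{\max} < 2$ for stability. The remaining error is the statistical error of a finite time average.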

In [2]:
key = jax.random.PRNGKey(0) # random PRNG key
dimension = 100 # problem size
mean = jnp.zeros(dimension) # mean vector
A = jax.random.normal(key, shape=(dimension, 2 * dimension))
A = A @ A.T # random symmetric positive-definite matrix (a Wishart sample)
b = jax.random.normal(key, shape=(dimension,)) # reuses `key`; jax.random.split is the recommended way to get independent keys
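Because $G$ has more columns ($2n$) than rows ($n$), the product $GG^\top$ is positive definite with probability one, not merely semi-definite. The sketch below repeats the construction with split keys and checks this. The spread of the eigenvalues also matters here: the slowest mode of the OU process relaxes on a time scale of about $1/\lambda_{\min}$. The names `G`, `A_spd` and `b_vec` are hypothetical, chosen so that they do not overwrite the notebook's `A` and `b`.

```python
import jax
import jax.numpy as jnp

# Independent keys for the matrix and the right-hand side.
key_A, key_b = jax.random.split(jax.random.PRNGKey(0))
dimension = 100

G = jax.random.normal(key_A, shape=(dimension, 2 * dimension))
A_spd = G @ G.T  # Wishart sample with 2 * dimension degrees of freedom
b_vec = jax.random.normal(key_b, shape=(dimension,))

# eigvalsh assumes a symmetric matrix and returns eigenvalues in ascending order.
eigvals = jnp.linalg.eigvalsh(A_spd)
print("smallest eigenvalue:", eigvals[0])
print("condition number:", eigvals[-1] / eigvals[0])
```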



In [3]:
x = thermox.linalg.solve(A, b)
x

Array([-0.04291015,  0.00502984,  0.01298528, -0.00028778, -0.01526742,
       -0.00567479,  0.01533074, -0.01003827, -0.0050336 , -0.00972303,
       -0.01738689, -0.00678701,  0.02312178,  0.00757717, -0.00638804,
        0.00790122, -0.01896622,  0.01315646, -0.02966048, -0.01008006,
        0.00765205, -0.00272523, -0.01280687,  0.01449428, -0.00697023,
        0.00979751,  0.00458921, -0.00285588,  0.01644798, -0.00951699,
        0.01204343, -0.02854497,  0.0011238 , -0.00382948, -0.01173987,
        0.02511922, -0.01580777, -0.02315124,  0.00045882,  0.00419719,
       -0.00103077, -0.01266642,  0.02562701,  0.01339333,  0.02357866,
        0.00899012,  0.00454874, -0.00576508, -0.00708199,  0.00854757,
       -0.00158417,  0.0159873 ,  0.01952366, -0.00064415,  0.0175536 ,
       -0.01097133,  0.00723158,  0.03912605, -0.00986775, -0.01613855,
        0.00183138,  0.03305778, -0.00103753,  0.0260798 , -0.00595791,
        0.02330123,  0.00846258, -0.01343022, -0.00485881,  0.03

We now look at the absolute error $\|\bar{x} - A^{-1}b\|$ of the thermodynamic estimate $\bar{x}$:

In [9]:
print(r"||\bar{x} - A^{-1}b|| = ", jnp.linalg.norm(solve(A,b) - x))

||\bar{x} - A^{-1}b|| =  0.0100456355


## Thermodynamic matrix inverse

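This time there is no need to define the vector $b$; it can be taken to be zero. With $D = \mathbb{I}$ and a symmetric $A$, the inverse is the stationary covariance of the process, estimated as a time average over the trajectory:
$$A^{-1} \approx \Sigma = \langle x(t)\, x(t)^\top \rangle$$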
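As with the solver, the idea can be sketched in plain JAX without thermox. This minimal sketch simulates the process with $b = 0$ and $D = \mathbb{I}$ using Euler-Maruyama, then estimates $A^{-1}$ as the equal-time second moment $\langle x x^\top \rangle$ of the trajectory. Euler-Maruyama biases the stationary covariance by a term of order `dt`, so the step here is smaller than in a mean-only estimate. The names `ou_covariance` and `A_small` are hypothetical and not part of thermox.

```python
import jax
import jax.numpy as jnp


def ou_covariance(key, A, dt=0.01, num_steps=100_000, burnin=1_000):
    """Estimate <x x^T> for the Euler-Maruyama discretization of
    dx = -A x dt + N(0, 2 I dt), i.e. the OU process with b = 0 and D = I."""
    dim = A.shape[0]
    noise = jnp.sqrt(2.0 * dt) * jax.random.normal(key, (num_steps, dim))

    def step(x, xi):
        x_next = x - dt * (A @ x) + xi
        return x_next, x_next

    _, xs = jax.lax.scan(step, jnp.zeros(dim), noise)
    xs = xs[burnin:]
    # Uncentered second moment: the stationary mean is zero because b = 0.
    return xs.T @ xs / xs.shape[0]


# Hypothetical small symmetric positive-definite example.
A_small = jnp.array([[4.0, 1.0, 0.0],
                     [1.0, 3.0, 0.5],
                     [0.0, 0.5, 2.5]])
Sigma = ou_covariance(jax.random.PRNGKey(1), A_small)
A_inv = jnp.linalg.inv(A_small)
print(jnp.linalg.norm(Sigma - A_inv) / jnp.linalg.norm(A_inv))  # relative error
```

A covariance estimate converges more slowly than a mean estimate from the same trajectory, so longer runs are needed to reach a comparable accuracy.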