In [53]:
from classes import MPS, MPO
from jax.random import PRNGKey, normal
from helper_functions import get_norm_from_MPS

Let's create a Matrix Product State with randomly distributed entries:

In [9]:
key = PRNGKey(13)

A = normal(key, shape=(1, 5, 3))
B = normal(key, shape=(3, 6, 2))
C = normal(key, shape=(2, 5, 1))

In [34]:
mps = MPS([A, B, C])

In [20]:
print(repr(mps))

MPS( component 0 of size (1, 5, 3)
component 1 of size (3, 6, 2)
component 2 of size (2, 5, 1)
)


MPS class has method dot implementing scalar product. As expected, multiplying object by itself produces its (Frobenius) norm.

In [24]:
get_norm_from_MPS(mps)

29.219938

In [26]:
jnp.sqrt(mps.dot(mps))

DeviceArray(29.219929, dtype=float32, weak_type=True)

Also, MPS can be adduced to canonical form, either left or right

In [30]:
for x in mps.components:
    print(jnp.tensordot(x, x, 3))

15.94236
31.23802
10.811666


In [35]:
mps.left_canonical()

In [32]:
for x in mps.components:
    print(jnp.tensordot(x, x, 3))

3.0000002
2.0
853.8036


Canonical form enables to implement SVD-truncation for a given truncation value

In [40]:
mps.left_svd_trunc(3)
get_norm_from_MPS(mps)

26.905174

Another class is Matrix Product Operator which allows to introduce the evolution of MPS

In [45]:
D = normal(key, shape=(1, 5, 3, 5))
E = normal(key, shape=(3, 6, 3, 6))
F = normal(key, shape=(3, 5, 1, 5))

mpo = MPO([D, E, F])

In [50]:
mpo.process(mps)