In [2]:
import jax.numpy as jnp
import numpy as np
import unittest

In [3]:
class MPS:
    def __init__(self, tensors):
        self.components = tensors
        self.rank = len(tensors)

    def dot(self, rhs):               #scalar product function
        if self.rank != rhs.rank:
            return "ranks do not match"
        result = jnp.tensordot(self.components[0], rhs.components[0], [[0], [0]])
        for i in range(1, self.rank):
           result = jnp.tensordot(result, self.components[i], [[0],[0]])
           result = jnp.tensordot(result, rhs.components[i], [[0,1], [0,1]])
        return result

    def view(self):                  #prints shape and elements of the resulting tensor
        result = self.components[0]
        for i in range(1, self.rank):
            result = jnp.tensordot(result, self.components[i], 1)
        print(result.shape)
        print(result)

    def get_tensor(self):            #returns the resulting tensor
        result = self.components[0]
        for i in range(1, self.rank):
            result = jnp.tensordot(result, self.components[i], 1)
        return result



A small test just for me:

In [4]:
e = jnp.array([[1., 0],
                [2,1]])
m = MPS([e, e])
print(m.dot(m))
print(jnp.linalg.norm(m.get_tensor())**2)



18.0
17.999998


In [11]:
class Test(unittest.TestCase):
    def test_dot(self):
        self.max_dim = 6         #maximum dimension of any index

        for rank in range (2, 6):

            self.outer_dims = np.random.randint(2, self.max_dim, rank)          #visible dimensions of each tensor
            self.inner_dims_a = np.random.randint(2, self.max_dim, rank-1)      #bond dimensions of tensor A
            self.inner_dims_b = np.random.randint(2, self.max_dim, rank-1)

            print(self.outer_dims, self.inner_dims_a, self.inner_dims_b)

            self.comp_a = [np.random.randn(self.outer_dims[0], self.inner_dims_a[0])]  #filling the first components of mps
            self.comp_b = [np.random.randn(self.outer_dims[0], self.inner_dims_b[0])]

            for i in range (1, rank-1):
            
                self.comp_a.append(np.random.randn(self.inner_dims_a[i-1], self.outer_dims[i], self.inner_dims_a[i]))
                self.comp_b.append(np.random.randn(self.inner_dims_b[i-1], self.outer_dims[i], self.inner_dims_b[i]))

            self.comp_a.append(np.random.randn(self.inner_dims_a[-1], self.outer_dims[-1]))   #the last components of mps
            self.comp_b.append(np.random.randn(self.inner_dims_b[-1], self.outer_dims[-1]))

            self.a = MPS(self.comp_a)
            self.b = MPS(self.comp_b)

            result = self.a.dot(self.b)
            expected = jnp.tensordot(self.a.get_tensor(), self.b.get_tensor(), rank)
            
            self.assertAlmostEqual(result, expected, places=2, msg=None, delta=None)
 

In [12]:
unittest.main(argv=[''], verbosity=2, exit=False)

test_dot (__main__.Test) ... 

[2 2] [5] [4]
[3 5 5] [3 2] [4 4]
[2 5 5 4] [4 4 5] [3 2 2]
[4 2 5 5 5] [5 5 3 4] [5 4 2 5]


ok

----------------------------------------------------------------------
Ran 1 test in 0.827s

OK


<unittest.main.TestProgram at 0x7f5cddf48210>