In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
%env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu

env: JAX_ENABLE_X64=1
env: JAX_PLATFORM_NAME=cpu


In [3]:
import jax
from jax import jit

import jax.numpy as jnp

In [4]:
from typing import Any, List

PRNGKeyArray = Any # type alias


In [5]:
from functools import reduce, partial

# Random Data Generation

In [6]:
def random_mps(
    key: PRNGKeyArray,
    size: int,
    local_dim: int,
    bond_dim: int,
    dtype=jnp.float64) -> List[jnp.DeviceArray]:
    """
    Generate a random MPS where each core tensor
    is drawn i.i.d. from a uniform distribution 
    between -1 and 1.

    Input:
    ------
    key:        The random key.
    size:       The size (length) of an MPS.
    local_dim:  The local dimension size.
    bond_dim:   The bond dimension size.
    dtype:      The type of data to return.
    """
    # initialize MPS data collection
    mps = []
     
    for i in range(size):
        key, _ = jax.random.split(key)
        if i == 0:  # left most tensor
            tensor = jax.random.uniform(
                key, shape=(1, local_dim, bond_dim), minval=-1, maxval=1, dtype=dtype)
        elif i == size-1:  # right most tensor
            tensor = jax.random.uniform(
                key, shape=(bond_dim, local_dim, 1), minval=-1, maxval=1, dtype=dtype)
        else:  # middle tensors
            tensor = jax.random.uniform(
                key, shape=(bond_dim, local_dim, bond_dim), minval=-1, maxval=1, dtype=dtype)
        mps.append(tensor)

    return mps

In [7]:
def random_sample(
    key: PRNGKeyArray,
    size: int,
    local_dim: int,
    n_factors: int,
    dtype=jnp.float64) -> List[jnp.DeviceArray]:
    """
    Generate random data samples where components
    corrsponding to MPS tensors are drawn i.i.d. 
    from a normal distribution.

    Input:
    ------
    key:        The random key.
    size:       The sample size.
    local_dim:  The dimension of each sample.
    n_factors:  The number of factors (equal to the MPS size).
    dtype:      The type of data to return.
    """
    # initialize the collection
    samples = []
        
    for _ in range(size):
        x = []  # collects the components of a single sample 
        for _ in range(n_factors):
            key, _ = jax.random.split(key, num=2)
            x.append(
                jax.random.normal(
                    key, 
                    shape=(local_dim,), 
                    dtype=dtype)
            )
        samples.append(jnp.asarray(x))

    return samples

In [8]:
# def contract(mps):
#     """Fully contract the MPS with its conjugate."""
#     tensors = [jnp.einsum('pqr,uqv->purv', t, t.conj()) for t in mps]
#     res = reduce(lambda x,y: jnp.einsum('purv,rvts->puts', x,y), tensors)
#     return res.squeeze()

# # func alias
# mps_norm = contract

# def dot_one(mps: List[DeviceArray], sample: DeviceArray):
#     # contracts each mps tensor with a corresponding sample component
#     dot_wx = lambda w, x: jnp.einsum('pqr,q->pr', w, x)
#     # contract all tensor in the mps (equiv to a dotproduct of an mps)
#     res = reduce(jnp.matmul, jax.tree_multimap(dot_wx, mps, list(sample)))
#     return res.squeeze()

# def dot_many(mps: List[DeviceArray], samples: List[DeviceArray]):
#     """Contract MPS with data"""
#     res = jax.tree_map(lambda s: dot_one(mps, s), samples)
#     return res

In [16]:
def _align_dims(data: List[jnp.DeviceArray]):
    """Add extra dims to align tensor in an MPS"""
    if data[-1].ndim == 1:
        # adding an extra dimensions
        data = [x[jnp.newaxis,:,jnp.newaxis] for x in data]
    return data

def dot(mps1: List[jnp.DeviceArray], mps2: List[jnp.DeviceArray]) -> jnp.double:
    """
    Dot product of an MPS with another mps or data (i.e. castrated MPS)
    --A1----A2--...--An-- (MPS)
      |     |        |
    
      |     |        |
      x1    x2  ...  xn   (data sample)
    """
    mps1 = _align_dims(mps1)
    mps2 = _align_dims(mps2)
    # contracts individual components
    dot = lambda x, y: jnp.einsum('pqr,uqv->purv', x, y)
    # multiply two neighbouring tensors
    mult = lambda x, y: jnp.einsum('purv,rvts->puts', x, y)
    # contract all
    res = reduce(mult, jax.tree_multimap(dot, mps1, mps2))
    return res.squeeze()

def mps_norm(mps: List[jnp.DeviceArray]) -> jnp.double:
    """Computing the norm of an MPS"""
    mps_c = [t.conj() for t in mps]
    return dot(mps, mps_c)

def mdot(mps: List[jnp.DeviceArray], samples: List[jnp.DeviceArray]) -> List[jnp.DeviceArray]:
    """Apply dot product to many samples"""
    res = jax.tree_map(lambda s: dot(mps, s), samples)
    return res

## Constants

In [10]:
# PRNG seed
SEED = 161803

# model size (MPS_SIZE * LOCAL_DIM + BOND_DIM = TRAIN_SIZE)
MPS_SIZE = 3
LOCAL_DIM = 2
BOND_DIM = 4

# data sample
SAMPLE_SIZE = 10000

In [11]:
key = jax.random.PRNGKey(SEED)

# Spliting the key
key_params, key_data, key_noise, key_run = jax.random.split(key, num=4)

# target MPS model
mps = random_mps(key_params, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=BOND_DIM)
# generate samples
samples = random_sample(key_data, size=SAMPLE_SIZE, local_dim=LOCAL_DIM, n_factors=MPS_SIZE)

In [15]:
%%time
z = mdot(mps, samples)

CPU times: user 29.6 s, sys: 279 ms, total: 29.9 s
Wall time: 30.1 s
