# Imports and settings

In [1]:
%matplotlib inline
%env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu

env: JAX_ENABLE_X64=1
env: JAX_PLATFORM_NAME=cpu


In [2]:
import jax
import jax.numpy as jnp
from jax import jit, grad
from jax.example_libraries import optimizers

In [3]:
import os
import pickle
import time
import itertools

from collections import defaultdict
from functools import reduce

In [4]:
from typing import Any, List, DefaultDict, Dict, Tuple, Optional

## Types

In [5]:
# type alias
PRNGKeyArray = Any
DeviceArray = jnp.DeviceArray

# Random Data Generation

## MPS

In [250]:
from copy import deepcopy

class MPS:

    """MPS class"""
    
    def __init__(self, size: int, local_dim: int, bond_dim: int):
        self.size = size
        self.local_dim = local_dim
        self.bond_dim = bond_dim
        self.svs = [None]*size
        self.dw = [None]*size
        self._tensors = [None]*size
        self._normalized = False
        self._normalization_method = None
        self._norm = None

    def __repr__(self) -> str:
        return f'An MPS of size: {self.size}; ' +\
               f'local dim: {self.local_dim}; ' +\
               f'bond dim: {self.bond_dim}.'

    def __setitem__(self,idx, tensor):
        self._tensors[idx] = tensor

    def __getitem__(self, idx) -> DeviceArray:
        return self._tensors[idx]

    def __len__(self) -> int:
        return self.size

    def __iter__(self):
        pass

    @property
    def norm(self) -> jnp.double:
        if self._norm is None:
            self._norm = self.get_overlap(self.conjugate)
        return jnp.real(self._norm)

    @property
    def conjugate(self):
        self_conj = deepcopy(self)
        self_conj._tensors = jax.tree_map(jnp.conj, self._tensors)
        return self_conj

    @property
    def is_normalized(self) -> bool:
        return self._normalized

    @property
    def normalization_method(self) -> str:
        return self._normalization_method

    def get_overlap(self, other) -> jnp.double:
        """
        Overlap of this MPS with another MPS.
        --A1----A2--...--An-- (this)
          |     |        |
        
          |     |        |
        --B1----B2--...--Bn-- (other)
        """
        # contracts individual components
        dot = lambda x, y: jnp.einsum('pqr,uqv->purv', x, y)
        # multiply two neighbouring tensors
        mult = lambda x, y: jnp.einsum('purv,rvts->puts', x, y)
        # contract all
        res = reduce(mult, jax.tree_multimap(dot, self._tensors, other._tensors))
        return res.squeeze()

    def check_normalization(self, method: str, idx=-1, tol=1e-5):

        def _check(start, stop, method):
            bad_sites = {}  # keep track where normalization fails
            for i in range(start, stop, 1):
                A = self[i]
                u, _, w = A.shape
                if method == 'left':
                    A = A.reshape(-1, w)
                elif method == "right":
                    A = A.reshape(u, -1)
                    A = A.conj().T
                delta = jnp.max(abs(A.conj().T @ A - jnp.eye(A.shape[1])))
                if  delta > tol:
                    bad_sites[i] = delta
            if bad_sites:
                for i, delta in bad_sites.items():
                    print(f'Badly normalized {i}-th site; delta: {delta}')
            else:
                print(f'The MPS is {method} normalized correctly between {start} and {stop-1}')

        if method == 'left':
            if idx == -1:
                stop = self.size
            else:
                stop = idx
            _check(0, stop, 'left')
        elif method == 'right':
            if idx == -1:
                start = 0
            else:
                start = idx
            _check(start, self.size, 'right')
        elif method == 'site':
            if not 0 < idx < self.size - 1:
                raise ValueError('Need to provide site idx')
            _check(0, idx, 'left')
            _check(idx+1, self.size, 'right')
        elif method == 'bond':
            if not 0 < idx < self.size - 1:
                raise ValueError('Need to provide site idx')
            _check(0, idx+1, 'left')
            _check(idx+1, self.size, 'right')
        else:
            raise ValueError(f'Unknown method {method}')

    def normalize(self, method, idx=-1, max_bond=jnp.inf, force=False):
        """
        Normalize to one of the canonical forms.

        Input:
        ------
        method: str         One of the canonical forms, incl.
                            'left', 'right', 'site' or 'bond'.
        idx: int            A site index which is used to stop
                            the normalization process. The value of -1
                            indicates a full propagations. Default: -1.
        max_bond: int       Max bond dimension to keep Default: np.inf.
        force: bool         Force normalizaion even if already normalized.
        """
        if self.is_normalized and not force:
            print('The MPS is already normalized using' +\
                  f' the "{self.normalization_method}" method.' +\
                  f' Use force=True to renormalize')
            return

        {
            'left': self._left_normalize,
            'right': self._right_normalize,
            'bond': self._bond_normalize,
            'site': self._site_normalize
        }[method](idx, max_bond)

        self._normalized = True
        self._normalization_method = method

    def _left_normalize(self, idx, max_bond):
        """
        Normalize to left canonical form.

        Input:
        ------
        idx: int    A site index which is used to stop
                    the normalization process. The value of -1
                    indicates a full propagations. Default: -1.
        """   
        # determine where to stop
        if idx == -1:
            stop = self.size
        else:
            assert 0 < idx < self.size-1, 'The stoping site index must be 0 < idx < self.size-1'
            stop = idx

        # getting the cuttof dimension
        max_bond = min(max_bond, self.bond_dim)

        for i in range(stop):

            # get the current tensor
            M = self[i]

            # reshape into a matrix by merging left/bottom legs
            u, v, w = M.shape
            M = M.reshape(u*v, w)

            # perform SVD
            U, s, Vh = jnp.linalg.svd(M, full_matrices=False)

            # cutting off the bond dimension
            U  = U[:,:max_bond]
            Vh = Vh[:max_bond,:]

            # calculating the discarded weight before shrinking
            self.dw[i] = (s[max_bond:]**2).sum()
            s  = s[:max_bond]

            # storing the singular values
            self.svs[i] = s

            # assign the current tensor to U and reshape
            self[i] = U.reshape(U.shape[0] // v, v, U.shape[1])
                        
            if i < self.size-1:
                M_next = self[i+1].reshape(Vh.shape[1], -1)
                self[i+1] = (jnp.diag(s) @ Vh @ M_next).reshape(U.shape[1], v, -1)
            else:
                self._norm = s**2

        return

    def _right_normalize(self, idx, max_bond):
        """
        Normalize to right canonical form.

        Input:
        ------
        idx: int    Index of the site indicating where to stop
                    the normalization process. The value of -1
                    indicate a full propagations. Default: -1.
        """
        # assertions
        if idx == -1:
            stop = -1
        else:
            assert 0 < idx < self.size-1, 'The stoping index must be 0 < idx < self.size-1'
            stop = idx

        # getting the cuttof dimension
        max_bond = min(max_bond, self.bond_dim)

        for i in range(self.size-1, stop, -1):
            M = self[i]
            u, v, w = M.shape
            M = M.reshape(u, v*w)

            # perform SVD
            U, s, Vh = jnp.linalg.svd(M, full_matrices=False)

            # cutting off the bond dimension
            U  = U[:,:max_bond]
            Vh = Vh[:max_bond,:]

            # calculating the discarded weight before shrinking
            self.dw[i] = (s[max_bond:]**2).sum()
            s  = s[:max_bond]

            # storing the singular values
            self.svs[i] = s

            self[i] = Vh.reshape(Vh.shape[0], v, Vh.shape[1] // v)
            if i > 0:
                M_new = self[i-1].reshape(-1, U.shape[0])
                self[i-1] = (M_new @ U @ jnp.diag(s)).reshape(-1, v, Vh.shape[0])
            else:
                self._norm = s**2

        return
    
    def _site_normalize(self, idx, max_bond):
        """
        Normalize to site canonical form.

        Input:
        ------
        idx: int    Index of the site indicating where to stop
                    the normalization process.
        """

        self._left_normalize(idx, max_bond)
        self._right_normalize(idx, max_bond)

        return

    def _bond_normalize(self, idx, max_bond):
        """
        Normalize to bond canonical form.

        Input:
        ------
        idx: int    Index of the site indicating where to stop
                    the normalization process. The site with the
                    index 'idx' is included to the left normalized
                    part of the MPS, while the site with the index
                    'idx+1' is included to the right normalized
                    part of the MPS.
        """

        self._left_normalize(idx, max_bond)
        self._right_normalize(idx, max_bond)
        
        M = self[idx]
        u, v, w = M.shape
        M = M.reshape(u*v, w)

        # perform SVD
        U, s, Vh = jnp.linalg.svd(M, full_matrices=False)

        # insert the bond singular values between idx and idx+1,
        # shifting all to the right
        self.svs[idx] = s

        self[idx] = U.reshape(U.shape[0] // v, v, U.shape[1])

        M_next = self[idx+1].reshape(Vh.shape[1], -1)
        self[idx+1] = (Vh @ M_next).reshape(Vh.shape[0], v, -1)

        return

    def site_expectation(self, op, idx):
        """
        Computing the expectation value of the operator.

        Input:
        ------
        op:     Operator
        """
        assert op.shape[0] == op.shape[1]
        assert op.shape[0] == self[idx].shape[1]

        self._site_normalize(idx, max_bond=jnp.inf)

        M = self[idx]
        A = jnp.tensordot(M.conj(), M, [[0, 2], [0, 2]])
        
        return jnp.trace(A @ op)

    def entanglement_entropy(self, idx):
        """
        Entanglement entropy calculation between sites idx and idx+1.
        """

        self._bond_normalize(idx, max_bond=jnp.inf)
        svs2 = jnp.square(self.svs[idx])

        return -jnp.sum(svs2 * jnp.log2(svs2))

In [251]:
key = jax.random.PRNGKey(0)
MPS_SIZE = 4
LOCAL_DIM = 3
BOND_DIM = 6 

# Spliting the key
key_mps, key_data, key_noise, key_run = jax.random.split(key, num=4)

# target MPS model
mps = random_mps(key_mps, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=BOND_DIM, dtype=jnp.complex128)

In [248]:
mps.normalize('left')

In [252]:
mps.norm

DeviceArray(1.52819807, dtype=float64)

In [253]:
mps.get_overlap(mps.conjugate)

DeviceArray(1.52819807-2.42861287e-17j, dtype=complex128)

In [178]:
mps.is_normalized, mps._normalization_method, mps.norm

(True, 'left', 1.528198068605834)

In [46]:
def random_matrix(
    key: PRNGKeyArray,
    dim1: int,
    dim2: Optional[int] = None,
    dtype=jnp.complex128
    ) -> DeviceArray:
    """
    Genarate a random matrix of size (dim1, dim2).

    Input:
    ------
    dim1:   First dimension.
    dim2:   Second dimension. Optional, set to be
            equal to the first dimension if not provided.
    """
    if not dim2:
        # generate a square matrix
        dim2 = dim1
    return jax.random.normal(key, (dim1, dim2), dtype=dtype)

def random_unit_spectrum(
    key: PRNGKeyArray,
    dim1: int,
    dim2: Optional[int]=None,
    dtype=jnp.complex128
    ) -> DeviceArray:
    """
    Genarate a random matrix of size (dim1, dim2)
    with a unit spectrum.

    Input:
    ------
    dim1:   First dimension.
    dim2:   Second dimension. Optional, set to be
            equal to the first dimension if not provided.
    """
    if not dim2:
        dim2 = dim1
    X = random_matrix(key, dim1, dim2, dtype=dtype)
    U, _, Vh = jnp.linalg.svd(X, full_matrices=False)
    return (U * jnp.ones(U.shape[1])) @ Vh

In [131]:
def random_mps(
    key: PRNGKeyArray,
    size: int,
    local_dim: int,
    bond_dim: int,
    dtype=jnp.complex128) -> List[DeviceArray]:
    """
    Generate a random MPS where each core tensor
    is drawn i.i.d. from a uniform distribution 
    between -1 and 1.

    Input:
    ------
    key:        The random key.
    size:       The size (length) of an MPS.
    local_dim:  The local dimension size.
    bond_dim:   The bond dimension size.
    dtype:      The type of data to return.
    """
    # initialize MPS data collection
    mps = MPS(size, local_dim, bond_dim)
     
    for i in range(size):
        key, _ = jax.random.split(key)
        if i == 0:  # left most tensor
            tensor = random_unit_spectrum(key, local_dim, bond_dim, dtype=dtype)
            tensor = tensor.reshape(1, local_dim, bond_dim)
        elif i == size-1:  # right most tensor
            tensor = random_unit_spectrum(key, bond_dim, local_dim, dtype=dtype)
            tensor = tensor.reshape(bond_dim, local_dim, 1)
        else:  # middle tensors
            tensor = random_unit_spectrum(key, bond_dim*local_dim, bond_dim, dtype=dtype)
            tensor = tensor.reshape(bond_dim, local_dim, bond_dim)
        mps[i] = tensor
    return mps

In [132]:
mps.norm

DeviceArray(1.47651315+1.04083409e-17j, dtype=complex128)

In [49]:
mps.normalized()

AttributeError: 'list' object has no attribute 'normalized'

In [7]:
def dot(mps1: List[DeviceArray], mps2: List[DeviceArray]) -> jnp.double:
    """
    Dot product of an MPS with another mps.
    --A1----A2--...--An-- (MPS1)
      |     |        |
    
      |     |        |
    --B1----B2--...--Bn-- (MPS2)
    """
    # contracts individual components
    cdot = lambda x, y: jnp.einsum('pqr,uqv->purv', x, y)
    # multiply two neighbouring tensors
    mult = lambda x, y: jnp.einsum('purv,rvts->puts', x, y)
    # contract all
    res = reduce(mult, jax.tree_multimap(cdot, mps1, mps2))
    return res.squeeze()

def mps_norm(mps: List[DeviceArray]) -> jnp.double:
    """Computing the squared norm of an MPS"""
    mps_conj = jax.tree_map(jnp.conj, mps)
    return dot(mps, mps_conj)

## Samples

In [8]:
def _random_sample(
    key: PRNGKeyArray, 
    num_factors: int, 
    local_dim: int, 
    dtype=jnp.double) -> DeviceArray:
    """Generate a single sample with a number of factors
    where each factor is generated from a Normal distribution.
    """
    keys = jax.random.split(key, num=num_factors)
    func = lambda k: jax.random.normal(k, (local_dim,), dtype)
    return jax.vmap(func)(keys)

def random_samples(
    key: PRNGKeyArray, 
    sample_size: int, 
    num_factors: int, 
    local_dim: int, 
    dtype=jnp.double) -> DeviceArray:
    """Genarate random samples of a specific size"""
    keys = jax.random.split(key, num=sample_size)
    return jax.vmap(lambda k: _random_sample(k, num_factors, local_dim, dtype))(keys)

In [9]:
def sample_as_mps(sample: DeviceArray) -> List[DeviceArray]:
    """
    Represent a data sample as an MPS. Useful for contracting with another MPS.
    |    |    |       |        |     |     |          |
    x1   x2   x3 ...  xn  => --x1----x2----x3-- ... --xn--
    """
    return list(sample[:,jnp.newaxis,:,jnp.newaxis])

def dot_samples(mps: List[DeviceArray], samples: DeviceArray) -> DeviceArray:
    """Apply dot product to many samples"""
    return jax.vmap(lambda s: dot(mps,sample_as_mps(s)))(samples)

# Helpers

In [10]:
def save_pkl(file_path: str, data: Any) -> None:
    dir_name = os.path.dirname(file_path)
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)

In [11]:
def idict() -> DefaultDict:
    """Infinitely nested dict"""
    return defaultdict(idict)

def idict2dict(dic) -> Dict:
    if isinstance(dic, defaultdict):
        dic = {k: idict2dict(v) for k, v in dic.items()}
    return dic

In [12]:
def make_dirs(root_dir: str, noise_level: float) -> None:
    # A timestamp used in the experiments to store data
    time_stamp = time.strftime('%Y%m%d', time.localtime())
    exp_dir = os.path.join(f'{root_dir}/{time_stamp}-noise-{noise_level}')
    lrn_dir = os.path.join(exp_dir, 'learning')  # stores the learning progress
    res_dir = os.path.join(exp_dir, 'results')   # stores the experimenal results
    for d in (lrn_dir, res_dir):
        if not os.path.isdir(d):
            os.makedirs(d)
        elif len(os.listdir(d)) > 0:
            raise FileExistsError(f'Directory {d} is not empty!')
        else:
            pass
    return exp_dir, lrn_dir, res_dir

## Constants

In [13]:
# PRNG seed
SEED = 123

# model size (MPS_SIZE * LOCAL_DIM * BOND_DIM ~ TRAIN_SIZE)
MPS_SIZE = 4
LOCAL_DIM = 4
BOND_DIM = 8

# data sample
TRAIN_SIZE = 1000
TEST_SIZE  = 5000

# training params
LEARNING_RATE = 1e-4

# max num of epochs
NUM_EPOCHS = int(1)

# batch size
BATCH_SIZE = 50

# APPROX RANK
APPROX_RANK = [16, 14, 12, 10, 8, 6, 4, 2]

# NOISE MODEL
PERCENT_NOISE = [0.1, 0.25, 0.5, 1, 5, 10] # noise level in percentages to the data std

# SAVE/PRINT after that many epochs
SAVE_AFTER_EPOCHS = 100

In [14]:
def save_settings(settings_file):
    with open(settings_file, 'w') as f:
        txt = f"""
SEED = {SEED}

# model size (MPS_SIZE * LOCAL_DIM + BOND_DIM = TRAIN_SIZE)
MPS_SIZE = {MPS_SIZE}
LOCAL_DIM = {LOCAL_DIM}
BOND_DIM = {BOND_DIM}

# data sample
TRAIN_SIZE = {TRAIN_SIZE}
TEST_SIZE = {TEST_SIZE}

# training params
LEARNING_RATE = {LEARNING_RATE}

# max num of epochs
NUM_EPOCHS = {NUM_EPOCHS}

# batch size
BATCH_SIZE = {BATCH_SIZE}

# APPROX RANK
APPROX_RANK = {APPROX_RANK}

# NOISE MODEL
PERCENT_NOISE = {PERCENT_NOISE}
"""
        f.write(txt)


## Data Prep

In [15]:
key = jax.random.PRNGKey(SEED)

# Spliting the key
key_params, key_data, key_noise, key_run = jax.random.split(key, num=4)

# target MPS model
true_params = random_mps(key_params, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=BOND_DIM)

# generate samples
data = random_samples(key_data, sample_size=TRAIN_SIZE+TEST_SIZE, num_factors=MPS_SIZE, local_dim=LOCAL_DIM)

# train/test split
train_data, test_data = data[:TRAIN_SIZE], data[TRAIN_SIZE:]

# test targets
test_targets = dot_samples(true_params, test_data)

## Loss function

In [16]:
def loss(params: List[DeviceArray], data: Tuple[DeviceArray, DeviceArray]) -> jnp.double:
    inputs, targets = data
    outputs = dot_samples(params, inputs)
    err = jnp.subtract(targets, outputs)
    return 0.5 * jnp.mean(jnp.log(jnp.power(err, 2) + 10))
    # return jnp.sqrt(0.5 * jnp.mean(err))

# computing the gradient wrt the loss
grad_loss = jit(grad(loss, argnums=0))

## Optimization

In [17]:
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

## Execution

In [18]:
def gauss_noise(key: PRNGKeyArray, sample_size: int, scale=1.0) -> DeviceArray:
    return scale * jax.random.normal(key, shape=(sample_size,))

In [19]:
root_dir = './experiment'

# determining the step size for SGD
num_complete_batches, leftover = divmod(TRAIN_SIZE, BATCH_SIZE)
num_batches = num_complete_batches + bool(leftover)

In [20]:
def data_iterator():
    while True:
        perm = jax.random.permutation(key_run, TRAIN_SIZE)
        for i in range(num_batches):
            batch_idx = perm[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            yield train_data[batch_idx], train_targets[batch_idx]


In [21]:
for perc_noise in PERCENT_NOISE:
    
    # noise model
    noise = gauss_noise(key_noise, sample_size=TRAIN_SIZE, scale=perc_noise * jnp.std(jnp.asarray(train_data)))

    # generate outputs by contracting MPS with data
    train_targets = dot_samples(true_params, train_data) + noise
    
    # making exp directory
    exp_dir, lrn_dir, res_dir = make_dirs(root_dir, perc_noise)
    
    # storing the settings into a file
    settings_file = os.path.join(exp_dir, 'settings.txt')
    save_settings(settings_file)

    # storing results
    results = idict()

    ref_loss_tr = jnp.inf
    ref_loss_te = jnp.inf

    for approx_rank in APPROX_RANK:
    
        # timer
        tic = time.time()
        
        print(f'Approximation rank: {approx_rank}')
        print('='*100)
        
        loss_tr = []
        loss_te = []
        
        # get access to the data batches stream
        batches = data_iterator()
        
        # params for the optimization
        opt_init, opt_update, get_params = optimizers.adam(step_size=LEARNING_RATE)
        
        # initialize MPS parameters randomly but a different key is used from the true params
        init_params = random_mps(key_run, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=approx_rank)
        opt_state = opt_init(init_params)

        # iteration counter
        itercounter = itertools.count()
            
        # Main loop
        for epoch in range(NUM_EPOCHS):

            # update parameters
            for _ in range(num_batches):
                opt_state = update(next(itercounter), opt_state, next(batches))
            
            # get new params
            params = get_params(opt_state)
            
            # Generalization risk
            l_tr = loss(params, (train_data, train_targets))
            l_te = loss(params, (test_data, test_targets))
                    
            # storing errors for statistics (saving memory)
            loss_tr.append(l_tr)
            loss_te.append(l_te)
            
            # printing epochs
            if epoch % SAVE_AFTER_EPOCHS == 0:
                
                print(f'Epoch: {epoch:<15} \t|\t Train loss: {l_tr:<10.3f} \t|\t Test loss: {l_te:<10.3f}')
            
                # storing parameters during training
                file_path = os.path.join(lrn_dir, f'./approx_rank_{approx_rank}/epoch_{epoch}.pkl')
                save_pkl(file_path, params)
            
            # update the reference
            ref_loss_tr = l_tr
            ref_loss_te = l_te
            
        # storing train/test loss
        results["train"][approx_rank] = loss_tr
        results["test"][approx_rank] = loss_te

        print('-'*100)
        print(f'Time for rank {approx_rank}: {(time.time() - tic):0.2f} sec')
        print(f'Train loss: {ref_loss_tr:0.2f}')
        print(f'Test loss: {ref_loss_te:0.2f}')
        print('='*100)
        
        file_path = os.path.join(res_dir, f'./approx_rank_{approx_rank}/loss.pkl')
        save_pkl(file_path, idict2dict(results))

Approximation rank: 16
Epoch: 0               	|	 Train loss: 3.706      	|	 Test loss: 3.693     
----------------------------------------------------------------------------------------------------
Time for rank 16: 2.12 sec
Train loss: 3.71
Test loss: 3.69
Approximation rank: 14
Epoch: 0               	|	 Train loss: 3.547      	|	 Test loss: 3.565     
----------------------------------------------------------------------------------------------------
Time for rank 14: 1.48 sec
Train loss: 3.55
Test loss: 3.56
Approximation rank: 12
Epoch: 0               	|	 Train loss: 3.219      	|	 Test loss: 3.242     
----------------------------------------------------------------------------------------------------
Time for rank 12: 1.41 sec
Train loss: 3.22
Test loss: 3.24
Approximation rank: 10
Epoch: 0               	|	 Train loss: 3.199      	|	 Test loss: 3.203     
----------------------------------------------------------------------------------------------------
Time for rank 10: 1.