# About

- `Title:` "Lower and Upper Bounds on the VC-Dimension of Tensor Network Models"
- `Main Author:` Behnoush Khavari
- `Source:` [ArXiv](https://arxiv.org/abs/2106.11827)
- `Publish Date:` 22-06-2021
- `Reviewed Date:` 22-11-2021

## Citation

```latex
@article{khavari2021lower,
  title={Lower and Upper Bounds on the VC-Dimension of Tensor Network Models},
  author={Khavari, Behnoush and Rabusseau, Guillaume},
  journal={arXiv preprint arXiv:2106.11827},
  year={2021}
}
```

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
%env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu

env: JAX_ENABLE_X64=1
env: JAX_PLATFORM_NAME=cpu


In [3]:
import jax
import jax.numpy as jnp
from jax import jit, grad

In [4]:
from typing import Any

PRNGKeyArray = Any # type alias

In [5]:
import matplotlib.pyplot as plt

# Global plot appearance: the legacy "classic" matplotlib look,
# with larger tick labels on both axes.
plt.style.use('classic')
plt.rcParams.update({
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

In [6]:
import pickle
import time
import itertools
from functools import reduce

# Helpers

In [7]:
from collections import defaultdict

def idict():
    """
    Return an empty, infinitely nestable dictionary.

    Missing keys are filled with another `idict()` on first access,
    so `d['a']['b']['c'] = 1` works without creating the levels by hand.
    """
    nested = defaultdict(idict)
    return nested

def idict2dict(dic):
    """
    Recursively convert nested `defaultdict`s into plain `dict`s.

    Anything that is not a `defaultdict` (leaf values, but also plain
    dicts) is returned untouched and is not descended into.
    """
    if not isinstance(dic, defaultdict):
        return dic

    plain = {}
    for name, child in dic.items():
        plain[name] = idict2dict(child)
    return plain

In [8]:
import os
import pickle

def save_pkl(file_path, data):
    """
    Pickle `data` to `file_path`, creating parent directories as needed.

    Input:
    ------
    file_path:  Destination file; an existing file is overwritten.
    data:       Any picklable object.
    """
    dir_name = os.path.dirname(file_path)
    # A bare file name has an empty dirname and os.makedirs('') raises,
    # so only create directories when there is a directory component.
    # exist_ok avoids failing if the directory appears in the meantime.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)

# Random Data Generation

In [9]:
def random_mps(
    key: PRNGKeyArray,
    size: int,
    local_dim: int,
    bond_dim: int,
    dtype=jnp.float64):
    """
    Generate a random MPS where each core tensor
    is drawn i.i.d. from a uniform distribution
    between -1 and 1.

    Core tensors have shape (left_bond, local_dim, right_bond). The two
    outer bonds of the chain always have dimension 1, including the case
    `size == 1`, where the single core has shape (1, local_dim, 1).

    Input:
    ------
    key:        The random key.
    size:       The size (length) of an MPS.
    local_dim:  The local dimension size.
    bond_dim:   The bond dimension size.
    dtype:      The type of data to return.

    Output:
    -------
    A list of `size` core tensors (empty when `size` is 0).
    """
    # initialize MPS data collection
    mps = []

    for i in range(size):
        # advance the key chain: the first half of the split becomes the
        # key for this site (and the input of the next split)
        key, _ = jax.random.split(key)

        # open boundary conditions: dimension-1 bond at either end of the chain
        left_dim = 1 if i == 0 else bond_dim
        right_dim = 1 if i == size - 1 else bond_dim

        tensor = jax.random.uniform(
            key, shape=(left_dim, local_dim, right_dim), minval=-1, maxval=1, dtype=dtype)
        mps.append(tensor)

    return mps

In [10]:
def random_sample(
    key: PRNGKeyArray,
    size: int,
    local_dim: int,
    n_factors: int,
    dtype=jnp.float64):
    """
    Generate random data samples where the components
    corresponding to MPS tensors are drawn from a
    standard normal distribution.

    Input:
    ------
    key:        The random key.
    size:       The sample size.
    local_dim:  The dimension of each sample component.
    n_factors:  The number of factors (equal to the MPS size).
    dtype:      The type of data to return.

    Output:
    -------
    An array built from `size` samples, each a list of `n_factors`
    vectors of length `local_dim`.
    """
    def _draw_sample():
        """Draw the `n_factors` components of one sample, advancing the key."""
        nonlocal key
        components = []
        for _ in range(n_factors):
            # keep the first half of the split as the new key
            key = jax.random.split(key, num=2)[0]
            components.append(
                jax.random.normal(key, shape=(local_dim,), dtype=dtype))
        return components

    return jnp.asarray([_draw_sample() for _ in range(size)])

In [11]:
def fully_contract(mps):
    """
    Fully contract the MPS with its conjugate.

    Each core is first contracted with its own conjugate over the
    physical index, giving a four-leg tensor (two left bonds, two right
    bonds). These are then chained left to right over the shared bond
    pairs, and size-1 axes are squeezed out of the result.
    """
    def _double_layer(core):
        # [p,q,r] x conj[u,q,v] -> [p,u,r,v]  (sum over physical index q)
        return jnp.einsum('pqr,uqv->purv', core, core.conj())

    def _absorb(left, right):
        # join the right bond pair of `left` with the left bond pair of `right`
        return jnp.einsum('purv,rvts->puts', left, right)

    chained = reduce(_absorb, map(_double_layer, mps))
    return chained.squeeze()

def mps_norm(mps):
    """
    Return the contraction of the MPS with its own conjugate.

    This is exactly the value of `fully_contract(mps)`; no square root
    is applied to it.
    """
    overlap = fully_contract(mps)
    return overlap

def contract(params, samples):
    """
    Contract an MPS with data.

    Every sample is a sequence of vectors, one per core tensor. Each core
    absorbs its vector, the resulting matrices are multiplied left to
    right, and the per-sample results are stacked and squeezed.
    """
    def _site_matrix(core, component):
        """
        Contract one core tensor with one data component.
        [p] --W-- [r]
          [q] |       = [p] --(Wx)-- [r]
              x
        """
        return jnp.einsum('pqr,q->pr', core, component)

    def _chain(left, right):
        return left @ right

    def _evaluate(sample):
        matrices = [
            _site_matrix(core, component)
            for core, component in zip(params, sample)
        ]
        # full reduction (contraction) along the chain
        return reduce(_chain, matrices)

    stacked = jnp.asarray([_evaluate(sample) for sample in samples])
    return stacked.squeeze()

In [12]:
# key = jax.random.PRNGKey(0)
# mps = random_mps(key, 3, 3, 5)

# norm_1 = mps_norm(mps).item()

# a = jnp.tensordot(mps[0], mps[0].conj(), axes=((1),(1)))
# b = jnp.tensordot(mps[1], mps[1].conj(), axes=((1),(1)))
# c = jnp.tensordot(mps[2], mps[2].conj(), axes=((1),(1)))

# ab  = jnp.tensordot(a, b, axes=((1,3),(0,2)))
# abc = jnp.tensordot(ab, c, axes=((2,3),(0,2)))

# norm_2 = abc.item()

# import numpy as np
# from ncon import ncon

# mps = list(map(np.array, mps))
# TensorArray = [mps[0], mps[0].conj(), mps[1], mps[1].conj(), mps[2], mps[2].conj()]
# IndexArray = [[-1,1,2],[-2,1,4],[2,3,6],[4,3,8],[6,5,-4],[8,5,-3]]
# norm_3 = ncon(TensorArray, IndexArray).item()

# print (norm_1, norm_2, norm_3)

## Constants

In [13]:
# PRNG seed
SEED = 161803

# model size
# NOTE(review): the comment here used to state
# MPS_SIZE * LOCAL_DIM + BOND_DIM = TRAIN_SIZE; with the values below that is
# 4 * 4 + 6 = 22, not 100 - confirm which relation was intended.
MPS_SIZE = 4
LOCAL_DIM = 4
BOND_DIM = 6

# data sample (number of training / test inputs)
TRAIN_SIZE = 100
TEST_SIZE  = 200

# training params (initial step size handed to the line search)
LEARNING_RATE = 1e-1

# max num of iterations
MAX_STEPS = int(5e4)

# APPROX RANK (bond dimensions of the fitted models: 2, 4, 6, 8, 10)
APPROX_RANK = list(range(2,11,2))

# NOISE MODEL
# NOTE(review): described as percentages, but the experiment loop uses each
# value as a direct multiplier of jnp.std(train_data) (no division by 100),
# so 10 means ten times that std - confirm the intended scale.
PERCENT_NOISE = [0.1, 0.25, 0.5, 1, 5, 10] # noise level relative to the data std

# SAVE/PRINT after that many steps
SAVE_AFTER_EPOCHS = 50

In [14]:
def make_dirs(root_dir, noise_level):
    """
    Prepare the directory tree of one experiment.

    The experiment directory is `<root_dir>/<YYYYMMDD>-noise-<noise_level>`
    and holds two sub-directories: `learning` (learning progress) and
    `results` (experimental results). Missing directories are created;
    existing ones must be empty.

    Returns the tuple (exp_dir, lrn_dir, res_dir).
    Raises FileExistsError if an existing sub-directory is not empty.
    """
    # date stamp (local time) used to label the experiment
    date_tag = time.strftime('%Y%m%d', time.localtime())
    exp_dir = f'{root_dir}/{date_tag}-noise-{noise_level}'
    lrn_dir, res_dir = (
        os.path.join(exp_dir, sub) for sub in ('learning', 'results')
    )

    for folder in (lrn_dir, res_dir):
        if not os.path.isdir(folder):
            os.makedirs(folder)
            continue
        # never mix the output of two runs in one directory
        if os.listdir(folder):
            raise FileExistsError(f'Directory {folder} is not empty!')

    return exp_dir, lrn_dir, res_dir

In [15]:
def save_settings(settings_file):
    """
    Write the experiment constants to `settings_file` as a text record.

    Raises FileExistsError if the file already exists, so that an
    earlier record is never overwritten.
    """
    if os.path.exists(settings_file):
        raise FileExistsError(f'File {settings_file} already exists - STOP!')

    # storing for records
    with open(settings_file, 'w') as f:
        f.write(f"""
SEED = {SEED}

# model size (MPS_SIZE * LOCAL_DIM + BOND_DIM = TRAIN_SIZE)
MPS_SIZE = {MPS_SIZE}
LOCAL_DIM = {LOCAL_DIM}
BOND_DIM = {BOND_DIM}

# data sample
TRAIN_SIZE = {TRAIN_SIZE}
TEST_SIZE = {TEST_SIZE}

# training params
LEARNING_RATE = {LEARNING_RATE}

# max num of iterations
MAX_STEPS = {MAX_STEPS}

# APPROX RANK
APPROX_RANK = {APPROX_RANK}

# NOISE MODEL
PERCENT_NOISE = {PERCENT_NOISE}
""")

## Loss function

In [16]:
@jit
def loss(params, data):
    """
    Log-based regression loss of the MPS model on `data`.

    `data` is a pair (inputs, targets). The model outputs are
    `contract(params, inputs)`, and the value returned is
    0.5 * mean(log((targets - outputs)^2 + 0.5)).
    """
    inputs, targets = data
    residual = targets - contract(params, inputs)
    log_sq_err = jnp.log(jnp.power(residual, 2) + 0.5)
    # alternative objective (disabled): jnp.sqrt(0.5 * jnp.mean(err))
    return 0.5 * jnp.mean(log_sq_err)

# jit-compiled gradient of the loss with respect to its first argument (params)
grad_loss = jax.jit(jax.grad(loss, argnums=0))

def line_search(params, data, lr=1e-2, beta=1e-6):
    """
    Backtracking line search along the negative gradient.

    Starting from step size `lr`, the step is halved until the loss at
    the trial point is no larger than the loss at `params`; the trial
    parameters are then returned.

    Input:
    ------
    params:  List of parameter tensors.
    data:    Pair (inputs, targets) passed on to `loss` / `grad_loss`.
    lr:      Initial step size.
    beta:    Unused. The Armijo sufficient-decrease term it belongs to is
             not applied; the argument is kept so existing calls still work.

    Output:
    -------
    The list of updated parameter tensors.
    """
    grads = grad_loss(params, data)
    # The loss at the current point does not change while backtracking,
    # so evaluate it once instead of on every trial step.
    current_loss = loss(params, data)
    while True:
        params_new = [
            p - lr * g
            for p, g in zip(params, grads)
        ]
        # accept the first step that does not increase the loss
        if not loss(params_new, data) > current_loss:
            return params_new
        lr /= 2

## Data Prep

In [17]:
# root PRNG key of the experiment, derived from SEED
key = jax.random.PRNGKey(SEED)

In [18]:
# Splitting the key into four independent keys: target model (key_mps),
# input samples (key_data), training noise (key_noise) and the random
# initialization of the fitted models (key_run)
key_mps, key_data, key_noise, key_run = jax.random.split(key, num=4)

In [19]:
# target MPS model (the ground truth that generates the targets)
true_params = random_mps(key_mps, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=BOND_DIM)

# generate samples: one pool of TRAIN_SIZE + TEST_SIZE inputs
data = random_sample(key_data, size=TRAIN_SIZE+TEST_SIZE, local_dim=LOCAL_DIM, n_factors=MPS_SIZE)

# train/test split: the first TRAIN_SIZE samples train, the rest test
train_data, test_data = data[:TRAIN_SIZE], data[TRAIN_SIZE:]

# test targets: outputs of the true model on the test inputs (no noise added here)
test_targets = contract(true_params, test_data)

## Execution

In [20]:
def gauss_noise(key, sample_size, scale=1.0):
    """
    Draw `sample_size` standard normal values and multiply them by `scale`.

    Input:
    ------
    key:          The random key.
    sample_size:  Number of noise values (length of the returned vector).
    scale:        Multiplier applied to the standard normal draws.
    """
    unit_noise = jax.random.normal(key, shape=(sample_size,))
    return scale * unit_noise

In [21]:
root_dir = './experiment'

# For every noise level: build noisy training targets from the true model,
# then fit MPS models of increasing bond dimension to them and record the
# train/test loss along the way.
for perc_noise in PERCENT_NOISE:

    # noise model
    # NOTE(review): perc_noise multiplies the std of the *inputs* directly
    # (no division by 100), and key_noise is the same for every noise level,
    # so the levels share one noise draw up to scale - confirm both.
    noise = gauss_noise(key_noise, sample_size=TRAIN_SIZE, scale=perc_noise * jnp.std(train_data))

    # generate outputs by contracting MPS with data
    train_targets = contract(true_params, train_data) + noise

    # making exp directory
    exp_dir, lrn_dir, res_dir = make_dirs(root_dir, perc_noise)

    # storing the settings into a file
    settings_file = os.path.join(exp_dir, 'settings.txt')
    save_settings(settings_file)

    # storing results
    results = idict()

    # last observed losses; the defaults are reported if no step is run
    ref_loss_tr = 1e6
    ref_loss_te = 1e6

    train_set = (train_data, train_targets)
    test_set = (test_data, test_targets)

    for approx_rank in APPROX_RANK:

        # timer, restarted for every rank so that the report below is the
        # time spent on this rank only
        tic = time.time()

        print(f"Approximation rank: {approx_rank}")
        print('='*50)

        loss_tr = []
        loss_te = []

        # initialize MPS parameters randomly, with a different key than was
        # used for the true params (the same key_run is used for every rank)
        params = random_mps(key_run, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=approx_rank)

        for step in range(MAX_STEPS):

            # update parameters
            params_new = line_search(params, train_set, LEARNING_RATE)

            # Generalization risk
            l_tr = loss(params_new, train_set)
            l_te = loss(params_new, test_set)

            # printing epochs
            if step % SAVE_AFTER_EPOCHS == 0:

                # storing errors for statistics (saving memory)
                loss_tr.append(l_tr)
                loss_te.append(l_te)

                print(f'Step: {step:<15} \t|\t Train loss: {l_tr:<10.3f} \t|\t Test loss: {l_te:<10.3f}')

                # storing parameters during training
                # NOTE(review): this checkpoints the parameters *before* this
                # step's update, while the losses above are for params_new.
                file_path = os.path.join(lrn_dir, f'approx_rank_{approx_rank}', f'step_{step}.pkl')
                save_pkl(file_path, params)

            # update the parameters
            params = params_new

            # update the reference
            ref_loss_tr = l_tr
            ref_loss_te = l_te

        # storing train/test loss
        results["train"][approx_rank] = loss_tr
        results["test"][approx_rank] = loss_te

        print('-'*100)
        print(f'Time for rank {approx_rank}: {(time.time() - tic):0.2f} sec')
        print(f'Train loss: {ref_loss_tr:0.2f}')
        print(f'Test loss: {ref_loss_te:0.2f}')
        print('='*100)

        # rewritten after every rank so partial results survive an interruption
        file_path = os.path.join(res_dir, 'loss.pkl')
        save_pkl(file_path, idict2dict(results))

Approximation rank: 2
Step: 0               	|	 Train loss: 1.915      	|	 Test loss: 2.012     
Step: 50              	|	 Train loss: 1.739      	|	 Test loss: 2.015     
Step: 100             	|	 Train loss: 1.688      	|	 Test loss: 1.987     
Step: 150             	|	 Train loss: 1.651      	|	 Test loss: 1.975     
Step: 200             	|	 Train loss: 1.641      	|	 Test loss: 1.977     
Step: 250             	|	 Train loss: 1.632      	|	 Test loss: 1.979     
Step: 300             	|	 Train loss: 1.621      	|	 Test loss: 1.979     
Step: 350             	|	 Train loss: 1.612      	|	 Test loss: 1.978     
Step: 400             	|	 Train loss: 1.607      	|	 Test loss: 1.977     
Step: 450             	|	 Train loss: 1.602      	|	 Test loss: 1.976     
Step: 500             	|	 Train loss: 1.597      	|	 Test loss: 1.974     
Step: 550             	|	 Train loss: 1.591      	|	 Test loss: 1.972     
Step: 600             	|	 Train loss: 1.584      	|	 Test loss: 1.969     
Ste