# Imports and settings

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
%env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu

env: JAX_ENABLE_X64=1
env: JAX_PLATFORM_NAME=cpu


In [3]:
import jax
import jax.numpy as jnp
from jax import jit, grad
from jax.example_libraries import optimizers

In [4]:
import os
import pickle
import time
import itertools

from collections import defaultdict
from functools import reduce, partial

In [5]:
from typing import Any, List, DefaultDict, Dict

## Types

In [6]:
# type aliases
PRNGKeyArray = Any
# Newer JAX releases expose the array type as `jax.Array` and no longer
# provide `jnp.DeviceArray`; use whichever one the installed version has.
DeviceArray = jax.Array if hasattr(jax, "Array") else jnp.DeviceArray

# Random Data Generation

## MPS

In [7]:
def random_mps(
    key: "PRNGKeyArray",
    size: int,
    local_dim: int,
    bond_dim: int,
    dtype=jnp.double) -> "List[DeviceArray]":
    """
    Generate a random MPS where each core tensor
    is drawn i.i.d. from a uniform distribution
    between -1 and 1.

    Input:
    ------
    key:        The random key.
    size:       The size (length) of an MPS.
    local_dim:  The local dimension size.
    bond_dim:   The bond dimension size.
    dtype:      The type of data to return.

    Output:
    -------
    A list of `size` tensors of shape (left, local_dim, right). The left bond
    of the first tensor and the right bond of the last tensor have size 1;
    every other bond has size `bond_dim`. An MPS of size 1 therefore consists
    of a single (1, local_dim, 1) tensor, and a size of 0 gives an empty list.

    Note: the alias annotations are quoted so that they are not evaluated
    when the function is defined.
    """
    # initialize MPS data collection
    mps = []

    for i in range(size):
        # draw from a fresh subkey; `key` itself is only ever split, so no
        # key is used both for sampling and for deriving further keys
        key, subkey = jax.random.split(key)
        left_dim = 1 if i == 0 else bond_dim          # open left boundary
        right_dim = 1 if i == size - 1 else bond_dim  # open right boundary
        tensor = jax.random.uniform(
            subkey, shape=(left_dim, local_dim, right_dim), minval=-1, maxval=1, dtype=dtype)
        mps.append(tensor)

    return mps

In [8]:
def dot(mps1: "List[DeviceArray]", mps2: "List[DeviceArray]") -> jnp.double:
    """
    Dot product of an MPS with another mps.
    --A1----A2--...--An-- (MPS1)
      |     |        |

      |     |        |
    --A1----A2--...--An-- (MPS2)

    Both lists must have the same length; tensors at the same position are
    contracted over their middle (local) index, and the resulting blocks are
    then multiplied together from left to right over the bond indices.
    The result is squeezed, so it is a scalar when the outer bonds have
    size 1.

    Note: the alias annotations are quoted so that they are not evaluated
    when the function is defined.
    """
    # contract the local index of two tensors at the same site
    site_dot = lambda x, y: jnp.einsum('pqr,uqv->purv', x, y)
    # multiply two neighbouring site blocks over their shared bond indices
    mult = lambda x, y: jnp.einsum('purv,rvts->puts', x, y)
    # `jax.tree_util.tree_map` accepts several trees and replaces the removed
    # `jax.tree_multimap`
    res = reduce(mult, jax.tree_util.tree_map(site_dot, mps1, mps2))
    return res.squeeze()

def mps_norm(mps: "List[DeviceArray]") -> jnp.double:
    """Compute the squared norm of an MPS: its dot product with its complex conjugate."""
    # `jax.tree_util.tree_map` is available across JAX versions, unlike the
    # top-level `jax.tree_map` alias
    mps_c = jax.tree_util.tree_map(jnp.conj, mps)
    return dot(mps, mps_c)

## Samples

In [9]:
def _random_sample(key: PRNGKeyArray, num_factors: int, local_dim: int) -> DeviceArray:
    """Generate a single sample with a number of factors
    where each factor is generated from a Normal distribution.
    """
    keys = jax.random.split(key, num=num_factors)
    func = lambda k: jax.random.normal(k, (local_dim,), dtype)
    return jax.vmap(func)(keys)

def random_samples(key: "PRNGKeyArray", sample_size: int, num_factors: int, local_dim: int) -> "DeviceArray":
    """Generate random samples of a specific size.

    The key is split into `sample_size` keys and `_random_sample` is mapped
    over them, so the samples are stacked along the leading axis.
    """
    keys = jax.random.split(key, num=sample_size)
    # pass by keyword: `_random_sample` takes (key, num_factors, local_dim)
    return jax.vmap(
        lambda k: _random_sample(k, num_factors=num_factors, local_dim=local_dim))(keys)

In [37]:
def sample_as_mps(sample: DeviceArray) -> List[DeviceArray]:
    """
    Turn one data sample into a list of MPS-style tensors so that it can be
    contracted with another MPS.

    A singleton axis is inserted before and after every entry along the
    leading axis of `sample`:

    x1, x2, ..., xn  ->  --x1----x2-- ... --xn--
    """
    # None is the same object as jnp.newaxis
    expanded = sample[:, None, :, None]
    # one list element per entry along the leading axis
    return [expanded[idx] for idx in range(expanded.shape[0])]

def samples_dot(mps: "List[DeviceArray]", samples: "DeviceArray") -> "DeviceArray":
    """Apply dot product to many samples.

    `samples` is mapped over its leading axis; each entry is converted with
    `sample_as_mps` and contracted with `mps` via `dot`, giving one value per
    sample.

    Note: the alias annotations are quoted so that they are not evaluated
    when the function is defined.
    """
    def single_dot(sample):
        return dot(mps, sample_as_mps(sample))

    return jax.vmap(single_dot)(samples)

# Helpers

In [None]:
def save_pkl(file_path: str, data: Any) -> None:
    """Pickle `data` to `file_path`, creating the parent directory if needed.

    Input:
    ------
    file_path:  Destination file; an existing file is overwritten.
    data:       Any picklable object.
    """
    dir_name = os.path.dirname(file_path)
    # a bare file name has no directory part, and os.makedirs('') would raise;
    # exist_ok avoids failing when the directory already exists
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)

In [None]:
def idict() -> DefaultDict:
    """Create a dict whose missing keys are filled with further dicts of the same kind."""
    # the factory is this very function, so nesting can go arbitrarily deep
    nested = defaultdict(idict)
    return nested

def idict2dict(dic) -> Dict:
    """Recursively turn a (nested) defaultdict into plain dicts.

    Anything that is not a defaultdict is returned unchanged.
    """
    if not isinstance(dic, defaultdict):
        return dic
    return {key: idict2dict(value) for key, value in dic.items()}

## Constants

In [None]:
# seed for the PRNG key
SEED = 161803

# dimensions of the target MPS model
# NOTE(review): the original comment stated
# "MPS_SIZE * LOCAL_DIM + BOND_DIM = TRAIN_SIZE", which the values below do
# not satisfy - confirm the intended relation.
MPS_SIZE, LOCAL_DIM, BOND_DIM = 15, 3, 6

# number of samples for training and for testing
TRAIN_SIZE, TEST_SIZE = 20, 50

# optimizer step size
LEARNING_RATE = 1e-2

# maximum number of training epochs
NUM_EPOCHS = 500

# samples per mini-batch
BATCH_SIZE = 10

# bond dimensions tried for the approximating model
APPROX_RANK = [2, 4, 6, 8, 10]

# noise levels relative to the data std
# (further levels listed but disabled in the original: 0.25, 0.5, 1, 5, 10)
PERCENT_NOISE = [0.1]

# save parameters / print progress once every this many epochs
SAVE_AFTER_EPOCHS = 50

In [None]:
def save_settings(settings_file):
    """Write the experiment constants to `settings_file` as a plain-text record.

    Raises FileExistsError if the file is already there, so the record of an
    earlier run is never overwritten.
    """
    if os.path.exists(settings_file):
        raise FileExistsError(f'File {settings_file} already exists - STOP!')

    # storing for records
    with open(settings_file, 'w') as handle:
        record = f"""
SEED = {SEED}

# model size (MPS_SIZE * LOCAL_DIM + BOND_DIM = TRAIN_SIZE)
MPS_SIZE = {MPS_SIZE}
LOCAL_DIM = {LOCAL_DIM}
BOND_DIM = {BOND_DIM}

# data sample
TRAIN_SIZE = {TRAIN_SIZE}
TEST_SIZE = {TEST_SIZE}

# training params
LEARNING_RATE = {LEARNING_RATE}

# max num of epochs
NUM_EPOCHS = {NUM_EPOCHS}

# batch size
BATCH_SIZE = {BATCH_SIZE}

# APPROX RANK
APPROX_RANK = {APPROX_RANK}

# NOISE MODEL
PERCENT_NOISE = {PERCENT_NOISE}
"""
        handle.write(record)

In [None]:
def make_dirs(root_dir, noise_level):
    """Prepare the directory tree of one experiment.

    The experiment directory is named `<YYYYMMDD>-noise-<noise_level>` under
    `root_dir` and gets two sub-directories: `learning` (learning progress)
    and `results` (experimental results). Missing directories are created;
    an existing empty one is reused.

    Returns the tuple (exp_dir, lrn_dir, res_dir).
    Raises FileExistsError if either sub-directory exists and is not empty.
    """
    # date stamp of the local time, used to tell experiment runs apart
    stamp = time.strftime('%Y%m%d')
    exp_dir = f'{root_dir}/{stamp}-noise-{noise_level}'
    lrn_dir = os.path.join(exp_dir, 'learning')
    res_dir = os.path.join(exp_dir, 'results')

    for path in (lrn_dir, res_dir):
        if os.path.isdir(path):
            # never mix the output of two runs in one directory
            if os.listdir(path):
                raise FileExistsError(f'Directory {path} is not empty!')
        else:
            os.makedirs(path)

    return exp_dir, lrn_dir, res_dir

## Loss function

In [None]:
def loss(params, data):
    """Log-squared-error loss of an MPS model on a data set.

    Input:
    ------
    params:  The MPS (list of core tensors) being evaluated.
    data:    A pair (inputs, targets). Each may be a stacked array or a
             sequence of per-sample arrays; both are converted to arrays.

    Output:
    -------
    0.5 * mean(log(err**2 + 10)), with err = targets - outputs.
    """
    inputs, targets = data
    # batches may arrive as tuples/lists of per-sample arrays; stack them so
    # they can be mapped over and subtracted element-wise
    inputs = jnp.asarray(inputs)
    targets = jnp.asarray(targets)
    # contract the model with every input sample
    outputs = samples_dot(params, inputs)
    err = targets - outputs
    # NOTE(review): the offset 10 keeps the log argument >= 10; its choice is
    # not explained in this file - confirm
    return 0.5 * jnp.mean(jnp.log(jnp.power(err, 2) + 10))

# gradient of the loss with respect to the model parameters (first argument)
grad_loss = grad(loss, argnums=0)

## Data Prep

In [None]:
key = jax.random.PRNGKey(SEED)

# Splitting the key: one key each for the true parameters, the data,
# the noise and the training run
key_params, key_data, key_noise, key_run = jax.random.split(key, num=4)

# target MPS model
true_params = random_mps(key_params, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=BOND_DIM)

# generate samples: one factor per MPS site, each of length LOCAL_DIM
data = random_samples(key_data, sample_size=TRAIN_SIZE + TEST_SIZE, num_factors=MPS_SIZE, local_dim=LOCAL_DIM)

# train/test split
train_data, test_data = data[:TRAIN_SIZE], data[TRAIN_SIZE:]

# test targets (noise-free): contract the true MPS with every test sample
test_targets = samples_dot(true_params, test_data)

# Optimization

In [None]:
@jit
def update(i, opt_state, batch):
    """Perform one optimizer step on `batch` and return the new optimizer state.

    Uses the module-level `get_params` and `opt_update`, which this file
    assigns when the optimizer triple is unpacked, so those must exist before
    the first call.
    """
    current_params = get_params(opt_state)
    gradients = grad(loss)(current_params, batch)
    return opt_update(i, gradients, opt_state)

## Execution

In [None]:
def gauss_noise(key, sample_size, scale=1.0):
    """Draw `sample_size` zero-mean Gaussian noise values multiplied by `scale`.

    Returns a list with one 0-d array per noise value.
    """
    data = scale * jax.random.normal(key, shape=(sample_size,))
    # plain comprehension instead of the top-level `jax.tree_map` alias,
    # which is not available in every JAX version
    return [jnp.asarray(x, dtype=jnp.double) for x in data.tolist()]

In [None]:
# all experiment directories are created below this root
root_dir = './experiment'

# number of optimizer steps per epoch: every full batch plus one extra
# step for the leftover samples, if there are any
num_complete_batches = TRAIN_SIZE // BATCH_SIZE
leftover = TRAIN_SIZE % BATCH_SIZE
num_batches = num_complete_batches + (1 if leftover else 0)

In [None]:
for perc_noise in PERCENT_NOISE:

    # stack the training inputs once so they can be indexed by batch below
    train_inputs = jnp.asarray(train_data)

    # noise model: Gaussian noise scaled by `perc_noise` times the std of the training inputs
    noise = gauss_noise(key_noise, sample_size=TRAIN_SIZE, scale=perc_noise * jnp.std(train_inputs))

    # generate outputs by contracting MPS with data, then add the noise
    train_targets = samples_dot(true_params, train_inputs) + jnp.asarray(noise)

    # making exp directory
    exp_dir, lrn_dir, res_dir = make_dirs(root_dir, perc_noise)

    # storing the settings into a file
    settings_file = os.path.join(exp_dir, 'settings.txt')
    save_settings(settings_file)

    # storing results
    results = idict()

    ref_loss_tr = 1e6
    ref_loss_te = 1e6

    for approx_rank in APPROX_RANK:

        # timer, restarted per rank so the reported time covers this rank only
        tic = time.time()

        print(f'Approximation rank: {approx_rank}')
        print('='*70)

        loss_tr = []
        loss_te = []

        # separate keys for parameter initialization and for batch shuffling
        key_init, key_shuffle = jax.random.split(key_run)

        def data_iterator(key):
            """Yield (inputs, targets) mini-batches forever, reshuffling after every pass."""
            while True:
                # a new subkey per pass gives a new permutation per epoch
                key, perm_key = jax.random.split(key)
                perm = jax.random.permutation(perm_key, TRAIN_SIZE)
                for b in range(num_batches):
                    batch_idx = perm[b * BATCH_SIZE:(b + 1) * BATCH_SIZE]
                    yield train_inputs[batch_idx], train_targets[batch_idx]

        batches = data_iterator(key_shuffle)

        # params for the optimization (initial guess)
        opt_init, opt_update, get_params = optimizers.adam(step_size=LEARNING_RATE)

        # initialize MPS parameters randomly, with a key other than the one used for the true params
        init_params = random_mps(key_init, size=MPS_SIZE, local_dim=LOCAL_DIM, bond_dim=approx_rank)
        opt_state = opt_init(init_params)

        # global step counter handed to the optimizer
        itercounter = itertools.count()

        for epoch in range(NUM_EPOCHS):

            # one pass over the training data
            for _ in range(num_batches):
                opt_state = update(next(itercounter), opt_state, next(batches))

            params = get_params(opt_state)

            # loss on the full training set and on the test set
            l_tr = loss(params, (train_inputs, train_targets))
            l_te = loss(params, (test_data, test_targets))

            # storing errors for statistics
            loss_tr.append(l_tr)
            loss_te.append(l_te)

            # printing epochs
            if epoch % SAVE_AFTER_EPOCHS == 0:

                print(f'Epoch: {epoch:<15} \t|\t Train loss: {l_tr:<10.3f} \t|\t Test loss: {l_te:<10.3f}')

                # storing parameters during training
                file_path = os.path.join(lrn_dir, f'./approx_rank_{approx_rank}/epoch_{epoch}.pkl')
                save_pkl(file_path, params)

            # update the reference (the losses of the most recent epoch)
            ref_loss_tr = l_tr
            ref_loss_te = l_te

        # storing train/test loss
        results["train"][approx_rank] = loss_tr
        results["test"][approx_rank] = loss_te

        print('-'*100)
        print(f'Time for rank {approx_rank}: {(time.time() - tic):0.2f} sec')
        print(f'Train loss: {ref_loss_tr:0.2f}')
        print(f'Test loss: {ref_loss_te:0.2f}')
        print('='*100)

        # rewritten after every rank so partial results survive an interrupted run
        file_path = os.path.join(res_dir, 'loss.pkl')
        save_pkl(file_path, idict2dict(results))