In [20]:
%load_ext autoreload
%autoreload 2

from src.dataset_utils import bars_and_stripes, hypertn_from_data, plot_binary_data

import numpy as np
import quimb.tensor as qtn
import matplotlib.pyplot as plt

import cotengra as ctg
# Notebook-wide contraction-path optimizer; it is handed to every `.contract`
# call below so the path search is shared rather than repeated.
ctg_opt = ctg.ReusableHyperOptimizer(
    max_time=20,        # budget for the path search (seconds, per cotengra docs)
    minimize='combo',
    # stage 1: basic slicing
    slicing_opts={'target_size': 2**40},
    # stage 2: advanced slicing with reconfiguring
    slicing_reconf_opts={'target_size': 2**28},
    # stage 3: finally, higher-quality reconfiguring on its own
    reconf_opts={'subtree_size': 14},
    parallel=True,
    progbar=True,
    directory=True,
)

chi = 100          # passed as the bond dimension when the MPS is built
num_sites = 100    # number of MPS sites

# Side length of the square grid (presumably the data are dim x dim images
# flattened onto the num_sites MPS sites -- confirm against dataset_utils).
# `round` guards against float error in the square root, and the explicit
# check stops a non-square `num_sites` from being silently truncated.
dim = round(num_sites ** 0.5)
if dim * dim != num_sites:
    raise ValueError(
        f"num_sites must be a perfect square, got {num_sites}"
    )

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Fixed seed so the random initial state -- and every number derived from it --
# is reproducible after Restart Kernel -> Run All.
psi_seed = 0
psi = qtn.MPS_rand_state(num_sites, chi, dtype=np.complex128, seed=psi_seed)

# Build the dataset and wrap it as a tensor network that can be contracted
# against `psi` (both helpers come from src.dataset_utils; presumably `bas`
# holds dim x dim bars-and-stripes patterns -- confirm there).
bas = bars_and_stripes(dim)
htn_data = hypertn_from_data(bas)
htn_data

In [22]:
def nnl_loss(psi, htn_data, optimize=None):
    """
    Mean negative log-likelihood (NLL) of the dataset under the MPS.

    Contracts ``psi`` with ``htn_data`` leaving only the ``'hyper'`` index open
    (presumably one amplitude per data sample -- confirm against
    ``hypertn_from_data``) and returns ``-mean(log(|amplitude|**2))``.

    Notes
    -----
    * This is the NLL, not the KL divergence: the two differ by the entropy of
      the data distribution, which does not depend on ``psi``.
    * No normalisation term is added, so the value is a proper NLL only when
      ``psi`` is normalised.

    Parameters
    ----------
    psi : tensor network
        The MPS; combined with ``htn_data`` via ``|``.
    htn_data : tensor network
        The dataset network exposing a ``'hyper'`` index.
    optimize : optional
        Contraction optimizer forwarded to ``contract``. Defaults to the
        notebook-level ``ctg_opt``.

    Returns
    -------
    float
        The mean negative log-likelihood.
    """
    if optimize is None:
        # Looked up at call time so the notebook-level optimizer is reused.
        optimize = ctg_opt

    amplitudes = (psi | htn_data).contract(output_inds=['hyper'],
                                           optimize=optimize).data

    # log(|a|**2) == 2 * log(|a|). Taking the log before squaring avoids
    # |a|**2 underflowing to 0 (and the loss becoming inf) for tiny amplitudes.
    return -2. * np.mean(np.log(np.abs(amplitudes)))

In [23]:
# Loss of the randomly initialised MPS on the dataset. The contraction is run
# with `ctg_opt`, which has progbar=True -- hence the progress bar in the output.
nnl_loss(psi, htn_data)

  0%|          | 0/128 [00:00<?, ?it/s]

F=9.57 C=9.79 S=18.64 P=21.53:  35%|███▌      | 45/128 [00:20<00:37,  2.20it/s]                 


np.float64(70.50344336403545)