In [None]:
import random
import numpy as np
import math
import quimb.tensor as qtn
import smpo

In [None]:
# Random 8-site operator from the local smpo module, used for the sanity checks below.
mpo = smpo.SpacedMatrixProductOperator.rand(n=8, spacing=1)

In [None]:
# Norm of the random operator before any normalization.
mpo.norm()

#### Check normalization

In [None]:
# Normalize the operator (`insert=0` presumably selects the tensor that absorbs the scale factor).
# NOTE(review): if this resolves to quimb's TensorNetwork.normalize, it may return a normalized
# copy (inplace=False by default) instead of modifying `mpo`; the next cell would then still show
# the old norm — confirm smpo's behaviour.
mpo.normalize(insert=0)

In [None]:
# Expected to be ~1 only if the previous cell normalized `mpo` in place.
mpo.norm()

#### Check trace

In [None]:
# Trace of the operator, pairing its upper indices with its lower indices
# (presumably how smpo's trace uses left_inds/right_inds — confirm).
mpo.trace(left_inds=mpo.upper_inds, right_inds=mpo.lower_inds)

#### Check canonization

In [None]:
# Canonize with the orthogonality center at site 0 (assuming smpo follows quimb's canonize semantics).
mpo.canonize(0)

In [None]:
# Inspect the tensor at site 0.
mpo.tensors[0]

In [None]:
# Per-tensor norms after canonize(0); presumably only the center tensor should match mpo.norm().
mpo.tensors[0].norm()

In [None]:
mpo.tensors[1].norm()

In [None]:
mpo.tensors[2].norm()

In [None]:
# Contract tensor 1 with a copy of itself over all (identically named) indices.
# NOTE(review): without `.H` on one side this is the sum of squared entries, which is not |T|^2
# for complex data — confirm whether `mpo.tensors[1].H` was intended.
(mpo.tensors[1] & mpo.tensors[1]) ^ all

In [None]:
# Norm of the whole operator after canonize(0), for comparison with mpo.tensors[0].norm().
mpo.norm()

In [None]:
# Re-canonize using the tuple (1, 4) (range semantics per smpo/quimb canonize — confirm).
mpo.canonize((1,4))

In [None]:
# The overall norm should be unchanged by re-canonization.
mpo.norm()

In [None]:
# Site 0 is no longer the center, so its norm need not equal the total norm.
mpo.tensors[0].norm()

In [None]:
# Report where the operator's orthogonality center is currently considered to be.
mpo.calc_current_orthog_center()

#### Check MPS&MPO

In [None]:
# Random 8-site computational-basis state (a random bond-dimension-4 MPS is kept as a commented alternative).
#mps = qtn.MPS_rand_state(L=8, bond_dim=4)
mps = qtn.MPS_rand_computational_state(L=8) # all 0 and 1s

In [None]:
# Draw the operator attached to the state, marking the MPO's lower and upper indices as outputs.
(mpo&mps).draw(show_tags=False, show_inds=True, output_inds=[*mpo.lower_inds, *mpo.upper_inds])

In [None]:
# Contract the combined network one site tag at a time (sites 0..7) and draw the result.
# NOTE(review): assumes the MPS tensors carry the same site tags as mpo.site_tag(i) — confirm.
((mpo&mps)^mpo.site_tag(0)^mpo.site_tag(1)^mpo.site_tag(2)^mpo.site_tag(3)^mpo.site_tag(4)^mpo.site_tag(5)^mpo.site_tag(6)^mpo.site_tag(7)).draw()

In [None]:
# Independent deep copy of the operator, drawn for inspection.
mpo_copy = mpo.copy(deep=True)
mpo_copy.draw()

### Optimization

In [None]:
from tqdm import tqdm
import embeddings as e
import FeatureMap as fm
import itertools

In [None]:
import importlib
# Development helper: re-import the local smpo module so edits to it are picked up
# without restarting the kernel.
# NOTE(review): objects created before the reload (e.g. `mpo` above) keep the old class
# definitions; `%autoreload 2` in the setup cell would be a more robust alternative.
importlib.reload(smpo)

In [None]:
def loss_miss(phi, P, n_features=None):
    """Squared-log loss ``(log <phi|P^H P|phi> - 1) ** 2`` for one sample.

    A second copy of ``phi`` and ``P`` gets its physical indices ``k{i}``
    renamed to ``k_{i}`` so that the two conjugated copies connect to each
    other rather than to the plain ``P`` / ``phi``.

    Parameters
    ----------
    phi : tensor network
        Embedding of one sample (from ``fm.embed``); assumed to carry physical
        indices ``k0 .. k{n_features-1}`` -- confirm.
    P : tensor network
        Operator applied to ``phi``; shares the ``k{i}`` indices with it.
    n_features : int, optional
        Number of ``k{i}`` indices to rename. Defaults to the notebook-level
        ``N_features`` (instead of a fixed 8 sites).

    Returns
    -------
    float
    """
    if n_features is None:
        n_features = N_features
    rename = {f'k{i}': f'k_{i}' for i in range(n_features)}
    phi_renamed = phi.reindex(rename)
    P_renamed = P.reindex(rename)
    l2_norm = (phi_renamed.H & P_renamed.H & P & phi) ^ all
    return (math.log(l2_norm) - 1) ** 2

In [None]:
def gradient_miss(phi, P_orig, P_rem, sites, n_features=None):
    """Gradient of ``loss_miss`` w.r.t. the merged two-site tensor at ``sites``.

    With ``N = <phi|P^H P|phi>`` and ``loss_miss = (log N - 1) ** 2`` the
    gradient is ``2 * (log N - 1) / N * (first + second)``. ``first`` is the
    network with the two-site hole on the conjugated (``.H``) copy of ``P``;
    ``second`` has the hole on the plain copy.

    Parameters
    ----------
    phi : tensor network
        Embedding of one sample; assumed to carry indices ``k{i}`` -- confirm.
    P_orig : tensor network
        The full operator, including the two tensors at ``sites``.
    P_rem : tensor network
        ``P_orig`` with the two tensors at ``sites`` removed.
    sites : pair of int
        The adjacent sites being updated, in sweep order.
    n_features : int, optional
        Number of sites; defaults to the notebook-level ``N_features``.

    Returns
    -------
    Tensor-like
        Sum of the two hole contractions, scaled as above.
    """
    if n_features is None:
        n_features = N_features
    left, right = sites[0], sites[1]
    # Bond that gets a temporary name; same rule the training loop uses to
    # rename it back ('bond{i}' -> 'bond_{i}').
    at_edge = (right == 0 and right < left) or (right == n_features - 1 and right > left)
    index_to_remove = left if at_edge else right

    # Separate the renamed copy's physical indices (k{i} -> k_{i}) from the plain ones.
    rename = {f'k{i}': f'k_{i}' for i in range(n_features)}
    phi_renamed = phi.reindex(rename)
    # NOTE(review): only one virtual bond of the full operator is renamed; confirm against
    # smpo's bond naming that this is the bond that would otherwise clash with the open
    # bond of P_rem for every pair visited by the sweep.
    P_renamed = P_orig.reindex({**rename, f'bond_{index_to_remove}': f'bond{index_to_remove}'})

    l2_norm = (phi_renamed.H & P_renamed.H & P_orig & phi) ^ all
    first = (phi.H & P_rem.H & P_renamed & phi_renamed) ^ all
    second = (phi_renamed.H & P_renamed.H & P_rem & phi) ^ all
    return 2 * (math.log(l2_norm) - 1) * (1 / l2_norm) * (first + second)

In [None]:
def loss_reg(P, alpha):
    """Regularization loss ``alpha * max(0, log <P|P>)``.

    Penalizes the squared norm ``<P|P>`` only once it exceeds 1; below that
    the log is negative and the penalty is clipped to zero.
    """
    log_norm_sq = math.log((P.H & P) ^ all)
    clipped = log_norm_sq if log_norm_sq > 0 else 0
    return alpha * clipped

In [None]:
def gradient_reg(P_orig, P_rem, alpha, sites, n_features=None):
    """Gradient of ``loss_reg`` w.r.t. the merged two-site tensor at ``sites``.

    With ``N = <P|P>`` (squared Frobenius norm) and
    ``loss_reg = alpha * max(0, log N)``, the gradient is ``alpha / N * dN``
    for ``N >= 1`` and ``0`` otherwise. ``dN`` is the SUM of the two hole
    environments: the hole on the conjugated copy plus the hole on the plain
    copy. This is the same ``first + second`` structure used by ``gradient_miss``.

    Parameters
    ----------
    P_orig : tensor network
        The full operator, including the two tensors at ``sites``.
    P_rem : tensor network
        ``P_orig`` with the two tensors at ``sites`` removed.
    alpha : float
        Regularization strength.
    sites : pair of int
        The adjacent sites being updated, in sweep order.
    n_features : int, optional
        Number of sites; defaults to the notebook-level ``N_features``.

    Returns
    -------
    Tensor-like or float
        The gradient, or ``0.0`` when ``N < 1`` (flat region of the penalty).
    """
    frob_norm_sq = (P_orig.H & P_orig) ^ all
    if frob_norm_sq < 1:
        # log N < 0 here, so the clipped penalty has zero gradient.
        return 0.0
    if n_features is None:
        n_features = N_features
    left, right = sites[0], sites[1]
    at_edge = (right == 0 and right < left) or (right == n_features - 1 and right > left)
    index_to_remove = left if at_edge else right
    # Rename this bond in the FULL operator (exactly as gradient_miss does), so it does not
    # meet the identically named open bond of P_rem. The result then has the same open
    # indices as gradient_miss's output, and the two can be added in the training loop.
    P_orig_renamed = P_orig.reindex({f'bond_{index_to_remove}': f'bond{index_to_remove}'})
    bra_hole = (P_rem.H & P_orig_renamed) ^ all
    ket_hole = (P_orig_renamed.H & P_rem) ^ all
    # Sum (not contract) the two environment terms; d(alpha*log N) = alpha*dN/N, so no factor 2.
    return alpha * (1 / frob_norm_sq) * (bra_hole + ket_hole)

In [None]:
N_features = 8  # number of sites / input features per sample
alpha = 0.3     # regularization strength (weight of loss_reg)
lamda = 0.5     # gradient-descent step size ("lambda" is a Python keyword)

In [None]:
batch_size = 100
n_samples = 10000  # total number of synthetic training samples
# Each iteration ("epoch") of the training loop consumes exactly one batch,
# so the number of epochs equals the number of batches.
n_epochs = n_samples // batch_size

In [None]:
# Synthetic training data: uniform values in [0, 1), N_features columns per sample,
# split into n_epochs batches of batch_size rows each. Deriving the row count from
# n_epochs * batch_size (rather than a hard-coded 10000) keeps np.split's
# equal-division requirement satisfied for any batch_size.
inputs = np.random.rand(n_epochs * batch_size, N_features)
inputs_batch = np.split(inputs, n_epochs)

In [None]:
# init P: random initial operator with n=N_features sites (spacing=2; meaning defined by smpo).
P_orig = smpo.SpacedMatrixProductOperator.rand(spacing=2, n=N_features)

In [None]:
# Train on a deep copy of the initial operator, canonized with its orthogonality center at site 0.
P = P_orig.copy(deep=True)
ortog_center = 0
P.canonize(ortog_center)

In [None]:
# Visualize the initial operator with all index labels shown.
P_orig.draw(figsize=(15,15), font_size_inner=15, show_inds='all')

In [None]:
# Two-site sweep optimization of P: for each adjacent pair of sites, merge the two
# tensors, take one gradient step on the merged tensor and split it back into two.
draw_debug = False  # True draws the merged tensor and P at every sweep step (many figures)

# Left-to-right pairs (0, 1) .. (N-2, N-1), then right-to-left pairs (N-1, N-2) .. (1, 0).
forward_sweep = [(i, i + 1) for i in range(N_features - 1)]
backward_sweep = [(i + 1, i) for i in reversed(range(N_features - 1))]

with tqdm(range(n_epochs)) as progressbar:
    for it in progressbar:
        batch = inputs_batch[it]
        n_batch = len(batch)
        for sites in forward_sweep + backward_sweep:
            sitel, siter = sites
            site_tags = [P.site_tag(site) for site in sites]

            # canonize P with root in sites
            P.canonize(sites, cur_orthog=ortog_center)
            ortog_center = sites

            # Snapshot of the full, current P (same gauge as the merged tensor below).
            # Loss and gradients are evaluated on this snapshot. The global P_orig is the
            # initial random operator and is never updated; using it made the reported loss
            # constant and paired the hole of P with environments from a different network.
            P_full = P.copy(deep=True)

            # merge the two site tensors into one
            origl, origr = P.select_tensors(site_tags, which="any")
            tensor_orig = (origl & origr) ^ all
            if draw_debug:
                tensor_orig.draw(show_inds='all', font_size_inner=15, figsize=(15, 15))

            # Virtual bond on the outer side of `sitel` (none at an end site); it stays with
            # the left factor when the updated tensor is split again.
            if sitel == 0 or (sitel == N_features - 1 and sitel > siter):
                vindl = []
            elif 0 < sitel < siter:
                vindl = [P.bond(sitel - 1, sitel)]
            else:
                vindl = [P.bond(sitel, sitel + 1)]

            # leave a two-site hole in P; the gradients are contractions around this hole
            P.delete(site_tags, which="any")

            grad_miss = 0
            loss_miss_batch = 0
            for sample in batch:
                phi, _ = fm.embed(sample, fm.trigonometric)
                loss_miss_batch += loss_miss(phi, P_full)
                grad_miss += gradient_miss(phi, P_full, P, sites)

            loss = (1 / n_batch) * loss_miss_batch + loss_reg(P_full, alpha)
            progressbar.set_postfix(loss=loss)
            grad = (1 / n_batch) * grad_miss + gradient_reg(P_full, P, alpha, sites)

            # undo the temporary bond renaming ('bond_{i}' -> 'bond{i}') done in gradient_miss
            at_edge = (siter == 0 and siter < sitel) or (siter == N_features - 1 and siter > sitel)
            index_to_remove = sitel if at_edge else siter
            grad = grad.reindex({f'bond{index_to_remove}': f'bond_{index_to_remove}'})

            # gradient step, renormalize, and split back into two site tensors
            tensor_new = tensor_orig - lamda * grad
            tensor_new.normalize(inplace=True)
            tensorl, tensorr = tensor_new.split(
                get="tensors",
                left_inds=[P.upper_ind(sitel), P.lower_ind(sitel), *vindl],
            )

            # link the new tensors back into P under their site tags
            for site, tensor in zip(sites, (tensorl, tensorr)):
                tensor.drop_tags()
                tensor.add_tag(P.site_tag(site))
                P.add_tensor(tensor)
            if draw_debug:
                P.draw(show_inds='all', font_size_inner=15, figsize=(15, 15))