In [None]:
import numpy as np
import opt_einsum as oe
from copy import deepcopy
import itertools
from scipy import integrate
import pickle as pk
import copy
import math
import scipy

import torch
torch.set_default_dtype(torch.float64)

import os
import sys
sys.path.append(os.path.realpath('/Users/wei/Documents/physics/code/tnpy'))
sys.path.append(os.path.realpath('/Users/wei/Documents/physics/code/tnpy/tnpy'))

%reload_ext autoreload
%autoreload all

import tnpy as tp

In [None]:
t = torch.randn(4, 3)
print(t)

u, s, v = torch.linalg.svd(t, full_matrices=False)
print(u.shape, s.shape, v.shape)
print(u)
print(s)
print(v)

print(torch.linalg.norm((u @ s.diag() @ v)-t))

u, s, v = tp.linalg.svd(t)
print(u.shape, s.shape, v.shape)
print(u)
print(s)
print(v)

print(torch.linalg.norm((u @ s.diag() @ v)-t))

In [None]:
# Heisenberg model
# exact GS energy: E=-0.66944

chi = 4
nx, ny = 2, 1

tps = tp.tps.SquareTPS.rand(nx=nx, ny=ny, chi=chi, cflag=False)
print(tps.coords, tps.nx, tps.ny)

xyz = tp.models.SquareXYZ(Jx=1.0, Jy=1.0, Jz=1.0, cflag=False)
ham_xyz = xyz.twobody_ham()
print(ham_xyz)
teo = xyz.twobody_img_time_evo(delta=0.1)

In [None]:
deltas = [1E-2, 1E-3, 1E-4]
nums = [1000]

counter = 0
for d, n in zip(deltas, nums):

    teo = xyz.twobody_img_time_evo(delta=d)
    u, s, v = tp.linalg.tsvd(teo, group_dims=((0, 2), (1, 3)), svd_dims=(1, 0))
    ss = torch.sqrt(s).diag()
    us = torch.einsum('abc,bB->aBc', u, ss)
    sv = torch.einsum('Aa,abc->Abc', ss, v)
    te_mpo = us, sv

    # print(te_mpo)

    for l in range(n):
        old_lts = tps.link_tensors()
        tps.simple_update_proj(te_mpo)
        counter += 1

        if l % 100 == 0:
            new_lts = tps.link_tensors()
            diff = 0.0
            for key, value in new_lts.items():
                diff += torch.linalg.norm(value[0]-old_lts[key][0])
                diff += torch.linalg.norm(value[1]-old_lts[key][1])
            print(d, l, diff)

            print(tps.betaX_twobody_measure(ham_xyz).item())

In [29]:
# TPS tensors
tps_ts = []
for c in tps.coords:
    tps_ts.append(tps.merged_tensor(c))

tps_t = torch.stack(tps_ts, dim=0).requires_grad_(True)
print(tps_t.shape, tps_t.requires_grad)

# CTM tensors
rho = 8

# corners
cs = [torch.rand(rho, rho).requires_grad_(False) for i in range(4)]

# edges
up_es = [torch.rand(rho, rho, chi, chi).requires_grad_(False) for i in range(nx)]
down_es = [t.clone().detach().requires_grad_(False) for t in up_es]
left_es = [torch.rand(rho, rho, chi, chi).requires_grad_(True) for j in range(ny)]
right_es = [t.clone().detach().requires_grad_(False) for t in left_es]

ctms = [cs, up_es, down_es, left_es, right_es]

print(right_es[0].shape, right_es[0].requires_grad)

torch.Size([2, 4, 4, 4, 4, 2]) True
torch.Size([8, 8, 4, 4]) False


In [None]:
for i in range(4):
    ene = tps.ctm_twobody_measure(tps_tensor=tps_t, ctm_tensors=ctms, op=ham_xyz)
    print(ene)
    ctms = tps.ctmrg(tps_tensor=tps_t, init_ctms=ctms, if_print=True)
    ene = tps.ctm_twobody_measure(tps_tensor=tps_t, ctm_tensors=ctms, op=ham_xyz)
    print(ene)

ene.backward(inputs=[tps_t])
print(tps_t.grad)


In [27]:
class model_xyz(torch.nn.Module):

    # def __init__(self, tps_tensor, dtype=torch.float64, device='cpu', use_checkpoint=False):
    def __init__(self, tps_tensor: torch.tensor):
        super(model_xyz, self).__init__()

        self._t = torch.nn.Parameter(tps_tensor)
        # self._ctms = torch.nn.Parameter(ctms)

    def forward(self, ctms: list, op: torch.tensor):

        # a fixed number of RG to update CTM tensors
        num_rg = 5
        for i in range(num_rg):
            ctms = tps.ctmrg(tps_tensor=self._t, init_ctms=ctms)

        # after RG, calculate ene
        bond_ene = tps.ctm_twobody_measure(tps_tensor=self._t, ctm_tensors=ctms, op=op)

        return bond_ene, ctms
    
model = model_xyz(tps_t)
# res = model(ctms=ctms, op=ham_xyz)
# print(res)

params = list(model.parameters())
print(len(params), params[0].shape)

opt = torch.optim.SGD(model.parameters(), lr=1E-3)

1 torch.Size([2, 4, 4, 4, 4, 2])


In [33]:
loss, ctms = model(ctms, ham_xyz)
print(loss)
opt.zero_grad()
loss.backward(retain_graph=True)
opt.step()

tensor(-0.29297, grad_fn=<DivBackward0>)


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
num_opt = 10
num_rg = 5

for j in range(num_opt):
    print('opt:', j)
    old_ene = tps.ctm_twobody_measure(tps_tensor=tps_t, ctm_tensors=ctms, op=ham_xyz)

    # perform fixed number of RGs
    for i in range(num_rg):
        ctms = tps.ctmrg(tps_tensor=tps_t, init_ctms=ctms)

    # compute energy backpropagation
    ene = tps.ctm_twobody_measure(tps_tensor=tps_t, ctm_tensors=ctms, op=ham_xyz)
    ene.backward(inputs=[tps_t])

    # optimize
    optimizer = torch.optim.SGD()

    # restart a new computation graph
    tps_t = tps_t.detach().requires_grad_(True)
    ctms = ctms.requires_grad_(True)

In [None]:
def test(tps_tensors, op, ctm_tensors):

    cs, up_es, down_es, left_es, right_es = ctm_tensors

    # SVD to MPO
    u, s, v = tp.linalg.tsvd(op, group_dims=((0, 2), (1, 3)), svd_dims=(0, 0))
    ss = torch.sqrt(s).diag()
    us = torch.einsum('Aa,abc->Abc', ss, u)
    sv = torch.einsum('Aa,abc->Abc', ss, v)

    mpo = us, sv

    mts, mts_conj = {}, {}
    for i, c in enumerate(tps.coords):
        # temp = self.merged_tensor(c)
        mts.update({c: tps_tensors[i]})
        mts_conj.update({c: tps_tensors[i].conj()})

    res = 0.0
    for i, c in enumerate(tps.coords):
        res += torch.einsum('abcde,abcde', mts[c], mts_conj[c])

    return torch.einsum('abcde,abcde', mts[(0, 0)], mts_conj[(0, 0)])

    '''
    pair = (0, 0), (1, 0)

    # mps_u = [t.clone().detach() for t in up_es]
    mps_u = [t for t in up_es]
    mps_u.insert(0, cs[2])
    mps_u.append(cs[3])

    # mps_d = [t.clone().detach() for t in down_es]
    mps_d = [t for t in down_es]
    mps_d.insert(0, cs[0])
    mps_d.append(cs[1])

    # build pure and impure double tensors
    pure_dts = [
            torch.einsum('ABCDe,abcde->AaBbCcDd', mts_conj[pair[0]], mts[pair[0]]),
            torch.einsum('ABCDe,abcde->AaBbCcDd', mts_conj[pair[1]], mts[pair[1]])]

    impure_dts = [
            torch.einsum('ABCDE,fEe,abcde->AaBbCfcDd', mts_conj[pair[0]], mpo[0], mts[pair[0]]),
            torch.einsum('ABCDE,fEe,abcde->AfaBbCcDd', mts_conj[pair[1]], mpo[1], mts[pair[1]])]

    temp_num = torch.einsum('ab,bcde,fc->afde', mps_d[0], left_es[0], mps_u[0])
    
    temp_num = torch.einsum('egAa,efDd,AaBbCicDd,ghBb->fhCic', temp_num, mps_d[1], impure_dts[0], mps_u[1])
    temp_num = torch.einsum('egAia,efDd,AiaBbCcDd,ghBb->fhCc', temp_num, mps_d[2], impure_dts[1], mps_u[2])

    num = torch.einsum('fhCc,fi,ijCc,hj', temp_num, mps_d[3], right_es[0], mps_u[3])

    pair = (0, 0), (1, 0)

    mps_u = [t.clone().detach() for t in up_es]
    # mps_u = up_es
    mps_u.insert(0, cs[2])
    mps_u.append(cs[3])

    mps_d = [t.clone().detach() for t in down_es]
    # mps_d = down_es
    mps_d.insert(0, cs[0])
    mps_d.append(cs[1])

    # build pure and impure double tensors
    pure_dts = [
            torch.einsum('ABCDe,abcde->AaBbCcDd', mts_conj[pair[0]], mts[pair[0]]),
            torch.einsum('ABCDe,abcde->AaBbCcDd', mts_conj[pair[1]], mts[pair[1]])]

    impure_dts = [
            torch.einsum('ABCDE,fEe,abcde->AaBbCfcDd', mts_conj[pair[0]], mpo[0], mts[pair[0]]),
            torch.einsum('ABCDE,fEe,abcde->AfaBbCcDd', mts_conj[pair[1]], mpo[1], mts[pair[1]])]

    # denominator
    temp_den = torch.einsum('ab,bcde,fc->afde', mps_d[0], left_es[0], mps_u[0])
    # temp_num = temp_den.clone().detach()
    # temp_num = torch.einsum('ab,bcde,fc->afde', mps_d[0], left_es[0], mps_u[0])
    
    temp_den = torch.einsum('egAa,efDd,AaBbCcDd,ghBb->fhCc', temp_den, mps_d[1], pure_dts[0], mps_u[1])
    temp_den = torch.einsum('egAa,efDd,AaBbCcDd,ghBb->fhCc', temp_den, mps_d[2], pure_dts[1], mps_u[2])

    den = torch.einsum('fhCc,fi,ijCc,hj', temp_den, mps_d[3], right_es[0], mps_u[3])

    # numerator
    # temp_num = torch.einsum('egAa,efDd,AaBbCicDd,ghBb->fhCic', temp_num, mps_d[1], impure_dts[0], mps_u[1])
    # temp_num = torch.einsum('egAia,efDd,AiaBbCcDd,ghBb->fhCc', temp_num, mps_d[2], impure_dts[1], mps_u[2])

    # num = torch.einsum('fhCc,fi,ijCc,hj', temp_num, mps_d[3], right_es[0], mps_u[3])

    # print(num.item(), den.item(), num / den)
    '''

    return num

def test_2(tps_tensors):

    mts, mts_conj = {}, {}
    for i, c in enumerate(tps.coords):
        # temp = self.merged_tensor(c)
        mts.update({c: tps_tensors[i]})
        mts_conj.update({c: tps_tensors[i].conj()})
    res = 0.0
    for i, c in enumerate(tps.coords):
        res += torch.einsum('abcde,abcde', mts[c], mts_conj[c])

    # return torch.einsum('abcde,abcde', tps_tensors[0], tps_tensors[0].conj())
    return res

In [None]:
tps_ts = []
for c in tps.coords:
    tps_ts.append(tps.merged_tensor(c).requires_grad_(True))

ene = tps.ctm_twobody_measure(tps_tensors=tps_ts, op=ham_xyz, ctm_tensors=ctm_ts)
# ene = test_2(tps_ts)
print(ene.item())

# print(tps_ts[0])

ene.backward(inputs=tps_ts)


In [None]:
print(tps_ts[0].grad)

In [None]:

lr = 1E-4
for i in range(10000):
    ene = tps.ctm_twobody_measure(tps_tensors=tps_ts, op=ham_xyz, ctm_tensors=ctm_ts)
    ene.backward(inputs=tps_ts)

    # update
    for i in range(2):
        tps_ts[i] = tps_ts[i] - lr
    if 0 == i % 200:
        print(i, c.item())

In [None]:
%history