In [20]:
import torch
from scipy.sparse.linalg import eigsh
import numpy as np
import yastn
import yastn.tn.fpeps as fpeps
import yastn.operators as op_mod
from tqdm import tqdm
from functions_fpeps import *
from functions_ed import *
import matplotlib.pyplot as plt

# Number of CPU threads PyTorch is allowed to use in this session.
NUM_THREADS = 8
torch.set_num_threads(NUM_THREADS)

In [21]:
# --- Lattice -----------------------------------------------------------------
L_x, L_y = 2, 2            # lattice dimensions passed to fpeps.SquareLattice

# --- Model couplings ---------------------------------------------------------
t, U, mu = 1.0, 10.0, 0.0  # t scales the hopping term, U the on-site term;
                           # mu is presumably the chemical potential -- TODO confirm

# --- Variational optimisation (presumably consumed by later cells) -----------
learning_rate = 0.01
n_var_steps = 10

# --- Tensor-network sizes ----------------------------------------------------
D_target = 4               # presumably the target PEPS bond dimension -- TODO confirm
chi = D_target * 5         # used as "D_total" in the CTM SVD truncation options

# --- CTM / reporting settings ------------------------------------------------
output_no = 8
max_ctm_sweeps = 20
tol_ctm = 1e-10

# Keyword arguments forwarded to the yastn operator factory.
config_kwargs = dict(
    backend="torch",
    default_dtype="float64",
    device="cpu",
)

In [22]:
# Local operators for spinful fermions with U1xU1 symmetry.
ops = yastn.operators.SpinfulFermions(sym="U1xU1", **config_kwargs)

I = ops.I()
c_up, c_dn = ops.c('u'), ops.c('d')
cdag_up, cdag_dn = ops.cp('u'), ops.cp('d')
n_up, n_dn = ops.n('u'), ops.n('d')

# Single-site occupation vectors: (1, 0) and (0, 1) are the occupations
# handed to vec_n for the "up" and "down" starting states respectively.
state_up = ops.vec_n((1, 0))
state_dn = ops.vec_n((0, 1))

geometry = fpeps.SquareLattice(dims=(L_x, L_y), boundary="obc")

# Checkerboard pattern: sites with even x + y start in state_up, the rest
# in state_dn.
vectors = {
    (x, y): state_up if (x + y) % 2 == 0 else state_dn
    for (x, y) in geometry.sites()
}

psi = fpeps.product_peps(geometry=geometry, vectors=vectors)

# Turn on gradient tracking for every site tensor of the PEPS.
for site in geometry.sites():
    psi[site].requires_grad_(True)

In [24]:
# Build the CTM environment of `psi`, run CTMRG sweeps until the energy per
# site stops changing (or the sweep budget is exhausted), and print it.
env_ctm = fpeps.EnvCTM(psi, init="eye")
opts_svd_ctm = {
    "D_total": chi,
    "tol": 1e-10,
}

# With iterator=True, ctmrg_ returns a generator: no sweep is executed until
# it is iterated.  It therefore has to be consumed by the loop below;
# otherwise every measurement is taken on the untouched init="eye" environment.
ctm_iter = env_ctm.ctmrg_(
    opts_svd=opts_svd_ctm,
    iterator=True,
    max_sweeps=max_ctm_sweeps,  # sweep budget from the configuration cell
)

tol_energy = 1e-7
sites_list = list(geometry.sites())
bonds_list = list(geometry.bonds())
num_sites = len(sites_list)
num_bonds = len(bonds_list)
# Average number of bonds attached to a site; converts the bond-averaged
# hopping expectation value into a per-site contribution.
z_eff = 2.0 * num_bonds / num_sites

on_site_op = (n_up - I / 2) @ (n_dn - I / 2)

energy = None
energy_prev = None
for ctm_info in ctm_iter:
    # Re-measure the energy on the environment produced by this sweep.
    ev_loc = env_ctm.measure_1site(on_site_op)
    ev_loc_mean = mean_dict_values(ev_loc)

    ev_cdagc_up = env_ctm.measure_nn(cdag_up, c_up)
    ev_cdagc_dn = env_ctm.measure_nn(cdag_dn, c_dn)
    ev_cdagc_up_mean = mean_dict_values(ev_cdagc_up)
    ev_cdagc_dn_mean = mean_dict_values(ev_cdagc_dn)

    E_kin_site = - z_eff * t * (ev_cdagc_up_mean + ev_cdagc_dn_mean)
    E_loc_site = U * ev_loc_mean

    energy = E_kin_site + E_loc_site

    # Convergence is judged on a detached float, so the comparison itself
    # is not recorded by autograd.
    energy_now = float(energy.detach())
    if energy_prev is not None and abs(energy_now - energy_prev) < tol_energy:
        break
    energy_prev = energy_now

if energy is None:
    # The generator yielded nothing, so no environment was ever computed.
    raise RuntimeError(
        "CTMRG performed no sweeps; max_ctm_sweeps must be at least 1."
    )

print(float(energy.detach()))

-2.5
