In [None]:
import importlib
import Chebyshev.Chebyshev as Cbs

importlib.reload(Cbs)



In [None]:
# Grid configuration: n is the number of MPS tensors (passed as L=n below),
# and the time grid has N = 2**n equally spaced points with spacing dt.
n = 10
N = 1 << n      # number of time points (2**n)
dt = 0.01       # time step


# Correlator evaluated on arbitrary (dimensionless) sampling points.
def Cs_chebyshev(t_matrix, psi=psi):
    """Evaluate C(t) = <psi| e^{-iHt} |psi> e^{i E0 t} at rescaled times.

    NOTE(review): this equals <psi0| e^{iHt} X e^{-iHt} X |psi0> if
    psi = X|psi0> and H|psi0> = E0|psi0>; confirm against where psi is built.

    Parameters
    ----------
    t_matrix : array_like
        Dimensionless times; each entry t is mapped to the physical time
        t * (N-1) * dt. Any shape is accepted, including a scalar.
    psi : ndarray
        State vector. Defaults to the global ``psi`` at definition time.

    Returns
    -------
    ndarray of complex128 with the same shape as ``t_matrix``.
    """
    t_matrix = np.asarray(t_matrix)
    times = t_matrix.reshape(-1) * (N - 1) * dt
    bra = psi.conj()
    corr = np.empty(times.shape, dtype=np.complex128)
    for k, tt in enumerate(times):
        # Only the overlap is needed, so the evolved states are not stored.
        corr[k] = bra @ expm_multiply(-1j * H * tt, psi)
    corr *= np.exp(1j * E0 * times)
    return corr.reshape(t_matrix.shape)


# TODO: persist the exact correlator (numpy.save / numpy.savez / HDF5) so it
# need not be recomputed on every run.

# Evolve psi on the uniform grid t_k = k*dt, k = 0, ..., N-1 (one row per time).
psis = expm_multiply(-1j * H,
                     psi,
                     start=0,
                     stop=N*dt,
                     num=N,
                     endpoint=False)

# Time grid matching the evolution above.
xs = np.arange(N) * dt

# Exact correlator C(t) = <psi| e^{-iHt} |psi> e^{i E0 t}.
# NOTE(review): this equals <psi0| e^{iHt} X e^{-iHt} X |psi0> if psi = X|psi0>
# and H|psi0> = E0|psi0> -- confirm against the cell that builds psi.
Cs = np.einsum('j, ij -> i', psi.conj(), psis) * np.exp(1j * E0 * xs)
func_vals = Cs

# Build the MPS from a Chebyshev interpolation of the correlator.
chi = 20
As, _, _, _, _ = Cbs.Chebyshev_interpolation(Cs_chebyshev,  # function to be interpolated
                                             func_vals,
                                             L=n,            # number of MPS tensors
                                             chi=chi)        # (half of the) bond dimension

# Points on which the function is evaluated for the interpolation, in the
# dimensionless units Cs_chebyshev expects. They are converted to physical time
# after evaluation so they can be drawn on the same axis as xs.
Cheb_xs = 0.5 * np.sort(np.arange(2)[:, None] + Cbs.c_a_N(np.arange(chi), chi-1)[None, :])
Cheb_vals = Cs_chebyshev(Cheb_xs)
Cheb_xs *= (N-1)*dt

# Contract the MPS to reconstruct the interpolated function.
func_interp = Cbs.interpolate_singlesite(As)

# Left: modulus; right: real and imaginary parts. Crosses mark the sampling points.
fig, axs = plt.subplots(ncols=2, dpi=300, figsize=(8,4), sharex=True)

axs[0].plot(xs, np.abs(func_vals))
axs[0].plot(xs, np.abs(func_interp), '--')
axs[0].plot(Cheb_xs, np.abs(Cheb_vals), 'x', ms=4, color='0.3')
axs[0].set_xlabel(r"$t$")
axs[0].set_ylabel(r"$|C(t)|$")

axs[1].plot(xs, func_vals.real, label='Real part (exact)')
axs[1].plot(xs, func_interp.real, '--', label='Real part (interp.)')
axs[1].plot(Cheb_xs, Cheb_vals.real, 'x', ms=4, color='0.3')
axs[1].plot(xs, func_vals.imag, label='Imaginary part (exact)')
axs[1].plot(xs, func_interp.imag, '--', label='Imaginary part (interp.)')
axs[1].plot(Cheb_xs, Cheb_vals.imag, 'x', ms=4, color='0.3')
axs[1].set_xlabel(r"$t$")
axs[1].set_ylabel(r"$C(t)$")
axs[1].legend()

plt.tight_layout()
plt.show()


### Error vs. number of function evaluations

In [None]:
# Sweep the bond dimension chi. For each value, record the interpolation errors
# (eps_inf and eps_2), the number of function evaluations and the reconstructed
# function. The lists are reset here, so re-running this cell is safe.
err_max = []
err_2 = []
evals = []
func_interp_list = []

chi_list = [*range(2, 10), *range(10, 40, 5), 70]

for chi in chi_list:
    # Build the MPS from a Chebyshev interpolation. `n_evals` (not `eval`)
    # avoids shadowing the Python builtin in the notebook namespace.
    _, n_evals, errmax, err2, func_interp = Cbs.Chebyshev_interpolation(
        Cs_chebyshev,   # function to be interpolated
        func_vals,
        L=n,            # number of MPS tensors
        chi=chi,        # (half of the) bond dimension
    )
    err_max.append(errmax)
    err_2.append(err2)
    evals.append(n_evals)
    func_interp_list.append(func_interp)




In [None]:
# Extend the sweep to large bond dimensions. Any chi already in chi_list is
# skipped, so re-running this cell neither duplicates points nor repeats the
# expensive interpolations. chi is recorded only after its interpolation
# succeeds, so chi_list stays aligned with the result lists.
for chi in (200, 400):
    if chi in chi_list:
        continue
    _, n_evals, errmax, err2, func_interp = Cbs.Chebyshev_interpolation(
        Cs_chebyshev,   # function to be interpolated
        func_vals,
        L=n,            # number of MPS tensors
        chi=chi,        # (half of the) bond dimension
    )
    chi_list.append(chi)
    err_max.append(errmax)
    err_2.append(err2)
    evals.append(n_evals)
    func_interp_list.append(func_interp)


In [None]:
# Cache the chi-sweep results on disk, or restore them in a fresh kernel where
# the sweep cells above were not run. np.save appends ".npy" to a bare file
# name but np.load does not, so the full file names are used on both sides.
if 'err_max' in globals():
    np.save("chi_list.npy", chi_list)
    np.save("err_max.npy", err_max)
    np.save("err_2.npy", err_2)
    np.save("evals.npy", evals)
    # NOTE(review): assumes every interpolant has the same length (N), so the
    # list stacks into a regular 2-D array.
    np.save("func_interp_list.npy", func_interp_list)
else:
    # Restore as lists so the sweep cells can still .append() to them.
    chi_list = np.load("chi_list.npy").tolist()
    err_max = np.load("err_max.npy").tolist()
    err_2 = np.load("err_2.npy").tolist()
    evals = np.load("evals.npy").tolist()
    func_interp_list = list(np.load("func_interp_list.npy"))

## Plot of error vs. number of function evaluations

In [None]:
# Interpolation errors against the number of function evaluations, with one
# marker per bond dimension chi. Each marker is annotated with its chi value:
# above the point for eps_inf, below it for eps_2.
ax = plt.gca()
evals_arr = np.array(evals)
ax.plot(evals_arr, np.array(err_max), ".-", label=r"$\epsilon_\infty$")
ax.plot(evals_arr, np.array(err_2), ".-", label=r"$\epsilon_2$")

shared_annot = dict(textcoords="offset points", fontsize=8, ha='center')
for i, chi in enumerate(chi_list):
    x_i = evals[i]
    ax.annotate(f"{chi}", (x_i, err_max[i]), xytext=(5, 5),
                arrowprops=dict(arrowstyle="-", lw=0.5), **shared_annot)
    ax.annotate(f"{chi}", (x_i, err_2[i]), xytext=(5, -10),
                arrowprops=dict(arrowstyle="-", lw=0.5), **shared_annot)

ax.set_yscale("log")
ax.set_xlabel(r"$evals(\chi)$")
ax.set_ylabel(r"error")
ax.legend()
ax.set_title("Error vs # of evaluations")
plt.tight_layout()
plt.show()

In [None]:
# Compare the exact correlator with each chi's interpolant on a 2-column grid,
# filled column by column in the order of func_interp_list.
n_plots = len(func_interp_list)
ncols = 2
nrows = max(1, -(-n_plots // ncols))  # ceil(n_plots / ncols); subplots needs >= 1 row

# squeeze=False always returns a 2-D array of axes, whatever nrows is.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 2 * nrows),
                        sharex=True, squeeze=False)

# Column-major flattening: fill the left column first, then the right one.
axs = axs.ravel(order="F")

for i, element in enumerate(func_interp_list):
    # Both curves are complex. Plot the real part explicitly instead of letting
    # matplotlib discard the imaginary part with a ComplexWarning.
    axs[i].plot(xs, np.real(func_vals), '-')
    axs[i].plot(xs, np.real(element), '--')
    # The errors are small (plotted on a log scale above); ".2f" would round
    # them to 0.00, so use scientific notation.
    axs[i].set_title(f"Plot with error: {err_2[i]:.2e}")

for unused_ax in axs[n_plots:]:
    fig.delaxes(unused_ax)  # remove empty subplots

fig.tight_layout()