In [1]:
import jax.numpy as np
from Hubbard.plot import *

# --- Graph construction ---------------------------------------------------
# N and R0 are passed straight to HubbardGraph; the remaining keyword
# arguments select a 4-site lattice, band=1, dim=1, sparse solving and
# symmetry handling. (Exact meaning of each option lives in HubbardGraph —
# NOTE(review): confirm units of R0 and semantics of `avg` there.)
N = 20
R0 = np.array([3, 3, 7.2])
G = HubbardGraph(
    N,
    R0=R0,
    lattice=np.array([4], dtype=int),
    band=1,
    dim=1,
    avg=1 / 2,
    sparse=True,
    symmetry=True,
)

# `v` is forwarded unchanged as the last argument of G.cbd_cost_func (see
# cost_func); `fixed` decides whether Vtarget is computed or left as None.
v = True
fixed = False

# --- Reference single-band solution --------------------------------------
A, U = G.singleband_solution(True)

# Targets handed to the cost function: mean of U, the tunneling returned by
# G.nn_tunneling(A) split into x/y links, and optionally the mean of the
# real diagonal of A.
Utarget = np.mean(U)
nnt = G.nn_tunneling(A)
xlinks, ylinks, nntx, nnty = G.xy_links(nnt)
Vtarget = np.mean(np.real(np.diag(A))) if fixed else None

def cost_func(offset: np.ndarray) -> float:
    """Evaluate ``G.cbd_cost_func`` at ``offset`` against the stored targets.

    Thin wrapper that binds the module-level link lists and target values so
    that the result depends on ``offset`` only, which is the form expected by
    function transformations that act on a single-argument callable.

    Args:
        offset: Parameter vector forwarded as the first argument of
            ``G.cbd_cost_func``.
            NOTE(review): layout is presumably the same as ``v0`` built in a
            later cell — confirm against ``cbd_cost_func``.

    Returns:
        Whatever ``G.cbd_cost_func`` returns (annotated as a scalar cost).
    """
    # Globals are looked up at call time, so re-running the setup cell
    # changes what this function evaluates.
    links = (xlinks, ylinks)
    targets = (Vtarget, Utarget, nntx, nnty)
    return G.cbd_cost_func(offset, links, targets, v)

DVR: dx=[0.15]w is set.
DVR: n=[20] is set.
DVR: R0=[3.]w is set.
['x']-reflection symmetry is used.
param_set: trap parameter V0=104.52kHz w=1000nm
Triangular lattice size adjust to: [4 1]
lattice: dx is fixed at: [0.15 0.   0.  ]w
lattice: lattice shape is square
lattice: Full lattice sizes: [4]
lattice: lattice constants: [1.52 1.69 1.52]w
DVR: dx=[0.15]w is set.
DVR: n=[35] is set.
DVR: R0=[5.28]w is set.
H_op: n=[35] dx=[0.15]w p=[1] Gaussian sparse diagonalization is enabled. Lowest 4 states are to be calculated.
H_op: n=[35] dx=[0.15]w p=[1] Gaussian operator constructed.
H_solver: diagonalize sparse hermitian matrix.
H_solver: Gaussian Hamiltonian solved. Time spent: 0.00s.
H_solver: eigenstates memory usage: 0.00 MiB.
H_op: n=[35] dx=[0.15]w p=[-1] Gaussian sparse diagonalization is enabled. Lowest 4 states are to be calculated.
H_op: n=[35] dx=[0.15]w p=[-1] Gaussian operator constructed.
H_solver: diagonalize sparse hermitian matrix.
H_solver: Gaussian Hamiltonian solved. Ti

In [4]:
from jax import grad, jit, vmap

# --- Initial guess ---------------------------------------------------------
# v01: one unit entry per independent trap (G.Nindep of them).
# v02: trap centers selected through the first column of G.reflection.
# NOTE(review): v02 is indexed as v02[k, 0] / v02[k, 1] below, so it is
# assumed to be 2-D with (x, y) columns — confirm against HubbardGraph.
v01 = np.ones(G.Nindep)
v02 = G.trap_centers[G.reflection[:, 0]]

# --- Box constraints -------------------------------------------------------
# Bound trap depth variation
b1 = [(0.9, 1.1) for _ in range(G.Nindep)]

# Bound lattice spacing variation: +/- 0.05 around each center coordinate.
xbonds = tuple(
    (v02[k, 0] - 0.05, v02[k, 0] + 0.05) for k in range(G.Nindep))
if G.lattice_dim == 1:
    # 1D lattice: the y coordinate is pinned to the degenerate interval (0, 0).
    ybonds = tuple((0, 0) for _ in range(G.Nindep))
else:
    ybonds = tuple(
        (v02[k, 1] - 0.05, v02[k, 1] + 0.05) for k in range(G.Nindep))

# Interleave as (x_0, y_0, x_1, y_1, ...) so the order matches the flattened
# v02 used in v0 below.
nested = tuple(zip(xbonds, ybonds))
b2 = [bound for pair in nested for bound in pair]

# Full parameter vector and its matching bounds: depth entries first, then
# the flattened trap-center coordinates.
flat_centers = v02.reshape(-1)
v0 = np.concatenate((v01, flat_centers))
bonds = (*b1, *b2)

# --- JAX transforms of the cost function -----------------------------------
cg = grad(cost_func)  # gradient of cost_func with respect to `offset`
cj = jit(cost_func)   # JIT-compiled cost_func