In [None]:
from scipy.integrate import solve_ivp
import jax
from jax import jit,vmap,grad
from jax import random
import numpy as np
import jax.numpy as jnp
from jax import vjp

from qutip import *

N = 4 # number of qubits (sites), used as a global parameter by every cell below
N1 = 10 # number of sine modes in the Fourier series of each control (Omega and each Delta_i)
key = random.PRNGKey(42) # root PRNG key; split into sub-keys in the main cell

In [None]:
def generate_bach(U,N=100,*args):
    '''
    Generate a batch of data points (not implemented yet).

    Placeholder: ignores all of its arguments and returns None.
    '''
    return None

def normalize(psi):
    '''
    Scale a state vector so that its 2-norm equals one.

    A zero vector is not special-cased (the division yields nan).
    '''
    length = jnp.linalg.norm(psi)
    return psi / length

To proceed with the neural-network optimization, we need an ODE solver. Here we use `scipy.integrate.solve_ivp`.

In [None]:
def forward(t1,psi0,flat_p,method='function',**solve_kwargs):
    '''
    Integrate the Schrodinger equation d(psi)/dt = -i H(t) psi on [0, t1].

    Parameters
    ----------
    t1 : final time of the evolution (the initial time is 0).
    psi0 : initial state vector.
    flat_p : flat parameter vector of the control pulses (see ``unpackp``).
    method : 'function' applies the Hamiltonian with ``H`` (bit-flip loop),
        'matrix' applies it with the dense-matrix version ``Hmat``.
    **solve_kwargs : extra keyword arguments forwarded to
        ``scipy.integrate.solve_ivp`` (e.g. ``rtol``/``atol``; the SciPy
        defaults rtol=1e-3, atol=1e-6 are loose for comparing quantum states).
        ``t_eval`` defaults to [0, t1].

    Returns
    -------
    (t, y) from ``solve_ivp``: ``y[:, k]`` is the state at time ``t[k]``.

    Raises
    ------
    NameError
        If ``method`` is neither 'function' nor 'matrix'.
    RuntimeError
        If the integrator reports failure (otherwise a truncated solution
        would be returned silently).
    '''
    if method == 'function':
        apply_h = H
    elif method == 'matrix':
        apply_h = Hmat
    else:
        raise NameError("wrong solver")

    def rhs(t, x):
        # solve_ivp calls fun(t, y); the Schrodinger right-hand side is -i H psi
        return -1j*apply_h(t, x, flat_p, t1)

    solve_kwargs.setdefault('t_eval', [0, t1])
    sol = solve_ivp(rhs, [0, t1], psi0, **solve_kwargs)
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")
    return sol.t, sol.y

In [None]:
def initial(key1 = None):
    '''
    Draw a random flat parameter vector of length N1*(N+1)+1.

    The layout matches ``unpackp``: N1 drive coefficients, then N rows of
    N1 detuning coefficients, then one interaction strength.

    Parameters
    ----------
    key1 : jax PRNG key, or None. With a key the draw uses
        ``jax.random.normal`` and returns a jax array; with None it uses
        NumPy's global RNG and returns a numpy array.

    Returns
    -------
    Standard-normal samples of shape (N1*(N+1)+1,).
    '''
    shape = (N1*(N+1)+1,)
    # Test identity against None: calling ``.all()`` on the default None
    # would raise AttributeError before any comparison happens.
    if key1 is None:
        return np.random.normal(size=shape)
    return random.normal(key1, shape=shape)

In [None]:
def n(i,x):
    '''
    Occupation of site ``i`` in the computational basis state ``x``.

    Returns bit ``i`` of the integer ``x`` as 0.0 or 1.0; the true division
    keeps the result a float.
    '''
    bit = 2**i
    return (x & bit)/bit

def unpackp(x):
    '''
    Split the flat parameter vector into its pieces.

    Layout of ``x`` (length N1*(N+1)+1):
      x[:N1]             -> Fourier coefficients u_omega of the drive Omega(t)
      x[N1:(N+1)*N1]     -> detuning coefficients u_d, one row of N1 per site
      x[-1]              -> interaction strength V

    Returns
    -------
    (u_omega, u_d, V) with u_omega of shape (N1,), u_d of shape (N, N1).
    V is a Python float for concrete input. When ``x`` is a JAX tracer
    (inside ``grad``/``vjp``/``jit``), V stays the traced scalar so that
    derivatives with respect to V can flow.
    '''
    u_omega = jnp.array(x[:N1])
    u_d = jnp.array([x[(i+1)*N1:(i+2)*N1] for i in range(N)])
    last = x[-1]
    # float() on a traced value raises a concretization error under jit and
    # cannot carry a derivative under grad/vjp, so only convert concrete values.
    V = last if isinstance(last, jax.core.Tracer) else float(last)
    return u_omega, u_d, V

def flatten(x):
    '''
    Inverse of ``unpackp``: pack (u_omega, u_d, V) back into one flat vector.

    Parameters
    ----------
    x : tuple (u_omega, u_d, V). u_omega is 1-D, u_d is 2-D with one row
        per site, and V is a scalar.

    Returns
    -------
    1-D jax array [u_omega, u_d flattened row by row, V]. Row i of u_d lands
    at positions (i+1)*N1:(i+2)*N1, matching ``unpackp``.
    '''
    u_omega, u_d, V = x
    return jnp.concatenate([jnp.ravel(u_omega), jnp.ravel(u_d), jnp.ravel(jnp.asarray(V))])

Our Hamiltonian is $$\begin{aligned}
\frac{H}{\hbar}=& \frac{\Omega(t)}{2} \sum_{i=1}^{N} \sigma_{x}^{(i)}-\sum_{i=1}^{N} \Delta_{i}(t) n_{i} \\
&+\sum_{i<j} \frac{V}{|i-j|^{6}} n_{i} n_{j}
\end{aligned}$$

In [None]:
def H(t,psi,flat_p,t1):
    '''
    Apply the Hamiltonian H(t) to ``psi`` without building a matrix.

    Basis state ``x`` stores the site occupations in its bits (n_i(x) is bit i
    of x). The diagonal terms (detuning and interaction) scale psi[x]. The
    sigma_x drive on site i couples x with x XOR 2**i.

    Parameters
    ----------
    t : time at which the Hamiltonian is evaluated.
    psi : state vector of length D.
    flat_p : flat parameter vector (see ``unpackp``).
    t1 : final time; sets the base Fourier frequency w = 2*pi/t1.

    Returns
    -------
    H(t) @ psi, accumulated in a complex64 buffer of length D.
    '''
    D, = jnp.shape(psi)
    res = jnp.zeros(D,dtype=jnp.complex64)

    w = 2*jnp.pi/t1

    u_omega, u_d, V = unpackp(flat_p)
    # Omega(t) = sum_k u_omega[k] sin(w (k+1) t); Delta_j(t) likewise with u_d[j]
    omega = jnp.sum(jnp.array([u_omega[i]*jnp.sin(w*(i+1)*t) for i in range(N1)]))
    delta = [jnp.sum(jnp.array([u_d[j,i]*jnp.sin(w*(i+1)*t) for i in range(N1)])) for j in range(N)]

    for x in range(D):
        Ci = psi[x]
        # diagonal part: -sum_i Delta_i n_i + sum_{i<j} V/|i-j|^6 n_i n_j
        diag = - jnp.sum(jnp.array([ delta[i]*n(i,x) for i in range(N)])) \
            + jnp.sum(jnp.array([[V*n(i,x)*n(j,x)/jnp.abs(i-j)**6 if i<j else 0 for i in range(N)] for j in range(N)]))
        # ``.at[...].add`` replaces jax.ops.index_add, which was removed from JAX
        res = res.at[x].add(diag*Ci)
        # x ^ 2**i flips bit i: the Omega/2 sigma_x coupling on each site
        cast = jnp.array([x^2**i for i in range(N)])
        res = res.at[cast].add(omega*Ci/2)
    return res

For speed, we use a dense matrix to represent the Hamiltonian.

In [None]:
# Dense number operators: n0[i] is the (2**N, 2**N) diagonal matrix whose
# entry (x, x) is n(i, x), the occupation of site i in basis state x.
n0 = jnp.array([jnp.diag(jnp.array([n(i,x) for x in range(2**N)])) for i in range(N)]) #number operator matrix form n0[i] = n_i

print(n0.shape)
#print(jnp.matmul(n0[1],n0[2]))

def H_independent():
    '''
    Dense time-independent interaction matrix sum_{i<j} n_i n_j / |i-j|**6.

    Built from the number operators ``n0``. The interaction strength V is not
    included; ``Hmat`` multiplies the result by V.
    '''
    dim = 2**N
    total = jnp.zeros((dim, dim))
    for i in range(N):
        for j in range(i + 1, N):
            weight = 1/np.abs(i-j)**6
            total += weight*(jnp.dot(n0[i], n0[j]))
    return total
# Time-independent dense pieces used by Hmat.
# H1 = sum_{i<j} n_i n_j / |i-j|**6 (Hmat scales it by V).
H1 = H_independent()
print(jnp.trace(H1))

# f(x, y, i) is 1 when basis states x and y differ only in bit i (sigma_x on site i).
f = lambda x,y,i : 1 if y==x^2**i else 0
# H2 = sum_i sigma_x^(i) as a dense matrix: row y, column x holds sum_i f(x, y, i).
H2 = sum([jnp.array([[f(x,y,i) for x in range(2**N)] for y in range(2**N)]) for i in range(N)]) #checked

def Hmat(t,psi,flat_p,t1):
    '''
    Apply H(t) to ``psi`` using the precomputed dense matrices.

    H(t) = V*H1 + Omega(t)/2 * H2 - sum_i Delta_i(t) n0[i], where
    Omega(t) = sum_k u_omega[k] sin(w (k+1) t),
    Delta_i(t) = sum_k u_d[i, k] sin(w (k+1) t) and w = 2*pi/t1.

    Parameters
    ----------
    t : time at which the Hamiltonian is evaluated.
    psi : state to act on; any array ``H(t) @ psi`` accepts.
    flat_p : flat parameter vector (see ``unpackp``).
    t1 : final time; sets the base Fourier frequency.

    Returns
    -------
    H(t) @ psi.
    '''
    w = 2*jnp.pi/t1

    u_omega, u_d, V = unpackp(flat_p)

    # sin(w (k+1) t) for k = 0..N1-1
    ft = jnp.array([jnp.sin(w*(i+1)*t) for i in range(N1)])

    omega = jnp.dot(u_omega,ft)
    delta = u_d@ft

    return (V*H1 + 0.5* omega* H2 - jnp.einsum('i,ijk->jk',delta,n0))@psi
    


In [None]:
@jit
def loss(psi,p,t1,A,psi0):
    '''
    Weighted loss on the final state and the control parameters.

    L = A[0]*l1 + A[1]*l2 + A[2]*l3 + A[3]*l4 with
      l1 = 1 - |<psi|psi0>|^2                    (infidelity)
      l2 = 1/2 sum_k ((k+1) w u_omega[k])^2      (frequency-weighted drive penalty)
      l3 = 1/2 sum_{j,k} ((k+1) w u_d[j, k])^2   (same, for every site's detuning)
      l4 = 1/2 V^2
    where w = 2*pi/t1 and ``p`` has the flat layout read by ``unpackp``.

    Mode k multiplies sin(w (k+1) t) in ``H``/``Hmat``, so its frequency is
    (k+1) w. Weighting by k w would leave the first mode unpenalized.
    '''
    omega = 2*jnp.pi/t1
    freqs = omega*jnp.arange(1, N1+1)
    l1 = 1 - jnp.abs(jnp.dot(jnp.conjugate(psi),psi0))**2 # Overlap

    u_omega = p[:N1]
    # all N detuning channels: one row of N1 coefficients per site
    u_d = jnp.reshape(p[N1:N1*(N+1)], (N, N1))

    l2 = 0.5*jnp.sum((freqs*u_omega)**2)
    l3 = 0.5*jnp.sum((freqs*u_d)**2)
    l4 = 0.5*p[-1]**2

    return A[0]*l1+A[1]*l2+A[2]*l3+A[3]*l4

In [None]:
def grad_all(psi1,flat_p,t1,A,psi0,method='function'):
    '''
    Adjoint-method gradients of ``loss`` evaluated at the final state.

    Integrates the augmented state [z, a, a_t, a_p] backwards from ``t1`` to 0
    (continuous adjoint method of Chen et al., "Neural Ordinary Differential
    Equations"), where
      dz/dt   =  f(t, z) = -i H(t) z     (the dynamics ``forward`` integrates)
      da/dt   = -a . df/dz
      da_t/dt = -a . df/dt
      da_p/dt = -a . df/dp
    and ``a`` starts as the JAX gradient of ``loss`` with respect to ``psi1``.

    Parameters
    ----------
    psi1 : state at ``t1``, 1-D of length D.
    flat_p : flat parameter vector (see ``unpackp``).
    t1 : final time.
    A : loss weights, forwarded to ``loss``.
    psi0 : last argument of ``loss`` (the state the overlap is taken with).
    method : 'function' (uses ``H``) or 'matrix' (uses ``Hmat``), as in ``forward``.

    Returns
    -------
    The ``solve_ivp`` result of the backward integration. ``y[:, -1]`` is
    the augmented state at t = 0, laid out as [z (D), a (D), a_t (1),
    a_p (len(flat_p))]. The real part of a_p is the gradient of the loss
    through the dynamics. The direct dependence of ``loss`` on ``flat_p``
    (its l2/l3/l4 terms) is NOT included.

    Raises
    ------
    NameError
        If ``method`` is neither 'function' nor 'matrix'.

    Notes
    -----
    Not jitted: ``solve_ivp`` needs concrete NumPy values.
    NOTE(review): the vjp with respect to ``flat_p`` traces through
    ``unpackp``. That only works if ``unpackp`` does not apply ``float()`` to
    traced values — confirm.
    '''
    D, = jnp.shape(psi1)
    P, = jnp.shape(flat_p)

    if method == 'function':
        apply_h = H
    elif method == 'matrix':
        apply_h = Hmat
    else:
        raise NameError("wrong solver")

    def f(t, z, p):
        # Schrodinger right-hand side, identical to the one ``forward`` integrates
        return -1j*apply_h(t, z, p, t1)

    a0 = grad(loss, argnums=0)(psi1,flat_p,t1,A,psi0)
    # dL/dt1 through the final state: Re(a . f(t1, z(t1)))
    pLt1 = jnp.real(jnp.dot(a0, f(t1, psi1, flat_p)))

    # layout: [z (D), a (D), a_t (1), a_p (P)]; a_p starts at zero
    aug_state = jnp.concatenate([psi1, a0, jnp.reshape(-pLt1, (1,)), jnp.zeros((P,))])

    def unpack(x):
        # z, a (adjoint of z), a_t, a_p
        return x[:D], x[D:2*D], x[2*D], x[2*D+1:]

    def aug_dynamics(t, augment_state):
        # solve_ivp calls fun(t, y): time comes first
        z, a, _, _ = unpack(augment_state)
        dz, vjp_fun = vjp(f, t, z, flat_p)
        # the cotangent must have the dtype of f's output
        a = jnp.asarray(a, dtype=dz.dtype)
        vjp_t, vjp_z, vjp_p = vjp_fun(a)
        return jnp.concatenate([dz, -vjp_z, -jnp.reshape(vjp_t, (1,)), -vjp_p])

    aug_state0 = solve_ivp(aug_dynamics, [t1, 0], np.asarray(aug_state))
    return aug_state0

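Our Hamiltonian (restated here for the QuTiP cross-check below) is

$$\begin{aligned}
\frac{H}{\hbar}=& \frac{\Omega(t)}{2} \sum_{i=1}^{N} \sigma_{x}^{(i)}-\sum_{i=1}^{N} \Delta_{i}(t) n_{i} \\
&+\sum_{i<j} \frac{V}{|i-j|^{6}} n_{i} n_{j}
\end{aligned}$$

where $\Omega(t)=\sum_{k=1}^{N_1} u^{\Omega}_{k}\sin(2\pi k t/t_1)$ and $\Delta_i(t)=\sum_{k=1}^{N_1} u^{\Delta}_{ik}\sin(2\pi k t/t_1)$.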


In [None]:
def check_with_qutip(psi0,flat_p,t1):
    '''
    Evolve ``psi0`` with QuTiP's ``mesolve`` as an independent cross-check.

    Builds the same Hamiltonian as ``H``/``Hmat`` from QuTiP tensor products:
        V * sum_{i<j} n_i n_j / |i-j|**6
        + Omega(t)/2 * sum_i sigma_x^(i)
        - sum_i Delta_i(t) n_i
    with Omega(t) = sum_k u_omega[k] sin(w (k+1) t),
    Delta_i(t) = sum_k u_d[i, k] sin(w (k+1) t) and w = 2*pi/t1.

    Parameters
    ----------
    psi0 : numpy array of length 2**N, the initial state.
    flat_p : flat parameter vector (see ``unpackp``).
    t1 : final time. The solution is sampled at 100 equally spaced times in [0, t1].

    Returns
    -------
    The result object returned by ``mesolve``; ``.states[-1]`` is the state at t1.
    '''
    q_psi = Qobj(psi0,dims=[[2]*N, [1]*N])

    tlist = np.linspace(0,t1,100)
    w = 2*np.pi/t1

    u_omega, u_d, V = unpackp(flat_p)
    # Plain NumPy copies: QuTiP evaluates the coefficient functions many times
    # and should get ordinary Python numbers back, not JAX arrays.
    u_omega = np.asarray(u_omega)
    u_d = np.asarray(u_d)
    modes = np.arange(1, N1+1)

    def ft(t):
        # sin(w (k+1) t) for k = 0..N1-1
        return np.sin(w*modes*t)

    def omega(t,args):
        return float(u_omega @ ft(t))

    def make_delta(site):
        # Bind ``site`` through a factory: a lambda in a comprehension would
        # close over the loop variable, and every site would see the last index.
        def delta_i(t,args):
            return -float(u_d[site] @ ft(t))
        return delta_i

    delta_func = [make_delta(site) for site in range(N)]

    si = qeye(2)
    sx = sigmax()
    ni = num(2)

    def embed(site_ops):
        # Tensor product over all N sites: identity everywhere except the
        # single-site operators given in ``site_ops`` (dict site -> operator).
        op_list = [si]*N
        for site, op in site_ops.items():
            op_list[site] = op
        return tensor(op_list)

    # time-independent interaction: sum_{i<j} V/|i-j|^6 n_i n_j
    H_int = sum([V/np.abs(i-j)**6*embed({i: ni, j: ni}) for i in range(N) for j in range(i+1, N)])
    # drive: Omega(t)/2 * sum_i sigma_x^(i)
    H_drive = sum([0.5*embed({site: sx}) for site in range(N)])

    h = [H_int, [H_drive, omega]]
    # detuning: -Delta_i(t) n_i (the minus sign lives in delta_func)
    for site in range(N):
        h.append([embed({site: ni}), delta_func[site]])

    output = mesolve(h, q_psi, tlist, progress_bar = True)
    return output

In [None]:
if __name__=='__main__':
    keyre , keyim, keyinit = random.split(key,num=3)
    t1 = 2
    # random complex initial state on the 2**N-dimensional Hilbert space, normalized
    psi0 = random.normal(keyre,shape=(2**N,)) + 1j*random.normal(keyim,shape=(2**N,))
    psi0 = normalize(psi0)
    flat_params = initial(keyinit)
    # propagate with our own solver (dense-matrix Hamiltonian) ...
    res = forward(t1,psi0,flat_params,method='matrix')
    # ... and with QuTiP for comparison (QuTiP needs a NumPy array)
    res0 = check_with_qutip(np.array(psi0),flat_params, t1)

In [None]:
# Compare the results: res[1] is solve_ivp's y (columns at t=0 and t=t1);
# res0.states[-1] is QuTiP's state at the last time point, t1.
print(res[1])
print(res0.states[-1])

In [None]:
# Ket dims follow the global qubit count N instead of a hardcoded 4.
q_psi = Qobj(np.array(psi0),dims=[[2]*N, [1]*N])
print(q_psi)

In [None]:
# Re-run only the forward propagation (same call as in the main cell), without the QuTiP check.
res = forward(t1,psi0,flat_params,method='matrix')

In [None]:
# Scratch cell: single-site QuTiP operators (identity, Pauli x/y/z and the
# number operator on a 2-level space), then print the identity.
si = qeye(2)
sx = sigmax()
sy = sigmay()
sz = sigmaz()
ni = num(2)
print(si)