In [None]:
from scipy.integrate import solve_ivp
import jax
from jax import jit,vmap,grad
from jax import random
import numpy as np
import jax.numpy as jnp
from jax import vjp

! pip install qutip
from qutip import *

# Global configuration shared by every cell below.
N, N1 = 4, 10  # N: number of qubits (state vectors have length 2**N); N1: Fourier modes per control field
key = random.PRNGKey(42)  # JAX PRNG key with seed 42

Collecting qutip
  Using cached https://files.pythonhosted.org/packages/71/94/c79bd57137320657c3acc8e7a08d3c66dcf9d697e594e945fbc905dfec42/qutip-4.5.2.tar.gz
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: qutip
  Building wheel for qutip (PEP 517) ... [?25l[?25hdone
  Created wheel for qutip: filename=qutip-4.5.2-cp36-cp36m-linux_x86_64.whl size=12467997 sha256=7d3cd0fc3b0078573fa62042c1ccda38a0e3bea123d0e9c0777946fdb2786e22
  Stored in directory: /root/.cache/pip/wheels/ae/90/e9/f26fdecfb6c0e9d9d6f5fa564d16d26ff2bdfd8ad6e7a8a28a
Successfully built qutip
Installing collected packages: qutip
Successfully installed qutip-4.5.2


In [None]:
def generate_bach(U,N=100,*args):
    '''
    Generate a batch of training data points (placeholder, not implemented yet).

    Args:
      U: input the batch is generated from; presumably a target evolution — TODO confirm.
      N: presumably the number of data points in the batch. Note that inside this
        function it shadows the global qubit count N.
      *args: unused.
    Raises:
      NotImplementedError: always. Failing loudly is safer than silently returning
        None to a caller that expects data.
    '''
    raise NotImplementedError("generate_bach is a placeholder and has not been implemented")

To run the forward pass of the neural ODE we need an ODE solver; here we use `scipy.integrate.solve_ivp`.

In [None]:
def forward(t1,psi0,flat_p):
    '''
    Integrate the Schrodinger equation d psi/dt = -i H(t) psi (hbar = 1) on [0, t1].

    Uses scipy.integrate.solve_ivp (default RK45) starting from psi0.

    Args:
      t1: final time; it also sets the control base frequency 2*pi/t1 inside H.
      psi0: initial state vector. It is cast to complex, because a real initial
        vector would make solve_ivp integrate in real arithmetic and discard the
        imaginary part of -i H psi.
      flat_p: flat parameter vector (layout as in `unpackp`).
    Returns:
      (sol.t, sol.y): the evaluation times [0, t1] and the corresponding states,
      where sol.y[:, k] is the state at sol.t[k].
    '''
    psi0 = np.asarray(psi0, dtype=complex)
    # H(...) returns H psi; the Schrodinger right-hand side is -i H psi.
    rhs = lambda t, x: -1j*np.asarray(H(t, x, flat_p, t1))
    sol = solve_ivp(rhs, [0, t1], psi0, t_eval=[0, t1])
    return sol.t, sol.y

In [None]:
def initial(rng=None):
    '''
    Return randomized parameters as one flat vector.

    Layout (matching `unpackp`): N1 Fourier coefficients of Omega(t), then N*N1
    coefficients of the detunings Delta_i(t), then the interaction strength V.
    The total length is N1*(N+1)+1. The entries are standard-normal samples.

    Args:
      rng: optional numpy random source (np.random.Generator or RandomState) for
        reproducible draws. It defaults to the global numpy RNG, which is the
        previous behaviour.
    '''
    sampler = np.random if rng is None else rng
    return sampler.normal(size=(N1*(N+1)+1,))

In [None]:
def n(i,x):
    '''
    Occupation number n_i of qubit i in the computational-basis state x.

    Reads bit i of the integer x (bit 0 is the least significant) and returns it
    as 0.0 or 1.0. The bit operations also broadcast over numpy integer arrays.
    '''
    mask = 1 << i
    return (x & mask) / mask

def unpackp(x):
  '''
  Split the flat parameter vector x into its three parts.

  Returns:
    (u_omega, u_d, V):
    - u_omega: x[:N1] as a jnp array.
    - u_d: a jnp array whose row k is x[(k+1)*N1:(k+2)*N1], for k in range(N).
    - V: the last entry, x[-1], returned unchanged.
  '''
  rows = [x[N1*(k+1):N1*(k+2)] for k in range(N)]
  return jnp.array(x[:N1]), jnp.array(rows), x[-1]

def flatten(x):
  '''
  Inverse of `unpackp`: pack (u_omega, u_d, V) back into one flat vector.

  Args:
    x: tuple (u_omega, u_d, V) as returned by `unpackp`. u_omega is 1-D, u_d is
      indexed as u_d[row, mode], and V is a scalar.
  Returns:
    1-D jnp array [u_omega, u_d row 0, u_d row 1, ..., V]. Rows are taken in
    order (row-major ravel), which is the slicing order used by `unpackp`.
  '''
  u_omega, u_d, V = x
  return jnp.concatenate([
      jnp.ravel(jnp.asarray(u_omega)),
      jnp.ravel(jnp.asarray(u_d)),
      jnp.atleast_1d(jnp.asarray(V)),
  ])

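Our Hamiltonian is $$\begin{aligned}
\frac{H}{\hbar}=& \frac{\Omega(t)}{2} \sum_{i=1}^{N} \sigma_{x}^{(i)}-\sum_{i=1}^{N} \Delta_{i}(t) n_{i} \\
&+\sum_{i<j} \frac{V}{|i-j|^{6}} n_{i} n_{j}
\end{aligned}$$
where the controls are Fourier sine series with base frequency $\omega = 2\pi/t_1$: $\Omega(t)=\sum_{k=1}^{N_1} u^{\Omega}_k \sin(k\omega t)$ and $\Delta_i(t)=\sum_{k=1}^{N_1} u^{\Delta}_{i,k} \sin(k\omega t)$.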

In [None]:
def H(t,psi,flat_p,t1):
    '''
    Apply the time-dependent Hamiltonian H(t)/hbar to the state vector psi.

    Basis index x encodes the N qubit occupations as bits: bit i of x is n_i
    (see `n`). The Hamiltonian has two parts:
      - diagonal: -sum_i Delta_i(t) n_i + sum_{i<j} V n_i n_j / |i-j|^6
      - off-diagonal: Omega(t)/2 * sum_i sigma_x^(i), implemented with xor
        (sigma_x^(i) maps |x> to |x xor 2^i>)

    The controls are Omega(t) = sum_k u_omega[k-1] sin(k w t) and
    Delta_i(t) = sum_k u_d[i, k-1] sin(k w t), for k = 1..N1 and w = 2*pi/t1.

    Args:
      t: time at which the controls are evaluated.
      psi: state vector of length 2**N.
      flat_p: flat parameter vector (layout as in `unpackp`).
      t1: total evolution time; it sets the base frequency w.
    Returns:
      Complex jnp array H psi with the same length as psi. This is H psi, not the
      Schrodinger right-hand side -i H psi.
    Raises:
      ValueError: if len(psi) != 2**N.
    '''
    D, = jnp.shape(psi)
    if D != 2**N:
        raise ValueError(f"H expects a state of length 2**N = {2**N}, got {D}")
    psi = jnp.asarray(psi)
    # Always work in complex arithmetic; keep complex128 if x64 is enabled.
    psi = psi.astype(jnp.promote_types(psi.dtype, jnp.complex64))

    w = 2*jnp.pi/t1
    u_omega, u_d, V = unpackp(flat_p)
    modes = jnp.sin(w*jnp.arange(1, N1+1)*t)   # sin(k w t) for k = 1..N1
    omega = jnp.dot(u_omega, modes)            # Omega(t)
    delta = jnp.dot(u_d, modes)                # Delta_i(t), shape (N,)

    x = np.arange(D)[:, None]                  # basis indices as a column
    qubit = np.arange(N)[None, :]              # qubit indices as a row
    occ = n(qubit, x)                          # occ[x, i] = n_i(x), shape (D, N)

    # Static 1/|i-j|^6 couplings summed over occupied pairs i < j.
    qi, qj = np.triu_indices(N, k=1)
    vdw = np.sum(occ[:, qi]*occ[:, qj]/np.abs(qi - qj)**6, axis=1)  # shape (D,)

    diag = -jnp.dot(occ, delta) + V*vdw
    # Because xor is an involution, (sum_i sigma_x^(i) psi)[y] = sum_i psi[y xor 2^i].
    flipped = psi[x ^ (1 << qubit)]            # shape (D, N)
    return diag*psi + omega/2*jnp.sum(flipped, axis=1)

In [None]:
@jit
def loss(psi,p,t1,A,psi0):
    '''
    Weighted training loss for the control parameters.

    L = A[0] * (1 - |<psi|psi0>|^2)                    infidelity
      + A[1] * 0.5 * sum_k (k w u_omega[k-1])^2        smoothness of Omega(t)
      + A[2] * 0.5 * sum_{i,k} (k w u_d[i,k-1])^2      smoothness of every Delta_i(t)
      + A[3] * 0.5 * V^2
    Here k = 1..N1 and w = 2*pi/t1.

    The mode number is k (not k-1) because H builds the controls from sin(k w t)
    with k = 1..N1. The time derivative of mode k therefore scales with k w.
    All N detuning channels from the `unpackp` layout are regularised, not only
    the first one.

    Args:
      psi: state to score (complex vector).
      p: flat parameter vector (layout as in `unpackp`).
      t1: total evolution time.
      A: four loss weights.
      psi0: reference state for the overlap term. The infidelity is only in
        [0, 1] if both states are normalised.
    '''
    w = 2*jnp.pi/t1
    u_omega, u_d, V = unpackp(p)
    k = jnp.arange(1, N1+1)
    l1 = 1 - jnp.abs(jnp.vdot(psi, psi0))**2   # vdot conjugates psi
    l2 = 0.5*jnp.sum((k*w*u_omega)**2)
    l3 = 0.5*jnp.sum((k*w*u_d)**2)             # broadcasts over the N detuning rows
    l4 = 0.5*V**2
    return A[0]*l1+A[1]*l2+A[2]*l3+A[3]*l4

In [None]:
def grad_all(psi1,flat_p,t1,A,psi0):
    '''
    Adjoint-method backward pass for `loss`, following Chen et al. 2018,
    "Neural Ordinary Differential Equations", Algorithm 2.

    The augmented state [z, a, a_t, a_p] is integrated backward from t1 to 0 with
    scipy's solve_ivp:
      - z is the quantum state.
      - a = dL/dz is the state adjoint.
      - a_t is the time adjoint.
      - a_p = dL/dflat_p is the parameter adjoint.
    The dynamics are d psi/dt = -i H psi (hbar = 1).

    This is not jitted as a whole, because solve_ivp is a numpy/scipy routine.
    Only the per-step vjp is jitted.

    Args:
      psi1: state at time t1 (complex vector).
      flat_p: flat parameter vector (layout as in `unpackp`).
      t1: final time.
      A: the four loss weights passed to `loss`.
      psi0: reference state passed to `loss`.
    Returns:
      The scipy OdeResult of the backward solve. sol.y[:, -1] is laid out as
      [z(0) (D entries), a(0) (D), a_t (1), dL/dflat_p (len(flat_p))]. The
      parameter block starts from the direct (regularisation) gradient of `loss`,
      so its final value is the total parameter gradient.
      NOTE(review): the dependence of the dynamics on t1 through w = 2*pi/t1 is
      not included in a_t.
    '''
    psi1 = jnp.asarray(psi1)
    flat_p = jnp.asarray(flat_p)
    D, = psi1.shape

    def dynamics(t, z, p):
        # H returns H psi; the Schrodinger right-hand side is -i H psi.
        return -1j*H(t, z, p, t1)

    a1 = grad(loss, argnums=0)(psi1, flat_p, t1, A, psi0)
    dL_dp = grad(loss, argnums=1)(psi1, flat_p, t1, A, psi0)
    f1 = dynamics(t1, psi1, flat_p)
    # JAX pairs a cotangent c with a tangent v as Re(sum(c*v)), hence the real part.
    dL_dt1 = jnp.real(jnp.dot(a1, f1))

    aug_state1 = np.concatenate([
        np.asarray(psi1, dtype=complex),
        np.asarray(a1, dtype=complex),
        np.asarray([-dL_dt1], dtype=complex),
        np.asarray(dL_dp, dtype=complex),
    ])

    @jit
    def adjoint_rhs(t, z, a):
        # One vjp yields a^T df/dt, a^T df/dz and a^T df/dp together.
        f, vjp_fun = vjp(dynamics, t, z, flat_p)
        vjp_t, vjp_z, vjp_p = vjp_fun(a.astype(f.dtype))
        return f, -vjp_z, -jnp.reshape(vjp_t, (1,)), -vjp_p

    def aug_dynamics(t, y):
        # solve_ivp calls fun(t, y), where y is the flat augmented state.
        z, a = y[:D], y[D:2*D]
        return np.concatenate([np.asarray(part) for part in adjoint_rhs(t, z, a)])

    return solve_ivp(aug_dynamics, [t1, 0], aug_state1)

Our Hamiltonian is $$\begin{aligned}
\frac{H}{\hbar}=& \frac{\Omega(t)}{2} \sum_{i=1}^{N} \sigma_{x}^{(i)}-\sum_{i=1}^{N} \Delta_{i}(t) n_{i} \\
&+\sum_{i<j} \frac{V}{|i-j|^{6}} n_{i} n_{j}
\end{aligned}$$

In [None]:
def check_with_qutip(psi0,flat_p,t1):
    '''
    Evolve psi0 with QuTiP's `mesolve` under the same Hamiltonian that `H`
    implements. This gives an independent reference for `forward`.

    Qubit ordering: in `H`, qubit i is bit i of the basis index (bit 0 is the
    least significant). In a QuTiP tensor product, the first factor is the most
    significant bit. Qubit i is therefore placed at tensor position N-1-i.

    Args:
      psi0: initial state vector of length 2**N.
      flat_p: flat parameter vector (layout as in `unpackp`).
      t1: final time; it also sets the base frequency w = 2*pi/t1.
    Returns:
      The QuTiP result returned by `mesolve`, evaluated at 100 evenly spaced
      times in [0, t1].
    '''
    q_psi = Qobj(psi0, dims=[[2]*N, [1]*N])
    tlist = np.linspace(0, t1, 100)
    w = 2*np.pi/t1
    k = np.arange(1, N1+1)

    u_omega, u_d, V = unpackp(flat_p)
    # QuTiP coefficient callbacks should return plain floats, not jnp arrays.
    u_omega = np.asarray(u_omega)
    u_d = np.asarray(u_d)
    V = float(V)

    def omega(t, args):
        return float(np.dot(u_omega, np.sin(w*k*t)))

    def make_delta(row):
        # Bind `row` eagerly. A lambda created inside a loop late-binds its loop
        # variable, so every qubit would otherwise use the last detuning channel.
        def delta(t, args):
            return -float(np.dot(row, np.sin(w*k*t)))
        return delta

    si = qeye(2)
    ni = num(2)

    def site_op(ops):
        # ops maps a qubit index (H's bit convention) to a single-site operator.
        op_list = [si]*N
        for q, op in ops.items():
            op_list[N-1-q] = op
        return tensor(op_list)

    # time-independent part: sum_{i<j} V/|i-j|^6 n_i n_j
    H0 = 0*site_op({})
    for i in range(N):
        for j in range(i+1, N):
            H0 = H0 + V/abs(i-j)**6*site_op({i: ni, j: ni})
    h = [H0]

    # +Omega(t)/2 sum_i sigma_x^(i), with the same sign as in `H`.
    H1 = 0*site_op({})
    for q in range(N):
        H1 = H1 + 0.5*site_op({q: sigmax()})
    h.append([H1, omega])

    # -Delta_i(t) n_i (the minus sign lives in the coefficient function).
    for i in range(N):
        h.append([site_op({i: ni}), make_delta(u_d[i])])

    return mesolve(h, q_psi, tlist)


In [None]:
if __name__=='__main__':
  # Random normalised N-qubit initial state; the overlap loss assumes unit norm.
  psi0 = np.random.normal(size=(2**N,)) + 1j*np.random.normal(size=(2**N,))
  psi0 = psi0/np.linalg.norm(psi0)
  flat_params = initial()
  # `res` is consumed by the gradient cell below, so it must be computed here.
  res = forward(2,psi0,flat_params)
  # Independent QuTiP evolution of the same state, kept for comparison with `res`.
  qutip_result = check_with_qutip(psi0,flat_params, 2)

In [None]:
A = np.random.uniform(size=(4,))   # loss weights A[0..3]
# State at t1: column 1 of sol.y, because forward evaluates at t_eval=[0, t1].
z1 = jnp.array(res[1][:, 1])
grad_all(z1, flat_params, 2, A, psi0)

In [None]:
# Sanity check of the transverse-field term: apply sum_k sigma_x^(k) to a random N-qubit state.
# A separate name keeps this demo from overwriting the `psi0` used by earlier cells.
psi_demo = np.random.normal(size=(2**N,)) + 1j*np.random.normal(size=(2**N,))
qipsi = Qobj(psi_demo, dims=[[2]*N, [1]*N])
# One sigma_x on each qubit k, with identities elsewhere. The hand-written sum
# listed the last qubit twice and skipped the third.
sx_terms = [tensor([sigmax() if m == k else identity(2) for m in range(N)]) for k in range(N)]
H1 = sx_terms[0]
for term in sx_terms[1:]:
    H1 = H1 + term
print(H1)
print(qipsi)
print(H1*qipsi)

In [None]:
# List every qubit pair (i, j) with i < j.
for i in range(N):
  for j in range(i + 1, N):
    print("({0},{1})".format(i, j))