In [198]:
import jax
import pennylane.numpy as jnp
import matplotlib.pyplot as plt 
import pennylane as qml 
from jax import random
import optax
key = random.PRNGKey(6032021)  # Random seed is explicit in JAX
from math import pi

In [209]:
N=3
layers = 3
params = 0.01*2*pi*random.uniform(key, shape=(2**N-1,))
#params_hea = 0.01*2*pi*random.uniform(key, shape=(3*N*layers,))
params_hea = jnp.array(jnp.random.rand(3*N*layers),requires_grad=True)

In [219]:
dev = qml.device('lightning.qubit', wires=N)
@qml.qnode(dev,diff_method='best')
def cost_paula(params):
	i = 1
	for qubit in range(N):
		for _ in range(2):
			for ctrl in range(qubit):
				qml.CNOT(wires=[ctrl, qubit])
				qml.RY(params[qubit], wires=qubit)
				i += 1  
	return qml.state()

@qml.qnode(dev,diff_method='best')
def cost_hea_x(params):
		for layer in range(layers):
				for i in range(N):
						qml.RZ(params[layer*N+i],i)
						qml.RX(params[layers*N+layer*N+i],i)
						qml.RZ(params[2*layers*N+layer*N+i],i)
				for i in range(N-1):
						qml.CZ(wires=[i,i+1])
				qml.CZ(wires=[N-1,0])
				qml.Barrier(range(0,N))
		return qml.state()

@qml.qnode(dev,diff_method='best')
def cost_hea_p(params):
		for layer in range(layers):
				for i in range(N):
						qml.RZ(params[layer*N+i],i)
						qml.RX(params[layers*N+layer*N+i],i)
						qml.RZ(params[2*layers*N+layer*N+i],i)
				for i in range(N-1):
						qml.CZ(wires=[i,i+1])
				qml.CZ(wires=[N-1,0])
				qml.Barrier(range(0,N))
		qml.QFT(wires=range(N))
		return qml.state()




In [220]:
params_hea

tensor([0.2190278 , 0.5928065 , 0.88033761, 0.03297469, 0.13580099,
        0.5635394 , 0.41945441, 0.19386535, 0.84728077, 0.60357165,
        0.22905637, 0.83265108, 0.0440028 , 0.4798372 , 0.44483311,
        0.08576009, 0.81641102, 0.0959478 , 0.89942232, 0.88177944,
        0.04877636, 0.2285974 , 0.5297159 , 0.00820806, 0.26895669,
        0.4684872 , 0.3638163 ], requires_grad=True)

In [221]:
cost_hea_x(params_hea)

array([-0.38137766+0.52749492j, -0.44037909+0.29562994j,
       -0.08347031+0.42248748j, -0.08958947-0.1071032j ,
       -0.25558071+0.11578102j,  0.04641504-0.01894248j,
       -0.02873916-0.07476908j,  0.04811643-0.0061303j ])

In [222]:
xmin = -2
xmax = +2

def Fourier( dim ):
    F = jnp.exp( 2j*jnp.pi*jnp.outer(jnp.arange(dim),
                                    jnp.arange(dim))/dim 
                                    ) / jnp.sqrt(dim) 
    return F 
def grid_op(N,xmin,xmax):

    dim = 2 ** N
    L   = xmax - xmin
    dx  = L / (2**N - 1)
    k  = jnp.linspace( -jnp.pi / dx, 
                        jnp.pi / dx, 
                        dim + 1)[:-1]
    
    x_values = xmin + dx * jnp.arange(dim)
    p_values = jnp.fft.fftshift(k)

    return x_values, p_values

def matrix_op(N,xmin,xmax):

    x_values, p_values = grid_op(N ,
                                xmin, 
                                xmax  )

    dim = 2 ** N
    F = Fourier(dim)

    Op = jnp.diag( x_values**2/2 ) \
        + F.T.conj()@jnp.diag( p_values**2/2 )@F
    
    return Op

In [225]:
A = matrix_op(N,xmin,xmax)

In [226]:
jnp.vdot(cost_hea_x(params_hea),jnp.dot(cost_hea_x(params_hea),A))

(4.243231150128113+4.440892098500626e-16j)

In [227]:
def final_cost(params_hea):
  state_x = cost_hea_x(params_hea)
  state_p = cost_hea_p(params_hea)
  return jnp.real(jnp.vdot(state_x,jnp.dot(state_x,A)))

In [228]:
final_cost(params_hea)

4.243231150128113

In [229]:
jit_cost = jax.jit(final_cost)

In [196]:
max_iterations = 100
conv_tol = 1e-04

opt = optax.adam(learning_rate=0.2)
optimizer=optax.adam(0.01)
opt_state = optimizer.init(params_hea)

In [197]:
jax.grad(jit_cost)(params_hea)

TypeError: Custom JVP rule must produce primal and tangent outputs with equal shapes and dtypes, but got complex64[8] and float32[8] respectively.

In [170]:
prev_energy,grads = jax.value_and_grad(jit_cost)(params_hea)
print(grads)

TypeError: Custom JVP rule must produce primal and tangent outputs with equal shapes and dtypes, but got complex64[8] and float32[8] respectively.

In [130]:
optimizer=optax.adam(0.01)
opt_state = optimizer.init(params_hea)

gd_cost_history = []
#grad_list_history = []
for n in range(max_iterations):
    prev_energy,grads = jax.value_and_grad(jit_cost)(params_hea)
    updates, opt_state = optimizer.update(grads, opt_state)
    params_hea = optax.apply_updates(params_hea, updates)
    energy = jit_cost(params_hea)
    
    gd_cost_history.append(prev_energy)

    # Calculate difference between new and old energies
    conv = jnp.abs(energy - prev_energy)
    
    # Calculate norm of gradient of cost function

    #gradient = gradient_params(cost,params)
    #grad_list_history.append(gradient)

    if n % 20 == 0:
        print(
            "Iteration = {:},  Energy = {:.6f}".format(n, energy))
    if conv <= conv_tol:
        break

Iteration = 0,  Energy = 0.189398
Iteration = 20,  Energy = 0.169887
Iteration = 40,  Energy = 0.143693


In [73]:
jnp.real(jnp.sum((cost_hea_x(params_hea)*jnp.conj(cost_hea_x(params_hea)))**2))

Array(0.985648, dtype=float32)

In [64]:
sum((cost_hea_p(params_hea)*jnp.conj(cost_hea_p(params_hea)))**2)

Array(0.12636326-5.4385985e-10j, dtype=complex64)

In [35]:
final_cost(params_hea)

TypeError: 'StateMP' object is not iterable