In [None]:
import pennylane as qml
import qutip as qtp
import qutipHam
import numpy as np
from jax import numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
from scipy.optimize import minimize
import timeit
import mod_initstates as init
import mod_timeevol as te

In [None]:
# Register size: number of qubits (wires) used throughout the notebook.
n = 3
# Dimension of the n-qubit Hilbert space; SpecialUnitary below uses d**2 - 1 parameters.
d = 2**n
# Second argument to the qutipHam constructors (presumably a coupling constant -- confirm).
cc = 1

isingHam = qutipHam.H_ising(n, cc)
sbHam = qutipHam.H_sb(n, cc)

# Hamiltonian driving the optimisation below, and the evolution time derived from it.
initHam = isingHam
tau = te.characteristic_time(initHam)

tensor_initial_states, initial_states = init.init_states(n - 1)
# Turn each state into a JAX array, then stack them into one JAX array.
initial_states = jnp.array([jnp.array(state) for state in initial_states])

# wireLists[0] is the identity ordering; wireLists[k] (k >= 1) swaps wires 0 and k.
wireLists = [list(range(n))]
for k in range(1, n):
    order = list(range(n))
    order[0], order[k] = order[k], order[0]
    wireLists.append(order)

In [None]:
# Mixed-state simulator: purity() prepares a density matrix directly with
# qml.QubitDensityMatrix, which needs a device that supports mixed states.
devRho = qml.device("default.mixed", wires=n)

@qml.qnode(devRho)
def purity(thetas, n, H, tau, rho, wireList, trotter_steps=1):
    """Purity of wire 0 after a time evolution conjugated by a SpecialUnitary.

    Circuit: prepare ``rho`` on wires ``0..n-1``, apply ``U(thetas)`` on
    ``wireList``, apply the (Trotterised) evolution under ``H`` for time
    ``tau``, apply ``U(thetas)^dagger``, and measure the purity of wire 0.

    Args:
        thetas: Parameters of ``qml.SpecialUnitary`` on ``wireList``
            (``4**len(wireList) - 1`` of them).
        n: Number of wires ``rho`` is prepared on.
        H: PennyLane Hamiltonian passed to ``qml.ApproxTimeEvolution``.
        tau: Evolution time.
        rho: Density matrix on ``n`` qubits.
        wireList: Wire order on which the SpecialUnitary (and its adjoint) act.
        trotter_steps: Number of Trotter steps for ``ApproxTimeEvolution``.
            The default of 1 reproduces the original circuit. A single step is
            exact only if all terms of ``H`` commute; increase it otherwise.

    Returns:
        Purity of the reduced state on wire 0.
    """
    qml.QubitDensityMatrix(rho, wires=list(range(n)))
    qml.SpecialUnitary(thetas, wires=wireList)
    qml.ApproxTimeEvolution(H, tau, trotter_steps)
    # Function-transform form: builds Adjoint(SpecialUnitary) without first
    # queuing (and then de-queuing) the forward operator.
    qml.adjoint(qml.SpecialUnitary)(thetas, wires=wireList)
    return qml.purity(0)

def cost(thetas, n, H, tau, rho, allWireLists):
    """Largest impurity (1 - purity) over the first ``n`` wire orderings.

    For each ``k`` in ``0..n-1`` the matrix ``H`` is re-expressed as a Pauli
    sum on the wire order ``allWireLists[k]``, ``purity`` is evaluated with
    that ordering, and ``1 - purity`` is recorded. The maximum is returned.

    Args:
        thetas: SpecialUnitary parameters (the quantity being optimised).
        n: Number of wires / number of orderings taken from ``allWireLists``.
        H: Hamiltonian as a dense matrix, decomposed with ``qml.pauli_decompose``.
        tau: Evolution time.
        rho: Initial density matrix.
        allWireLists: Sequence of wire orderings; at least ``n`` entries are read.

    Returns:
        The maximal ``1 - purity`` value.
    """
    def impurity(k):
        # Relabel H so that its qubit j lives on wire allWireLists[k][j].
        wires = allWireLists[k]
        hamiltonian = qml.pauli_decompose(H, wire_order=wires)
        return 1 - purity(thetas, n, hamiltonian, tau, rho, wires)

    return max(impurity(k) for k in range(n))

# Gradient of cost with respect to thetas (argument 0); other arguments are held fixed.
gradCost = jax.grad(cost, argnums=0)

In [None]:
# Fixed seed so that Restart & Run All reproduces the same starting point.
SEED = 1234
rng = np.random.default_rng(SEED)
# qml.SpecialUnitary on d = 2**n dimensions takes d**2 - 1 parameters.
rThetas = jnp.array(rng.standard_normal(d**2 - 1))

# value_and_grad returns (cost, gradient) together. With jac=True, BFGS takes
# both from one call instead of calling cost and gradCost separately, which
# evaluated the circuits' forward pass twice per iterate.
costAndGrad = jax.value_and_grad(cost, argnums=0)

start = timeit.default_timer()
opt = minimize(costAndGrad, rThetas,
               args=(n, initHam.full(), tau, initial_states[0], wireLists),
               method='BFGS', jac=True)
end = timeit.default_timer()
print(end - start)