In [1]:
import pennylane as qml
import qutip as qtp
import qutipHam
import numpy as np
from jax import numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
from scipy.optimize import minimize
import timeit
import mod_initstates as init
import mod_timeevol as te

In [5]:
# --- System size and coupling -------------------------------------------------
n = 4  # total number of qubits
# NOTE(review): an earlier note here said the Hamiltonian is defined only for
# 3 qubits, which contradicts n = 4 — confirm that qutipHam supports this n.
d = 2**n  # dimension of the composite system
cc = 1  # coupling constant

# --- Candidate Hamiltonians (built by the local qutipHam module) --------------
isingHam = qutipHam.H_ising(n, cc)
sbHam = qutipHam.H_sb(n, cc)

# Pick the Hamiltonian ONCE and derive both the characteristic time and the
# Pauli decomposition from that same object, so that swapping in e.g. sbHam
# cannot leave tau and initHam describing different Hamiltonians.
selectedHam = isingHam
tau = te.characteristic_time(selectedHam)
# initHam is the PennyLane (Pauli-decomposed) form consumed by the circuit cells.
initHam = qml.pauli_decompose(selectedHam.full())

# --- Initial states -----------------------------------------------------------
# NOTE(review): init_states is called with n - 1 while the circuit acts on n
# wires — presumably one qubit is handled separately inside mod_initstates;
# verify against that module.
tensor_initial_states, initial_states = init.init_states(n - 1)
# Convert every state to a JAX array, then stack them into a single array.
initial_states = jnp.array([jnp.array(state) for state in initial_states])

# --- Wire orderings -----------------------------------------------------------
# wireLists[i] is range(n) with positions 0 and i exchanged (i = 0 leaves the
# ordering untouched), giving one ordering per qubit.
wireLists = []
for i in range(n):
    wireL = list(range(n))
    wireL[0], wireL[i] = wireL[i], wireL[0]
    wireLists.append(wireL)

In [8]:
# Mixed-state simulator covering all n qubits.
devRho = qml.device("default.mixed", wires=n)


@qml.qnode(devRho)
def purity(thetas, n, H, tau, rho, wireList):
    """Purity of wire 0 after a conjugated time evolution of ``rho``.

    The circuit prepares the density matrix ``rho`` on wires ``0 .. n-1``,
    applies ``SpecialUnitary(thetas)`` on ``wireList``, evolves under ``H``
    for time ``tau``, and then applies the adjoint of the same special unitary.

    Args:
        thetas: parameters passed to ``qml.SpecialUnitary``.
        n: number of wires the density matrix is prepared on.
        H: Hamiltonian handed to ``qml.ApproxTimeEvolution``.
        tau: evolution time.
        rho: density matrix used to initialise the device state.
        wireList: wire ordering on which the special unitary acts.

    Returns:
        The measurement ``qml.purity`` evaluated on wire 0.
    """
    system_wires = list(range(n))
    qml.QubitDensityMatrix(rho, wires=system_wires)

    qml.SpecialUnitary(thetas, wires=wireList)
    # Third argument is the number of Trotter steps. Original note: "If exact,
    # set to 1. See ApproxTE documentation."
    # NOTE(review): confirm that a single step is exact for the H in use.
    qml.ApproxTimeEvolution(H, tau, n=1)
    qml.adjoint(qml.SpecialUnitary(thetas, wires=wireList))

    return qml.purity(wires=0)

def cost(thetas, n, H, tau, rho, allWireLists):
    """Worst-case impurity over the first ``n`` wire orderings.

    For each ``i`` in ``range(n)`` the circuit ``purity`` is evaluated with
    ``allWireLists[i]`` and converted to an impurity ``1 - purity``; the
    largest impurity is returned.

    Args:
        thetas: circuit parameters forwarded to ``purity``.
        n: number of qubits, and number of wire orderings evaluated.
        H: Hamiltonian forwarded to ``purity``.
        tau: evolution time forwarded to ``purity``.
        rho: initial density matrix forwarded to ``purity``.
        allWireLists: indexable collection holding at least ``n`` wire orderings.

    Returns:
        The maximum of ``1 - purity(...)`` over the evaluated orderings.
    """
    impurities = [
        1 - purity(thetas, n, H, tau, rho, allWireLists[k])
        for k in range(n)
    ]
    # Python's built-in max (rather than jnp.max) selects a single entry.
    return max(impurities)


# Gradient of the cost with respect to its first argument (thetas).
gradCost = jax.grad(cost, argnums=0)

In [9]:
# Draw the starting parameters from an explicitly seeded generator so that a
# fresh "Restart & Run All" reproduces the same optimisation run.
thetaSeed = 0
thetaRng = np.random.default_rng(thetaSeed)
# d**2 - 1 parameters, i.e. 4**n - 1 for the n-qubit special unitary.
rThetas = jnp.array(thetaRng.standard_normal(d**2 - 1))

# Time a BFGS minimisation of the cost for the first initial state, using the
# JAX gradient as the Jacobian.
start = timeit.default_timer()
opt = minimize(
    cost,
    rThetas,
    args=(n, initHam, tau, initial_states[0], wireLists),
    method='BFGS',
    jac=gradCost,
)
end = timeit.default_timer()
print(end - start)

8.234920574999997


In [10]:
opt

      fun: 1.5031521582997698e-09
 hess_inv: array([[ 9.85188200e-01, -1.63346055e-02,  2.08022093e-03, ...,
         3.45297350e-03,  1.31219555e-02, -1.23355178e-02],
       [-1.63346055e-02,  1.00318324e+00, -2.66787995e-03, ...,
        -6.75443117e-03,  4.84811874e-03, -9.98569241e-03],
       [ 2.08022093e-03, -2.66787995e-03,  9.87152751e-01, ...,
         3.98078171e-04,  1.87053226e-03, -9.64406290e-04],
       ...,
       [ 3.45297350e-03, -6.75443117e-03,  3.98078171e-04, ...,
         9.87711078e-01, -2.66542756e-03, -8.82597898e-03],
       [ 1.31219555e-02,  4.84811874e-03,  1.87053226e-03, ...,
        -2.66542756e-03,  1.00480290e+00,  1.38175274e-02],
       [-1.23355178e-02, -9.98569241e-03, -9.64406290e-04, ...,
        -8.82597898e-03,  1.38175274e-02,  9.95939251e-01]])
      jac: array([ 3.43079952e-06, -1.63552123e-06, -2.47290087e-06, -2.17051367e-06,
        2.34778283e-06, -1.42914082e-06, -5.60579097e-06, -1.66657024e-06,
       -3.51920100e-06, -5.19175333e-

In [11]:
opt.x

array([-0.70240307,  0.16169422,  2.06790347, -0.10682974, -1.28590659,
        0.79555994,  0.69297983,  0.68864775,  0.39137178,  1.61841993,
       -2.12352278, -0.62670252,  0.99562151, -0.10020471,  0.48762087,
        0.30529369,  1.44516103,  1.13451758, -0.14724252, -1.50474362,
        0.68549644,  0.20888162, -0.38239007, -0.36302438, -0.44236673,
        1.61584477,  0.11320293,  0.6100624 ,  1.96017786, -0.54254091,
       -0.0278598 , -0.71349696, -0.78283525, -2.26392275, -1.44497501,
       -0.72431634,  0.1873764 ,  0.39095997, -0.58681142, -0.53896641,
        1.65890009, -0.49030067,  0.10062121, -0.27067275, -0.94756888,
        0.6906477 , -1.50054083,  0.14441646,  0.09905711,  0.74192588,
        0.69864828,  0.19682459,  0.43958194, -0.84275508,  0.98292579,
        0.30952141, -0.53059608, -0.39521393, -0.23169053,  1.21030688,
        2.5213386 , -0.29565929,  0.68709123, -2.20525743, -0.65634924,
       -0.84271316,  0.52348181, -0.41354269,  0.93145141,  0.08