# Lindblad evolution

In [6]:
# imports

import numpy as np
import matplotlib.pyplot as plt
#import scipy as sp
from jax import numpy as jnp
import jax.scipy as sp

from src import visuals as vis
%matplotlib qt

# definitions

# 2x2 operator matrices used to build Hamiltonians and observables.
rho_x = jnp.array([[0, 1],
                   [1, 0]])
rho_y = jnp.array([[0, -1j],
                   [1j, 0]])
rho_z = jnp.array([[1, 0],
                   [0, -1]])

# Off-diagonal matrices: rho_p maps psi_0 onto psi_1, rho_m maps psi_1 onto psi_0.
rho_m = jnp.array([[0, 1],
                   [0, 0]])
rho_p = jnp.array([[0, 0],
                   [1, 0]])

# Two-component basis state vectors.
psi_0, psi_1 = jnp.array([1, 0]), jnp.array([0, 1])

def H(w, g):
    """Return the 2x2 Hamiltonian ``w * rho_z + g * rho_x``.

    ``w`` scales the module-level ``rho_z`` matrix and ``g`` scales the
    module-level ``rho_x`` matrix.
    """
    z_term = w * rho_z
    x_term = g * rho_x
    return z_term + x_term

def unitary_step(H, state, dt):
    """Apply the propagator ``expm(-1j * H * dt)`` to ``state`` and return the result.

    H:     matrix generating the evolution.
    state: vector (or matrix) the propagator is applied to from the left.
    dt:    step length multiplying ``H`` in the exponent.
    """
    generator = -1j * H * dt
    propagator = sp.linalg.expm(generator)
    return propagator @ state

def lindbladian(rho, H, L, gamma, dt):
    """Return ``rho`` advanced by one explicit Euler step of the Lindblad equation (hbar = 1).

    The increment is
        -1j * [H, rho] + gamma * (L rho L^dag - 0.5 * {L^dag L, rho})
    multiplied by ``dt`` and added to ``rho``.

    rho:   current density matrix.
    H:     Hamiltonian matrix.
    L:     jump operator matrix.
    gamma: rate multiplying the dissipative term.
    dt:    Euler step length.
    """
    Ldag = L.conj().T
    LdagL = Ldag @ L

    commutator = H @ rho - rho @ H
    anticommutator = LdagL @ rho + rho @ LdagL
    dissipator = L @ rho @ Ldag - 0.5 * anticommutator

    drho = -1j * commutator + gamma * dissipator
    return rho + drho*dt

# simulation
#
# Integrate the Lindblad equation with explicit Euler steps, recording the
# expectation values of rho_x / rho_y / rho_z together with trace, purity and
# overlap with the initial state before every step.

N = 100      # number of Euler steps (and of recorded samples)
T = 1        # total evolution time
dt = T/N

# Sample i is recorded before the i-th step, i.e. at t = i*dt.
# np.linspace(0, T, N) would space the samples by T/(N-1) and therefore
# disagree with the integrator step dt = T/N.
times = np.arange(N) * dt
s_x = np.zeros(N, dtype='complex')
s_y = np.zeros(N, dtype='complex')
s_z = np.zeros(N, dtype='complex')
tr = np.zeros(N)
pur = np.zeros(N)
fid = np.zeros(N)

# |psi><psi| has elements psi_i * conj(psi_j): the conjugate belongs on the
# second factor.  (Identical to the previous result for real psi_0.)
rho_init = np.outer(psi_0, np.conj(psi_0))

w = 5
g = 4
L = rho_p
gamma = 1

rho = rho_init
for i in range(N):
    s_x[i] = np.trace(rho @ rho_x)
    s_y[i] = np.trace(rho @ rho_y)
    s_z[i] = np.trace(rho @ rho_z)

    # tr / pur / fid are real arrays; take the real part explicitly instead of
    # relying on an implicit complex -> real cast (which emits a warning).
    tr[i] = np.real(np.trace(rho))
    pur[i] = np.real(np.trace(rho @ rho))
    fid[i] = np.real(np.trace(rho @ rho_init))

    rho = lindbladian(rho, H(w,g), L, gamma, dt)

# vis.plotBlochSphereTrajectory(s_x, s_y, s_z, 100)

  tr[i] = np.trace(rho)
  pur[i] = np.trace(rho @ rho)
  fid[i] = np.trace(rho @ rho_init)


In [2]:
# Plot either the three expectation values or the trace/purity/fidelity
# diagnostics recorded by the simulation cell.
fig, ax = plt.subplots()

plotting_spins = 1
if plotting_spins:
    # real parts of the recorded expectation values
    for component, colour, name in ((s_x, 'C1', 'x'), (s_y, 'C2', 'y'), (s_z, 'C3', 'z')):
        ax.plot(times, component.real, c=colour, label=name)
    # dotted zero line across the plotted time range
    ax.hlines([0], [0],[times[-1]], colors=['black'], linestyles=[':'])

plotting_markers = not plotting_spins
if plotting_markers:
    for series, colour, name in ((tr, 'black', 'Trace'), (pur, 'pink', 'Purity'), (fid, 'cyan', 'Fidelity')):
        ax.plot(times, series, ls='--', c=colour, label=name)

ax.legend()
fig.show()


# ML version of control

In [3]:
from flax import nnx
import optax
import jax
from jax import numpy as jnp


In [4]:
class Control(nnx.Module):
  """Three-layer perceptron: Linear -> relu -> Linear -> relu -> Linear."""

  def __init__(self, din, dmid1, dmid2, dout, rngs: nnx.Rngs):
    """Create the layers with sizes din -> dmid1 -> dmid2 -> dout, all drawing from ``rngs``."""
    self.linear_in = nnx.Linear(din, dmid1, rngs=rngs)
    self.linear_mid = nnx.Linear(dmid1, dmid2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid2, dout, rngs=rngs)

  def __call__(self, x):
    """Return the output layer applied to the twice-activated hidden features of ``x``."""
    hidden = nnx.relu(self.linear_in(x))
    hidden = nnx.relu(self.linear_mid(hidden))
    return self.linear_out(hidden)
  

# Layer sizes: one input (t), two hidden layers of width 3, and two outputs
# (the Hamiltonian parameters as functions of t).
din, dmid1, dmid2, dout = 1, 3, 3, 2

control_model = Control(din, dmid1, dmid2, dout, rngs=nnx.Rngs(0))

In [5]:
def unitary_evolution(psi, ws, gs, dt):
    """Propagate ``psi`` through ``len(ws)`` successive ``unitary_step`` calls.

    Step ``k`` uses the Hamiltonian ``H(ws[k], gs[k])`` and step length ``dt``.
    Returns the state after the last step.
    """
    for k, w_k in enumerate(ws):
        psi = unitary_step(H(w_k, gs[k]), psi, dt)
    return psi

In [6]:
# Adam optimiser bound to control_model (the model is shared by reference).
optimizer = nnx.Optimizer(control_model, optax.adam(1e-3))


@nnx.jit  # NNX handles model/optimizer state across the JAX transform.
def train_step(model, optimizer, x, y):
  """Run one optimisation step and return the loss.

  The state ``x`` is evolved through 100 ``unitary_step`` calls over total time
  1, with the two Hamiltonian parameters at each step taken from ``model``
  evaluated at that step's time.  The loss is ``1 - |<psi|y>|^2`` for the final
  state ``psi``; its gradients are applied through ``optimizer.update``.
  """
  # hyperparameters
  T = 1
  nsteps = 100

  def loss_fn(model):
    step_times = jnp.linspace(0,T,nsteps)
    dt = T/nsteps

    psi = x
    for step in range(nsteps):
      controls = model(jnp.array([step_times[step]]))
      psi = unitary_step(H(controls[0], controls[1]), psi, dt)

    overlap = jnp.dot(jnp.conj(psi), y)
    return 1 - jnp.abs(overlap)**2

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # updates the model parameters in place

  return loss

In [7]:
# Train the control network to take psi_0 to psi_1, keeping the loss of every step.
x_ex, y_ex = psi_0, psi_1

ns = 300
L = np.zeros(ns)  # loss history

for step in range(ns):
    L[step] = train_step(control_model, optimizer, x_ex, y_ex)

In [8]:
# Loss recorded at each training step (L was filled by the training loop).
plt.plot(L)

[<matplotlib.lines.Line2D at 0x16f5bc32860>]

In [67]:
# Evaluate the trained control network on N time points in [0, 1] and store
# its two outputs per time point in ws and gs.
N = 100

ts = jnp.linspace(0,1,N)
ws, gs = np.zeros(N), np.zeros(N)
for k in range(N):
    controls = control_model(ts[k:k+1])  # length-1 slice as model input
    ws[k], gs[k] = controls[0], controls[1]

In [76]:
# Learned control outputs versus step index: first ws, then gs.
for series in (ws, gs):
    plt.plot(series)

[<matplotlib.lines.Line2D at 0x254c9378a00>]

In [73]:
# Loss history stored in L, plotted against the training-step index.
plt.plot(L)

[<matplotlib.lines.Line2D at 0x254c95d9e40>]

In [84]:
# Replay the sampled controls (ws, gs) on the initial state and compare the
# outcome with the target state.
psi = unitary_evolution(x_ex, ws, gs, dt)

print(psi.round(2))
print(y_ex)
overlap = jnp.dot(jnp.conj(psi), y_ex)
print(jnp.abs(overlap)**2)


[ 0.  -0.j -0.02-1.j]
[0 1]
1.0000007


# Lindblad evolution, ML control

In [1]:
import numpy as np
import matplotlib.pyplot as plt
#import scipy as sp
from jax import numpy as jnp
import jax.scipy as sp

from src import visuals as vis
%matplotlib qt

from flax import nnx
import optax
import jax
from jax import numpy as jnp


In [157]:
# 2x2 operator matrices used to build Hamiltonians and dissipators.
rho_x = jnp.array([[0, 1],
                   [1, 0]])
rho_y = jnp.array([[0, -1j],
                   [1j, 0]])
rho_z = jnp.array([[1, 0],
                   [0, -1]])

# Off-diagonal matrices: rho_p maps psi_0 onto psi_1, rho_m maps psi_1 onto psi_0.
rho_m = jnp.array([[0, 1],
                   [0, 0]])
rho_p = jnp.array([[0, 0],
                   [1, 0]])

# Two-component basis state vectors.
psi_0, psi_1 = jnp.array([1, 0]), jnp.array([0, 1])

def H(w, g):
    """Return the 2x2 Hamiltonian ``w * rho_z + g * rho_x``.

    ``w`` scales the module-level ``rho_z`` matrix and ``g`` scales the
    module-level ``rho_x`` matrix.
    """
    z_term = w * rho_z
    x_term = g * rho_x
    return z_term + x_term

In [158]:
@jax.jit
def lindbladian(rho, H, L, gamma, dt, key=None):
    """Return ``rho`` after one explicit Euler step of a randomly rescaled Lindblad equation (hbar = 1).

    The increment is
        -1j * [H, rho] + weight * gamma * (L rho L^dag - 0.5 * {L^dag L, rho})
    multiplied by ``dt``.  ``weight`` is drawn uniformly from [0.5, 1.5), so
    this is deliberately NOT precisely the GKSL equation.

    rho:   current density matrix.
    H:     Hamiltonian matrix.
    L:     jump operator matrix.
    gamma: rate multiplying the dissipative term.
    dt:    Euler step length.
    key:   optional ``jax.random`` PRNG key.  When given, ``weight`` is drawn
           with ``jax.random.uniform`` inside the compiled function, so a new
           key gives a new weight on every call.  When omitted, the original
           ``np.random.uniform`` draw is kept for backward compatibility --
           but under ``jax.jit`` that host-side call only runs while the
           function is being traced, so the value is baked into the compiled
           code and reused by later calls that hit the same trace.
    """
    Ldag = L.conj().T
    if key is None:
        # Trace-time constant (see docstring): not re-sampled per call.
        weight = np.random.uniform(0.5,1.5)
    else:
        weight = jax.random.uniform(key, minval=0.5, maxval=1.5)
    drho = -1j * (H @ rho - rho @ H) + weight * gamma * (L @ rho @ Ldag - 0.5 * ((Ldag @ L) @ rho + rho @ (Ldag @ L) ))
    rho_new = rho + drho*dt
    return rho_new

In [159]:
######################################### PROBLEM DEFINITION HERE ###########################################

T = 1          # total evolution time
nsteps = 100   # number of Euler steps
times = jnp.linspace(0,T,nsteps)  # time values fed to the control network
dt = T/nsteps

# |psi><psi| has elements psi_i * conj(psi_j): the conjugate belongs on the
# second factor.  np.outer(conj(psi), psi) is the transpose, which only
# coincides with the density matrix for real state vectors.
rho_init = np.outer(psi_0, np.conj(psi_0))
#psi_final = 1/jnp.sqrt(2) * (psi_0 + psi_1)
psi_final = psi_1
rho_final = np.outer(psi_final, np.conj(psi_final))

diss = rho_x   # jump operator passed to lindbladian
gamma = 1      # dissipation rate passed to lindbladian

In [160]:
class Control(nnx.Module):
  """Three-layer perceptron: Linear -> relu -> Linear -> gelu -> Linear."""

  def __init__(self, din, dmid1, dmid2, dout, rngs: nnx.Rngs):
    """Create the layers with sizes din -> dmid1 -> dmid2 -> dout, all drawing from ``rngs``."""
    self.linear_in = nnx.Linear(din, dmid1, rngs=rngs)
    self.linear_mid = nnx.Linear(dmid1, dmid2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid2, dout, rngs=rngs)

  def __call__(self, x):
    """Return the output layer applied to the activated hidden features of ``x``."""
    hidden = nnx.relu(self.linear_in(x))
    hidden = nnx.gelu(self.linear_mid(hidden))
    return self.linear_out(hidden)
  

# Layer sizes: one input (t), two hidden layers of width 5, and two outputs
# (the Hamiltonian parameters as functions of t).
din, dmid1, dmid2, dout = 1, 5, 5, 2

control_model_lindblad = Control(din, dmid1, dmid2, dout, rngs=nnx.Rngs(0))

In [161]:
# Adam optimiser bound to control_model_lindblad (the model is shared by reference).
optimizer_lindblad = nnx.Optimizer(control_model_lindblad, optax.adam(1e-3))


@nnx.jit  # NNX handles model/optimizer state across the JAX transform.
def train_step_lindblad(control_model_lindblad, optimizer_lindblad, x, y):
  """Run one optimisation step for the dissipative control problem and return the loss.

  The density matrix ``x`` is pushed through ``nsteps`` calls of
  ``lindbladian``; at each step the two Hamiltonian parameters come from the
  model evaluated at ``times[step]``.  ``nsteps``, ``times``, ``diss``,
  ``gamma`` and ``dt`` are read from module scope.  The loss is
  ``(1 - <y|rho|y>)**2`` for the final ``rho``, which treats the target ``y``
  as a pure state vector.
  """

  def loss_fn(control_model_lindblad):
    rho = x
    for step in range(nsteps):
      controls = control_model_lindblad(jnp.array([times[step]]))
      rho = lindbladian(rho, H(controls[0], controls[1]), diss, gamma, dt)

    # special case: pure final state, so the fidelity is <y|rho|y>
    y_dag = jnp.conj(y).T
    infidelity = 1 - jnp.real(y_dag @ rho @ y)
    return infidelity**2

  loss, grads = nnx.value_and_grad(loss_fn)(control_model_lindblad)
  optimizer_lindblad.update(grads)  # updates the model parameters in place

  return loss

In [162]:
################################ TRAINING #######################################

# Run the optimiser for nepochs steps, keeping the loss of every step, then plot it.
nepochs = 8000
L = np.zeros(nepochs)  # loss history

for epoch in range(nepochs):
    loss_value = train_step_lindblad(control_model_lindblad, optimizer_lindblad, rho_init, psi_final)
    L[epoch] = loss_value

plt.plot(L)
plt.show()

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


In [163]:
# Evaluate the trained control network at every entry of `times`, store its
# two outputs per time point in ws and gs, and plot both.
ws, gs = np.zeros(nsteps), np.zeros(nsteps)

for k in range(nsteps):
    controls = control_model_lindblad(times[k:k+1])  # length-1 slice as model input
    ws[k], gs[k] = controls[0], controls[1]

for series in (ws, gs):
    plt.plot(series)
plt.show()

In [164]:
# Loss history stored in L, plotted against the training-step index.
plt.plot(L)
plt.show()

In [165]:
# Replay the sampled controls (ws, gs) through the Lindblad stepper, then print
# the initial, final and target density matrices and the overlap with the target.
rho = rho_init
for step in range(nsteps):
    hamiltonian = H(ws[step], gs[step])
    rho = lindbladian(rho, hamiltonian, diss, gamma, dt)

for matrix in (rho_init, rho.round(2), rho_final):
    print(matrix)

# <psi_final| rho |psi_final>
psi_final_dag = jnp.conj(psi_final).T
fid = jnp.real(psi_final_dag @ rho @ psi_final)
print(fid)

[[1 0]
 [0 0]]
[[0.  +0.j   0.13+0.09j]
 [0.13-0.09j 1.  +0.j  ]]
[[0 0]
 [0 1]]
0.9999994


# Fidelity as a function of gamma (dissipation rate)

In [None]:
######################################### PROBLEM DEFINITION HERE ###########################################

T = 1          # total evolution time
nsteps = 100   # number of Euler steps
times = jnp.linspace(0,T,nsteps)  # time values fed to the control network
dt = T/nsteps

# |psi><psi| has elements psi_i * conj(psi_j): the conjugate belongs on the
# second factor.  np.outer(conj(psi), psi) is the transpose, which only
# coincides with the density matrix for real state vectors.
rho_init = np.outer(psi_0, np.conj(psi_0))
#psi_final = 1/jnp.sqrt(2) * (psi_0 + psi_1)
psi_final = psi_1
rho_final = np.outer(psi_final, np.conj(psi_final))

diss = rho_x   # jump operator passed to lindbladian

In [129]:
# Sweep the dissipation rate gamma: for each value train a fresh control
# network, replay its controls, and store the final overlap with psi_final.
ngammas = 10
max_gamma = 2
gammas = jnp.linspace(0,max_gamma,ngammas)
fids = np.zeros(ngammas)

### MODEL
# Defined once, outside the sweep: nothing in the class, the layer sizes or the
# training step depends on the value of gamma, and redefining them on every
# iteration forces the jitted step to be traced again for each gamma.

class Control(nnx.Module):
  """Three-layer perceptron: Linear -> relu -> Linear -> gelu -> Linear."""

  def __init__(self, din, dmid1, dmid2, dout, rngs: nnx.Rngs):
    self.linear_in = nnx.Linear(din, dmid1, rngs=rngs)
    self.linear_mid = nnx.Linear(dmid1, dmid2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid2, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.gelu(self.linear_mid(nnx.relu(self.linear_in(x))))
    return self.linear_out(x)


din = 1 # t
dmid1 = 4
dmid2 = 4
dout = 2 # hamiltonian parameters as functions of t

nepochs = 5000


@nnx.jit  # Automatic state management for JAX transforms.
def train_step_lindblad_gamma(control_model_lindblad, optimizer_lindblad, x, y, gamma):
  """Run one optimisation step at dissipation rate ``gamma`` and return the loss.

  Same computation as the per-gamma closure it replaces, except that ``gamma``
  is an explicit argument instead of a captured global.  ``nsteps``, ``times``,
  ``diss`` and ``dt`` are still read from module scope.  The loss is
  ``(1 - <y|rho|y>)**2`` for the final ``rho``.
  """

  def loss_fn(control_model_lindblad):
    rho = x
    for i in range(nsteps):
      params = control_model_lindblad(jnp.array([times[i]]))
      rho = lindbladian(rho, H(params[0], params[1]), diss, gamma, dt)

    ############################ SPECIAL CASE: PURE FINAL STATE ##############################
    y_dag = jnp.conj(y).T
    fid = jnp.real(y_dag @ rho @ y)
    return (1-fid)**2

  loss, grads = nnx.value_and_grad(loss_fn)(control_model_lindblad)
  optimizer_lindblad.update(grads)  # in-place updates

  return loss


for z in range(ngammas):
  gamma = gammas[z]

  # Fresh model (same seed) and optimizer for every gamma.
  control_model_lindblad = Control(din, dmid1, dmid2, dout, rngs=nnx.Rngs(0))
  optimizer_lindblad = nnx.Optimizer(control_model_lindblad, optax.adam(1e-3))  # Reference sharing.

  ################################ TRAINING #######################################

  L = np.zeros(nepochs)

  for i in range(nepochs):
    L[i] = train_step_lindblad_gamma(control_model_lindblad, optimizer_lindblad, rho_init, psi_final, gamma)

  ################################ DRIVING #######################################

  ws = np.zeros(nsteps)
  gs = np.zeros(nsteps)

  for i in range(nsteps):
    params = control_model_lindblad(times[i:i+1])
    ws[i] = params[0]
    gs[i] = params[1]

  rho = rho_init
  for i in range(nsteps):
    rho = lindbladian(rho, H(ws[i], gs[i]), diss, gamma, dt)

  # <psi_final| rho |psi_final> after replaying the learned controls
  psi_final_dag = jnp.conj(psi_final).T
  fid = jnp.real(psi_final_dag @ rho @ psi_final)
  fids[z] = fid

In [131]:
# Final overlap with the target (fids) against the dissipation rate (gammas).
fig, ax = plt.subplots(figsize=(6,6))

ax.plot(gammas, fids)
ax.set_ylim(0.5,1.09)
fig.show()

In [124]:
# Control outputs currently stored in ws and gs, versus step index.
for series in (ws, gs):
    plt.plot(series)
plt.show()