# Open system evolution in Stinespring fashion

Specifying the form of the environment and of its interaction with the system makes it possible to calculate the entropy production explicitly.

In [83]:
import numpy as np
import matplotlib.pyplot as plt
#import scipy as sp
from jax import numpy as jnp
import jax.scipy as sp

from src import visuals as vis
%matplotlib qt

from flax import nnx
import optax
import jax
from jax import numpy as jnp

from tqdm import tqdm

import sys

import seaborn 

In [88]:
### DECLARATIONS:
# Pauli matrices. They are named rho_* but are operators, not density matrices.
# rho_x and rho_z are integer-typed arrays; rho_y is complex.

rho_x = jnp.array([[0,1],[1,0]])
rho_y = jnp.array([[0,-1j],[1j,0]])
rho_z = jnp.array([[1,0],[0,-1]])

# Ladder operators in the basis (psi_0, psi_1):
# rho_m @ psi_1 == psi_0 and rho_p @ psi_0 == psi_1 (each annihilates the other basis state).
rho_m = jnp.array([[0,1],[0,0]])
rho_p = jnp.array([[0,0],[1,0]])

# Computational basis states; psi_0 is the +1 eigenvector of rho_z, psi_1 the -1 eigenvector.
psi_0 = jnp.array([1,0])
psi_1 = jnp.array([0,1])

### SYSTEM:
@nnx.jit
def H_S_single(w, g):
    """Single-qubit system Hamiltonian (2x2): w * rho_z + g * rho_x.

    w scales the rho_z (level-splitting) term and g the rho_x (external
    driving) term.
    """
    return w * rho_z + g * rho_x

@nnx.jit
def H_S(w, g):
    """System Hamiltonian embedded in the two-qubit (system x environment) space.

    Returns kron(H_S_single(w, g), I_2): the single-qubit system Hamiltonian
    w * rho_z + g * rho_x acts on the first (system) factor, and the identity
    acts on the environment qubit. The function delegates to H_S_single so
    the single- and two-qubit definitions cannot drift apart.
    """
    return jnp.kron(H_S_single(w, g), jnp.eye(2))

### ENVIRONMENT:
@nnx.jit
def H_E_single(r):
    """Single-qubit environment Hamiltonian (2x2): r * rho_z."""
    return r * rho_z


@nnx.jit
def H_E(r):
    """Environment Hamiltonian embedded in the two-qubit space: kron(I_2, H_E_single(r)).

    The identity acts on the first (system) factor and r * rho_z on the
    environment qubit. The function delegates to H_E_single so the single-
    and two-qubit definitions stay consistent.
    """
    return jnp.kron(jnp.eye(2), H_E_single(r))

### INTERACTION:
@nnx.jit
def H_I(k):
    """System-environment exchange coupling k * (A + A^dagger), A = kron(rho_m, rho_p).

    A applies rho_m to the system qubit and rho_p to the environment qubit.
    Adding its conjugate transpose makes the coupling Hermitian.
    """
    exchange = jnp.kron(rho_m, rho_p)
    return k * (exchange + exchange.conj().T)

def H(w, g, r, k):
    """Total two-qubit Hamiltonian: system part + environment part + coupling."""
    system_part = H_S(w, g)
    bath_part = H_E(r)
    coupling = H_I(k)
    return system_part + bath_part + coupling

def U(w, g, r, k):
    """Propagator expm(-i H) of the total Hamiltonian for unit time (no dt factor)."""
    generator = -1j * H(w, g, r, k)
    return sp.linalg.expm(generator)

def unitary_step(H, state, dt):
    """Advance a state vector by one step: return expm(-i H dt) @ state."""
    propagator = sp.linalg.expm(-1j * H * dt)
    return propagator @ state

def rho_th(T, H):
    """Return the Gibbs (thermal) state expm(-H / T) / Z of a Hermitian matrix H.

    Parameters
    ----------
    T : temperature, in the same units as the eigenvalues of H (beta = 1 / T).
    H : Hermitian matrix (Hamiltonian).

    Returns
    -------
    The normalised density matrix, with unit trace.

    The spectrum is shifted by its lowest eigenvalue before exponentiating.
    The shift is a scalar multiple of the identity, so it cancels in the
    normalisation. It keeps every exponent non-positive and prevents
    expm(-beta * H) from overflowing to inf (which yields NaN after dividing
    by Z) at low temperature or for large energies.
    """
    beta = 1 / T
    H = jnp.asarray(H)
    if not jnp.issubdtype(H.dtype, jnp.inexact):
        # eigvalsh needs a floating/complex input; integer Hamiltonians are allowed
        H = H.astype(float)
    e_min = jnp.linalg.eigvalsh(H)[0]  # eigvalsh returns eigenvalues in ascending order
    shifted = H - e_min * jnp.eye(H.shape[0], dtype=H.dtype)
    rho_tmp = sp.linalg.expm(-beta * shifted)
    Z = jnp.trace(rho_tmp)
    return rho_tmp / Z

def partial_tr(M, dim_sys=2, dim_env=2):
    """Trace the environment out of an operator on the composite system x environment space.

    M must be ordered like jnp.kron(sys_op, env_op), i.e. the system index
    is the outer (slow) index of each row and column. Row s * dim_env + e
    corresponds to (system s, environment e).

    Parameters
    ----------
    M : (dim_sys * dim_env, dim_sys * dim_env) matrix.
    dim_sys, dim_env : subsystem dimensions. The defaults (2, 2) give the
        two-qubit case this notebook uses.

    Returns
    -------
    The (dim_sys, dim_sys) reduced operator, sum_e M[(s, e), (s', e)].
    """
    M = jnp.asarray(M)
    # split each composite index into (system, environment) ...
    blocks = M.reshape(dim_sys, dim_env, dim_sys, dim_env)
    # ... and sum over matching environment indices
    return jnp.einsum('iaja->ij', blocks)

def stinespring_step(rho_sys, rho_env, H, dt):
    """One step of the reduced system dynamics in Stinespring form.

    The joint state kron(rho_sys, rho_env) is evolved for time dt by
    U = expm(-i H dt), i.e. U rho U^dagger. The environment is then traced
    out with partial_tr. H must match the dimension of the joint state.
    Only the reduced system state is returned; the evolved environment is
    discarded.
    """
    propagator = sp.linalg.expm(-1j * H * dt)
    joint = jnp.kron(rho_sys, rho_env)
    evolved = propagator @ joint @ jnp.conj(propagator).T
    return partial_tr(evolved)

In [201]:
### MODEL

T = 2 * np.pi # total time of simulation
nsteps = 1000 # number of timesteps
dt = T/nsteps
# endpoint=False makes the grid spacing exactly dt, so times[i] == i * dt is the
# time of the state recorded before the i-th step of the loop. With the default
# endpoint=True the spacing would be T/(nsteps-1), inconsistent with dt.
times = jnp.linspace(0, T, nsteps, endpoint=False)

### HAMILTONIAN
w = 0
g = 0
r = 1
k = 1
H_tot = H(w,g,r,k)

### SYSTEM
# density matrix |psi_0><psi_0| = outer(psi, conj(psi))
rho_init = jnp.outer(psi_0, jnp.conj(psi_0)) # initial state

### ENVIRONMENT 
temp = 1 # bath temperature
rho_env = rho_th(temp, H_E_single(r))  # thermal state of the environment qubit w.r.t. r * rho_z
#rho_env = jnp.outer(psi_1, jnp.conj(psi_1))

print(rho_env)


[[0.11920292 0.        ]
 [0.         0.8807971 ]]


In [208]:
# Per-step observables, recorded BEFORE each update (index i <-> time times[i]).
s_x = np.zeros(nsteps)  # Tr(rho rho_x)
s_y = np.zeros(nsteps)  # Tr(rho rho_y)
s_z = np.zeros(nsteps)  # Tr(rho rho_z)
tr = np.zeros(nsteps)   # Tr(rho); should stay 1
pur = np.zeros(nsteps)  # purity Tr(rho^2)
fid = np.zeros(nsteps)  # overlap Tr(rho rho_init); the fidelity with the pure initial state
W = 0  # accumulated work; placeholder, dW is currently always 0 (see loop)

# Random Hamiltonian parameters, one draw per step (global NumPy RNG, unseeded).
ws = np.random.normal(1,2,nsteps)
gs = np.random.normal(1,1,nsteps)
rs = np.random.normal(10,1,nsteps)
ks = np.random.normal(1,1,nsteps)

rate = 10

rho = rho_init
for i in range(nsteps):
    decay = jnp.exp(-rate*times[i]/T) # - 0.5
    #decay = jnp.sin(2*jnp.pi*times[i]/T)
    s_x[i] = np.trace(rho @ rho_x).real
    s_y[i] = np.trace(rho @ rho_y).real
    s_z[i] = np.trace(rho @ rho_z).real

    tr[i] = np.real(np.trace(rho))
    pur[i] = np.real(np.trace(rho @ rho))
    fid[i] = np.real(np.trace(rho @ rho_init))
    # NOTE(review): rho_env is thermal w.r.t. H_E_single(r), but the environment
    # term used here is r + rs[i] (rs has mean 10). Confirm this mismatch is intended.
    H_tot = H(w*decay+ws[i], g*decay+gs[i], r+rs[i], k+ks[i])
    # the same rho_env is re-attached every step (fresh environment per step)
    rho_new = stinespring_step(rho, rho_env, H_tot, dt)

    if i>0:
        # TODO: work accounting is disabled; intended form kept for reference:
        # dH = H(ws[i], gs[i]) - H(ws[i-1], gs[i-1]); dW = Tr(dH @ rho_new)
        dH = 0
        dW = 0
        W += jnp.real(dW)
    rho = rho_new


print(rho_init)
print()
print("Final state: ")
print(rho.round(2))

[[1 0]
 [0 0]]

Final state: 
[[ 0.66999996+0.j         -0.05      -0.39999998j]
 [-0.05      +0.39999998j  0.32999998-0.j        ]]


In [209]:
fig = plt.figure(figsize=(14,6))
ax = fig.add_subplot(121)

plotting_spins = True
if plotting_spins:
    # s_x, s_y, s_z are real-valued arrays (the loop stores the real parts)
    ax.plot(times, s_x, c='C1', label='x')
    ax.plot(times, s_y, c='C2', label='y')
    ax.plot(times, s_z, c='C3', label='z')
    ax.hlines([0], [0],[times[-1]], colors=['black'], linestyles=[':'])
    ax.set_xlabel('time')
    ax.set_ylabel('spin expectation value')
    # legend only when labelled artists exist (avoids matplotlib's
    # "No artists with labels" warning when the flag is off)
    ax.legend()

ax2 = fig.add_subplot(122)
plotting_markers = True
if plotting_markers:
    ax2.plot(times, tr, ls='--', c='black', label='Trace')
    ax2.plot(times, pur, ls='--', c='pink', label='Purity')
    ax2.plot(times, fid, ls='--', c='cyan', label='Fidelity')
    ax2.set_xlabel('time')
    ax2.legend()

fig.suptitle("Spin components exp. vals. and qubit figures of merit")
fig.show()


In [210]:
# Draw the (s_x, s_y, s_z) trajectory with the project's visuals helper
# (src.visuals). Its signature and return values are defined there, not here.
fig, ax = vis.plotBlochSphereTrajectory(s_x, s_y, s_z, nsteps)