# Open system evolution in Stinespring fashion

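Specifying the environment and its interaction with the system explicitly makes it possible to calculate entropy production in a clear way: system and environment evolve jointly under a unitary, and the reduced state of the system is recovered by tracing out the environment.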

In [2]:
import numpy as np
import matplotlib.pyplot as plt
#import scipy as sp
from jax import numpy as jnp
import jax.scipy as sp

from src import visuals as vis
%matplotlib qt

from flax import nnx
import optax
import jax
from jax import numpy as jnp

from tqdm import tqdm

import sys

In [None]:
### DECLARATIONS:
# Single-qubit building blocks, written in the basis psi_0 = [1, 0], psi_1 = [0, 1]
# (defined at the end of this cell).
# NOTE(review): despite the "rho_" prefix these are operators (Pauli and transition
# matrices), not density matrices.

# Pauli matrices sigma_x, sigma_y, sigma_z. Integer literals give integer arrays;
# only rho_y is complex.
rho_x = jnp.array([[0,1],[1,0]])
rho_y = jnp.array([[0,-1j],[1j,0]])
rho_z = jnp.array([[1,0],[0,-1]])

# Transition operators: rho_m = |0><1| (maps psi_1 -> psi_0),
# rho_p = |1><0| (maps psi_0 -> psi_1); rho_p is the transpose of rho_m.
rho_m = jnp.array([[0,1],[0,0]])
rho_p = jnp.array([[0,0],[1,0]])

# Computational basis states of one qubit.
psi_0 = jnp.array([1,0])
psi_1 = jnp.array([0,1])

### SYSTEM:
@nnx.jit
def H_S(w, g):
    """System Hamiltonian (w * sigma_z + g * sigma_x), embedded as (.) ⊗ I_2.

    The system is the first tensor factor and the environment the second,
    so the result is a 4x4 matrix acting trivially on the environment.
    """
    # bare two-level splitting plus the external driving term
    h_sys = w * rho_z + g * rho_x
    return jnp.kron(h_sys, jnp.eye(2))

### ENVIRONMENT:
@nnx.jit
def H_E(r):
    """Environment Hamiltonian r * sigma_z on the second qubit: I_2 ⊗ (r * sigma_z).

    Returns a 4x4 matrix that acts trivially on the system (first factor).
    """
    return jnp.kron(jnp.eye(2), r * rho_z)

### INTERACTION:
@nnx.jit
def H_I(k):
    """Exchange coupling k * (rho_m ⊗ rho_p + h.c.) between system and environment.

    rho_m ⊗ rho_p = |01><10|, so the Hermitian sum swaps the basis states
    |01> and |10> of the 4-dimensional system ⊗ environment space.
    """
    hop = jnp.kron(rho_m, rho_p)
    # adding the conjugate transpose makes the coupling Hermitian
    return k * (hop + hop.conj().T)

In [14]:
# System Hamiltonian for w=2, g=3 (system ⊗ I_2).
H_S(w=2, g=3)

Array([[ 2.,  0.,  3.,  0.],
       [ 0.,  2.,  0.,  3.],
       [ 3.,  0., -2., -0.],
       [ 0.,  3., -0., -2.]], dtype=float32)

In [17]:
# Environment Hamiltonian with r=0 (all-zero matrix).
H_E(r=0)

Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)

In [19]:
# Interaction Hamiltonian with coupling k=4.
H_I(k=4)

Array([[0, 0, 0, 0],
       [0, 0, 4, 0],
       [0, 4, 0, 0],
       [0, 0, 0, 0]], dtype=int32)

In [20]:
def H(w, g, r, k):
    """Total Hamiltonian on system ⊗ environment: H_S(w, g) + H_E(r) + H_I(k)."""
    h_system = H_S(w, g)
    h_env = H_E(r)
    h_int = H_I(k)
    return h_system + h_env + h_int

In [24]:
def U(w, g, r, k, t=1.0):
    """Return the joint system+environment propagator exp(-i H t).

    H is the total Hamiltonian H(w, g, r, k) = H_S(w, g) + H_E(r) + H_I(k).
    Units are such that the propagator is exp(-i H t) (hbar = 1).

    Parameters
    ----------
    w, g : parameters of the system Hamiltonian H_S.
    r : parameter of the environment Hamiltonian H_E.
    k : coupling strength of the interaction H_I.
    t : evolution time. The default 1.0 reproduces the previous exp(-i H).

    Returns
    -------
    The matrix exponential exp(-i H t) (4x4 for the two-qubit setup).
    """
    H_tmp = H(w, g, r, k)
    # same operand order as unitary_step: (-1j * H) * time
    return sp.linalg.expm(-1j * H_tmp * t)

In [31]:
def unitary_step(H, state, dt):
    """Advance `state` by one time step `dt` under the Hamiltonian `H`.

    Builds the propagator exp(-i H dt) and applies it from the left.
    This is the update for a state vector.
    NOTE(review): a density matrix would instead need U @ rho @ U^dagger.
    """
    # named `propagator` so it does not shadow the module-level function U
    propagator = sp.linalg.expm(-1j * H * dt)
    return propagator @ state

In [33]:
def rho_th(T, H):
    """Return the thermal (Gibbs) state exp(-H/T) / Z at temperature T.

    Parameters
    ----------
    T : temperature (units with k_B = 1, so beta = 1/T); must be nonzero.
    H : square Hamiltonian matrix, assumed Hermitian (eigvalsh below
        assumes a Hermitian input).

    Returns
    -------
    The density matrix exp(-beta H) / tr(exp(-beta H)), with unit trace.

    Notes
    -----
    The exponent -beta*H is shifted by its largest eigenvalue c before
    calling expm. Because expm(A - c I) = exp(-c) expm(A) and the scalar
    cancels in the normalisation, the result is mathematically unchanged.
    The largest exponential becomes exp(0) = 1, which avoids float32
    overflow (inf / inf = nan) at low temperature or large energies.
    """
    beta = 1 / T
    exponent = -beta * jnp.asarray(H)
    # eigvalsh returns eigenvalues in ascending order; [-1] is the largest
    top = jnp.linalg.eigvalsh(exponent)[-1]
    shifted = exponent - top * jnp.eye(exponent.shape[0], dtype=exponent.dtype)
    rho_tmp = sp.linalg.expm(shifted)
    Z = jnp.trace(rho_tmp)
    return rho_tmp / Z

In [35]:
# Thermal state at T=1 of the full Hamiltonian with w = g = r = k = 1.
R = rho_th(T=1, H=H(w=1, g=1, r=1, k=1))

In [36]:
import seaborn 

In [37]:
# Visualise the entries of the thermal state R.
seaborn.heatmap(data=R)

<Axes: >

In [58]:
row = [1, 2, 3, 4]

# 4x4 test matrix: each successive row is the base row scaled by the next power
# of ten, so every 2x2 block is easy to recognise after a partial trace.
Z = np.array([[scale * x for x in row] for scale in (1, 10, 100, 1000)])

In [62]:
# Display the 4x4 test matrix that partial_tr is later applied to.
Z

array([[   1,    2,    3,    4],
       [  10,   20,   30,   40],
       [ 100,  200,  300,  400],
       [1000, 2000, 3000, 4000]])

In [63]:
# Bottom-right 2x2 block: rows 2-3, then columns 2-3 (the slice pattern partial_tr uses).
Z[2:4][:, 2:4]

array([[ 300,  400],
       [3000, 4000]])

In [64]:
def partial_tr(M, dims=(2, 2)):
    """Trace out the environment (second tensor factor) of a bipartite operator.

    `M` acts on system ⊗ environment in the ordering produced by
    ``jnp.kron(system_op, environment_op)``. The entry ``M[i*dE + a, j*dE + b]``
    is viewed as ``T[i, a, j, b]``, and the environment indices are summed over
    ``a == b``. Equivalently, every dE x dE block of M is replaced by its trace.

    Parameters
    ----------
    M : square matrix of size (dS*dE, dS*dE).
    dims : (dS, dE), the system and environment dimensions. The default
        (2, 2) is the two-qubit case handled by the original implementation.

    Returns
    -------
    The dS x dS reduced operator on the system.

    Raises
    ------
    ValueError
        If the shape of M does not match ``dims``.
    """
    d_sys, d_env = dims
    M = jnp.asarray(M)
    n = d_sys * d_env
    if M.shape != (n, n):
        raise ValueError(
            f"expected a {n}x{n} matrix for dims={tuple(dims)}, got shape {M.shape}"
        )
    # axes: (system row, env row, system col, env col); trace the two env axes
    blocks = M.reshape(d_sys, d_env, d_sys, d_env)
    return jnp.trace(blocks, axis1=1, axis2=3)


In [65]:
# Expect the block traces [[21, 43], [2100, 4300]].
partial_tr(M=Z)

Array([[  21,   43],
       [2100, 4300]], dtype=int32)