In [276]:
#Developing the framework to test the information content across different NMR protocols. 
#The core ideas are:
#1) Leverage our tools for Linbladian construction in the Pauli basis. Caveat: it has a dependency on Spinach basis output. We need to be careful that input parameters correspond to
# those used in the Spinach simulation
# 2) The NMR protocols are defined as classmethods that can be differentiable through JAX


import sys
sys.path.append('../linbladian_utils/')
import numpy as np
import scipy.io as spio
from scipy.linalg import expm
from matplotlib import pyplot as plt
import sys 
import openfermion as of
import pandas as pd 
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from functools import partial
#from jax.experimental import sparse
from jax import lax
from basis_utils import Sx,Sy,Sz
from basis_utils import MatRepLib, S_plus, S_minus

from analytical_fit import  GetRelManySpins, get_chemical_shifts, Get_Det_And_Rates
from basis_utils import read_spinach_info, build_list_ISTs, NormalizeBasis, build_symbolic_list_ISTs, InnProd
#from simulation_utils import Hellinger_2D, GenNOESYSpectrum

from analytical_fit import lat_S_plus, lat_S_minus, lat_Sz, J_plus, J_minus,Get_Det_And_Rates_latex



In [335]:
import scipy

# Simple test case: DFG

In [256]:
#Auxiliary functions to verify correctness 

def HamMatRep(H,basis,n_qubits=2):
    N = len(basis)

    Matrep= np.zeros([N,N],dtype=complex)
    for i in range(N):
        for j in range(N):
            Matrep[i,j] = InnProd(basis[i],of.commutator(H,basis[j]),n_qubits=n_qubits)

    
    return Matrep

def GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma):
    """
    Returns: the zeroth order Hamiltonian considered for dynamical evolution in nthe simulations, in OpenFermion format
    Args:
    offset: frequency offset for the spin Zeeman frequencies, in Hz
    B0: Strength of the magnetic field in Teslas
    zeeman_scalars: list of chemical shifts for spins, in ppm
    Jcoups: matrix of size N x N, N being the number of spins, that encodes the scalar couplings between spins (in Hz)
    gamma: gyromagnetic ratio of the spins (an homonuclear case is assumed) 
    """


    #offset = -46681
    #B0 = 9.3933
    w0 = -gamma*B0
    o1 = 2*np.pi*offset
    Nspins = len(zeeman_scalars)

    Hamiltonian = of.QubitOperator()

    for i in range(Nspins):
        w = o1+w0*zeeman_scalars[i]/1e6
        Hamiltonian+=w*Sz(i)
        for j in range(i+1,Nspins):

            Hamiltonian+=2*np.pi*Jcoups[i,j]*(Sx(i)*Sx(j)+Sy(i)*Sy(j)+Sz(i)*Sz(j))
    
    return Hamiltonian


In [None]:
#For purposes of performance and cleanin-ness in coding with JAX, we consider some reformulation in the construction of Liouvillian in the Pauli basis..
from functools import partial
#from jax.experimental import sparse
from jax import lax
import scipy.linalg
from basis_utils import Sx,Sy,Sz
from basis_utils import MatRepLib, S_plus, S_minus



def InnProd_jax(Op1,Op2):
    """ 
    Op1 and Op2 are JAX arrays
    """

    return jnp.trace(Op1.conj().T@Op2)


def commutator(A,B):
    """
    Returns: commutators between matrices A and B 
    """

    return A@B-B@A


def single_comm_superOp(observable,basis):
    """
    Returns: a matrix representation of the superoperator A in AO=[a,O] in a Pauli basis contained in the "basis" array 
    Args:
    observable, a matrix 
    basis, Pauli basis, array of matrices
    """
    def compute_element(i,j,observable=observable,basis=basis):
        #print("i,j:",i,j)
        basis_i = lax.dynamic_index_in_dim(basis, i, axis=0, keepdims=False)
        basis_j = lax.dynamic_index_in_dim(basis, j, axis=0, keepdims=False)

        return InnProd_jax(basis_i,commutator(observable,basis_j))

    N=len(basis)

    i_idx, j_idx = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing='ij')
    i_flat = i_idx.flatten()
    j_flat = j_idx.flatten()

    # Vectorized computation
    vec_compute = jax.vmap(compute_element)
    elements = vec_compute(i_flat, j_flat)

    # Reshape to matrix
    matrix = elements.reshape(N, N)

    return matrix

def double_comm_superop(outer_op,inn_op,basis):
    """
    Returns: a matrix representation of the superoperator A in AO=[outer_op,[inn_op,O] in a Pauli basis contained in the "basis" array 
    Args:
    outer_op, matrix representation for outer_op
    inn_op, matrix representation for inner op 
    basis, Pauli basis, array of matrices
    """

    def compute_element(i,j,outer_op=outer_op,inn_op=inn_op,basis=basis):
        #print("i,j:",i,j)
        basis_i = lax.dynamic_index_in_dim(basis, i, axis=0, keepdims=False)
        basis_j = lax.dynamic_index_in_dim(basis, j, axis=0, keepdims=False)

        return InnProd_jax(basis_i,commutator(outer_op(commutator(inn_op,basis_j))))

    N=len(basis)

    i_idx, j_idx = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing='ij')
    i_flat = i_idx.flatten()
    j_flat = j_idx.flatten()

    # Vectorized computation
    vec_compute = jax.vmap(compute_element)
    elements = vec_compute(i_flat, j_flat)

    # Reshape to matrix
    matrix = elements.reshape(N, N)

    return matrix

def Linb_chann_superop(An,Am,basis):
    """
    Returns: a matrix representation of the channel L_{An,Am}[rho]=An*rho*Am^{\dagger}-0.5*(An*Am^{\dagger}*rho+rho*An*Am^{\dagger} )
    Args: An, matrix form of a jump operator
    Am, matrix form of a jump operator
    basis, array of matrices, the Pauli basis used to represent the operators 
    """
    def compute_element(i,j,An=An,Am=Am,basis=basis):
        #print("i,j:",i,j)
        basis_i = lax.dynamic_index_in_dim(basis, i, axis=0, keepdims=False)
        basis_j = lax.dynamic_index_in_dim(basis, j, axis=0, keepdims=False)
        Am_dag = jnp.transpose(jnp.conjugate(Am))


        Linb_on_j = An@basis_j@Am_dag -0.5*(An@Am_dag@basis_j+basis_j@An@Am_dag)

        return InnProd_jax(basis_i,Linb_on_j)


    N=len(basis)

    i_idx, j_idx = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing='ij')
    i_flat = i_idx.flatten()
    j_flat = j_idx.flatten()

    # Vectorized computation
    vec_compute = jax.vmap(compute_element)
    elements = vec_compute(i_flat, j_flat)

    # Reshape to matrix
    matrix = elements.reshape(N, N)
    return matrix


#@jax.jit
#def build_R_operator(gammas,two_spin_ops,basis):

#    n_ops = len(two_spin_ops)

#    i_idx, j_idx = jnp.meshgrid(jnp.arange(n_ops), jnp.arange(n_ops), indexing='ij')
#    i_flat = i_idx.flatten()
#    j_flat = j_idx.flatten()

#    def term(gamma,i_idx,j_idx,two_spin_ops =two_spin_ops,basis=basis):
        
#        Mk = double_comm_superop(two_spin_ops[i_idx],two_spin_ops[j_idx], basis)
#        return gamma * Mk

    # Vectorize over theta and observables
#    return jax.vmap(term)(gammas, i_flat,j_flat).sum(axis=0)


@jax.jit
def build_H0_operator(theta, observables, basis):
    def term(t, M):
        Mk = single_comm_superOp(M, basis)
        return t * Mk

    # Vectorize over theta and observables
    return jax.vmap(term)(theta, observables).sum(axis=0)

@jax.jit
def exponentiate_H0(theta, observables, basis):

    return jax.scipy.linalg.expm(-1j*build_H0_operator(theta, observables, basis))


##### Functions dedicated to the construction of damping rates and jump operators...

def SpecFunc(w,tc):

    return 0.2*tc/(1 + w**2 * tc**2)

@jax.jit
def GammaRates_jax(w,tc,coord1,coord2,coord3,coord4,gamma):
    hbar = 1.054571628*1e-34
    diff1 = coord2-coord1
    #jax.debug.print("gamma^4 is={}",gamma**4)
    #jax.debug.print("diff between coord2 and coord1 ={}",diff1)

    diff2 = coord4 - coord3
    r1 = jnp.sqrt(jnp.dot(diff1,diff1))
    r2 = jnp.sqrt(jnp.dot(diff2,diff2))

    #jax.debug.print("r1 = {}", r1)
    #jax.debug.print("r2 = {}", r2)
    #jax.debug.print("diff1 ={}",diff1)
    #jax.debug.print("r2 = {}", r2)

    x1 = diff1[0]
    y1 = diff1[1]
    z1 = diff1[2]

    x2 = diff2[0]
    y2 = diff2[1]
    z2 = diff2[2]

    r1 = jnp.sqrt(x1**2+y1**2+z1**2)
    r2 = jnp.sqrt(x2**2+y2**2+z2**2)

    #jax.debug.print("x1 = {}", x1)

    #r1 = jnp.sqrt(x1**2+y1**2+z1**2)
    #r2 = jnp.sqrt(x2**2+y2**2+z2**2)

    Invariant = y2**2 * (2*y1**2-z1**2)-x2**2 * (y1**2+z1**2)+6*y1*y2*z1*z2-(y1**2-2*z1**2)*z2**2

    #jax.debug.print("Invariant1 = {}", Invariant)
    Invariant += 6*x1*x2*(y1*y2+z1*z2)+x1**2 * (2*x2**2-y2**2-z2**2)
    #jax.debug.print("Invariant2 = {}", Invariant)
    Invariant = (3/(r1**2 * r2**2))*Invariant
    #jax.debug.print("Invariant3 = {}", Invariant)
    
    Apref = Invariant

    #jax.debug.print("Apref = {}", Apref)
    #jax.debug.print("Specfunc vakues is = {}",SpecFunc(w,tc))

    #jax.debug.print("hbar^2 gamma^4/(r1^3 * r2^3) is={}",hbar**2 * gamma**4 /(r1**3 * r2**3))

    return hbar**2 * 1e-14*gamma**4 * SpecFunc(w,tc)*Apref/(r1**3 * r2**3)

####TODO: constructing the family of jump operators under the K_i classification is computationally inefficient
#but useful for debugging purposes. In a near-future version, we can generate the family of operators with 
#a single function and avoid redundant calculations


def K2_ops(Nspins,basis):
    """
    Returns: a jax array that contains the corresponding list of operators to the rates generated by the K2_rates_jax function 
    Args:
    Nspins: number of spins
    basis: array of matrices that encode the Pauli basis 
    """
    ops =[]
    for i in range(Nspins):
        for j in range(i+1,Nspins):
            
            for k in range(Nspins):
                for l in range(k+1,Nspins):
                    An = jnp.array(of.get_sparse_operator(S_plus(i)*S_plus(j),n_qubits=Nspins).toarray())
                    Am = jnp.array(of.get_sparse_operator(S_plus(k)*S_plus(l),n_qubits=Nspins).toarray())

                    dum= Linb_chann_superop(An,Am,basis)

                    An = jnp.array(of.get_sparse_operator(S_minus(k)*S_minus(l),n_qubits=Nspins).toarray())
                    Am = jnp.array(of.get_sparse_operator(S_minus(i)*S_minus(j),n_qubits=Nspins).toarray())

                    dum+= Linb_chann_superop(An,Am,basis)

                    dum+=jnp.conjugate(jnp.transpose(dum))

                    ops.append(dum)

    return jnp.array(ops)


def K1_ops(Nspins,basis):
    """
    Returns: two jax arrays that contains the corresponding list of operators to the rates generated by the K1_rates_jax function 
    Args:
    Nspins: number of spins
    basis: array of matrices that encode the Pauli basis 
    """
    def mat_wrap(of_op,Nspins=Nspins):
        return jnp.array(of.get_sparse_operator(of_op,n_qubits=Nspins).toarray())

    #ops_l =[]
    #ops_k =[]

    #ops = jnp.array([])
    ops = []
    for i in range(Nspins):
        for j in range(i+1,Nspins):
            
            for k in range(Nspins):
                for l in range(k+1,Nspins):
                    
                    dum_l=Linb_chann_superop(mat_wrap(Sz(i)*S_plus(j)),mat_wrap(Sz(k)*S_plus(l)),basis)
                    dum_l+=Linb_chann_superop(mat_wrap(Sz(k)*S_minus(l)),mat_wrap(Sz(i)*S_minus(j)),basis)
                    dum_l+=Linb_chann_superop(mat_wrap(S_plus(i)*Sz(j)),mat_wrap(Sz(k)*S_plus(l)),basis)
                    dum_l+=Linb_chann_superop(mat_wrap(Sz(k)*S_minus(l)),mat_wrap(S_minus(i)*Sz(j)),basis)
                    dum_l+=jnp.conjugate(jnp.transpose(dum_l))
                    #ops_l.append(dum_l)

                    dum_k = Linb_chann_superop(mat_wrap(S_plus(i)*Sz(j)),mat_wrap(S_plus(k)*Sz(l)),basis)
                    dum_k+= Linb_chann_superop(mat_wrap(S_minus(k)*Sz(l)),mat_wrap(S_minus(i)*Sz(j)),basis)
                    dum_k+= Linb_chann_superop(mat_wrap(Sz(i)*S_plus(j)),mat_wrap(S_plus(k)*Sz(l)),basis)
                    dum_k+= Linb_chann_superop(mat_wrap(S_minus(k)*Sz(l)),mat_wrap(Sz(i)*S_minus(j)),basis)
                    dum_k+=jnp.conjugate(jnp.transpose(dum_k))
                    #ops_k.append(dum_k)
                    #ops = jnp.concatenate([ops,jnp.array([dum_l,dum_k])])
                    ops.append(dum_l)
                    ops.append(dum_k)

                    


    return jnp.array(ops)


def K0_ops(Nspins,basis):
    """
    Returns: two jax arrays that contains the corresponding list of operators to the rates generated by the K1_rates_jax function 
    Args:
    Nspins: number of spins
    basis: array of matrices that encode the Pauli basis 
    """
    def mat_wrap(of_op,Nspins=Nspins):
        return jnp.array(of.get_sparse_operator(of_op,n_qubits=Nspins).toarray())
    
    #ops_0 =[]
    #ops_diff =[]
    #ops = jnp.array([])
    ops = []
    for i in range(Nspins):
        for j in range(i+1,Nspins):
            
            for k in range(Nspins):
                for l in range(k+1,Nspins):
                    dum_0 = (8.0/3.0)*Linb_chann_superop(mat_wrap(Sz(i)*Sz(j)),mat_wrap(Sz(k)*Sz(l)),basis)
                    dum_0+= -(2.0/3.0)*Linb_chann_superop(mat_wrap(S_plus(i)*S_minus(j)),mat_wrap(Sz(k)*Sz(l)),basis)
                    dum_0+= -(2.0/3.0)*Linb_chann_superop(mat_wrap(Sz(k)*Sz(l)),mat_wrap(S_minus(i)*S_plus(j)),basis)
                    dum_0 += jnp.conjugate(jnp.transpose(dum_0))

                    #ops_0.append(-1.0*dum_0)

                    dum_diff = -(2.0/3.0)*Linb_chann_superop(mat_wrap(Sz(i)*Sz(j)),mat_wrap(S_plus(k)*S_minus(l)),basis)
                    dum_diff+= -(2.0/3.0)*Linb_chann_superop(mat_wrap(S_minus(k)*S_plus(l)),mat_wrap(Sz(i)*Sz(j)),basis)
                    dum_diff+=  (1.0/6.0)*Linb_chann_superop(mat_wrap(S_plus(i)*S_minus(j)),mat_wrap(S_plus(k)*S_minus(l)),basis)
                    dum_diff+= (1.0/6.0)*Linb_chann_superop(mat_wrap(S_minus(k)*S_plus(l)),mat_wrap(S_minus(i)*S_plus(j)),basis)
                    dum_diff+= (1.0/6.0)*Linb_chann_superop(mat_wrap(S_minus(i)*S_plus(j)),mat_wrap(S_plus(k)*S_minus(l)),basis)
                    dum_diff+= (1.0/6.0)*Linb_chann_superop(mat_wrap(S_minus(k)*S_plus(l)),mat_wrap(S_plus(i)*S_minus(j)),basis)
                    dum_diff+= jnp.conjugate(jnp.matrix_transpose(dum_diff))
                    #ops_diff.append(-1.0*dum_diff)
                    #ops = jnp.concatenate([ops,[dum_0, dum_diff]])
                    ops.append(dum_0)
                    ops.append(dum_diff)


    return jnp.array(ops)

@partial(jax.jit, static_argnames=['Nspins'])
def K2_rates_jax(freqs,tc,coords,Nspins,gamma):
    """
    Returns: 1) jnp array of K2 damping rates
    Args:
    freqs: array that contains the Zeeman frequencies of the N spins
    tc: bath correlation time
    coords: coordinates of spins
    Nspins: number of spins
    gamma: gyromagnetic ratio for spins (assuming homonuclear for now)
    basis: Pauli basis for matrix representation of operators, it needs to be an array of matrices
    """
    
    def get_rate(i,j,k,l,freqs=freqs,tc=tc,coords=coords,gamma=gamma):

        damp_rate = GammaRates_jax(freqs[k]+freqs[l],tc,coords[i],coords[j],coords[k],coords[l],gamma)
        #jax.debug.print("Gamma = {}", damp_rate)

        return damp_rate
    
    
    vec_get_rate = jax.vmap(lambda x: get_rate(x[0], x[1], x[2], x[3]))
    #vec_get_op = jax.vmap(lambda x: get_op(x[0],x[1],x[2],x[3]))
    
    ij = jnp.array([(i, j) for i in range(Nspins) for j in range(i + 1, Nspins)])
    kl = jnp.array([(k, l) for k in range(Nspins) for l in range(k + 1, Nspins)])

    # Create the full 2D grid of combinations between ij and kl
    ij_grid = jnp.repeat(ij, len(kl), axis=0)
    kl_grid = jnp.tile(kl, (len(ij), 1))

    # Combine into a single array of shape 
    ijkl = jnp.hstack([ij_grid, kl_grid])

    
    rates = vec_get_rate(ijkl)
    #ops = vec_get_op(ijkl)
    #print("Rates are:", rates)

    return rates

@partial(jax.jit, static_argnames=['Nspins'])
def K1_rates_jax(freqs,tc,coords,Nspins,gamma):
    """
    Returns: 1) jnp array of K1_l damping rates and 2) array of K1_k damping rates
    Args:
    freqs: array that contains the Zeeman frequencies of the N spins
    tc: bath correlation time
    coords: coordinates of spins
    Nspins: number of spins
    gamma: gyromagnetic ratio for spins (assuming homonuclear for now)
    basis: Pauli basis for matrix representation of operators, it needs to be an array of matrices
    """
    def get_rate(i,j,k,l,freqs=freqs,tc=tc,coords=coords,gamma=gamma):
        rate_l = GammaRates_jax(freqs[l],tc,coords[i],coords[j],coords[k],coords[l],gamma)
        rate_k = GammaRates_jax(freqs[k],tc,coords[i],coords[j],coords[k],coords[l],gamma)
        #jax.debug.print("Gamma = {}", rate_l)
        return rate_l, rate_k
      
    
    vec_get_rate = jax.vmap(lambda x: get_rate(x[0], x[1], x[2], x[3]))
    #vec_get_op = jax.vmap(lambda x: get_op(x[0],x[1],x[2],x[3]))
    
    ij = jnp.array([(i, j) for i in range(Nspins) for j in range(i + 1, Nspins)])
    kl = jnp.array([(k, l) for k in range(Nspins) for l in range(k + 1, Nspins)])

    # Create the full 2D grid of combinations between ij and kl
    ij_grid = jnp.repeat(ij, len(kl), axis=0)
    kl_grid = jnp.tile(kl, (len(ij), 1))

    # Combine into a single array of shape (num_combinations, 4)
    ijkl = jnp.hstack([ij_grid, kl_grid])

    
    rates = vec_get_rate(ijkl)
    return jnp.concatenate(rates)

@partial(jax.jit, static_argnames=['Nspins'])
def K0_rates_jax(freqs,tc,coords,Nspins,gamma):
    """
    Returns: 1) jnp array of K0_0 damping rates and 2) array of K0_diff damping rates
    Args:
    freqs: array that contains the Zeeman frequencies of the N spins
    tc: bath correlation time
    coords: coordinates of spins
    Nspins: number of spins
    gamma: gyromagnetic ratio for spins (assuming homonuclear for now)
    basis: Pauli basis for matrix representation of operators, it needs to be an array of matrices
    """
    def get_rate(i,j,k,l,freqs=freqs,tc=tc,coords=coords,gamma=gamma):
        rate_0 = GammaRates_jax(0.0,tc,coords[i],coords[j],coords[k],coords[l],gamma)
        rate_diff = GammaRates_jax(freqs[k]-freqs[l],tc,coords[i],coords[j],coords[k],coords[l],gamma)
        #jax.debug.print("Gamma = {}", rate_l)
        return rate_0, rate_diff
      
    
    vec_get_rate = jax.vmap(lambda x: get_rate(x[0], x[1], x[2], x[3]))
    #vec_get_op = jax.vmap(lambda x: get_op(x[0],x[1],x[2],x[3]))
    
    ij = jnp.array([(i, j) for i in range(Nspins) for j in range(i + 1, Nspins)])
    kl = jnp.array([(k, l) for k in range(Nspins) for l in range(k + 1, Nspins)])

    # Create the full 2D grid of combinations between ij and kl
    ij_grid = jnp.repeat(ij, len(kl), axis=0)
    kl_grid = jnp.tile(kl, (len(ij), 1))

    # Combine into a single array of shape (num_combinations, 4)
    ijkl = jnp.hstack([ij_grid, kl_grid])

    
    rates = vec_get_rate(ijkl)
    return jnp.concatenate(rates)



@partial(jax.jit,static_argnames=['Nspins'])
def build_R_operator(freqs,tc,coords,gamma,basis,Nspins):
    """
    Returns: \sum_{i}\gamma_{i}O_{i}, \gamma_{i}'s being scalars and O_{i}'s being operators 
    """

    #get rates and operators...
    k2_rates = K2_rates_jax(freqs,tc,coords,Nspins,gamma)
    k1_rates = K1_rates_jax(freqs,tc,coords,Nspins,gamma)
    k0_rates = K0_rates_jax(freqs,tc,coords,Nspins,gamma)

    k2_ops = K2_ops(Nspins,basis)
    k1_ops = K1_ops(Nspins,basis)
    k0_ops = K0_ops(Nspins,basis)

    rates =0.25*jnp.concatenate([k2_rates,k1_rates,k0_rates])
    ops = jnp.concatenate([k2_ops,k1_ops,k0_ops])

    #return jnp.tensordot(k2_rates,k2_ops,axes=1),jnp.tensordot(k1_rates,k1_ops,axes=1),jnp.tensordot(k0_rates,k0_ops,axes=1),jnp.tensordot(rates, ops, axes=1)
    return jnp.tensordot(rates, ops, axes=1)

###Generator of the total Liouvillian and the time-evolution operator...
@partial(jax.jit,static_argnames=['Nspins'])
def build_time_evol_op(time,coh_params,coh_observables,freqs,tc,coords,gamma,basis,Nspins):
    
    R = build_R_operator(freqs,tc,coords,gamma,basis,Nspins)

    H0 = build_H0_operator(coh_params, coh_observables, basis)

    return jax.scipy.linalg.expm(-1j*time*H0+time*R)
    #return H0,R


####Just for the purposes of testing differentiation...
def build_ref_prop(time,offset,B0,zeeman_scalars,Jcoups,freqs,coords,tc,gamma,Normbasis):

    H0_of = GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma)
    H0 = HamMatRep(H0_of,Normbasis,n_qubits=2)

    K2,K1,K0,R_analytical = GetRelManySpins(2*np.pi*freqs,coords,tc,gamma,Normbasis)

    scipy.linalg.expm(-1j*time*H0)




In [235]:
from functools import partial
from openfermion.ops import QubitOperator
from openfermion import get_sparse_operator
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["Nspins"])
def build_ops(Nspins):
    ops = []

    
    def comp_el(i,j):
        op = S_plus(i)*S_plus(j)
        mat = get_sparse_operator(op, n_qubits=Nspins).toarray()

        return mat
    
    i_idx, j_idx = jnp.meshgrid(jnp.arange(Nspins), jnp.arange(Nspins), indexing='ij')
    i_flat = i_idx.flatten()
    j_flat = j_idx.flatten()
    
    return jax.vmap(comp_el)(i_flat,j_flat)


# Generation of the time-evolution operator



In [None]:
##
text="""1      (0,0)   (0,0)   
  2      (0,0)   (1,1)   
  3      (0,0)   (1,0)   
  4      (0,0)   (1,-1)  
  5      (1,1)   (0,0)   
  6      (1,1)   (1,1)   
  7      (1,1)   (1,0)   
  8      (1,1)   (1,-1)  
  9      (1,0)   (0,0)   
  10     (1,0)   (1,1)   
  11     (1,0)   (1,0)   
  12     (1,0)   (1,-1)  
  13     (1,-1)  (0,0)   
  14     (1,-1)  (1,1)   
  15     (1,-1)  (1,0)   
  16     (1,-1)  (1,-1)  
"""

data = read_spinach_info(text)

basis = build_list_ISTs(data)
prefacts,Symb_basis = build_symbolic_list_ISTs(data)

#Normbasis = NormalizeBasis(basis,n_qubits=4,checkOrth=True) I have verified the orthonormalization of the basis
Normbasis = NormalizeBasis(basis,n_qubits=4,checkOrth=False)
Normbasis = np.array(Normbasis)




In [None]:
gammaF = 251814800.0
coord1 = jnp.array([-0.0551,-1.2087,-1.6523],dtype=jnp.float64)*1e-10
coord2 = jnp.array([-0.8604 ,-2.3200 ,-0.0624],dtype=jnp.float64)*1e-10

coords = jnp.array([coord1,coord2])

w1 = -376417768.6316 
w2 = -376411775.1523 
freqs = jnp.array([w1,w2])
tc = 0.5255e-9
B0 = 9.3933

zeeman_scalar_1 = -113.8796
zeeman_scalar_2 = -129.8002
zeeman_scalars = jnp.array([zeeman_scalar_1,zeeman_scalar_2])

#w0*zeeman_scalars[i]/1e6
chem_shifts = get_chemical_shifts(gammaF,B0,zeeman_scalars)
Nspins = 2

mat_basis = []

for i in range(len(basis)):
    sp_op = of.get_sparse_operator(Normbasis[i],n_qubits=2)
    mat_basis.append(sp_op.toarray())

mat_basis=jnp.array(mat_basis)

#Generation of reference rate values and operators...

list_jumps, list_damp_rates, list_dets=Get_Det_And_Rates(2*np.pi*freqs,tc,coords,Nspins,gammaF,chem_shifts)

###Generation of the reference unitary...




In [271]:
offset = -46681
B0 = 9.3933
zeeman_scalar_1 = -113.8796
zeeman_scalar_2 = -129.8002
zeeman_scalars = np.array([zeeman_scalar_1,zeeman_scalar_2])
Jcoup = 238.0633#2*np.pi*238.0633

Jcoups = np.array([[0.0,Jcoup],[Jcoup,0.0]])

RefH0_of = GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gammaF)
RefH0 = HamMatRep(RefH0_of,Normbasis,n_qubits=2)





In [274]:
####Checking "by-hand" TODO: we need to developed the interface to extract these parameters and incorporate
#them in the workflow
omega1=2*np.pi*offset-B0*gammaF*zeeman_scalars[0]/1e6
omega2 = 2*np.pi*offset-B0*gammaF*zeeman_scalars[1]/1e6
angJ = 2*np.pi*Jcoup
theta = jnp.array([omega1,omega2,angJ])

obs0 = of.get_sparse_operator(Sz(0),n_qubits=2)
obs1 = of.get_sparse_operator(Sz(1),n_qubits=2)
obs2 = of.get_sparse_operator(Sz(0)*Sz(1)+Sx(0)*Sx(1)+Sy(0)*Sy(1))
observables = jnp.array([obs0.toarray(),obs1.toarray(),obs2.toarray()])

testH0  = build_H0_operator(theta, observables, mat_basis)

In [277]:
####Verifying the construction of the relaxation matrix...

K2,K1,K0,R_analytical = GetRelManySpins(2*np.pi*freqs,coords,tc,gammaF,Normbasis)


K2 type contributions finished
K1 type contributions finished
K0 type contributions finished


In [314]:
K2_jax,K1_jax,K0_jax,R_jax = build_R_operator(2*np.pi*freqs,tc,coords,gammaF,mat_basis,Nspins)

In [330]:
gard_el = jax.jacrev(lambda gyro: build_time_evol_op(0.1,theta,observables,freqs,tc,coords,gyro,mat_basis,Nspins),holomorphic=True)

In [334]:
auto_diff_res = gard_el(gammaF+0.0*1j)

In [327]:
###Numerical verification with respect to brute-force differentiation...

build_time_evol_op(0.1,theta,observables,freqs,tc,coords,gammaF,mat_basis,Nspins)



Array([[ 1.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j],
       [ 0.00000000e+00+0.00000000e+00j, -5.46674026e-01+7.59732071e-01j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
        -3.59605913e-04+6.86122948e-04j,  0.00000000e+00+0.00000000e+00j,
         4.62846708e-03+2.41934955e-03j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j, -1.16976615e-01-8.30686782e-02j,
         0.00000000e+00+0.00000000e+0

In [325]:
time=0.1
jax.scipy.linalg.expm(-1j*time*H0+time*R)

Array([[ 1.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j],
       [ 0.00000000e+00+0.00000000e+00j, -5.46674026e-01+7.59732071e-01j,
         0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
        -3.59605913e-04+6.86122948e-04j,  0.00000000e+00+0.00000000e+00j,
         4.62846708e-03+2.41934955e-03j,  0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j, -1.16976615e-01-8.30686782e-02j,
         0.00000000e+00+0.00000000e+0

# Deprecated 

In [152]:
####Render basis array in matrix form:

mat_basis = []

for i in range(len(basis)):
    sp_op = of.get_sparse_operator(basis[i],n_qubits=2)
    mat_basis.append(sp_op.toarray())

mat_basis=jnp.array(mat_basis)

J=1.1
inter_S = Sz(0)*Sz(1)+Sx(0)*Sx(1)+Sy(0)*Sy(1)

sp_inter_S = of.get_sparse_operator(inter_S,n_qubits=2)
observables = jnp.array([sp_inter_S.toarray()])
theta =jnp.array([J],dtype=complex)

H0_array = build_H0_operator(theta, observables, mat_basis)

In [157]:
##TODO: verify the differentiability of simulation to a state vector

S0S1_pauli = single_comm_superOp(observables[0], mat_basis)
eigs,eigvects = np.linalg.eig(S0S1_pauli)

#exponentiate_H0(theta, observables, mat_basis)
#rho_0 = jnp.ones(len(mat_basis))

grad_fn = jax.jacrev(lambda th: exponentiate_H0(th,observables,mat_basis)@eigvects[:,0],holomorphic=True)

grad_fn(theta)

Array([[ 0.00000000e+00+0.0000000e+00j],
       [-2.30812673e-17-2.9280605e-17j],
       [-1.10484595e-33-8.5067082e-34j],
       [ 1.30671799e-01+2.1313110e-01j],
       [ 2.30812673e-17+2.9280605e-17j],
       [ 0.00000000e+00+0.0000000e+00j],
       [ 1.79520973e-17+3.7646505e-17j],
       [ 2.19923468e-34+2.1410535e-33j],
       [ 1.09323381e-33+8.5067082e-34j],
       [-1.79520957e-17-3.7646498e-17j],
       [ 0.00000000e+00+0.0000000e+00j],
       [-1.30671829e-01-2.1313110e-01j],
       [-1.30671799e-01-2.1313111e-01j],
       [-2.19923468e-34-2.1023191e-33j],
       [ 1.30671799e-01+2.1313111e-01j],
       [ 0.00000000e+00+0.0000000e+00j]], dtype=complex64)

In [173]:
#S0S1_pauli = single_comm_superOp(observables[0], mat_basis)

#eigs,eigvects = np.linalg.eig(S0S1_pauli)


#S0S1_pauli@eigvects[:,0]

np.linalg.norm(jnp.subtract(jnp.squeeze(-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0]),jnp.squeeze(grad_fn(theta))))


5.575504e-08

In [174]:
np.linalg.norm(jnp.squeeze(-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0])-jnp.squeeze(grad_fn(theta)))

5.575504e-08

In [165]:
np.shape(-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0]-grad_fn(theta))

(16, 16)

In [163]:
-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0]

Array([ 0.0000000e+00+0.0000000e+00j, -2.3081267e-17-2.9280612e-17j,
       -1.1576580e-33-7.3885624e-34j,  1.3067180e-01+2.1313113e-01j,
        2.3081267e-17+2.9280612e-17j,  0.0000000e+00+0.0000000e+00j,
        1.7952096e-17+3.7646502e-17j,  4.5299663e-34+1.8881881e-33j,
        1.1576580e-33+7.3885624e-34j, -1.7952096e-17-3.7646502e-17j,
        0.0000000e+00+0.0000000e+00j, -1.3067180e-01-2.1313113e-01j,
       -1.3067180e-01-2.1313113e-01j, -4.5299663e-34-1.8881881e-33j,
        1.3067180e-01+2.1313113e-01j,  0.0000000e+00+0.0000000e+00j],      dtype=complex64)

In [156]:
np.linalg.norm(exponentiate_H0(theta,observables,mat_basis)-jax.scipy.linalg.expm(-1j*theta[0]*S0S1_pauli))

0.0

In [139]:
eigs,eigvects = np.linalg.eig(S0S1_pauli)

In [140]:
S0S1_pauli@eigvects[:,0]

Array([ 0.0000000e+00+0.j,  3.4345770e-17+0.j,  8.6666857e-34+0.j,
       -2.4999999e-01+0.j, -3.4345770e-17+0.j,  0.0000000e+00+0.j,
       -4.4158848e-17+0.j, -2.2148196e-33+0.j, -8.6666857e-34+0.j,
        4.4158848e-17+0.j,  0.0000000e+00+0.j,  2.4999999e-01+0.j,
        2.4999999e-01+0.j,  2.2148196e-33+0.j, -2.4999999e-01+0.j,
        0.0000000e+00+0.j], dtype=complex64)

In [None]:
###differentiation of the H0 operator....

grad_fn = jax.jacrev(lambda th: build_H0_operator(th, observables, mat_basis),holomorphic=True)
grads = grad_fn(theta)






In [119]:
grads.shape

(16, 16, 1)

In [44]:
####First target: definition of a differentiable function for Linbladian construction from  the input parameters.

####Coherent part of the Linbladian....
from basis_utils import Sx,Sy,Sz


def commutator(A,B):
    """
    Returns: commutators between matrices A and B 
    """

    return A@B-B@A


def GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma):
    """
    Returns: the zeroth order Hamiltonian considered for dynamical evolution in nthe simulations
    Args:
    offset: frequency offset for the spin Zeeman frequencies, in Hz
    B0: Strength of the magnetic field in Teslas
    zeeman_scalars: list of chemical shifts for spins, in ppm
    Jcoups: a vector which contains the coupling between the ith and jth spin at index N*i+j, N being the number of spins
    gamma: gyromagnetic ratio of the spins (an homonuclear case is assumed) 
    """

    #offset = -46681
    #B0 = 9.3933
    w0 = -gamma*B0
    o1 = 2*np.pi*offset
    Nspins = len(zeeman_scalars)

    Hamiltonian = jnp.zeros([2**(Nspins),2**(Nspins)],dtype=complex)

    for i in range(Nspins):
        w = o1+w0*zeeman_scalars[i]/1e6

        Hamiltonian+=w*of.get_sparse_operator(Sz(i),n_qubits=Nspins).toarray()
        for j in range(i+1,Nspins):
            idx = i*Nspins+j
            print("Value of J coupling:", Jcoups[idx])
            
            Hamiltonian+=2*jnp.pi*Jcoups[idx]*jnp.array(of.get_sparse_operator(Sx(i)*Sx(j)+Sy(i)*Sy(j)+Sz(i)*Sz(j),n_qubits=Nspins).toarray())
    
    return Hamiltonian



def InnProd_jax(Op1,Op2):
    """ 
    Op1 and Op2 are JAX arrays
    """

    return jnp.trace(Op1.conj().T@Op2)



def HamMatRep(H,basis,n_qubits=2):
    N = len(basis)

    Matrep= jnp.zeros([N,N],dtype=complex)
    for i in range(N):
        basis_i = jnp.array(of.get_sparse_operator(basis[i],n_qubits=n_qubits).toarray(),dtype=complex)
        for j in range(N):
            
            basis_j = jnp.array(of.get_sparse_operator(basis[j],n_qubits=n_qubits).toarray(),dtype=complex)

            #Matrep[i,j] = InnProd_jax(basis_i,commutator(H,basis_j))
            inn_prod = InnProd_jax(basis_i,commutator(H,basis_j))
            if np.abs(inn_prod)>0.0:
                print("Non vanishing matrix element",inn_prod)

                Matrep = Matrep.at[i,j].set(inn_prod)
                print("The effect on the matrix is:",Matrep[i,j] )


     
    return Matrep


def H0Ham_jax(offset,B0,zeeman_scalars,Jcoups,gamma,basis):
    """
    Returns: the zeroth order Hamiltonian considered for dynamical evolution in nthe simulations, in the Pauli basis
    Args:
    offset: frequency offset for the spin Zeeman frequencies, in Hz
    B0: Strength of the magnetic field in Teslas
    zeeman_scalars: list of chemical shifts for spins, in ppm
    Jcoups: matrix of size N x N, N being the number of spins, that encodes the scalar couplings between spins (in Hz)
    gamma: gyromagnetic ratio of the spins (an homonuclear case is assumed) 
    """
    H_of = GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma)
    
    n_qubits = len(zeeman_scalars)

    H_pauli = HamMatRep(H_of,basis,n_qubits=n_qubits)

    return H_pauli


def evolve_den_mat(offset,B0,zeeman_scalars,Jcoups,gamma,basis,tau,rho0):

    H_coh = H0Ham_jax(offset,B0,zeeman_scalars,Jcoups,gamma,basis)

    expmat = jax.scipy.linalg.expm(-1j*tau * H_coh)

    return expmat@rho0




In [5]:
#rho0 = jnp.ones(16)
#rho0[0]= 1
#rho0[5] =1

grad_fn = jax.jacrev(evolve_den_mat,0,holomorphic=True)
#grad_fn = jax.jvp(evolve_den_mat,0)

#made-up parameters....

offset = jnp.array(10.0, dtype=jnp.complex64)
B0 = jnp.array(9.3933, dtype=jnp.complex64)  # example value
zeeman_scalars = jnp.array([100.0, 200.0], dtype=jnp.complex64)  # example
Jcoups = jnp.array([2], dtype=jnp.complex64)
gammaF = jnp.array(1.0, dtype=jnp.complex64)  # example value
tau = jnp.array(2.0, dtype=jnp.complex64)
rho0 = jnp.ones(16, dtype=jnp.complex64)  # assuming it's a matrix, not a vector




jacobian = grad_fn(offset,B0,zeeman_scalars,Jcoups,gammaF,basis,tau,rho0)



In [94]:
basis[1]

-0.35355339059327373 [X1] +
-0.35355339059327373j [Y1]

In [None]:
####To verify that the framework is working, I will aim to differentiate with respect to one of the J couplings...
offset = jnp.array(10.0, dtype=jnp.complex64)
B0 = jnp.array(9.3933, dtype=jnp.complex64)  # example value
zeeman_scalars = jnp.array([100.0, 200.0], dtype=jnp.complex64)  # example
Jcoups = jnp.array([2], dtype=jnp.complex64)
gammaF = jnp.array(251814800, dtype=jnp.complex64)  # example value
tau = jnp.array(2.0, dtype=jnp.complex64)
rho0 = jnp.ones(16, dtype=jnp.complex64)  


def scalar_output(Jk_val,k):
    Jcoups_mod = Jcoups.at[k].set(Jk_val)
    rho = evolve_den_mat(offset, B0, zeeman_scalars, Jcoups_mod, gammaF, basis, tau, rho0)
    return rho

J_test = jnp.array(1.0,dtype=jnp.complex64)

grad_fn_Jcoup = jax.jacrev(scalar_output,holomorphic=True)
grad_Jk = grad_fn_Jcoup(J_test,0)



In [45]:

Jcoups = jnp.array([1.0], dtype=jnp.complex64)
zeeman_scalars = jnp.array([0.0, 0.0], dtype=jnp.complex64)
offset = jnp.array(0.0, dtype=jnp.complex64)

MatRep_SiSj = GenH0_Ham(offset,0.0,zeeman_scalars,Jcoups,0.0)


SiSj_paulibasis = HamMatRep(MatRep_SiSj,basis,n_qubits=2)

Value of J coupling: (1+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707

In [41]:
np.linalg.norm(SiSj_paulibasis)

0.0

In [34]:
Jcoups = jnp.array([1.0], dtype=jnp.complex64)
zeeman_scalars = jnp.array([0.0, 0.0], dtype=jnp.complex64)
offset = jnp.array(0.0, dtype=jnp.complex64)

GenH0_Ham(offset,0.0,zeeman_scalars,Jcoups,0.0)

Value of J coupling: (1+0j)


Array([[ 1.5707964+0.j,  0.       +0.j,  0.       +0.j,  0.       +0.j],
       [ 0.       +0.j, -1.5707964+0.j,  3.1415927+0.j,  0.       +0.j],
       [ 0.       +0.j,  3.1415927+0.j, -1.5707964+0.j,  0.       +0.j],
       [ 0.       +0.j,  0.       +0.j,  0.       +0.j,  1.5707964+0.j]],      dtype=complex64)

In [36]:
i=0
j=1


2*np.pi*of.get_sparse_operator(Sx(i)*Sx(j)+Sy(i)*Sy(j)+Sz(i)*Sz(j),n_qubits=Nspins).toarray()

array([[ 1.57079633+0.j,  0.        +0.j,  0.        +0.j,
         0.        +0.j],
       [ 0.        +0.j, -1.57079633+0.j,  3.14159265+0.j,
         0.        +0.j],
       [ 0.        +0.j,  3.14159265+0.j, -1.57079633+0.j,
         0.        +0.j],
       [ 0.        +0.j,  0.        +0.j,  0.        +0.j,
         1.57079633+0.j]])

In [24]:
for i in range(len(basis)):
    for j in range(len(basis)):
        test = InnProd_jax(of.get_sparse_operator(basis[j],n_qubits=Nspins).toarray(),commutator(MatRep_SiSj,of.get_sparse_operator(basis[i],n_qubits=Nspins).toarray()))
        if np.abs(test)>0.0:
            print("Non vanishing element:", test)

Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)


In [17]:
np.linalg.norm(MatRep_SiSj)

0.8660254037844386

In [11]:
rho0

Array([1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j,
       1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],      dtype=complex64)

In [10]:
##We compare the result with the anayltical result:


evolve_den_mat(offset,B0,zeeman_scalars,Jcoups,gammaF,basis,tau,rho0)


Array([1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j,
       1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],      dtype=complex64)

In [8]:
offset

Array(10.+0.j, dtype=complex64)