In [5]:
#Developing the framework to test the information content across different NMR protocols. 
#The core ideas are:
#1) Leverage our tools for Linbladian construction in the Pauli basis. Caveat: it has a dependency on Spinach basis output. We need to be careful that input parameters correspond to
# those used in the Spinach simulation
# 2) The NMR protocols are defined as classmethods that can be differentiable through JAX


import sys
sys.path.append('../linbladian_utils/')
import numpy as np
import scipy.io as spio
from scipy.linalg import expm
from matplotlib import pyplot as plt
import sys 
import openfermion as of
import pandas as pd 
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from functools import partial
#from jax.experimental import sparse
from jax import lax
from basis_utils import Sx,Sy,Sz
from basis_utils import MatRepLib, S_plus, S_minus

from analytical_fit import  GetRelManySpins, get_chemical_shifts, Get_Det_And_Rates
from basis_utils import read_spinach_info, build_list_ISTs, NormalizeBasis, build_symbolic_list_ISTs, InnProd
#from simulation_utils import Hellinger_2D, GenNOESYSpectrum


from jax_linb_utils import flat_upper_triang_mat, get_H0_ops, build_time_evol_op
import scipy

# Simple test case: DFG

In [2]:
#Auxiliary functions to verify correctness 

def HamMatRep(H,basis,n_qubits=2):
    N = len(basis)

    Matrep= np.zeros([N,N],dtype=complex)
    for i in range(N):
        for j in range(N):
            Matrep[i,j] = InnProd(basis[i],of.commutator(H,basis[j]),n_qubits=n_qubits)

    
    return Matrep

def GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma):
    """
    Returns: the zeroth order Hamiltonian considered for dynamical evolution in nthe simulations, in OpenFermion format
    Args:
    offset: frequency offset for the spin Zeeman frequencies, in Hz
    B0: Strength of the magnetic field in Teslas
    zeeman_scalars: list of chemical shifts for spins, in ppm
    Jcoups: matrix of size N x N, N being the number of spins, that encodes the scalar couplings between spins (in Hz)
    gamma: gyromagnetic ratio of the spins (an homonuclear case is assumed) 
    """


    #offset = -46681
    #B0 = 9.3933
    w0 = -gamma*B0
    o1 = 2*np.pi*offset
    Nspins = len(zeeman_scalars)

    Hamiltonian = of.QubitOperator()

    for i in range(Nspins):
        w = o1+w0*zeeman_scalars[i]/1e6
        Hamiltonian+=w*Sz(i)
        for j in range(i+1,Nspins):

            Hamiltonian+=2*np.pi*Jcoups[i,j]*(Sx(i)*Sx(j)+Sy(i)*Sy(j)+Sz(i)*Sz(j))
    
    return Hamiltonian


####Just for the purposes of testing differentiation...
def build_ref_prop(time,offset,B0,zeeman_scalars,Jcoups,freqs,coords,tc,gamma,Normbasis):

    H0_of = GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma)
    H0 = HamMatRep(H0_of,Normbasis,n_qubits=2)

    K2,K1,K0,R_analytical = GetRelManySpins(2*np.pi*freqs,coords,tc,gamma,Normbasis)

    #return H0,R_analytical,scipy.linalg.expm(-1j*time*H0+time*R_analytical)
    return scipy.linalg.expm(-1j*time*H0+time*R_analytical)


# Generation of the time-evolution operator



In [3]:
##Input from spinach to generate the Pauli basis
text="""1      (0,0)   (0,0)   
  2      (0,0)   (1,1)   
  3      (0,0)   (1,0)   
  4      (0,0)   (1,-1)  
  5      (1,1)   (0,0)   
  6      (1,1)   (1,1)   
  7      (1,1)   (1,0)   
  8      (1,1)   (1,-1)  
  9      (1,0)   (0,0)   
  10     (1,0)   (1,1)   
  11     (1,0)   (1,0)   
  12     (1,0)   (1,-1)  
  13     (1,-1)  (0,0)   
  14     (1,-1)  (1,1)   
  15     (1,-1)  (1,0)   
  16     (1,-1)  (1,-1)  
"""

data = read_spinach_info(text)

basis = build_list_ISTs(data)
prefacts,Symb_basis = build_symbolic_list_ISTs(data)

#Normbasis = NormalizeBasis(basis,n_qubits=4,checkOrth=True) I have verified the orthonormalization of the basis
Normbasis = NormalizeBasis(basis,n_qubits=4,checkOrth=False)
Normbasis = np.array(Normbasis)




In [None]:
###The minimal number and necessary parameters for simulation...
Nspins = 2
gammaF = 251814800.0
coord1 = jnp.array([-0.0551,-1.2087,-1.6523],dtype=jnp.float64)*1e-10
coord2 = jnp.array([-0.8604 ,-2.3200 ,-0.0624],dtype=jnp.float64)*1e-10
coords = jnp.array([coord1,coord2])

w1 = -376417768.6316 
w2 = -376411775.1523 
freqs = jnp.array([w1,w2])
tc = 0.5255e-9
B0 = 9.3933
freq_offset = -46681

#These J couplings can be retrieved from Spinach in matrix form, for this simple example we can input it "by hand"
Jcoup = 238.0633
Jcouplings = np.array([[0.0, Jcoup],
                        [Jcoup,0.0]])

####We need to flatten the Jcoupling array for JAX functionality...
flat_Js = flat_upper_triang_mat(Jcouplings)

mat_basis = []

for i in range(len(basis)):
    sp_op = of.get_sparse_operator(Normbasis[i],n_qubits=Nspins)
    mat_basis.append(sp_op.toarray())

mat_basis=jnp.array(mat_basis)
time=0.1

coh_observables = get_H0_ops(Nspins)



#H0_jax, R_jax  = build_time_evol_op(time,freqs,flat_Js,B0,freq_offset,tc,coords,gammaF,coh_observables,mat_basis,Nspins)
jax_evol_op = build_time_evol_op(time,freqs,flat_Js,B0,freq_offset,tc,coords,gammaF,coh_observables,mat_basis,Nspins)

###reference 
zeeman_scalar_1 = -113.8796
zeeman_scalar_2 = -129.8002
zeeman_scalars = np.array([zeeman_scalar_1,zeeman_scalar_2])


#H0_ref, R_ref = build_ref_prop(time,freq_offset,B0,zeeman_scalars,Jcouplings,freqs,coords,tc,gammaF,Normbasis)
ref_evol_op = build_ref_prop(time,freq_offset,B0,zeeman_scalars,Jcouplings,freqs,coords,tc,gammaF,Normbasis)



K2 type contributions finished
K1 type contributions finished
K0 type contributions finished


In [118]:
#### Verification of differentiation with respect to J coupling#####
def fin_deffJ_Linb_evol(Jcoup,deltaJ,time,offset,B0,zeeman_scalars,freqs,coords,tc,gamma,Normbasis):

    Jcoups = np.array([[0.0, Jcoup+deltaJ],
                        [Jcoup+deltaJ,0.0]])
    

    Fin_prop = build_ref_prop(time,offset,B0,zeeman_scalars,Jcoups,freqs,coords,tc,gamma,Normbasis)

    Jcoups = np.array([[0.0, Jcoup],
                        [Jcoup,0.0]])

    Init_prop = build_ref_prop(time,offset,B0,zeeman_scalars,Jcoups,freqs,coords,tc,gamma,Normbasis)


    return (Fin_prop-Init_prop)/deltaJ


num_dev_J = fin_deffJ_Linb_evol(Jcoup,1e-6,time,freq_offset,B0,zeeman_scalars,freqs,coords,tc,gammaF,Normbasis)

grad_fn = jax.jacrev(lambda J: build_time_evol_op(time,freqs,jnp.array([J]),B0,freq_offset,tc,coords,gammaF,coh_observables,mat_basis,Nspins),holomorphic=True)

grad_J = grad_fn(238.0633+0.0*1j)


print("Norm of JAX derivative with respect to J:",np.linalg.norm(grad_J))

print("Norm of difference between JAX derivative and finite-difference one:", np.linalg.norm(grad_J-num_dev_J))




K2 type contributions finished
K1 type contributions finished
K0 type contributions finished
K2 type contributions finished
K1 type contributions finished
K0 type contributions finished
Norm of JAX derivative with respect to J: 0.21635888730804725
Norm of difference between JAX derivative and finite-difference one: 8.206686939728381e-07


# Modular implementation of NMR protocols

In [192]:
import jax
import jax.numpy as jnp
from dataclasses import dataclass, field
from functools import partial
from abc import ABC, abstractmethod

####Dataclass to handle the parameters

@partial(jax.tree_util.register_dataclass,
         data_fields=['freqs', 'Jcouplings', 'B0', 'freq_offset', 'tc', 'coords', 'gamma'],
         meta_fields=['Nspins']
)
@dataclass
class spin_params:
    """
    Class that contains all the parameters needed to define the spin system to simulate.
    The default parameters here correspond to a minimal instance with two fluorine nuclei.
    """
    freqs: jax.Array = field(default_factory=lambda: jnp.array([-3.76417769e+08, -3.76411775e+08]))
    Jcouplings: jax.Array = field(default_factory=lambda: jnp.array([238.0633]))
    B0: float = 9.3933
    freq_offset: float = -46681
    tc: float = 5.255e-10
    coords: jax.Array = field(default_factory=lambda: jnp.array([
        [-5.5100e-12, -1.2087e-10, -1.6523e-10],
        [-8.6040e-11, -2.3200e-10, -6.2400e-12]
    ]))
    gamma: float = 251814800.0
    Nspins: int = 2

#data class to define the parameters of time grid in two dimensions, the initial state and the observable to measure in the pauli basis...
@partial(jax.tree_util.register_dataclass,
         data_fields=['rho0', 'coil'],
         meta_fields=['Tpts1','Tpts2','deltaT1','deltaT2']
)
@dataclass ###TODO: the initial density matrix as well as the observable in the pauli basis can also be built inside the NMR class itself
class sampling_params:
    """
    Class  that contains parameters of time resolution, total propagation time, initial density matrix and the observable to sample 
    """
    deltaT1: float  
    deltaT2: float
    Tpts1: int
    Tpts2: int
    rho0: jax.Array
    coil: jax.Array

####Data class to store the specific parameters of protocols..
@partial(jax.tree_util.register_dataclass,
         data_fields=['Tmix'],
         meta_fields=['Lx','Ly']
)
@dataclass ###TODO: the initial density matrix as well as the observable in the pauli basis can also be built inside the NMR class itself
class NOESY_params:
    """
    Class  that contains parameters of time resolution, total propagation time, initial density matrix and the observable to sample 
    """
    Tmix: float  
    Lx: jax.Array
    Ly: jax.Array



####Abstract class. Template for the definition of protocols
class NMR_protocol(ABC):
    """
    Abstract class for NMR protocols. Required input: pauli basis 
    """
    def __init__(self, basis: jax.Array, prot_pars,samp_pars: sampling_params,params: spin_params):

        self.spinparams = params
        ###Initialize 
        self.basis = basis
        self.coh_ops = get_H0_ops(self.spinparams.Nspins) #TODO: import get_H0_ops from relevant file
        
        self.samp_params = samp_pars
        self.prot_params = prot_pars
        
          
    @abstractmethod
    def get_FID(self):
        """
        In self, we can include objects that include all the information to run the protocol
        """
        pass


####Particular definitions for protocols...
class NOESY(NMR_protocol):

    def get_FID(self):
        #build_time_evol_op(time,freqs,Jcouplings,B0,freq_offset,tc,coords,gamma,coh_observables,basis,Nspins)

        L_dt1 = build_time_evol_op(self.samp_params.deltaT1,self.spinparams.freqs,self.spinparams.Jcouplings,
                                   self.spinparams.B0,self.spinparams.freq_offset,self.spinparams.tc,
                                   self.spinparams.coords,self.spinparams.gamma,
                                   self.coh_ops,self.basis,self.spinparams.Nspins)
        
        L_dt2 = build_time_evol_op(self.samp_params.deltaT2,self.spinparams.freqs,self.spinparams.Jcouplings,
                                   self.spinparams.B0,self.spinparams.freq_offset,self.spinparams.tc,
                                   self.spinparams.coords,self.spinparams.gamma,
                                   self.coh_ops,self.basis,self.spinparams.Nspins)
        
        pulse_mix = build_time_evol_op(self.prot_params.Tmix,self.spinparams.freqs,self.spinparams.Jcouplings,
                                   self.spinparams.B0,self.spinparams.freq_offset,self.spinparams.tc,
                                   self.spinparams.coords,self.spinparams.gamma,
                                   self.coh_ops,self.basis,self.spinparams.Nspins)


        pulse_90x = expm(-1j*self.prot_params.Lx*np.pi/2)
        pulse_90y = expm(-1j*self.prot_params.Ly*np.pi/2)
        pulse_90mx = expm(1j*self.prot_params.Lx*np.pi/2)
        pulse_90my = expm(1j*self.prot_params.Ly*np.pi/2)


        #First 90x pulse:
        
        rho_t = pulse_90x@self.samp_params.rho0

        rho_stack = []
        rho_stack.append(rho_t)

        rho_temp = jnp.copy(rho_t)
        for i in range(1,self.samp_params.Tpts1):
            rho_temp = L_dt1@rho_temp
            rho_stack.append(rho_temp)


        rho_stack1_1 = []
        rho_stack1_2 = []
        rho_stack1_3 = []
        rho_stack1_4 = []

        for i in range(self.samp_params.Tpts1):
            rho_stack1_1.append(pulse_90y@pulse_mix@pulse_90x@rho_stack[i])
            rho_stack1_2.append(pulse_90y@pulse_mix@pulse_90y@rho_stack[i])
            rho_stack1_3.append(pulse_90y@pulse_mix@pulse_90mx@rho_stack[i])
            rho_stack1_4.append(pulse_90y@pulse_mix@pulse_90my@rho_stack[i])


        fid_temp_1 = jnp.zeros([self.samp_params.Tpts2,self.samp_params.Tpts1],dtype=complex)
        fid_temp_2 = jnp.zeros([self.samp_params.Tpts2,self.samp_params.Tpts1],dtype=complex)
        fid_temp_3 = jnp.zeros([self.samp_params.Tpts2,self.samp_params.Tpts1],dtype=complex)
        fid_temp_4 = jnp.zeros([self.samp_params.Tpts2,self.samp_params.Tpts1],dtype=complex)

        for i in range(self.samp_params.Tpts1):
            rho1 = rho_stack1_1[i]
            rho2 = rho_stack1_2[i]
            rho3 = rho_stack1_3[i]
            rho4 = rho_stack1_4[i]

            for j in range(self.samp_params.Tpts2):
                fid_temp_1[j,i] = jnp.dot(self.samp_params.coil,rho1)
                rho1 = L_dt2@rho1

                fid_temp_2[j,i] = jnp.dot(self.samp_params.coil,rho2)
                rho2 = L_dt2@rho2

                fid_temp_3[j,i] = jnp.dot(self.samp_params.coil,rho3)
                rho3 = L_dt2@rho3

                fid_temp_4[j,i] = jnp.dot(self.samp_params.coil,rho4)
                rho4 = L_dt2@rho4
        
        return fid_temp_1, fid_temp_2, fid_temp_3, fid_temp_4




In [193]:
#Getting the NOESY FID for our minimal example...

#matrix form of Pauli basis...
mat_basis = []

for i in range(len(basis)):
    sp_op = of.get_sparse_operator(Normbasis[i],n_qubits=Nspins)
    mat_basis.append(sp_op.toarray())

mat_basis=jnp.array(mat_basis)

###TODO: test the NOESY simulations....

###default spin parameters
spin_pars = spin_params()
samp_pars = sampling_params(deltaT1=0.01,deltaT2=0.01,Tpts1=1024,Tpts2=1024,rho0=jnp.ones(16),coil=jnp.ones(16))
NOESY_pars = NOESY_params(Tmix=0.2,Lx=jnp.ones([16,16]),Ly=jnp.ones([16,16]))


test_noesy = NOESY(mat_basis,NOESY_pars,samp_pars,spin_pars)



In [194]:
test_noesy.get_FID()

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [181]:
test_samp_pars = sampling_params(Tpts1=10,Tpts2=10,rho0=jnp.array([1,2]),coil=jnp.array([2,3,4]))

NOESY_test = NOESY(mat_basis,test_samp_pars)

In [182]:
NOESY_test.get_FID()

Got here
Spin params are: [-3.76417769e+08 -3.76411775e+08]


# Deprecated 

In [152]:
####Render basis array in matrix form:

mat_basis = []

for i in range(len(basis)):
    sp_op = of.get_sparse_operator(basis[i],n_qubits=2)
    mat_basis.append(sp_op.toarray())

mat_basis=jnp.array(mat_basis)

J=1.1
inter_S = Sz(0)*Sz(1)+Sx(0)*Sx(1)+Sy(0)*Sy(1)

sp_inter_S = of.get_sparse_operator(inter_S,n_qubits=2)
observables = jnp.array([sp_inter_S.toarray()])
theta =jnp.array([J],dtype=complex)

H0_array = build_H0_operator(theta, observables, mat_basis)

In [157]:
##TODO: verify the differentiability of simulation to a state vector

S0S1_pauli = single_comm_superOp(observables[0], mat_basis)
eigs,eigvects = np.linalg.eig(S0S1_pauli)

#exponentiate_H0(theta, observables, mat_basis)
#rho_0 = jnp.ones(len(mat_basis))

grad_fn = jax.jacrev(lambda th: exponentiate_H0(th,observables,mat_basis)@eigvects[:,0],holomorphic=True)

grad_fn(theta)

Array([[ 0.00000000e+00+0.0000000e+00j],
       [-2.30812673e-17-2.9280605e-17j],
       [-1.10484595e-33-8.5067082e-34j],
       [ 1.30671799e-01+2.1313110e-01j],
       [ 2.30812673e-17+2.9280605e-17j],
       [ 0.00000000e+00+0.0000000e+00j],
       [ 1.79520973e-17+3.7646505e-17j],
       [ 2.19923468e-34+2.1410535e-33j],
       [ 1.09323381e-33+8.5067082e-34j],
       [-1.79520957e-17-3.7646498e-17j],
       [ 0.00000000e+00+0.0000000e+00j],
       [-1.30671829e-01-2.1313110e-01j],
       [-1.30671799e-01-2.1313111e-01j],
       [-2.19923468e-34-2.1023191e-33j],
       [ 1.30671799e-01+2.1313111e-01j],
       [ 0.00000000e+00+0.0000000e+00j]], dtype=complex64)

In [173]:
#S0S1_pauli = single_comm_superOp(observables[0], mat_basis)

#eigs,eigvects = np.linalg.eig(S0S1_pauli)


#S0S1_pauli@eigvects[:,0]

np.linalg.norm(jnp.subtract(jnp.squeeze(-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0]),jnp.squeeze(grad_fn(theta))))


5.575504e-08

In [174]:
np.linalg.norm(jnp.squeeze(-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0])-jnp.squeeze(grad_fn(theta)))

5.575504e-08

In [165]:
np.shape(-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0]-grad_fn(theta))

(16, 16)

In [163]:
-1j*S0S1_pauli@exponentiate_H0(theta,observables,mat_basis)@eigvects[:,0]

Array([ 0.0000000e+00+0.0000000e+00j, -2.3081267e-17-2.9280612e-17j,
       -1.1576580e-33-7.3885624e-34j,  1.3067180e-01+2.1313113e-01j,
        2.3081267e-17+2.9280612e-17j,  0.0000000e+00+0.0000000e+00j,
        1.7952096e-17+3.7646502e-17j,  4.5299663e-34+1.8881881e-33j,
        1.1576580e-33+7.3885624e-34j, -1.7952096e-17-3.7646502e-17j,
        0.0000000e+00+0.0000000e+00j, -1.3067180e-01-2.1313113e-01j,
       -1.3067180e-01-2.1313113e-01j, -4.5299663e-34-1.8881881e-33j,
        1.3067180e-01+2.1313113e-01j,  0.0000000e+00+0.0000000e+00j],      dtype=complex64)

In [156]:
np.linalg.norm(exponentiate_H0(theta,observables,mat_basis)-jax.scipy.linalg.expm(-1j*theta[0]*S0S1_pauli))

0.0

In [139]:
eigs,eigvects = np.linalg.eig(S0S1_pauli)

In [140]:
S0S1_pauli@eigvects[:,0]

Array([ 0.0000000e+00+0.j,  3.4345770e-17+0.j,  8.6666857e-34+0.j,
       -2.4999999e-01+0.j, -3.4345770e-17+0.j,  0.0000000e+00+0.j,
       -4.4158848e-17+0.j, -2.2148196e-33+0.j, -8.6666857e-34+0.j,
        4.4158848e-17+0.j,  0.0000000e+00+0.j,  2.4999999e-01+0.j,
        2.4999999e-01+0.j,  2.2148196e-33+0.j, -2.4999999e-01+0.j,
        0.0000000e+00+0.j], dtype=complex64)

In [None]:
###differentiation of the H0 operator....

grad_fn = jax.jacrev(lambda th: build_H0_operator(th, observables, mat_basis),holomorphic=True)
grads = grad_fn(theta)






In [119]:
grads.shape

(16, 16, 1)

In [44]:
####First target: definition of a differentiable function for Linbladian construction from  the input parameters.

####Coherent part of the Linbladian....
from basis_utils import Sx,Sy,Sz


def commutator(A,B):
    """
    Returns: commutators between matrices A and B 
    """

    return A@B-B@A


def GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma):
    """
    Returns: the zeroth order Hamiltonian considered for dynamical evolution in nthe simulations
    Args:
    offset: frequency offset for the spin Zeeman frequencies, in Hz
    B0: Strength of the magnetic field in Teslas
    zeeman_scalars: list of chemical shifts for spins, in ppm
    Jcoups: a vector which contains the coupling between the ith and jth spin at index N*i+j, N being the number of spins
    gamma: gyromagnetic ratio of the spins (an homonuclear case is assumed) 
    """

    #offset = -46681
    #B0 = 9.3933
    w0 = -gamma*B0
    o1 = 2*np.pi*offset
    Nspins = len(zeeman_scalars)

    Hamiltonian = jnp.zeros([2**(Nspins),2**(Nspins)],dtype=complex)

    for i in range(Nspins):
        w = o1+w0*zeeman_scalars[i]/1e6

        Hamiltonian+=w*of.get_sparse_operator(Sz(i),n_qubits=Nspins).toarray()
        for j in range(i+1,Nspins):
            idx = i*Nspins+j
            print("Value of J coupling:", Jcoups[idx])
            
            Hamiltonian+=2*jnp.pi*Jcoups[idx]*jnp.array(of.get_sparse_operator(Sx(i)*Sx(j)+Sy(i)*Sy(j)+Sz(i)*Sz(j),n_qubits=Nspins).toarray())
    
    return Hamiltonian



def InnProd_jax(Op1,Op2):
    """ 
    Op1 and Op2 are JAX arrays
    """

    return jnp.trace(Op1.conj().T@Op2)



def HamMatRep(H,basis,n_qubits=2):
    N = len(basis)

    Matrep= jnp.zeros([N,N],dtype=complex)
    for i in range(N):
        basis_i = jnp.array(of.get_sparse_operator(basis[i],n_qubits=n_qubits).toarray(),dtype=complex)
        for j in range(N):
            
            basis_j = jnp.array(of.get_sparse_operator(basis[j],n_qubits=n_qubits).toarray(),dtype=complex)

            #Matrep[i,j] = InnProd_jax(basis_i,commutator(H,basis_j))
            inn_prod = InnProd_jax(basis_i,commutator(H,basis_j))
            if np.abs(inn_prod)>0.0:
                print("Non vanishing matrix element",inn_prod)

                Matrep = Matrep.at[i,j].set(inn_prod)
                print("The effect on the matrix is:",Matrep[i,j] )


     
    return Matrep


def H0Ham_jax(offset,B0,zeeman_scalars,Jcoups,gamma,basis):
    """
    Returns: the zeroth order Hamiltonian considered for dynamical evolution in nthe simulations, in the Pauli basis
    Args:
    offset: frequency offset for the spin Zeeman frequencies, in Hz
    B0: Strength of the magnetic field in Teslas
    zeeman_scalars: list of chemical shifts for spins, in ppm
    Jcoups: matrix of size N x N, N being the number of spins, that encodes the scalar couplings between spins (in Hz)
    gamma: gyromagnetic ratio of the spins (an homonuclear case is assumed) 
    """
    H_of = GenH0_Ham(offset,B0,zeeman_scalars,Jcoups,gamma)
    
    n_qubits = len(zeeman_scalars)

    H_pauli = HamMatRep(H_of,basis,n_qubits=n_qubits)

    return H_pauli


def evolve_den_mat(offset,B0,zeeman_scalars,Jcoups,gamma,basis,tau,rho0):

    H_coh = H0Ham_jax(offset,B0,zeeman_scalars,Jcoups,gamma,basis)

    expmat = jax.scipy.linalg.expm(-1j*tau * H_coh)

    return expmat@rho0




In [5]:
#rho0 = jnp.ones(16)
#rho0[0]= 1
#rho0[5] =1

grad_fn = jax.jacrev(evolve_den_mat,0,holomorphic=True)
#grad_fn = jax.jvp(evolve_den_mat,0)

#made-up parameters....

offset = jnp.array(10.0, dtype=jnp.complex64)
B0 = jnp.array(9.3933, dtype=jnp.complex64)  # example value
zeeman_scalars = jnp.array([100.0, 200.0], dtype=jnp.complex64)  # example
Jcoups = jnp.array([2], dtype=jnp.complex64)
gammaF = jnp.array(1.0, dtype=jnp.complex64)  # example value
tau = jnp.array(2.0, dtype=jnp.complex64)
rho0 = jnp.ones(16, dtype=jnp.complex64)  # assuming it's a matrix, not a vector




jacobian = grad_fn(offset,B0,zeeman_scalars,Jcoups,gammaF,basis,tau,rho0)



In [94]:
basis[1]

-0.35355339059327373 [X1] +
-0.35355339059327373j [Y1]

In [None]:
####To verify that the framework is working, I will aim to differentiate with respect to one of the J couplings...
offset = jnp.array(10.0, dtype=jnp.complex64)
B0 = jnp.array(9.3933, dtype=jnp.complex64)  # example value
zeeman_scalars = jnp.array([100.0, 200.0], dtype=jnp.complex64)  # example
Jcoups = jnp.array([2], dtype=jnp.complex64)
gammaF = jnp.array(251814800, dtype=jnp.complex64)  # example value
tau = jnp.array(2.0, dtype=jnp.complex64)
rho0 = jnp.ones(16, dtype=jnp.complex64)  


def scalar_output(Jk_val,k):
    Jcoups_mod = Jcoups.at[k].set(Jk_val)
    rho = evolve_den_mat(offset, B0, zeeman_scalars, Jcoups_mod, gammaF, basis, tau, rho0)
    return rho

J_test = jnp.array(1.0,dtype=jnp.complex64)

grad_fn_Jcoup = jax.jacrev(scalar_output,holomorphic=True)
grad_Jk = grad_fn_Jcoup(J_test,0)



In [45]:

Jcoups = jnp.array([1.0], dtype=jnp.complex64)
zeeman_scalars = jnp.array([0.0, 0.0], dtype=jnp.complex64)
offset = jnp.array(0.0, dtype=jnp.complex64)

MatRep_SiSj = GenH0_Ham(offset,0.0,zeeman_scalars,Jcoups,0.0)


SiSj_paulibasis = HamMatRep(MatRep_SiSj,basis,n_qubits=2)

Value of J coupling: (1+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707964+0j)
Non vanishing matrix element (1.5707964+0j)
The effect on the matrix is: (1.5707964+0j)
Non vanishing matrix element (-1.5707964+0j)
The effect on the matrix is: (-1.5707

In [41]:
np.linalg.norm(SiSj_paulibasis)

0.0

In [34]:
Jcoups = jnp.array([1.0], dtype=jnp.complex64)
zeeman_scalars = jnp.array([0.0, 0.0], dtype=jnp.complex64)
offset = jnp.array(0.0, dtype=jnp.complex64)

GenH0_Ham(offset,0.0,zeeman_scalars,Jcoups,0.0)

Value of J coupling: (1+0j)


Array([[ 1.5707964+0.j,  0.       +0.j,  0.       +0.j,  0.       +0.j],
       [ 0.       +0.j, -1.5707964+0.j,  3.1415927+0.j,  0.       +0.j],
       [ 0.       +0.j,  3.1415927+0.j, -1.5707964+0.j,  0.       +0.j],
       [ 0.       +0.j,  0.       +0.j,  0.       +0.j,  1.5707964+0.j]],      dtype=complex64)

In [36]:
i=0
j=1


2*np.pi*of.get_sparse_operator(Sx(i)*Sx(j)+Sy(i)*Sy(j)+Sz(i)*Sz(j),n_qubits=Nspins).toarray()

array([[ 1.57079633+0.j,  0.        +0.j,  0.        +0.j,
         0.        +0.j],
       [ 0.        +0.j, -1.57079633+0.j,  3.14159265+0.j,
         0.        +0.j],
       [ 0.        +0.j,  3.14159265+0.j, -1.57079633+0.j,
         0.        +0.j],
       [ 0.        +0.j,  0.        +0.j,  0.        +0.j,
         1.57079633+0.j]])

In [24]:
for i in range(len(basis)):
    for j in range(len(basis)):
        test = InnProd_jax(of.get_sparse_operator(basis[j],n_qubits=Nspins).toarray(),commutator(MatRep_SiSj,of.get_sparse_operator(basis[i],n_qubits=Nspins).toarray()))
        if np.abs(test)>0.0:
            print("Non vanishing element:", test)

Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)
Non vanishing element: (0.25+0j)
Non vanishing element: (-0.25+0j)


In [17]:
np.linalg.norm(MatRep_SiSj)

0.8660254037844386

In [11]:
rho0

Array([1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j,
       1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],      dtype=complex64)

In [10]:
##We compare the result with the anayltical result:


evolve_den_mat(offset,B0,zeeman_scalars,Jcoups,gammaF,basis,tau,rho0)


Array([1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j,
       1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j, 1.+0.j],      dtype=complex64)

In [8]:
offset

Array(10.+0.j, dtype=complex64)