In [1]:
from dataclasses import dataclass, field, asdict
from scipy.optimize import minimize
from tqdm import tqdm
import functools
import itertools
import numpy as np
import pandas as pd
import scipy
from src.angular_momentum import generate_spin_matrices

# Setup

- $\ket{ \psi_0 } = \ket{0}_S \bigotimes \ket{0}_A$

In [2]:
@dataclass
class Settings:
    dim_s: int
    dim_a: int
    # n_trials: int
    true_delta_s: float
    probability_error_tolerance: float
    system_jx: np.array = field(init=False)
    system_jz: np.array = field(init=False)
    ancilla_jx: np.array = field(init=False)
    ancilla_jz: np.array = field(init=False)
    initial_state: np.array = field(init=False)

    def __post_init__(self):
        self.system_jx, self.system_jz = generate_spin_matrices(dim=self.dim_s)
        self.ancilla_jx, self.ancilla_jz = generate_spin_matrices(dim=self.dim_a)
        self.initial_state = np.zeros(self.dim_s * self.dim_a)
        self.initial_state[0] = 1
        self.initial_state = np.outer(self.initial_state, self.initial_state)

    def generate_hamiltonian(
        self,
        j_s: float,
        u_s: float,
        delta_s: float,
        j_a: float,
        u_a: float,
        delta_a: float,
        alpha_xx: float,
        alpha_xz: float,
        alpha_zx: float,
        alpha_zz: float,
    ) -> np.array:
        system_hamiltonian = np.kron(
            -1 * j_s * self.system_jx + u_s * self.system_jz @ self.system_jz + delta_s * self.system_jz,
            np.divide(np.eye(self.dim_a), self.dim_a)
        )
        ancillary_hamiltonian = np.kron(
            np.divide(np.eye(self.dim_s), self.dim_s),
            -1 * j_a * self.ancilla_jx + u_a * self.ancilla_jz @ self.ancilla_jz + delta_a * self.ancilla_jz
        )
        interaction_hamiltonian = functools.reduce(
            lambda x, y: x + y,
            [
                alpha_xx * np.kron(self.system_jx, self.ancilla_jx),
                alpha_xz * np.kron(self.system_jx, self.ancilla_jz),
                alpha_zx * np.kron(self.system_jz, self.ancilla_jx),
                alpha_zz * np.kron(self.system_jz, self.ancilla_jz),
            ]
        )
        return system_hamiltonian + ancillary_hamiltonian + interaction_hamiltonian

    def trace_out_ancillary(self, state: np.array):
        return np.trace(
            np.array(state).reshape(self.dim_a, self.dim_s, self.dim_a, self.dim_s),
            axis1=1,
            axis2=3
        )
    
    @staticmethod
    def calculate_final_state(
        hamiltonian: np.array,
        initial_state: np.ndarray,
        t: float = 0,
    ) -> np.array:
        return scipy.linalg.expm(-1j * t * hamiltonian) @ initial_state
    
    def calculate_probabilities(self, final_state: np.array) -> np.array:
        system_state = self.trace_out_ancillary(state=final_state)
        probabilities = [np.abs(x)**2 for x in np.diag(system_state)]
        # assert np.abs(np.sum(probabilities) - 1) < settings.probability_error_tolerance, f"The observed probabilities {probabilities} are unphysical"
        return probabilities
    
    def loss_function(self, x) -> float:
        j_s = x[0]
        u_s = x[1]
        j_a = x[2]
        u_a = x[3]
        delta_a = x[4]
        alpha_xx = x[5]
        alpha_xz = x[6]
        alpha_zx = x[7]
        alpha_zz = x[8]
        time = x[9]
        epsilon = x[10] # Used to approximate the numerical derivative
        hamiltonian = self.generate_hamiltonian(
            j_s=j_s,
            u_s=u_s,
            delta_s=self.true_delta_s + epsilon,
            j_a=j_a,
            u_a=u_a,
            delta_a=delta_a,
            alpha_xx=alpha_xx,
            alpha_xz=alpha_xz,
            alpha_zx=alpha_zx,
            alpha_zz=alpha_zz,
        )
        final_state = self.calculate_final_state(hamiltonian, initial_state=self.initial_state, t=time)
        probabilities_upper = self.calculate_probabilities(final_state)
        hamiltonian = self.generate_hamiltonian(
            j_s=j_s,
            u_s=u_s,
            delta_s=self.true_delta_s - epsilon,
            j_a=j_a,
            u_a=u_a,
            delta_a=delta_a,
            alpha_xx=alpha_xx,
            alpha_xz=alpha_xz,
            alpha_zx=alpha_zx,
            alpha_zz=alpha_zz,
        )
        final_state = self.calculate_final_state(hamiltonian, initial_state=self.initial_state, t=time)
        probabilities_lower = self.calculate_probabilities(final_state)
        probability_derivative_cost = -1 * np.divide(
            np.sum(np.abs(np.subtract(probabilities_upper, probabilities_lower))),
            epsilon
        ) # Maximize delta info obtained
        epsilon_cost = (1-epsilon)**2  # Avoid low epsilons to avoid numerical instability. High epsilons are naturally avoided by the derivative cost
        time_cost = time**2 # Avoids unnecessarily large times
        coefficients_l1_cost = np.sum([np.abs(x) for x in [j_s, u_s, j_a, u_a, delta_a, alpha_xx, alpha_xz, alpha_zx, alpha_zz]]) # Minimizes the number of non-zero coeffs
        return np.sum([
            100 * probability_derivative_cost,
            1e-2 * epsilon_cost,
            1e-10 * time_cost,
            1e-5 * coefficients_l1_cost
        ])

# Run optimization

In [3]:
@dataclass
class Results:
    j_s: float
    u_s: float
    j_a: float
    u_a: float
    delta_a: float
    alpha_xx: float
    alpha_xz: float
    alpha_zx: float
    alpha_zz: float
    time: float
    epsilon: float
    nits: int
    nfev: int
    loss: float
    success: bool
    message: str
    method: str

In [4]:
def get_optimal_solution(settings: Settings, method:str = "Nelder-Mead") -> Results:    
    initial_x = (
        1e-1, # j_s
        1e-1, # u_s
        1e-1, # j_a
        1e-1, # u_a
        1e-1, # delta_a
        1e-1, # alpha_xx
        1e-1, # alpha_xz
        1e-1, # alpha_zx
        1e-1, # alpha_zz
        1e-1, # time
        1e-1, # epsilon    
    )
    results = minimize(
        settings.loss_function,
        initial_x,
        method=method,
        # jac=None,
        # hess=None,
        # hessp=None,
        # bounds=None,
        # constraints=(),
        # tol=None,
        # callback=None,
        # options=None
    )
    return Results(
        j_s=results.x[0],
        u_s=results.x[1],
        j_a=results.x[2],
        u_a=results.x[3],
        delta_a=results.x[4],
        alpha_xx=results.x[5],
        alpha_xz=results.x[6],
        alpha_zx=results.x[7],
        alpha_zz=results.x[8],
        time=results.x[9],
        epsilon=results.x[10],
        nits=results.nit,
        nfev=results.nfev,
        loss=results.fun,
        success=results.success,
        message=results.message,
        method=method,
    )

# Generate results df

In [5]:
test = get_optimal_solution(Settings(
    dim_s = 2,
    dim_a = 4,
    # n_trials = 5,
    true_delta_s = 3,
    probability_error_tolerance = 1e-10,
))

In [6]:
options = range(2,10)
tqdm_generator = itertools.product(options, repeat=2)
tqdm_total = len(options)**2
settings_generator = (
    Settings(dim_s, dim_a, 3, 1e-10)
    for dim_s, dim_a
    in tqdm_generator
)

df = pd.DataFrame([
    {
        "dim_s": settings.dim_s,
        "dim_a": settings.dim_a,
        # "n_trials": settings.n_trials,
        "true_delta_s": settings.true_delta_s,
        "probability_error_tolerance": settings.probability_error_tolerance,
        **asdict(get_optimal_solution(settings))
    }
    for settings in tqdm(settings_generator, total=tqdm_total)]
)

100%|██████████| 64/64 [28:23<00:00, 26.62s/it] 


In [7]:
df.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,54,55,56,57,58,59,60,61,62,63
dim_s,2,2,2,2,2,2,2,2,3,3,...,8,8,9,9,9,9,9,9,9,9
dim_a,2,3,4,5,6,7,8,9,2,3,...,8,9,2,3,4,5,6,7,8,9
true_delta_s,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
probability_error_tolerance,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
j_s,-0.614164,0.400113,-4.808845,-5.251521,5.97029,-3.504574,-7.080327,1.309152,0.354777,-6.065865,...,0.89038,2.837332,-1.317831,-4.586215,2.602346,3.466666,-1.10107,0.537358,0.01408,-3.517571
u_s,13.035962,-9.063561,4.516185,-8.827469,-0.410294,-1.168505,0.483279,-4.283175,1.751447,-1.281647,...,0.630342,1.495663,0.73743,1.459927,-1.136438,-0.742083,1.018646,1.552601,1.109244,1.908023
j_a,-1.893744,-5.403509,-0.706968,0.276883,1.305826,-0.885973,0.730601,-3.728412,-1.017799,5.467036,...,-0.744437,-3.410284,6.380873,-3.892244,-6.876324,-2.276501,0.187026,-1.467712,-8.0859,-4.52134
u_a,2.791028,4.545929,1.54556,3.966685,6.847315,1.889797,1.670508,1.406339,-3.699153,-3.932155,...,1.940635,1.078107,-2.271747,0.343363,-11.714435,-1.735873,0.726863,3.519649,6.468964,1.093718
delta_a,-5.732849,13.767416,0.707584,-3.851547,-20.719336,-0.06183,-1.789909,-0.11524,-0.591386,-1.346478,...,0.716617,-1.789703,2.727562,4.108234,10.69357,-0.937116,0.185769,1.454115,-9.986251,1.867129
alpha_xx,1.353268,-5.893731,-0.26987,0.501047,0.586046,0.771212,1.057593,-1.548468,0.124327,1.580692,...,-0.064887,-0.030466,-0.715325,1.512545,0.269581,-0.666404,-0.019961,-0.676247,-0.109386,0.186442


In [8]:
df[["dim_s", "dim_a", "j_s", "u_s", "j_a", "u_a", "delta_a", "alpha_xx", "alpha_xz", "alpha_zx", "alpha_zz", "time", "loss", "success"]]

Unnamed: 0,dim_s,dim_a,j_s,u_s,j_a,u_a,delta_a,alpha_xx,alpha_xz,alpha_zx,alpha_zz,time,loss,success
0,2,2,-0.614164,13.035962,-1.893744,2.791028,-5.732849,1.353268,-0.837998,-0.909006,-2.939662,8.783537,-265.555741,False
1,2,3,0.400113,-9.063561,-5.403509,4.545929,13.767416,-5.893731,1.658588,-7.243728,1.326521,20.602379,-380.192621,False
2,2,4,-4.808845,4.516185,-0.706968,1.545560,0.707584,-0.269870,-1.476643,-1.765912,0.218168,3.471314,-47.092890,False
3,2,5,-5.251521,-8.827469,0.276883,3.966685,-3.851547,0.501047,-0.843229,-1.663650,-3.096562,9.580045,-104.066046,False
4,2,6,5.970290,-0.410294,1.305826,6.847315,-20.719336,0.586046,0.134090,0.349416,-2.114662,6.988315,-66.335400,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
59,9,5,3.466666,-0.742083,-2.276501,-1.735873,-0.937116,-0.666404,-0.208755,0.089554,-0.164232,3.682808,-68.317327,True
60,9,6,-1.101070,1.018646,0.187026,0.726863,0.185769,-0.019961,1.092435,0.000450,0.495204,-1.683763,-50.208210,True
61,9,7,0.537358,1.552601,-1.467712,3.519649,1.454115,-0.676247,-0.754005,-0.393379,0.479294,-2.152515,-42.958829,False
62,9,8,0.014080,1.109244,-8.085900,6.468964,-9.986251,-0.109386,-0.669878,0.019576,0.256925,-2.144060,-41.534831,False
