In [2]:
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple
import colorsys
import numpy as np
import matplotlib.pyplot as plt
import sk_dsp_comm.sigsys as ss
from numpy.typing import NDArray
from scipy.signal import find_peaks
import matplotlib.colors as colors
import random

NP1DF8 = NDArray[np.float64]
NP2DF8 = NDArray[np.float64]
NP1DC16 = NDArray[np.complex128]
NP2DC16 = NDArray[np.complex128]
TT_TYPE = NDArray[np.int64]

def filter_negWaveVectors(M_yx: NP2DC16) -> NP2DC16:
    M_kq = np.fft.fft2(M_yx)
    kLen = int(M_kq.shape[1] / 2 + 1)
    zeros = np.zeros((M_kq.shape[0], kLen))
    M_kq[:, :kLen] = zeros
    M_yx_filtered = np.fft.ifft2(M_kq)
    M_yx_filtered /= np.amax(
        np.abs(M_yx_filtered[:, int(M_yx_filtered.shape[1] / 2 + 100)])
    )
    return M_yx_filtered

def initialize_point_src() -> NP2DC16:
    M_oop = np.load("./point_source_yig/point_source_2600_MHz_200_mT_OOP.npz")["M"]
    pointSrc0 = filter_negWaveVectors(M_oop)
    return pointSrc0[:, int(pointSrc0.shape[1] / 2) :]

# def generate_digit_table(binary_table: List[int], number_of_inputs: int) -> List[int]:
#     # Generate the primary table with digits from 1 to number_of_inputs
#     primary_table = list(range(1, number_of_inputs + 1))
    
#     # Calculate the number of elements to select from the primary table
#     num_elements_to_select = sum(binary_table) * 3
    
#     # Ensure the number of elements to select does not exceed the primary table size
#     if num_elements_to_select > number_of_inputs:
#         raise ValueError("Number of elements to select exceeds the size of the primary table.")
    
#     # Select random, non-repeating elements from the primary table
#     selected_elements = random.sample(primary_table, num_elements_to_select)
    
#     return selected_elements

def generate_digit_table(binary_table: List[int], number_of_inputs: int) -> List[int]:
    table = []
    # If sum(binary_table)==0, then no waveguides are active.
    # If binary_table = [0,1,1], only waveguides #2, #3 active, etc.
    if binary_table[0] == 1:
        table.extend([1,2,3])  # waveguide #1 repeated 3 times
    if binary_table[1] == 1:
        table.extend([4,5,6])  # waveguide #2
    if binary_table[2] == 1:
        table.extend([7,8,9])  # waveguide #3
    return table

def generate_rectangular_function(
    distance_between_inputs: int,
    unit_cell_size: int,
    number_of_inputs: int,
    input_width: int,
) -> Tuple[NP1DF8, NP1DF8, float]:
    d = distance_between_inputs / unit_cell_size
    s = input_width / unit_cell_size
    width = d * number_of_inputs
    x_int = np.arange(0, width, 1)
    rec = ss.rect(np.mod(x_int, d) - d / 2, s)
    return rec, x_int, d

def calculate_analytic_signal(
    g_in: NP1DF8,
    x_int: NP1DF8,
    changed_src_numbers: List[int],
    input_phase_shift: float,
    active_input_amplitude: float,
):
        
    g_in_an = np.zeros_like(x_int, dtype=complex)
    b = 1
    for i in range(len(x_int)):
        if g_in[i] == 1 and (b in changed_src_numbers):
            phi_t = input_phase_shift
            Amp = active_input_amplitude
            if g_in[i + 1] == 0:
                b += 1
        elif g_in[i] == 1:
            phi_t = 0
            Amp = 1
            if g_in[i + 1] == 0:
                b += 1
        else:
            phi_t = 0
            Amp = 1
        g_in_an[i] = Amp * g_in[i] * np.exp(1j * phi_t)
    return g_in_an

def convolve_column(pointSrc: NP2DC16, g_in: NP1DC16, i: int) -> NP1DC16:
    return np.convolve(pointSrc[:, i], g_in, mode="full")

def convolve_columns(pointSrc: NP2DC16, g_in: NP1DC16) -> NP2DC16:
    newField = np.empty(
        (pointSrc.shape[0] + g_in.shape[0] - 1, pointSrc.shape[1]), dtype=np.complex128
    )
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(convolve_column, pointSrc, g_in, i)
            for i in range(pointSrc.shape[1])
        ]
        for i, future in enumerate(futures):
            newField[:, i] = future.result()
    return newField

def normalize_and_slice_field(
    newField: NP2DC16,
    pointSrc: NP2DC16,
    axhl: int,
) -> NP2DF8:
    newField2 = np.abs(newField) ** 2
    newField2 = newField2[
        int(pointSrc.shape[0] / 2) : -int(pointSrc.shape[0] / 2), : int(1.5*axhl)
    ]
    return newField2

def calculate_truth_table(
    distance_between_inputs: float,
    max_i: List[float], 
    number_of_inputs: int,
    input_width: int,
    unite_cell_size: int
    ) -> TT_TYPE:
    truth_table = np.zeros(number_of_inputs-1)
    for i in range(0, number_of_inputs-1):
        if any((i+0.5)*distance_between_inputs + input_width/2
               <= element*unite_cell_size <= 
               (i+1.5)*distance_between_inputs - input_width/2 for element in max_i):
            truth_table[i] = 1
    return truth_table

# def calculate_truth_table2(
#     distance_between_inputs: float,
#     max_i: List[float], 
#     number_of_inputs: int,
#     input_width: int,
#     unite_cell_size: int,
#     active_outputs: List[int]
#     ) -> TT_TYPE:
#     truth_table = np.zeros(number_of_inputs-1)
#     for i in range(0, number_of_inputs-1):
#         if i==(active_outputs[0] or i==active_outputs[1]) and any((i+0.5)*distance_between_inputs + input_width/2
#                <= element*unite_cell_size <= 
#                (i+1.5)*distance_between_inputs - input_width/2 for element in max_i):
#             truth_table[i] = 1
#     return truth_table

def get_talbot_length(
    distance_between_inputs: int,  # nm
    wavelength: float, # nm
    input_width: int,  # nm
    unit_cell_size: int,  # nm
):
    if (
        distance_between_inputs % unit_cell_size != 0
        or input_width % unit_cell_size != 0
    ):
        raise ValueError(
            f"All length parameters must be divisible by {unit_cell_size}."
        )

    # round to the nearest multiple of unit_cell_size
    talbot_length = (
        (1 * distance_between_inputs**2 / wavelength) // unit_cell_size * unit_cell_size
    )
    return talbot_length

def colourMode(mod, phase):
    mod = mod/np.amax(mod)
    array1 = np.zeros((phase.shape[0], phase.shape[1], 3))
    for y in range(phase.shape[0]):
        for x in range(phase.shape[1]):
            h = 0.5+phase[y,x]/(2*np.pi)  # angle
            l, s = (mod[y,x]), 1
            array1[y,x] = np.array(colorsys.hls_to_rgb(h, l, s))
    return array1

In [3]:
def reference_signal(
    wavelength: float,
    output_line: int,
    input_phase: float,
    plot: bool = True,
    ):
    def y(x, wavelength, input_phase):
        return np.exp(1j*(2*np.pi/wavelength*x+input_phase))
    x = np.arange(1.3*output_line)
    
    if plot:
        plt.figure(figsize=(8,1))
        plt.plot(x,y(x, wavelength, input_phase).imag)
        plt.axvline(output_line, color='r', linestyle='--')
        plt.axhline(y(output_line, wavelength, input_phase).imag, color='r', linestyle='--')
        plt.show()  
    return y(output_line, wavelength, input_phase).imag

In [4]:
def logic_oop_phase(
    number_of_inputs: int,
    distance_between_inputs: int,  # nm
    input_width: int,  # nm
    input_binary_table: List[int],
    # active_active_input_phase: List[int],
    active_input_phase_shift: float,
    active_input_amplitude: float,
    active_outputs: List[int],
    wavelength: float,  # nm
    unit_cell_size: int,  # nm
    offset: int,
    w_axhl: int,
    plots: bool = True,
    phase_plot_output: bool = True,
    amp_threshold: float = 0.8,
    phase_threshold = 10 # degrees
):
    
    talbot_length = get_talbot_length(
        distance_between_inputs, wavelength, input_width, unit_cell_size
    )
    
    input_output_distance = talbot_length - offset
    axhl = int(input_output_distance // unit_cell_size)

    pointSrc: NP2DC16 = initialize_point_src()

    g_in, x_int, d = generate_rectangular_function(
        distance_between_inputs,
        unit_cell_size,
        number_of_inputs,
        input_width,
    )
    
    active_inputs = generate_digit_table(input_binary_table, number_of_inputs)
    I1 = active_inputs[:3]
    I2 = active_inputs[3:6]
    C1 = active_inputs[6:9]

    g_in_c: NP1DC16 = calculate_analytic_signal(
        g_in,
        x_int,
        active_inputs,
        active_input_phase_shift,
        active_input_amplitude
    )
    
    # input_table = get_input_table(active_inputs, active_active_input_phase)

    newField: NP2DC16 = convolve_columns(pointSrc, g_in_c)
    intensity: NP2DF8 = normalize_and_slice_field(
        newField, pointSrc, axhl)
    intensity_slice = intensity[:, 20:]
    if intensity_slice.size > 0:
        intensity /= np.amax(np.abs(intensity_slice))
    else:
        # fallback: use entire field to avoid ValueError
        intensity /= np.amax(np.abs(intensity))
    
    int_avg: NP1DF8 = np.mean(intensity[:, axhl - w_axhl : axhl + w_axhl], axis=1)
    threshold_amp = amp_threshold*int_avg.max()
    max_i: List[float] = find_peaks(int_avg, height=threshold_amp)[0]
    
    phase_2plot = np.angle(newField[int(pointSrc.shape[0] / 2) : -int(pointSrc.shape[0] / 2), : int(1.5*axhl)])
    mod_2plot = np.abs(newField[int(pointSrc.shape[0] / 2) : -int(pointSrc.shape[0] / 2), : int(1.5*axhl)])
    intentsity_1d_x = np.arange(0, intensity.shape[0], 1)
    int_avg = np.mean(intensity[:, axhl - w_axhl:axhl + w_axhl], axis=1)
    phase_1d_x = np.arange(0, intensity.shape[0], 1)
    phase_1d = phase_2plot[:, axhl]
    max_i = find_peaks(int_avg, height=threshold_amp)[0]
    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'gold', 'black', 'orange', 'purple']

    truth_table_amp = calculate_truth_table(distance_between_inputs, max_i, number_of_inputs, input_width, unit_cell_size)

    phase_range = wavelength // unit_cell_size *1
    phases = []
    for i in range(1, number_of_inputs):
        phases.append(phase_2plot[distance_between_inputs*i//unit_cell_size, axhl-phase_range//2:axhl+phase_range//2])
    truth_table_phas = []
    for i in range(0, number_of_inputs-1):
        if (np.abs(np.mean((phases[i]-reference_signal(wavelength, axhl, 0, False))*180/np.pi,0)) >= phase_threshold) and (i in active_outputs):
            truth_table_phas.append(1)
            phase_diff = np.mean((phases[i]-reference_signal(wavelength, axhl, 0, False))*180/np.pi,0)
        elif (np.abs(np.mean((phases[i]-reference_signal(wavelength, axhl, 0, False))*180/np.pi,0)) < phase_threshold) and (i in active_outputs):
            truth_table_phas.append(0)
            phase_diff = np.mean((phases[i]-reference_signal(wavelength, axhl, 0, False))*180/np.pi,0)
        # else:
        #     truth_table_phas.append(0)
            
    if plots:

        f, ax = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
        x_values = np.arange(0, len(active_inputs), 1)

        for i, val in enumerate(active_inputs):
            if val in I1:
                ax[0].text(-distance_between_inputs/2+val*distance_between_inputs, 1.15, 'I1', ha='center', color='blue', size=16, fontweight='bold')
            elif val in I2:
                ax[0].text(-distance_between_inputs/2+val*distance_between_inputs, 1.15, 'I2', ha='center', color='green', size=16, fontweight='bold')
            elif val in C1:
                ax[0].text(-distance_between_inputs/2+val*distance_between_inputs, 1.15, 'C1', ha='center', color='red', size=16, fontweight='bold')
            

        ax[0].plot(x_int*unit_cell_size, np.abs(g_in_c), color='k')
        ax[0].fill_between(x_int*unit_cell_size, 0, np.abs(g_in_c), color='#CCCCCC')
        ax[0].set_ylim(-0.1, 1.1)
        ax[0].set_ylabel("Amplitude")
        ax[0].set_text = "Input"
        ax0a = ax[0].twinx()
        ax0a.plot(x_int*unit_cell_size, np.angle(g_in_c) * 180 / np.pi, color='maroon', ls='-', lw=4)
        ax0a.tick_params(axis='y', labelcolor='maroon')
        ax0a.set_ylabel("Phase (°)", color='maroon')
        ax0a.set_ylim(-190, 190)
        ax0a.set_yticks(range(-180, 190, 90))
        
        ax[1].imshow(colourMode(mod_2plot.T, phase_2plot.T), aspect='auto', interpolation='sinc', extent=[0, len(intentsity_1d_x)*unit_cell_size, 1.5*axhl*unit_cell_size, 0])
        ax[1].set_xlabel("x (nm)")
        ax[1].set_ylabel("y (nm)")
        ax[1].axhline(axhl*unit_cell_size, color='red', lw=w_axhl, alpha=0)
        for i in range(0, number_of_inputs):
            ax[1].vlines(distance_between_inputs*(i+0.5), color='yellow', lw=2, alpha=0.6, ymin=(axhl-w_axhl*30)*unit_cell_size, ymax=(axhl+w_axhl*30)*unit_cell_size)
        for i in range(1, number_of_inputs):
            ax[1].vlines(distance_between_inputs*i, color=colors[(i-1) % len(colors)], lw=5, ymin=(axhl-phase_range//2)*unit_cell_size, ymax=(axhl+phase_range//2)*unit_cell_size)
        
        ax[2].set_ylim(0, 1)
        ax[2].plot(intentsity_1d_x*unit_cell_size, int_avg, linewidth=3, color='blue')
        ax[2].set_ylabel("Amplitude")
        ax[2].axhline(threshold_amp, color='orange', ls='--')
        for i in range(0, number_of_inputs-1):
            ax[2].text(distance_between_inputs/unit_cell_size * (i+1)*unit_cell_size, 0.8, f"{int(truth_table_amp[i])}", color='black', fontsize=12, ha='center')
        for i in range(0, number_of_inputs):
            ax[2].axvline(distance_between_inputs/unit_cell_size * (i+0.5)*unit_cell_size, color='orange', lw=2, alpha=0.4)
        for j in max_i:
            ax[2].scatter(j*unit_cell_size, int_avg[j], zorder=3, color='red')
        
        ax0b = ax[2].twinx()
        ax0b.plot(phase_1d_x*unit_cell_size, phase_1d * 180 / np.pi, linewidth=1, color='maroon')
        ax0b.set_ylabel("Phase (°)", color='maroon')
        ax0b.set_ylim(-190, 190)
        ax0b.set_yticks(range(-180, 190, 90))
        ax0b.tick_params(axis='y', labelcolor='maroon')
    
        f, ax = plt.subplots(1, number_of_inputs-1, figsize=(13, 2), sharey=True, gridspec_kw = {'wspace':0, 'hspace':0})
        for i in range(0, number_of_inputs-1):
            x_values = np.arange(0, len(phases[i]), 1) * unit_cell_size
            ax[i].scatter(x_values, (phases[i]-reference_signal(wavelength, axhl, 0, False))*180/np.pi, color=colors[i % len(colors)], label=f'Output {i + 1}', lw=2)
            ax[i].set_xlim(0, len(phases[i]) * unit_cell_size)
            if i == 0:
                ax[i].set_ylabel('Phase (°)')

        f.suptitle('Phase Difference Plots for Each Output')
        
        plt.tight_layout()
        plt.show()
        
    #-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    if phase_plot_output==True:
        f = plt.figure(figsize=(11, 3))
        plt.title(f'I1={I1}, I2={I2}, C1={C1}, input: {input_binary_table} -----> output: {truth_table_phas}')
        plt.axhline(-phase_threshold, color='black', lw=1, ls='-.')
        plt.axhline(phase_threshold, color='black', lw=1, ls='-.')
        plt.axhspan(ymin=-phase_threshold, ymax=phase_threshold, facecolor='lightblue', alpha=0.3)

        phas=[]
        for i in range(0, number_of_inputs-1):
            if i in active_outputs:
                markerstyle = '*'
                marker_size = 150,
            else:
                markerstyle = 'o'
                marker_size = 100
            plt.scatter(i,np.mean((phases[i]-reference_signal(wavelength, axhl, 0, False))*180/np.pi,0), s=marker_size, marker=markerstyle,
                            color=colors[i % len(colors)], label=f'Output {i + 1}')
            phas.append(np.mean((phases[i]-reference_signal(wavelength, axhl, 0, False))*180/np.pi,0))
            if i == 0:
                plt.ylabel('Phase (°)')
        
        plt.ylim(-max(np.abs(phas))-10, max(np.abs(phas))+10)
        plt.plot(phas, color='grey', lw=1,ls='--')
        plt.tight_layout()
        plt.show()
        
    else:
        # print([1 if i in active_inputs else 0 for i in range(number_of_inputs)],'  ------>  ', truth_table_phas)
        print(f'I1={I1}, I2={I2}, C1={C1},\n {input_binary_table}  ------>  ', truth_table_phas)
        
    return I1, I2, C1, truth_table_phas
# , phase_diff

In [5]:
def is_full_adder_correct(logic_func, param_dict, verbose=False):

    FULL_ADDER_TRUTH_TABLE = [
        ([0,0,0], [0,0]),
        ([0,0,1], [1,0]),
        ([0,1,0], [1,0]),
        ([0,1,1], [0,1]),
        ([1,0,0], [1,0]),
        ([1,0,1], [0,1]),
        ([1,1,0], [0,1]),
        ([1,1,1], [1,1]),
    ]
    
    number_of_inputs        = param_dict["number_of_inputs"]
    distance_between_inputs = param_dict["distance_between_inputs"]
    input_width             = param_dict["input_width"]
    wavelength              = param_dict["wavelength"]
    unit_cell_size          = param_dict["unit_cell_size"]
    
    active_input_phase_shift= param_dict["active_input_phase_shift"]
    active_input_amplitude  = param_dict["active_input_amplitude"]
    active_outputs          = param_dict["active_outputs"] 
    phase_threshold         = param_dict["phase_threshold"]
    w_axhl                  = param_dict["w_axhl"]
    offset                  = param_dict["offset"]
    amp_threshold           = param_dict["amp_threshold"]
    
    plots = False
    phase_plot_output = False
    
    for combo, desired_out in FULL_ADDER_TRUTH_TABLE:

        I1_, I2_, C1_, truth_table_phas = logic_func(
            number_of_inputs          = number_of_inputs,
            distance_between_inputs   = distance_between_inputs,
            input_width               = input_width,
            input_binary_table        = combo,
            active_input_phase_shift  = active_input_phase_shift,
            active_input_amplitude    = active_input_amplitude,
            active_outputs            = active_outputs, 
            wavelength                = wavelength,
            unit_cell_size            = unit_cell_size,
            offset                    = offset,
            w_axhl                    = w_axhl,
            plots                     = plots,
            phase_plot_output         = phase_plot_output,
            amp_threshold             = amp_threshold,
            phase_threshold           = phase_threshold,
        )
        
        if len(truth_table_phas) == 0:
            sim_output = [0,0]
        else:
            sim_output = truth_table_phas
        
        if sim_output != desired_out:
            # if verbose:
            print(f" FAIL: input={combo}, got {sim_output}, expected {desired_out}")
            return False
    
    return True

In [6]:
def parameter_sweep_random(logic_func, n_iterations):
    
    param_dict_base = {
        "number_of_inputs": 9,
        "distance_between_inputs": 1440,
        "input_width": 240,            
        "wavelength": 232,             
        "unit_cell_size": 15,          
    }

    offset_candidates         = [0, 50, 100, 150, 200, 250, 300]
    w_axhl_candidates         = [5, 10, 15, 20]                   
    amp_threshold_candidates  = [0.5, 0.6, 0.7, 0.8, 0.9]
    active_input_amplitude_rng= (0.5, 1.5) 
    active_outputs_candidates = list(range(8))
    phase_shift_candidates    = [0, np.pi/6, np.pi/4, np.pi/2, 3*np.pi/4, np.pi, 5*np.pi/4, 3*np.pi/2, 7*np.pi/4]
    phase_threshold_candidates= list(range(5, 55, 5)) 

    for iteration in range(1, n_iterations+1):
        offset_val          = random.choice(offset_candidates)
        w_axhl_val          = random.choice(w_axhl_candidates)
        amp_thresh_val      = random.choice(amp_threshold_candidates)
        amplitude_val       = random.uniform(*active_input_amplitude_rng)
        out_idx             = random.sample(active_outputs_candidates, 2)
        phase_shift_val     = random.choice(phase_shift_candidates)
        phase_thresh_val    = random.choice(phase_threshold_candidates)

        param_dict = param_dict_base.copy()
        param_dict.update({
            "offset": offset_val,
            "w_axhl": w_axhl_val,
            "amp_threshold": amp_thresh_val,
            "active_input_amplitude": amplitude_val,
            "active_outputs": out_idx, 
            "active_input_phase_shift": phase_shift_val,
            "phase_threshold": phase_thresh_val,
        })
        
        print(f"Iteration {iteration}/{n_iterations}: "
              f"offset={offset_val}, w_axhl={w_axhl_val}, amp_thresh={amp_thresh_val}, "
              f"amplitude={amplitude_val:.2f}, out_idx={out_idx}, "
              f"phase_shift={phase_shift_val:.2f}, phase_thresh={phase_thresh_val}")
        
        if is_full_adder_correct(logic_func, param_dict, verbose=False):
            print(">>> FOUND a valid parameter set for the 3-input majority gate!")
            print(param_dict)
            return param_dict
        
    print("No suitable parameter combination found in this random sweep.")
    return None

In [8]:
n_iterations = 10000
found_params = parameter_sweep_random(logic_oop_phase, n_iterations)
if found_params:
    print("SUCCESS: Found full adder parameters:")
    print(found_params)
else:
    print(f"No success in {n_iterations} random trials.")

Iteration 1/10000: offset=200, w_axhl=10, amp_thresh=0.7, amplitude=1.38, out_idx=[1, 3], phase_shift=0.00, phase_thresh=40
I1=[], I2=[], C1=[],
 [0, 0, 0]  ------>   [0, 0]
I1=[7, 8, 9], I2=[], C1=[],
 [0, 0, 1]  ------>   [0, 0]
 FAIL: input=[0, 0, 1], got [0, 0], expected [1, 0]
Iteration 2/10000: offset=50, w_axhl=20, amp_thresh=0.7, amplitude=0.51, out_idx=[6, 5], phase_shift=0.00, phase_thresh=35
I1=[], I2=[], C1=[],
 [0, 0, 0]  ------>   [0, 0]
I1=[7, 8, 9], I2=[], C1=[],
 [0, 0, 1]  ------>   [0, 0]
 FAIL: input=[0, 0, 1], got [0, 0], expected [1, 0]
Iteration 3/10000: offset=150, w_axhl=10, amp_thresh=0.6, amplitude=1.20, out_idx=[4, 0], phase_shift=3.93, phase_thresh=50
I1=[], I2=[], C1=[],
 [0, 0, 0]  ------>   [0, 0]
I1=[7, 8, 9], I2=[], C1=[],
 [0, 0, 1]  ------>   [0, 0]
 FAIL: input=[0, 0, 1], got [0, 0], expected [1, 0]
Iteration 4/10000: offset=0, w_axhl=10, amp_thresh=0.6, amplitude=0.63, out_idx=[2, 3], phase_shift=4.71, phase_thresh=40
I1=[], I2=[], C1=[],
 [0, 0, 0