[comment]: <>  (-------------------------------------------------------------------------------------------------)
[comment]: <>  (Copyright (C) 2025 Matías Rubén Bolaños Wagner)
[comment]: <>  (SPDX-License-Identifier: GPL-3.0-or-later)
[comment]: <>  (---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- ---- --)
# WELCOME TO zPulse

zPulse can generate pulses as narrow as **62.5 ps** while letting you tune the repetition rate to basically whatever you want. Although the software currently limits the number of pulses to 10, this could in principle be increased to 1024 and beyond (looking at you, QKD folks).

## Connections

If you are using the [HiTech Global red board](https://www.hitechglobal.com/FMCModules/FMC_SMA_LVDS.htm), it's no surprise that the zPulse channel numbers **DO NOT** match the output pin numbers, so I prepared a fancy table showing where you need to connect your stuff.

| zPulse channel  | Red board pin  |
|---|---|
| 1  | C2M_4  |
| 2  | C2M_5  |
| 3  | C2M_6  |
| 4  | C2M_7  |
| 5  | C2M_0  |
| 6  | C2M_1  |
| 7  | C2M_2  |
| 8  | C2M_3  |

If you are using the [Terasic blue board](https://www.terasic.com.tw/cgi-bin/page/archive.pl?Language=English&CategoryNo=164&No=1226&PartNo=1), only channels 5 to 8 are available.

| zPulse channel  | Blue board pin  |
|---|---|
| 1  | NC    |
| 2  | NC    |
| 3  | NC    |
| 4  | NC    |
| 5  | TX_0  |
| 6  | TX_1  |
| 7  | TX_2  |
| 8  | TX_3  |

In [None]:
import math
import json
import os
import logging
from pathlib import Path
from typing import Dict, Optional, Tuple, List

from zPulse.zPulse_overlay import zPulseOverlay
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import Widget, BoundedFloatText, IntSlider
from IPython.display import display, HTML, clear_output
from pynq import PL

# Ten-entry colour palette (the Matplotlib "tab10" hex values); plot_pulses()
# picks the line colour with colors_hex[channel_index].
colors_hex = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]

# Overlay handle; stays None until bitstream_selection() loads a bitstream.
ol = None
# The four values below are None until a bitstream is loaded: RESOLUTION is
# taken from the resolution input and update_resolution() derives the rest
# (20x, 4x and 2x the resolution respectively).
DEFAULT_PERIOD: Optional[float] = None  #: Default period in ps
DEFAULT_WIDTH: Optional[float] = None  #: Default width of pulses
DEFAULT_START_POSITION: Optional[float] = None  #: Default start position for the pulses
RESOLUTION: Optional[float] = None  #: default value for resolution step

DEFAULT_NUM_PULSES: int = 1  #: Default number of pulses per channel
DEFAULT_GLOBAL_DELAY: float = 0  #: Default delay on the whole channel
CHANNELS: int = 8  #: Number of channels

# Widgets that create_channel_tab() returns for one channel, in order:
# period input, pulse-count slider, delay input, the per-pulse inputs keyed
# "pulse<N>_width" / "pulse<N>_start", and the plot output widget.
# The alias must stay a 5-tuple because get_all_settings() and
# apply_all_settings() unpack exactly five members.
UIElements = Tuple[
    BoundedFloatText,
    IntSlider,
    BoundedFloatText,
    Dict[str, BoundedFloatText],
    Widget,
]

# Function to locate the bitstream directory and create a dropdown for bitstream selection and resolution input
def bitstream_selection():
    """Show the bitstream picker and load the chosen bitstream on demand.

    Walks up from the current working directory until a ``zPulse`` folder is
    found, lists every ``.bit`` file in ``zPulse/Bitstream`` that has a
    matching ``.hwh`` file, and displays a dropdown, a resolution input and a
    "Load Bitstream" button.  Clicking the button creates the overlay, stores
    it in the global ``ol``, sets the global ``RESOLUTION`` and builds the
    channel UI.

    Returns ``None``; prints a message and returns early when the directory
    or a usable ``.bit``/``.hwh`` pair is missing.
    """
    placeholder = "-- Select a bitstream --"

    # Locate Bitstream directory relative to this notebook
    root_dir = Path.cwd()
    while not (root_dir / "zPulse").exists() and root_dir != root_dir.parent:
        root_dir = root_dir.parent

    bitstream_dir = root_dir / "zPulse" / "Bitstream"
    if not bitstream_dir.exists():
        print("Bitstream directory not found.")
        return

    # Keep only .bit files that come with a matching .hwh; sorted so the
    # dropdown order does not depend on directory listing order.
    bitstream_options = {}
    for bit_file in sorted(bitstream_dir.glob("*.bit")):
        hwh_file = bit_file.with_suffix(".hwh")
        if hwh_file.exists():
            bitstream_options[bit_file.stem] = bit_file
        else:
            print(f"Skipping '{bit_file.name}': missing '{hwh_file.name}'")

    if not bitstream_options:
        print("No matching .bit/.hwh pairs found in Bitstream folder.")
        return

    dropdown = widgets.Dropdown(
        options=[placeholder] + list(bitstream_options.keys()),
        value=placeholder,
        description='Select Bitstream:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='300px')
    )

    resolution_input = widgets.FloatText(
        description='Resolution (ps):',
        value=0.0,
        disabled=True,  # Initially disabled
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='200px')
    )

    button = widgets.Button(
        description="Load Bitstream",
        button_style='success',
        disabled=True  # Initially disabled
    )

    output = widgets.Output()

    def on_dropdown_change(change):
        selected_label = change['new']
        if selected_label == placeholder:
            resolution_input.disabled = True
            button.disabled = True
            resolution_input.value = 0.0
            return

        resolution_input.disabled = False
        button.disabled = False
        # Guess the resolution from the file name.  This is a plain substring
        # match checked in this order, so an unrelated "10" in the name also
        # selects 100 ps; the user can always overwrite the guess.
        if "10" in selected_label:
            resolution_input.value = 100
        elif "12_5" in selected_label:
            resolution_input.value = 80
        elif "16" in selected_label:
            resolution_input.value = 62.5
        else:
            resolution_input.value = 0.0  # Prompt user to input manually

    dropdown.observe(on_dropdown_change, names='value')

    def on_button_clicked(b):
        global ol, RESOLUTION

        with output:
            clear_output()
            bit_path = bitstream_options.get(dropdown.value)
            if bit_path is None:
                print("Please select a bitstream first.")
                return

            # Every later computation divides by the resolution, so a zero,
            # negative or non-finite value must be rejected before anything
            # is loaded (0.0 is the "enter it manually" placeholder).
            resolution = resolution_input.value
            if not math.isfinite(resolution) or resolution <= 0:
                print("Resolution must be a positive number of ps.")
                return

            print(f"Loading bitstream: {bit_path.name}")
            ol = zPulseOverlay(str(bit_path))

            RESOLUTION = resolution
            print(f"Step Resolution set to {RESOLUTION} ps\n")

            update_resolution(RESOLUTION)
            interactive_channels()

    button.on_click(on_button_clicked)

    display(dropdown, resolution_input, button, output)

# Function to update global resolution settings
def update_resolution(resolution) -> None:
    """Store ``resolution`` and derive the default pulse geometry from it.

    Sets the module globals ``RESOLUTION`` (to ``resolution``),
    ``DEFAULT_PERIOD`` (20 steps), ``DEFAULT_WIDTH`` (4 steps) and
    ``DEFAULT_START_POSITION`` (2 steps).

    Args:
        resolution: Step size; the defaults are expressed as multiples of it.
    """
    global RESOLUTION, DEFAULT_PERIOD, DEFAULT_WIDTH, DEFAULT_START_POSITION
    # Use the argument rather than whatever RESOLUTION happened to hold, and
    # keep the global in sync so the derived defaults can never disagree
    # with it.
    RESOLUTION = resolution
    DEFAULT_PERIOD = 20 * resolution  #: Default period in ps
    DEFAULT_WIDTH = 4 * resolution  #: Default width of pulses
    DEFAULT_START_POSITION = 2 * resolution  #: Default start position for the pulses
    
# Function to save settings to a JSON file
def save_settings(file_path: os.PathLike, settings: Dict[str, Dict]) -> None:
    """Write ``settings`` to ``file_path`` as indented JSON.

    Saving is best effort: if the file cannot be opened or written, a
    warning is logged and the function returns normally.

    Args:
        file_path: Destination file; overwritten if it exists.
        settings: JSON-serialisable mapping to store.
    """
    try:
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(settings, f, indent=4)
    except OSError:
        logging.warning("Settings could not be saved at %s", str(file_path))


# Function to load settings from a JSON file
def load_settings(file_path) -> Optional[Dict[str, Dict]]:
    """Read settings from the JSON file at ``file_path``.

    Args:
        file_path: File to read.

    Returns:
        The decoded JSON content, or ``None`` (after logging a warning) when
        the file cannot be read or does not contain valid JSON.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and undecodable bytes, so a
        # corrupt file is reported the same way as a missing one.
        logging.warning("Settings could not be read at %s", str(file_path))
        return None


# Function to get current settings from UI elements for all channels
def get_all_settings(
    channel_tabs: "List[Optional[Tuple[str, Widget, UIElements]]]",
) -> Dict[str, Dict]:
    """Collect the current widget values of every channel.

    Args:
        channel_tabs: One ``(channel_name, ui, ui_elements)`` tuple per
            channel.  ``None`` entries (channels whose tab has not been
            built) are skipped instead of raising on unpacking.

    Returns:
        Mapping of channel name to a dict with ``"period"``,
        ``"num_pulses"``, ``"delay"`` and ``"pulses"``; ``"pulses"`` holds
        one ``{"width", "start"}`` dict per active pulse.
    """
    settings = {}
    for entry in channel_tabs:
        if entry is None:
            continue
        channel_name, _, ui_elements = entry
        period_slider, num_pulses_slider, delay_slider, pulse_sliders, _ = ui_elements
        settings[channel_name] = {
            "period": period_slider.value,
            "num_pulses": num_pulses_slider.value,
            "delay": delay_slider.value,
            # Only the pulses that are currently active are stored.
            "pulses": [
                {
                    "width": pulse_sliders[f"pulse{number}_width"].value,
                    "start": pulse_sliders[f"pulse{number}_start"].value,
                }
                for number in range(1, num_pulses_slider.value + 1)
            ],
        }
    return settings


# Function to apply settings to UI elements for all channels
def apply_all_settings(
    channel_tabs: "List[Optional[Tuple[str, Widget, UIElements]]]",
    settings: Optional[Dict[str, Dict]],
):
    """Push stored settings into the widgets of every channel.

    Args:
        channel_tabs: One ``(channel_name, ui, ui_elements)`` tuple per
            channel; ``None`` entries (tabs not built yet) are skipped.
        settings: Mapping of channel name to its stored values, in the
            layout produced by ``get_all_settings``.  ``None`` (nothing
            could be loaded) is logged and ignored.

    Missing keys fall back to the module defaults.  Channels absent from
    ``settings`` are left untouched and logged.
    """
    if settings is None:
        logging.warning("No settings to apply")
        return

    for entry in channel_tabs:
        if entry is None:
            continue
        channel_name, _, ui_elements = entry
        period_slider, num_pulses_slider, delay_slider, pulse_sliders, _ = ui_elements
        if channel_name not in settings:
            logging.warning("%s not in settings", channel_name)
            continue

        channel_settings = settings[channel_name]
        # The period is applied first; the fallback is the default *period*
        # (the original used the default width here).
        period_slider.value = channel_settings.get("period", DEFAULT_PERIOD)
        num_pulses_slider.value = channel_settings.get(
            "num_pulses", DEFAULT_NUM_PULSES
        )
        delay_slider.value = channel_settings.get("delay", DEFAULT_GLOBAL_DELAY)

        pulse: Dict[str, float]
        for i, pulse in enumerate(channel_settings.get("pulses", [])):
            width_key = f"pulse{i+1}_width"
            start_key = f"pulse{i+1}_start"
            if width_key not in pulse_sliders or start_key not in pulse_sliders:
                logging.warning(
                    "%s: ignoring stored pulses from pulse %d onwards",
                    channel_name,
                    i + 1,
                )
                break
            pulse_sliders[width_key].value = pulse.get("width", DEFAULT_WIDTH)
            # Same default as the pulse tabs use: pulse N starts at
            # N * DEFAULT_START_POSITION, with N counted from 1.
            pulse_sliders[start_key].value = pulse.get(
                "start", (i + 1) * DEFAULT_START_POSITION
            )


def lcm(a: int, b: int) -> int:
    """Return the least common multiple of ``a`` and ``b``.

    Uses exact integer arithmetic, so large operands are not rounded the
    way a float division would round them.  Returns 0 when both arguments
    are 0 instead of dividing by zero.
    """
    divisor = math.gcd(a, b)
    if divisor == 0:
        return 0
    return a * b // divisor


def is_valid_value(value: float) -> bool:
    """Tell whether ``value`` is a usable duration.

    A value is accepted when it is a positive whole number of ``RESOLUTION``
    steps and the least common multiple of that step count and 64 stays
    below 262144.

    The step count is obtained by rounding and checked with a tolerance: an
    exact ``%`` test and ``int()`` truncation reject genuine multiples
    whenever ``RESOLUTION`` has no exact binary representation (for example
    0.3 / 0.1 evaluates to 2.9999999999999996).
    """
    steps = int(round(value / RESOLUTION))
    if steps < 1:
        return False
    if not math.isclose(steps * RESOLUTION, value, rel_tol=1e-9):
        return False
    return lcm(steps, 64) < 262144


def binary_array_to_integers(binary_array) -> List[int]:
    """Convert a binary array into an array of integers, grouping every 32 bits.

    Element ``32*k + j`` of ``binary_array`` becomes bit ``j`` of word ``k``,
    so the first sample of each group is the least significant bit.  A
    trailing group shorter than 32 elements fills the low bits only.

    The previous slice ``binary_array[i + 32 : i : -1]`` covered elements
    ``i+1 .. i+32``: it skipped the first sample of every group (element 0
    was never encoded at all), pulled in the first sample of the next group,
    and left the last word one bit short.

    Args:
        binary_array: Sequence of 0/1 values (list or NumPy array).

    Returns:
        One Python ``int`` per group of 32 elements; ``[]`` for empty input.
    """
    int_values = []
    for offset in range(0, len(binary_array), 32):
        word = 0
        for position, bit in enumerate(binary_array[offset : offset + 32]):
            if bit:
                word |= 1 << position
        int_values.append(word)

    return int_values


# Most recent tiled waveform of each channel, keyed by the zero-based channel
# index: that is the key plot_pulses() stores under and toggle_channel()
# reads back.  None means no waveform has been generated for that channel
# yet.  The count of 8 has to match CHANNELS.
global_combined_waveforms = {channel_index: None for channel_index in range(8)}


class OutputLogic:
    """In-memory set of per-channel enable flags."""

    class ChannelEnable:
        """Enable flag of one channel; every change is echoed with print()."""

        index: int
        state: bool

        def __init__(self, index):
            self.state = False
            self.index = index

        def _switch(self, enabled: bool, label: str) -> None:
            # Record the new state, then report it with the 1-based number.
            self.state = enabled
            print(f"Channel {self.index + 1} {label}")

        def on(self) -> None:
            """Mark the channel as enabled."""
            self._switch(True, "ON")

        def off(self) -> None:
            """Mark the channel as disabled."""
            self._switch(False, "OFF")

    ch_enable: List[ChannelEnable]

    def __init__(self, num_channels: int = 8):
        # One flag per channel, all initially off.
        self.ch_enable = list(map(self.ChannelEnable, range(num_channels)))


output_logic = OutputLogic()


def generate_pulse(pulse_width: float, start_point: float, period: float) -> np.ndarray:
    """Build one period of a 0/1 waveform holding a single pulse.

    All three arguments are divided by ``RESOLUTION`` and truncated to whole
    samples.  Samples that would fall past the end of the period are
    dropped; the pulse does not wrap around.

    Returns:
        Integer array with one entry per sample of the period.
    """
    sample_count = int(period / RESOLUTION)
    width_in_samples = int(pulse_width / RESOLUTION)
    first_sample = int(start_point / RESOLUTION)

    samples = np.zeros(sample_count, dtype=int)
    high_region = slice(first_sample, first_sample + width_in_samples)
    samples[high_region] = 1

    return samples


def plot_pulses(channel_name: str, channel_index: int, **kwargs) -> np.ndarray:
    """Build a channel's waveform, upload it through ``ol`` and plot it.

    Args:
        channel_name: Label used in the plot title.
        channel_index: Zero-based channel number; selects the overlay
            entries, the storage slot and the plot colour.
        **kwargs: ``period``, ``delay``, ``num_pulses`` plus
            ``pulse<N>_width`` / ``pulse<N>_start`` for every active pulse.

    Returns:
        The single-period waveform (after the delay has been applied).
    """
    samples_per_period = int(kwargs["period"] / RESOLUTION)
    delay_in_samples = int(kwargs["delay"] / RESOLUTION)

    # OR all active pulses together into one period.
    combined_waveform = np.zeros(samples_per_period, dtype=int)
    for pulse_number in range(1, kwargs["num_pulses"] + 1):
        pulse_width = kwargs[f"pulse{pulse_number}_width"]
        pulse_start = kwargs[f"pulse{pulse_number}_start"]
        single_pulse = generate_pulse(pulse_width, pulse_start, kwargs["period"])
        combined_waveform = np.maximum(combined_waveform, single_pulse)

    # The delay is a circular shift of the period.
    combined_waveform = np.roll(combined_waveform, delay_in_samples)

    # Repeat the period until the total length is a multiple of 64 samples,
    # and remember the result for this channel.
    period_length = len(combined_waveform)
    repeats = int(lcm(period_length, 64) / period_length)
    tiled_waveform = np.tile(combined_waveform, repeats)
    global_combined_waveforms[channel_index] = tiled_waveform

    memory_words = binary_array_to_integers(tiled_waveform)
    word_count = len(memory_words)

    # Remember whether the channel was running, since the overlay is reset
    # before the new data is written.
    was_enabled = ol.ch_enable[channel_index].read()

    ol.reset()
    ol.ch_player[channel_index][:word_count] = memory_words
    # NOTE(review): the meaning of "* 4 - 8" and of the 0xFFFFFFF argument is
    # defined by the overlay, not visible here -- confirm against zPulseOverlay.
    ol.addr_limit[channel_index].write(word_count * 4 - 8, 0xFFFFFFF)
    if was_enabled:
        ol.ch_enable[channel_index].on()

    # Plot a single period against time.
    time_axis = np.arange(period_length) * RESOLUTION  # Time in ps
    plt.figure(figsize=(8, 3))
    plt.step(
        time_axis,
        combined_waveform,
        where="post",
        linewidth=2,
        color=colors_hex[channel_index],
    )
    plt.xlabel("Time (ps)")
    plt.ylabel("Amplitude")
    plt.title(f"Multiple Pulse Waveform - {channel_name}")
    plt.ylim(-0.1, 1.1)
    plt.grid()
    plt.show()

    return combined_waveform


def round_to_step(value: float, step: Optional[float] = None) -> float:
    """Round ``value`` to the nearest multiple of ``step``.

    Args:
        value: Number to round.
        step: Grid spacing.  When omitted, the current ``RESOLUTION`` is
            looked up at call time.  (A ``step=RESOLUTION`` default is
            evaluated once, when the function is defined, and at that point
            ``RESOLUTION`` is still ``None``.)

    Returns:
        The multiple of ``step`` closest to ``value``.
    """
    if step is None:
        step = RESOLUTION
    return round(value / step) * step


def create_channel_tab(
    channel_index: int, channel_name: str
) -> Tuple[Widget, UIElements]:
    """Build the control panel of one channel.

    The panel holds an enable toggle, period / delay / pulse-count inputs,
    one sub-tab per pulse (width and start) and a live plot that calls
    ``plot_pulses`` whenever an input changes.

    Args:
        channel_index: Zero-based channel number passed on to the overlay.
        channel_name: Label shown on the enable button and in the plot title.

    Returns:
        ``(ui, (period_input, num_pulses_slider, delay_input,
        pulse_inputs, plot_output))`` where ``pulse_inputs`` maps
        ``"pulse<N>_width"`` / ``"pulse<N>_start"`` to their widgets.
    """
    # Matches the upper bound of the pulse-count slider below.
    max_pulses = 10

    # NOTE: the list of every valid period that used to be built here
    # (about 65k is_valid_value() calls per tab) was never read, so it has
    # been removed.

    def round_input(change):
        # Snap typed values onto the resolution grid.
        change["owner"].value = round_to_step(change["new"], RESOLUTION)

    period_slider = widgets.BoundedFloatText(
        value=DEFAULT_PERIOD,
        min=RESOLUTION,
        max=262144,
        step=RESOLUTION,
        description="Period (ps)",
        readout_format=".1f",
    )

    num_pulses_slider = widgets.IntSlider(
        value=DEFAULT_NUM_PULSES,
        min=1,
        max=max_pulses,
        step=1,
        description="Num Pulses",
    )

    delay_slider = widgets.BoundedFloatText(
        value=DEFAULT_GLOBAL_DELAY,
        min=0,
        max=period_slider.value,
        step=RESOLUTION,
        description="Delay (ps)",
        readout_format=".1f",
    )

    period_slider.observe(round_input, "value")
    delay_slider.observe(round_input, "value")

    enable_button = widgets.ToggleButton(
        value=False, description=f"Enable {channel_name}", button_style="danger"
    )

    def toggle_channel(change):
        if enable_button.value:  # If button is pressed (enabled)
            waveform = global_combined_waveforms.get(channel_index)
            if waveform is None:
                # Nothing has been generated for this channel (for example
                # because plotting failed); release the button instead of
                # crashing.  Resetting the value re-enters this handler
                # through the disable branch.
                print(f"No waveform available for {channel_name}")
                enable_button.value = False
                return
            int_waveform_to_memory = binary_array_to_integers(waveform)
            ol.reset()
            addr_limit = len(int_waveform_to_memory)
            ol.ch_player[channel_index][:addr_limit] = int_waveform_to_memory
            ol.enable_channel(channel_index)
            enable_button.button_style = "success"  # Change color to indicate ON
        else:  # If button is not pressed (disabled)
            ol.disable_channel(channel_index)
            enable_button.button_style = "danger"  # Change color to indicate OFF

    enable_button.observe(toggle_channel, "value")

    def create_pulse_tab(
        index: int,
    ) -> Tuple[Widget, BoundedFloatText, BoundedFloatText]:
        """Create the width/start inputs of pulse number ``index`` (1-based)."""
        width_slider = widgets.BoundedFloatText(
            value=DEFAULT_WIDTH,
            min=1 * RESOLUTION,
            max=period_slider.value,
            step=RESOLUTION,
            description="Width (ps)",
            readout_format=".1f",
        )
        start_slider = widgets.BoundedFloatText(
            value=index * DEFAULT_START_POSITION,
            min=0,
            max=period_slider.value,
            step=RESOLUTION,
            description="Start (ps)",
            readout_format=".1f",
        )

        width_slider.observe(round_input, "value")
        start_slider.observe(round_input, "value")

        return widgets.VBox([width_slider, start_slider]), width_slider, start_slider

    tabs = []
    pulse_sliders = {}

    for i in range(1, max_pulses + 1):
        tab_content, width_slider, start_slider = create_pulse_tab(i)
        tabs.append((f"Pulse {i}", tab_content))
        pulse_sliders[f"pulse{i}_width"] = width_slider
        pulse_sliders[f"pulse{i}_start"] = start_slider

    tab_widget = widgets.Tab()
    tab_widget.children = [tab[1] for tab in tabs]
    for i, (title, _) in enumerate(tabs):
        tab_widget.set_title(i, title)

    def update_tabs(*args):
        # Show only the active pulses and keep every upper bound tied to the
        # current period.
        num_pulses = num_pulses_slider.value
        tab_widget.children = [tabs[i][1] for i in range(num_pulses)]
        delay_slider.max = period_slider.value
        for i in range(num_pulses):
            tab_widget.set_title(i, tabs[i][0])
            pulse_sliders[f"pulse{i+1}_start"].max = period_slider.value
            pulse_sliders[f"pulse{i+1}_width"].max = period_slider.value

    period_slider.observe(update_tabs, "value")
    num_pulses_slider.observe(update_tabs, "value")
    update_tabs()

    out = widgets.interactive_output(
        lambda **kwargs: plot_pulses(channel_name, channel_index, **kwargs),
        {
            **pulse_sliders,
            "period": period_slider,
            "num_pulses": num_pulses_slider,
            "delay": delay_slider,
        },
    )

    ui = widgets.VBox(
        [enable_button, period_slider, delay_slider, num_pulses_slider, tab_widget, out]
    )

    return ui, (period_slider, num_pulses_slider, delay_slider, pulse_sliders, out)


def interactive_channels():
    """Display the per-channel tabs and the save/load controls.

    Each channel tab starts as an empty placeholder; its UI is built the
    first time the tab is selected (channel 1 is built immediately).  Below
    the tabs, a file name, a folder and two buttons let the user save the
    settings of the channels built so far, or load settings from a file.
    """
    tab_widget = widgets.Tab()
    tab_initialized = [False] * CHANNELS
    tab_outputs = [widgets.Output() for _ in range(CHANNELS)]
    tab_titles = [f"Channel {i+1}" for i in range(CHANNELS)]
    # None marks a channel whose UI has not been built yet.
    channel_tabs = [None] * CHANNELS

    # Set placeholder output widgets as initial tab contents
    tab_widget.children = tab_outputs
    for i, title in enumerate(tab_titles):
        tab_widget.set_title(i, title)

    def initialize_channel_tab(i):
        if tab_initialized[i]:
            return
        # Generate tab UI
        channel_name = tab_titles[i]
        channel_ui, ui_elements = create_channel_tab(i, channel_name)
        channel_tabs[i] = (channel_name, channel_ui, ui_elements)

        # Replace the Output widget content with the full UI
        with tab_outputs[i]:
            tab_outputs[i].clear_output()
            display(channel_ui)

        tab_initialized[i] = True

    def built_tabs():
        # The settings helpers unpack every entry, so never hand them the
        # None placeholders of tabs that were not opened.
        return [tab for tab in channel_tabs if tab is not None]

    def on_tab_selected(change):
        i = change['new']
        if i is not None:  # selected_index can be None when nothing is selected
            initialize_channel_tab(i)

    tab_widget.observe(on_tab_selected, names='selected_index')

    # Initialize the first tab to show content immediately
    initialize_channel_tab(0)

    # File name and path widgets for saving/loading
    save_file_name = widgets.Text(description="File Name", value="settings.json")
    save_file_path = widgets.Text(description="File Path", value=os.getcwd())

    save_button = widgets.Button(description="Save Settings")
    load_button = widgets.Button(description="Load Settings")

    def save_button_clicked(b):
        file_path = os.path.join(save_file_path.value, save_file_name.value)
        settings = get_all_settings(built_tabs())
        save_settings(file_path, settings)
        print(f"Settings saved to {file_path}")

    def load_button_clicked(b):
        file_path = os.path.join(save_file_path.value, save_file_name.value)
        settings = load_settings(file_path)
        if not isinstance(settings, dict):
            # The file was missing, unreadable or not a settings mapping.
            print(f"Could not load settings from {file_path}")
            return
        # Build the tabs of every stored channel first; otherwise the stored
        # values of a channel that was never opened would be dropped.
        for i, title in enumerate(tab_titles):
            if title in settings:
                initialize_channel_tab(i)
        apply_all_settings(built_tabs(), settings)
        print(f"Settings loaded from {file_path}")

    save_button.on_click(save_button_clicked)
    load_button.on_click(load_button_clicked)

    display(
        tab_widget,
        widgets.VBox([save_file_name, save_file_path, save_button, load_button]),
    )

# Notebook entry point, run when the cell executes.
# NOTE(review): PL.reset() presumably clears PYNQ's cached programmable-logic
# state so that a fresh bitstream can be loaded -- confirm against the PYNQ docs.
PL.reset()
# Show the bitstream picker; loading a bitstream builds the channel UI.
bitstream_selection()