In [None]:
import os
import pickle
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

In [None]:
class AWGPulseVisualizer:
    """Load AWG waveforms saved as per-channel pickle files and plot them.

    Each channel of a waveform named ``<name>`` is looked up on disk as
    ``<waveform_folder>/<name>_<channel>.pkl``, where ``<channel>`` is a key of
    ``channel_labels`` (e.g. ``'a_ch0'`` for analog, ``'d_ch1'`` for digital).
    """

    def __init__(self, waveform_folder: str, sample_rate: float, channel_labels: dict):
        """
        Args:
            waveform_folder: Directory containing the ``.pkl`` channel files.
            sample_rate: Sample rate in samples per second; used to convert
                sample indices into time on the plot's x axis.
            channel_labels: Mapping of channel key -> human-readable label. Its
                keys decide which channel files are looked up, and in what order.
        """
        self.waveform_folder = waveform_folder
        self.sample_rate = sample_rate
        self.channel_labels = channel_labels

    def waveform_to_load(self, waveform_name: str) -> dict:
        """Load the channels of ``waveform_name`` from disk into a dictionary.

        Channels whose file does not exist, or whose samples are all zero, are
        left out of the result. A one-line status is printed for every channel.

        Args:
            waveform_name: Base name of the waveform; each channel file is
                named ``f"{waveform_name}_{channel}.pkl"``.

        Returns:
            Dict mapping channel key -> unpickled sample sequence, ordered like
            ``channel_labels``. Empty if no non-zero channel was found.
        """
        waveform = {}
        for key in self.channel_labels:
            filepath = os.path.join(self.waveform_folder, f"{waveform_name}_{key}.pkl")
            try:
                # SECURITY: pickle.load can execute arbitrary code. Only point
                # waveform_folder at files produced by a trusted source.
                with open(filepath, 'rb') as f:
                    samples = pickle.load(f)
            except FileNotFoundError:
                # Channels that don't exist on disk are skipped
                print(f"{key} not found")
                continue
            if not np.any(samples):
                # Channels with only zeros are skipped (nothing to plot)
                print(f"{key} -> null")
                continue
            print(f"{key} -> {len(samples)}")
            waveform[key] = samples
        return waveform

    def plot_waveform(self, waveform: dict,
                      plot_range: Tuple[int, Optional[int]] = (0, None)) -> plt.Figure:
        """Plot every channel of a waveform dictionary on stacked, x-shared axes.

        Args:
            waveform: Dict mapping channel key -> sample sequence, e.g. as
                returned by ``waveform_to_load``. Keys starting with ``'a_'``
                are drawn as analog traces (markers + thin line); all other
                keys are drawn as digital traces.
            plot_range: ``(start, stop)`` sample indices applied to every
                channel with ordinary Python slice semantics. ``stop=None``
                (the default) plots through the last sample.

        Returns:
            The created matplotlib Figure.

        Raises:
            ValueError: If ``waveform`` contains no channels.
        """
        if not waveform:
            raise ValueError("waveform has no channels to plot")

        n_channels = len(waveform)
        if n_channels == 1:
            # Single channel plotting
            figsize = (8, 5)
        else:
            # Multi channel plotting
            figsize = (n_channels * 1.3, n_channels)

        # squeeze=False always yields a 2-D array of Axes (even for one row),
        # so single- and multi-channel figures are indexed the same way.
        fig, axes = plt.subplots(nrows=n_channels, sharex="all", figsize=figsize, squeeze=False)
        axes = axes[:, 0]

        window = slice(plot_range[0], plot_range[1])
        for idx, (key, samples) in enumerate(waveform.items()):
            samples = np.asarray(samples)
            # Slice the sample indices exactly like the data: the time step is
            # exactly 1/sample_rate and a window keeps its true position in time.
            time_in_us = np.arange(len(samples))[window] / self.sample_rate * 1e6
            window_samples = samples[window]

            ax = axes[idx]
            if key.startswith('a_'):
                # Analog channels
                ax.plot(time_in_us, window_samples, ".-", linewidth=0.5, color=f"C{idx}")
            else:
                # Digital channels
                ax.plot(time_in_us, window_samples, "-", color=f"C{idx}")

            # Unknown keys get no label instead of raising KeyError
            label = self.channel_labels.get(key, "")
            ax.set_title(f"{key} ({label})" if label else key)

        axes[-1].set_xlabel("Time (μs)")
        return fig

In [None]:
# Human-readable description of each AWG channel. The keys also decide which
# channel files AWGPulseVisualizer looks for ("a_" = analog, "d_" = digital);
# an empty string means the channel has no label.
labels = {
    'a_ch0': "In-phase",
    'a_ch1': "Quadrature",
    'a_ch2': "",
    'a_ch3': "",
    'd_ch0': "Laser trigger",
    'd_ch1': "Start gate trigger",
    'd_ch2': "Switch trigger",
    'd_ch3': "Next gate trigger",
    'd_ch4': "",
    'd_ch5': "MW sweep trigger"
}

# Folder containing the saved waveform .pkl files. Set the QUDI_WAVEFORM_FOLDER
# environment variable to use another location instead of editing this path.
folder = os.environ.get(
    "QUDI_WAVEFORM_FOLDER",
    os.path.join("C:\\", "qudi-hira", "saved_pulsed_assets", "waveform"),
)

In [None]:
%matplotlib widget

awg_pulse_visualizer = AWGPulseVisualizer(waveform_folder=folder, sample_rate=1.25e9, channel_labels=labels)

t1_waveform = awg_pulse_visualizer.waveform_to_load(waveform_name="t1")

figure = awg_pulse_visualizer.plot_waveform(t1_waveform)
figure.tight_layout()