In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap, ListedColormap, BoundaryNorm
from matplotlib.patches import Patch
from looptools.mokulaserlock import MokuLaserLock
from looptools.component import Component
import looptools.simulation as lsim
from looptools.plots import default_rc
from functools import partial
from scipy.interpolate import griddata
# Apply looptools' default matplotlib rc settings globally, so every figure in
# this notebook shares the same style.
plt.rcParams.update(default_rc)

In [None]:
# Sample rate passed to every Component / loop, and the frequency axis
# (1 Hz to 1 MHz, log-spaced) on which each swept loop is evaluated.
sps = 80e6
frfr = np.logspace(np.log10(1e0), np.log10(1e6), int(1e5))

# Plant transfer function from numerator / denominator coefficients.
# The denominator is right-padded with ten zero coefficients.
nume = [
    -27391.4746505128605349,
    28991.6861562978592701,
    27391.5753081338189077,
    -28991.5850488191608747,
]
deno = [
    1.0,
    -2.9907469440381682,
    2.9815121426943869,
    -0.9907651980332260,
] + [0.0] * 10
Plant = Component("Plant", sps, nume=nume, deno=deno)

# 2-D sweep: proportional gain (dB, linear spacing) x integrator crossover
# (Hz, log spacing from 1 Hz to 100 kHz). Key order sets the sweep-axis order.
Kp_vals = np.linspace(-100, 80, 60)
f_I_vals = np.logspace(0, 5, 40)
param_grid = dict(Servo_Kp_dB=Kp_vals, Servo_Fc_i=f_I_vals)
xlabel, ylabel = "P-gain (dB)", "Integrator crossover (Hz)"

# Baseline loop. Kp_db / f_I are starting values; presumably the sweep
# overrides them through the Servo_Kp_dB / Servo_Fc_i keys (confirm).
loop = MokuLaserLock(
    Plant=Plant,
    Amp_reference=1.0,
    Amp_input=1.0,
    LPF_cutoff=0.25e6,
    LPF_n=4,
    Cshift=14,
    Kp_db=0,
    f_I=1,
    f_II=None,
    n_reg=0,
    off=[None],
)

# Evaluate every grid point; phases are reported in degrees and unwrapped.
result = lsim.parameter_sweep_nd(
    loop,
    param_grid=param_grid,
    frequencies=frfr,
    deg=True,
    unwrap_phase=True,
    interpolate=True,
)

In [None]:
# Extract the sweep metrics (UGF and phase margin) and the swept inputs.
ugf = result["metrics"]["ugf"]
pm = result["metrics"]["phase_margin"]
f_I_vals = param_grid["Servo_Fc_i"]            # swept input param
Kp_vals = param_grid["Servo_Kp_dB"]            # swept input param

# The plotting cell builds its coordinate grids with
# np.meshgrid(Kp_vals, f_I_vals, indexing='ij'), i.e. it assumes the metrics
# are indexed [i_Kp, i_fI]. Check that layout here so a mismatch fails with a
# clear message instead of deep inside matplotlib.
expected_shape = (len(Kp_vals), len(f_I_vals))
assert np.shape(ugf) == expected_shape, \
    f"ugf has shape {np.shape(ugf)}, expected {expected_shape}"
assert np.shape(pm) == expected_shape, \
    f"phase_margin has shape {np.shape(pm)}, expected {expected_shape}"

print("Any NaNs in ugf?", np.isnan(ugf).any())
print("Any NaNs in pm?", np.isnan(pm).any())
# np.isnan misses infinities; the plot cell must mask those as well.
print("Any infs in ugf?", np.isinf(ugf).any())

In [None]:
# Create 2D grid of inputs to match result shape
Kp_grid, fI_grid = np.meshgrid(Kp_vals, f_I_vals, indexing='ij')  # same shape as ugf

fig = plt.figure(figsize=(6, 4), dpi=200)
ax = fig.add_subplot(111)

# Use contourf instead of pcolormesh to handle non-monotonic X
c = ax.contourf(ugf, fI_grid, pm, levels=100, cmap='viridis')
c.set_clim(-90, 90)

# pcolormesh directly with ugf as X, f_I as Y, pm as value
# NOTE: may issue warnings if ugf is not monotonic
# Mask invalid PM values
ugf_clean = np.where(np.isfinite(ugf), ugf, np.nan)
ugf_clean = np.nan_to_num(ugf_clean, nan=0.0, posinf=1e7, neginf=1e0)
valid = np.isfinite(ugf_clean) & np.isfinite(pm)
pm_clean = np.where(valid, pm, np.nan)
c = ax.pcolormesh(ugf_clean, fI_grid, pm_clean, shading='auto', cmap='viridis', vmin=-90, vmax=90)

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel("Unity gain frequency (Hz)")
ax.set_ylabel("Integrator crossover frequency (Hz)")
fig.colorbar(c, ax=ax, label="Phase margin (degrees)")

# Add contour line for π/4
contour = ax.contour(ugf, fI_grid, np.radians(pm), levels=[np.pi / 4], colors='k', linewidths=1.5)
ax.clabel(contour, fmt={np.pi / 4: r'Full model'}, inline=True, inline_spacing=1, fontsize=8, manual=[(1e4, 1e1)])

# Simplified model π/4 contour
x_sub = np.logspace(-1, 5, 100)
y_sub = np.logspace(-1, 5, 100)
X, Y = np.meshgrid(x_sub, y_sub)
# Phase margin function
def phase_margin(ugf, fz, fc):
    return np.arctan(ugf / fz) - np.arctan(ugf / fc)
Z = phase_margin(X, Y, 10e3)

contour = ax.contour(X, Y, Z, levels=[np.pi/4], colors='white', linewidths=1.5)
# ax.clabel(contour, fmt='Simple model', inline=True, inline_spacing=1, fontsize=8)
ax.text(0.4e3, 1e3, r'Simple model',
        fontsize=8.5,
        color='white',
        ha='center', va='center')

ax.set_xlim(20,1e5)
ax.set_ylim(1,1e5)
ax.grid(False)
fig.tight_layout()
plt.show()

In [None]:
def plot_sweep_results(result, *,
                       xlabel=None, ylabel=None,
                       logx=True, logy=False,
                       ax=None, cmap="RdYlGn", levels=20,
                       show_colorbar=True, title=None,
                       interpolation=False):
    """
    Plot the phase margin from a 2-D parameter sweep result.

    Parameters
    ----------
    result : dict
        Sweep result as returned by ``lsim.parameter_sweep_nd``. Read keys:
        ``result["parameter_names"]`` (the first two names become the x and
        y axes), ``result["parameter_grid"][name]`` (swept values for each
        name) and ``result["metrics"]["phase_margin"]`` (assumed to be in
        degrees, i.e. a sweep run with ``deg=True``; confirm).
    xlabel, ylabel : str, optional
        Axis labels. ``None`` falls back to the parameter names; pass ``''``
        to leave an axis unlabelled.
    logx, logy : bool, optional
        Use a logarithmic x / y axis.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. A new figure and axes are created when ``None``.
    cmap, levels : optional
        Currently ignored: the phase-margin colour scheme and the number of
        contour levels (512) are fixed. Kept for backward compatibility.
    show_colorbar : bool, optional
        Add a colorbar (``interpolation=True``) or a classification legend
        (``interpolation=False``).
    title : str, optional
        Axes title.
    interpolation : bool, optional
        If True, plots a smooth gradient (filled contours) from 0 to 90 deg.
        If False (default), plots a discrete grid coloured by phase-margin
        class.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn into.
    """
    metric = "phase_margin"
    param_names = result["parameter_names"]
    metric_values = result["metrics"][metric]

    with plt.rc_context(default_rc):

        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()

        x = result["parameter_grid"][param_names[0]]
        y = result["parameter_grid"][param_names[1]]
        z = metric_values
        # Sweep metrics are laid out [i_x, i_y] (axis 0 follows the first
        # swept parameter, the same 'ij' convention used when meshgridding the
        # inputs to match them). With 1-D x and y, matplotlib expects
        # Z[i_y, i_x]. Without this transpose a square sweep is silently drawn
        # with its axes swapped, and a non-square one raises a shape error.
        if np.ndim(x) == 1 and np.ndim(y) == 1 and np.shape(z) == (len(x), len(y)):
            z = np.transpose(z)

        if interpolation:
            # --- SMOOTH, INTERPOLATED PLOT ---
            # Gradient anchored at these phase-margin values (degrees).
            colors = ['#DC143C', '#FFD700', '#32CD32', '#3CB371', '#87CEEB']
            nodes = [0.0, 30.0, 60.0, 76.0, 90.0]

            # Map the anchor values onto the colormap's [0, 1] range.
            norm_nodes = [n / max(nodes) for n in nodes]

            cmap_gradient = LinearSegmentedColormap.from_list(
                "phase_margin_gradient", list(zip(norm_nodes, colors))
            )

            # Colours for values outside the 0-90 range.
            cmap_gradient.set_under('#DC143C')  # Unstable
            cmap_gradient.set_over('#6495ED')   # Very Sluggish

            norm = plt.Normalize(vmin=0, vmax=90)

            contour = ax.contourf(x, y, z, levels=512, cmap=cmap_gradient,
                                  norm=norm, extend='both')

            if show_colorbar:
                cbar = fig.colorbar(contour, ax=ax,
                                    label=metric.replace('_', ' ').title())
                cbar.set_ticks(nodes)  # ticks at the gradient anchors

        else:
            # --- DISCRETE, CATEGORICAL PLOT (pcolormesh) ---
            bounds = [-np.inf, 0, 30, 60, 76, 90, np.inf]
            labels = ['Unstable (<0°)', 'Marginally Stable (0-30°)', 'Well Damped (30-60°)',
                      'Optimally Damped (60-76°)', 'Overdamped (76-90°)', 'Very Sluggish (>90°)']
            colors = ['#DC143C', '#FFD700', '#32CD32', '#3CB371', '#87CEEB', '#6495ED']

            cmap_custom = ListedColormap(colors)
            norm = BoundaryNorm(bounds, cmap_custom.N)

            ax.pcolormesh(x, y, z, cmap=cmap_custom, norm=norm, shading='auto')

            if show_colorbar:  # for this plot type the flag controls the legend
                legend_handles = [Patch(facecolor=color, edgecolor='black', label=label)
                                  for color, label in zip(colors, labels)]
                ax.legend(handles=legend_handles, title="Phase Margin Classification",
                          loc="upper right", fontsize=7, title_fontsize=8)

        # Axis settings shared by both plot types.
        if logx:
            ax.set_xscale("log")
        if logy:
            ax.set_yscale("log")

        ax.set_xlabel(xlabel if xlabel is not None else param_names[0])
        ax.set_ylabel(ylabel if ylabel is not None else param_names[1])

        if title is not None:
            ax.set_title(title)

        return ax

In [None]:
# Parameters
sps = 80e6
Amp_reference = 1.0
Amp_input = 1.0
LPF_cutoff = 0.1e6
LPF_n = 1
Kp_db = -30
f_I = 1e2
f_II = None
n_reg = 0
frfr = np.logspace(np.log10(1e3), np.log10(1e5), int(1e5))

# Define plant
nume = [-27391.4746505128605349, 28991.6861562978592701,
        27391.5753081338189077, -28991.5850488191608747]
deno = [1.0, -2.9907469440381682, 2.9815121426943869,
        -0.9907651980332260] + [0.0]*10
Plant = Component("Plant", sps, nume=nume, deno=deno)

# Sweep parameters
Kp_vals = np.linspace(-60, 60, 40)
f_I_vals = np.logspace(0, 5, 40)
param_grid = {
    "Servo_Kp_dB": Kp_vals,
    "Servo_Fc_i": f_I_vals
}
xlabel, ylabel = "P-gain (dB)", "Integrator crossover (Hz)"

# Fixed parameters
loop_template = partial(
    MokuLaserLock,
    Plant=Plant,
    Amp_reference=Amp_reference,
    Amp_input=Amp_input,
    LPF_cutoff=LPF_cutoff,
    LPF_n=LPF_n,
    Kp_db=Kp_db,
    f_I=f_I,
    f_II=f_II,
    n_reg=n_reg,
    off=['LPF']
)

# Cshift values to sweep
cshift_vals = [0, 5, 10, 15, 20, 25]

# Plotting setup
fig, axes = plt.subplots(3, 2, figsize=(6.5, 9), sharex=True, sharey=True)
axes = axes.flatten()

# Run sweep and plot
for i, Cshift in enumerate(cshift_vals):
    loop = loop_template(Cshift=Cshift)
    result = lsim.parameter_sweep_nd(
        loop,
        param_grid=param_grid,
        frequencies=frfr,
        deg=True,
        unwrap_phase=True,
        interpolate=True
    )

    ax = axes[i]

    # Determine whether to show axis labels
    is_left_column = (i % 2 == 0)
    is_bottom_row = (i >= 4)

    plot_sweep_results(
        result,
        ax=ax,
        xlabel=xlabel if is_bottom_row else '',
        ylabel=ylabel if is_left_column else '',
        logx=False,
        logy=True,
        interpolation=False,
        show_colorbar=False
    )

    ax.set_title(f"Cshift = {Cshift}", fontsize=10)

# --- Add color legend at bottom ---
# Define colors and labels (should match plot_sweep_results)
colors = ['#DC143C', '#FFD700', '#32CD32', '#3CB371', '#87CEEB', '#6495ED']
labels = ['Unstable (<0°)', 'Marginally Stable (0-30°)', 'Well Damped (30-60°)',
          'Optimally Damped (60-76°)', 'Overdamped (76-90°)', 'Very Sluggish (>90°)']
legend_handles = [Patch(facecolor=color, edgecolor='black', label=label)
                  for color, label in zip(colors, labels)]

# Add a new axis for the legend at the bottom
legend_ax = fig.add_axes([0.1, 0.92, 0.8, 0.05])  # [left, bottom, width, height]
legend_ax.axis('off')  # Hide the axis frame
legend_ax.legend(handles=legend_handles, loc='center', ncol=3,
                 fontsize=8, title="Phase Margin Classification", title_fontsize=9,
                 frameon=False)
plt.show()