From c800f930440f3de439ecb3c540ee10fd53f0cb28 Mon Sep 17 00:00:00 2001 From: John Halloran Date: Mon, 11 Aug 2025 22:51:19 -0700 Subject: [PATCH 1/2] feat: add live plotting of updates --- src/diffpy/snmf/main.py | 1 + src/diffpy/snmf/plotter.py | 56 +++++++++++++++++++++++++++++++++++ src/diffpy/snmf/snmf_class.py | 24 +++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 src/diffpy/snmf/plotter.py diff --git a/src/diffpy/snmf/main.py b/src/diffpy/snmf/main.py index 3378d8d0..108fe05e 100644 --- a/src/diffpy/snmf/main.py +++ b/src/diffpy/snmf/main.py @@ -12,6 +12,7 @@ init_weights=init_weights_file, init_components=init_components_file, init_stretch=init_stretch_file, + show_plots=True, ) print("Done") diff --git a/src/diffpy/snmf/plotter.py b/src/diffpy/snmf/plotter.py new file mode 100644 index 00000000..8b63d484 --- /dev/null +++ b/src/diffpy/snmf/plotter.py @@ -0,0 +1,56 @@ +# helper_plot.py +import matplotlib.pyplot as plt +import numpy as np + + +class SNMFPlotter: + def __init__(self, figsize=(12, 4)): + plt.ion() + self.fig, self.axes = plt.subplots(1, 3, figsize=figsize) + titles = ["Components", "Weights (rows as series)", "Stretch (rows as series)"] + for ax, t in zip(self.axes, titles): + ax.set_title(t) + self.lines = {"components": [], "weights": [], "stretch": []} + self._layout_done = False + plt.show() + + def _ensure_lines(self, ax, key, n_series): + cur = self.lines[key] + if len(cur) != n_series: + ax.cla() + ax.set_title(ax.get_title()) + self.lines[key] = [ax.plot([], [])[0] for _ in range(n_series)] + return self.lines[key] + + def _update_series(self, ax, key, data_2d): + # Expect rows = separate series for components + data_2d = np.atleast_2d(data_2d) + n_series, n_pts = data_2d.shape + lines = self._ensure_lines(ax, key, n_series) + x = np.arange(n_pts) + for ln, y in zip(lines, data_2d): + ln.set_data(x, y) + ax.relim() + ax.autoscale_view() + + def update(self, components, weights, stretch, update_tag=None): + # Components: transpose before plotting + C = np.asarray(components).T + self._update_series(self.axes[0], "components", C) + + W = np.asarray(weights) + self._update_series(self.axes[1], "weights", W) + + S = np.asarray(stretch) + self._update_series(self.axes[2], "stretch", S) + + if update_tag is not None: + self.fig.suptitle(f"Updated: {update_tag}", fontsize=14) + + if not self._layout_done: + self.fig.tight_layout() + self._layout_done = True + + self.fig.canvas.draw() + self.fig.canvas.flush_events() + plt.pause(0.001) diff --git a/src/diffpy/snmf/snmf_class.py b/src/diffpy/snmf/snmf_class.py index 3bccee5a..489aabd9 100644 --- a/src/diffpy/snmf/snmf_class.py +++ b/src/diffpy/snmf/snmf_class.py @@ -1,5 +1,6 @@ import cvxpy as cp import numpy as np +from plotter import SNMFPlotter from scipy.optimize import minimize from scipy.sparse import coo_matrix, diags @@ -73,6 +74,7 @@ def __init__( tol=5e-7, n_components=None, random_state=None, + show_plots=False, ): """Initialize an instance of SNMF and run the optimization. @@ -112,6 +114,8 @@ def __init__( random_state : int Optional Default = None The seed for the initial guesses at the matrices (A, X, and Y) created by the decomposition. + show_plots : boolean Optional Default = False + Enables plotting at each step of the decomposition. """ self.source_matrix = source_matrix @@ -123,6 +127,7 @@ def __init__( self.signal_length, self.n_signals = source_matrix.shape self.num_updates = 0 self._rng = np.random.default_rng(random_state) + self.plotter = SNMFPlotter() if show_plots else None # Enforce exclusive specification of n_components or init_weights if (n_components is None and init_weights is None) or ( @@ -236,6 +241,13 @@ def normalize_results(self): print(f"Objective function after normalize_components: {self.objective_function:.5e}") self._objective_history.append(self.objective_function) self.objective_difference = self._objective_history[-2] - self._objective_history[-1] + if self.plotter is not None: + self.plotter.update( + components=self.components, + weights=self.weights, + stretch=self.stretch, + update_tag="normalize components", + ) if self.objective_difference < self.objective_function * self.tol and outiter >= 7: break @@ -252,6 +264,10 @@ def outer_loop(self): if self.objective_function < self.best_objective: self.best_objective = self.objective_function self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()] + if self.plotter is not None: + self.plotter.update( + components=self.components, weights=self.weights, stretch=self.stretch, update_tag="components" + ) self.update_weights() self.residuals = self.get_residual_matrix() @@ -262,6 +278,10 @@ def outer_loop(self): if self.objective_function < self.best_objective: self.best_objective = self.objective_function self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()] + if self.plotter is not None: + self.plotter.update( + components=self.components, weights=self.weights, stretch=self.stretch, update_tag="weights" + ) self.objective_difference = self._objective_history[-2] - self._objective_history[-1] if self._objective_history[-3] - self.objective_function < self.objective_difference * 1e-3: @@ -276,6 +296,10 @@ def outer_loop(self): if self.objective_function < self.best_objective: self.best_objective = self.objective_function self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()] + if self.plotter is not None: + self.plotter.update( + components=self.components, weights=self.weights, stretch=self.stretch, update_tag="stretch" + ) def get_residual_matrix(self, components=None, weights=None, stretch=None): # Initialize residual matrix as negative of source_matrix From 2c20e07b15f39e58a3e9a551466af2d346a9150b Mon Sep 17 00:00:00 2001 From: John Halloran Date: Mon, 11 Aug 2025 22:53:14 -0700 Subject: [PATCH 2/2] style: make plotting vars lowercase --- src/diffpy/snmf/plotter.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffpy/snmf/plotter.py b/src/diffpy/snmf/plotter.py index 8b63d484..9d8255ef 100644 --- a/src/diffpy/snmf/plotter.py +++ b/src/diffpy/snmf/plotter.py @@ -1,4 +1,3 @@ -# helper_plot.py import matplotlib.pyplot as plt import numpy as np @@ -35,14 +34,14 @@ def _update_series(self, ax, key, data_2d): def update(self, components, weights, stretch, update_tag=None): # Components: transpose before plotting - C = np.asarray(components).T - self._update_series(self.axes[0], "components", C) + c = np.asarray(components).T + self._update_series(self.axes[0], "components", c) - W = np.asarray(weights) - self._update_series(self.axes[1], "weights", W) + w = np.asarray(weights) + self._update_series(self.axes[1], "weights", w) - S = np.asarray(stretch) - self._update_series(self.axes[2], "stretch", S) + s = np.asarray(stretch) + self._update_series(self.axes[2], "stretch", s) if update_tag is not None: self.fig.suptitle(f"Updated: {update_tag}", fontsize=14)