In [7]:
from IPython import display
import matplotlib.pyplot as plt
from matplotlib import patches
from ipywidgets import widgets
import numpy as np
from scipy import ndimage
from sunpy.map import Map

In [3]:
def select_region(labels, idx, boxSize=256, contourSize=None):
    """
    Return the bottom left and top right corners of the bounding box of a
    labelled image region, padded outwards to "fuzzily" include features
    on the edge.

    Parameters
    ----------
    labels : np.ndarray
        The 2-D integer labelled mask (produced by e.g. scipy.ndimage.label).
        Axis 0 is treated as y and axis 1 as x.
    idx : int
        The labelled region to extract.
    boxSize : int
        The size of the boxes used to define the regions, i.e. the number
        of image pixels spanned by one cell of ``labels``.
    contourSize : Optional[int]
        Controls the padding around the region: ``contourSize // 2`` pixels
        are added onto each side of the bounding box. Default: boxSize // 2.

    Returns
    -------
    bottomLeft, topRight : Tuple[int, int]
        The (x, y) pixel coordinates of the corners of the selected region.
        ``bottomLeft`` is the minimum-coordinate corner and ``topRight`` the
        maximum-coordinate corner. The padding is not clipped, so
        ``bottomLeft`` may contain negative values for regions at the edge.

    Raises
    ------
    ValueError
        If no element of ``labels`` equals ``idx``.
    """
    if contourSize is None:
        contourSize = boxSize // 2

    coords = np.argwhere(labels == idx)
    if coords.size == 0:
        raise ValueError(f"No region labelled {idx!r} found in labels.")

    # argwhere yields (row, column) pairs: column 0 is y, column 1 is x.
    minY, maxY = np.min(coords[:, 0]), np.max(coords[:, 0])
    minX, maxX = np.min(coords[:, 1]), np.max(coords[:, 1])

    pad = contourSize // 2
    # The low corner sits on the leading edge of the first cell, the high
    # corner on the trailing edge of the last cell (hence the +1).
    bottomLeft = (minX * boxSize - pad, minY * boxSize - pad)
    topRight = ((maxX + 1) * boxSize + pad, (maxY + 1) * boxSize + pad)
    return bottomLeft, topRight

In [5]:
def overplot_rect_from_coords(ax, bottomLeft, topRight):
    """
    Overplot green rectangle from the provided corners on the axes.

    Parameters
    ----------
    ax : Matplotlib axes
        The axes to plot on.
    bottomLeft : Tuple[int, int]
        The bottom left corner (pixel coordinates) of the region.
    topRight : Tuple[int, int]
        The top right corner (pixel coordinates) of the region.
    """
    boxX = topRight[0] - bottomLeft[0]
    boxY = topRight[1] - bottomLeft[1]
    # Half-pixel shift of the anchor; presumably so the outline follows pixel
    # edges rather than pixel centres.
    boxAnchor = (bottomLeft[0] - 0.5, bottomLeft[1] - 0.5)
    # Matplotlib property names are case-sensitive and lower-case; the
    # camelCase spellings (lineWidth, edgeColor, faceColor) are rejected.
    newRect = patches.Rectangle(
        boxAnchor, boxX, boxY, linewidth=1, edgecolor="g", facecolor="none"
    )
    ax.add_patch(newRect)

In [6]:
def overplot_spots_from_mask(ax, mask, boxSize=128):
    """
    Plot the mask onto a given set of axes, drawing the selected regions in
    red.

    Parameters
    ----------
    ax : Matplotlib axes
        The axes to which to add the boxes from the mask.
    mask : np.ndarray
        The mask indicating the regions to draw; one square is drawn for
        every non-zero element. Axis 0 is treated as y and axis 1 as x.
    boxSize : int
        The size each block in the mask represents on the image.
    """
    # argwhere returns the indices of the non-zero cells in row-major order.
    for cell in np.argwhere(mask):
        row, col = cell[0], cell[1]
        # Half-pixel shift of the anchor; presumably so the outline follows
        # pixel edges rather than pixel centres.
        boxAnchor = (col * boxSize - 0.5, row * boxSize - 0.5)
        # Matplotlib property names are case-sensitive: "linewidth", not
        # "lineWidth".
        newRect = patches.Rectangle(
            boxAnchor, boxSize, boxSize, linewidth=1, edgecolor="r", facecolor="none"
        )
        ax.add_patch(newRect)

In [None]:
class VisualisePredictions:
    """
    Class for plotting label predictions or contiguous regions over their
    corresponding sunmap(s).
    Designed to be used in a notebook with the matplotlib widget backend.

    Parameters
    ----------
    fileList : List[str]
        List of paths to the fits files to be used.
    predictions: np.ndarray[np.ndarray]
        Stack of binary mask predictions, one 2-D mask per entry of
        fileList. Each mask cell is drawn as a boxSize x boxSize pixel
        square on the image.
    plot_regions : bool
        If true, draw one green bounding box per contiguous region of each
        mask (as found by scipy.ndimage.label). Otherwise draw one red
        square per non-zero mask cell.
    boxSize : Optional[int]
        The size of the boxes used to define the regions.
    """

    def __init__(self, fileList, predictions, plot_regions=False, boxSize=64):
        self.idx = 0
        self.files = fileList
        self.predictions = predictions
        self.plot_regions = plot_regions
        self.boxsize = boxSize
        # Always defined, so that pickling also works when regions are not
        # being plotted.
        self.regions = {}
        if self.plot_regions:
            self.get_regions()
        im = Map(self.files[0])
        self.fig = plt.figure(figsize=(8, 8))
        self.ax = plt.subplot(projection=im)
        self.setup_buttons()
        self.setup_im(0)

    def get_regions(self):
        """
        Label the contiguous regions of every prediction mask and store
        their bounding-box corners in ``self.regions``, keyed by image index.
        """
        self.regions = {}
        for index, prediction in enumerate(self.predictions):
            labeled_array, num_features = ndimage.label(prediction)
            # Labels produced by ndimage.label start at 1; 0 is background.
            self.regions[index] = [
                select_region(labeled_array, label, boxSize=self.boxsize)
                for label in range(1, num_features + 1)
            ]

    def setup_buttons(self, startingSlider=0):
        """Create and display the image-index slider and hook up its callback."""
        self.slider = widgets.IntSlider(
            startingSlider, 0, len(self.files) - 1, description="Image Index:"
        )
        self.slider.layout.margin = "0px 10% 0px 10%"
        self.slider.layout.width = "40%"
        display.display(self.slider)
        self.slider.observe(self.change_image, names="value")

    def setup_im(self, idx):
        """Load image ``idx``, plot it and overlay its regions or mask."""
        self.idx = idx
        # Iterate over a copy: removing a patch modifies the axes' patch list.
        for patch in list(self.ax.patches):
            patch.remove()
        self.fig.canvas.flush_events()
        self.im = Map(self.files[idx])
        # Replace NaNs in the map data with zeros before plotting.
        self.im.data[np.isnan(self.im.data)] = 0
        self.im.plot(axes=self.ax)
        if self.plot_regions:
            self.region = self.regions[idx]
            for corner in self.region:
                overplot_rect_from_coords(self.ax, corner[0], corner[1])
        else:
            self.mask = self.predictions[idx]
            overplot_spots_from_mask(self.ax, self.mask, boxSize=self.boxsize)

    def change_image(self, event):
        """Slider callback: show the image at the slider's new value."""
        self.setup_im(event["new"])

    def __getstate__(self):
        """Return the picklable state (the figure and widgets are excluded)."""
        s = {}
        s["idx"] = self.idx
        s["files"] = self.files
        s["predictions"] = self.predictions
        s["plot_regions"] = self.plot_regions
        # ``regions`` may be missing on instances that never computed them.
        s["regions"] = getattr(self, "regions", {})
        s["boxSize"] = self.boxsize
        return s

    def __setstate__(self, s):
        """Restore the pickled state and rebuild the figure and slider."""
        self.idx = s["idx"]
        self.files = s["files"]
        self.predictions = s["predictions"]
        self.plot_regions = s["plot_regions"]
        self.regions = s.get("regions", {})
        self.boxsize = s["boxSize"]

        self.fig = plt.figure()
        im = Map(self.files[0])
        self.ax = plt.subplot(projection=im)
        self.setup_buttons(startingSlider=s["idx"])
        self.setup_im(s["idx"])
        self.slider.send_state({"value": s["idx"]})

In [None]:
class SunspotSelector:
    """
    Class for selecting rectangular regions of an image sequence.
    Designed to be used in a notebook with the matplotlib widget backend.
    Reworked for Kedro compatibility, dataset dict should be output from Kedro
    catalog load.
    Can specify a local or remote dataset as determined by catalog entry specs.

    Parameters
    ----------
    dataset : dict
        Mapping from key to a zero-argument callable that returns the image
        to plot for that key.
    dataset_name : str
        Name of the dataset; stored and pickled, not otherwise used here.
    pixels_per_cell : tuple[int, int]
        The (x, y) size in pixels of one selectable box.
    """

    def __init__(
        self, dataset: dict, dataset_name: str, pixels_per_cell: tuple[int, int]
    ) -> None:
        self.dataset = dataset
        self.dataset_name = dataset_name
        self.keys = list(self.dataset.keys())
        self.pixels_per_cell = pixels_per_cell
        # One independent list of selections per image.
        self.hmi_coords_all = [[] for _ in self.keys]
        self.box_coords_all = [[] for _ in self.keys]

        self.index = 0
        image = self.dataset[self.keys[self.index]]()
        self.fig = plt.figure()
        self.axes = plt.subplot(projection=image)
        self.setup_buttons()
        self.setup_image(self.index)

    def _box_anchor(self, box_coord: tuple) -> tuple:
        """Return the anchor corner of the rectangle drawn for a box index."""
        return (
            box_coord[0] * self.pixels_per_cell[0] - 0.5,
            box_coord[1] * self.pixels_per_cell[1] - 0.5,
        )

    def _add_box(self, box_coord: tuple) -> None:
        """Draw the red outline of the box with the given (x, y) box index."""
        new_rect = patches.Rectangle(
            self._box_anchor(box_coord),
            self.pixels_per_cell[0],
            self.pixels_per_cell[1],
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        self.axes.add_patch(new_rect)

    def _remove_all_patches(self) -> None:
        """Remove every patch from the axes and flush the canvas events."""
        # Iterate over a copy: removing a patch modifies the axes' patch list.
        for patch in list(self.axes.patches):
            patch.remove()
        self.fig.canvas.flush_events()

    def setup_buttons(self, starting_slider: int = 0) -> None:
        """Sets up buttons for the selector widget."""
        self.receiver = self.fig.canvas.mpl_connect("button_press_event", self.on_click)
        self.slider = widgets.IntSlider(
            starting_slider, 0, len(self.keys) - 1, description="File Number"
        )
        display.display(self.slider)
        self.slider.observe(self.change_image, names="value")
        self.clear_button = widgets.Button(description="Clear Image")
        display.display(self.clear_button)
        self.clear_button.on_click(self.clear)

    def setup_image(self, index: int) -> None:
        """Plots image of a specified index from the dataset."""
        self.index = index
        self._remove_all_patches()
        self.image = self.dataset[self.keys[index]]()
        self.image.plot(axes=self.axes)
        # Aliases of the per-image lists, so edits made through them are
        # stored in the *_all containers.
        self.hmi_coords = self.hmi_coords_all[index]
        self.box_coords = self.box_coords_all[index]

        for coord in self.box_coords:
            self._add_box(coord)
        self.axes.autoscale_view()

    def change_image(self, event: dict) -> None:
        """Changes index of image setup function."""
        self.setup_image(event["new"])

    def on_click(self, event) -> None:
        """
        Determines the behaviour of clicking on the selector widget.
        If wrong mode is selected, the clicking will do nothing.
        If the click is outside the image axes, the clicking will do nothing.
        If clicking on an existing box, the box will be removed.
        If clicking on blank space, a new box will be added.

        ``event`` is the matplotlib mouse event delivered to the
        "button_press_event" callback.
        """
        # Ignore clicks made while a toolbar tool is active (mode is "" when
        # none is). A canvas without a toolbar is treated as having none active.
        toolbar = getattr(self.fig.canvas.manager, "toolbar", None)
        if toolbar is not None and toolbar.mode != "":
            return

        # xdata/ydata are None when the click is outside any axes.
        if (
            event.inaxes is not self.axes
            or event.xdata is None
            or event.ydata is None
        ):
            return

        box_coord = (
            int((event.xdata + 0.5) // self.pixels_per_cell[0]),
            int((event.ydata + 0.5) // self.pixels_per_cell[1]),
        )
        if box_coord in self.box_coords:
            box_anchor = self._box_anchor(box_coord)
            # Iterate over a copy so removals cannot skip entries.
            for patch in list(self.axes.patches):
                if patch.get_xy() == box_anchor:
                    patch.remove()

            position = self.box_coords.index(box_coord)
            del self.box_coords[position]
            del self.hmi_coords[position]
        else:
            # Record the raw click position alongside the box index.
            self.hmi_coords.append((event.xdata, event.ydata))
            self.box_coords.append(box_coord)
            self._add_box(box_coord)

    def disconnect_matplotlib(self, _) -> None:
        """Disconnect the receiver's callback id."""
        self.fig.canvas.mpl_disconnect(self.receiver)

    def clear(self, _) -> None:
        """Clears all boxes from the widget's current selected image."""
        # Slice assignment empties the lists in place, so the per-image
        # lists held in the *_all containers are emptied too.
        self.hmi_coords[:] = []
        self.box_coords[:] = []
        self._remove_all_patches()

    def display_widget(self, dataset: dict) -> None:
        """
        Image widget initialiser for post-Pickle reload. Required because
        dataset dict of callables cannot be pickled to cloud. For cases where
        only the box coords and hmi data are required, this function need not be
        called.
        """
        self.dataset = dataset
        self.fig = plt.figure()
        image = self.dataset[self.keys[0]]()
        self.axes = plt.subplot(projection=image)
        self.setup_buttons(starting_slider=self.index)
        self.setup_image(self.index)
        self.slider.send_state({"value": self.index})

    def __getstate__(self) -> dict:
        """Return the picklable state; dataset, figure and widgets are excluded."""
        state = {
            "index": self.index,
            "dataset_name": self.dataset_name,
            "keys": self.keys,
            "pixels_per_cell": self.pixels_per_cell,
            "hmi_coords_all": self.hmi_coords_all,
            "box_coords_all": self.box_coords_all,
        }
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore the pickled selections; call display_widget to rebuild the UI."""
        self.index = state["index"]
        self.dataset_name = state["dataset_name"]
        self.keys = state["keys"]
        self.pixels_per_cell = state["pixels_per_cell"]
        self.hmi_coords_all = state["hmi_coords_all"]
        self.box_coords_all = state["box_coords_all"]

In [8]:
# Scratch arithmetic; evaluates to 240.0.
# NOTE(review): nothing in this notebook says what 180 and 0.75 represent -
# name them as constants or remove this cell.
180 / 0.75

240.0