In [None]:
from pathlib import Path
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray
from matplotlib.colors import LogNorm
from skimage import color, data, filters
from skimage.transform import resize
from spectral import io, mixing, unmixing
from spectral.binlets.components import binlets_components_transform
from spectral.binlets.independent import binlets_independent_components, binlets_poisson

# Input data folder, relative to the notebook's working directory.
data_dir = Path("..", "data")


class Dim:
    """Dimension names used for the calibration and the spectrum/component arrays."""

    components = "components"  # axis along which unmixed components are stacked
    spectrum = "spectrum"  # spectral axis of the `spectrum` arrays


# Names of `skimage.data` loader functions used as synthetic component images.
# `build_components` assigns them to components in this order.
DATA_NAMES = [
    "astronaut", "camera", "cat", "coffee",
    "eagle", "grass", "gravel", "rocket",
]


def build_components(
    *,
    shape: tuple[int, int],
    dim: str,
    order: Sequence[str],
    sigma: float = 0,
):
    """Stack stock ``skimage.data`` images into a synthetic component array.

    The first ``len(order)`` names of ``DATA_NAMES`` are loaded, converted to
    grayscale when they are RGB, resized to ``shape`` and optionally blurred.
    Each image's position (not its label) decides which entry of ``order``
    it is labelled with along ``dim``.

    Args:
        shape: Output (y, x) size of every image.
        dim: Name of the stacking dimension.
        order: Coordinate labels for ``dim``; at most ``len(DATA_NAMES)`` entries.
        sigma: Gaussian blur sigma applied after resizing; 0 disables blurring.

    Returns:
        xarray.DataArray of float64 with dims ``(dim, "y", "x")``.
    """
    n_components = len(order)
    assert n_components <= len(DATA_NAMES), "Not enough stock images"

    def load_stock_image(stock_name: str) -> np.ndarray:
        # Each name is a zero-argument loader function in skimage.data.
        img: np.ndarray = getattr(data, stock_name)()
        if img.ndim == 3:
            img = color.rgb2gray(img)
        img = resize(img, shape).astype(np.float64)
        return filters.gaussian(img, sigma) if sigma > 0 else img

    stack = np.empty((n_components, *shape), dtype=np.float64)
    for slot, stock_name in zip(range(n_components), DATA_NAMES):
        stack[slot] = load_stock_image(stock_name)

    return xarray.DataArray(stack, coords={dim: order}, dims=[dim, "y", "x"])

In [None]:
# Reference spectra from the spreadsheet. Later cells read its `Dim.components`
# coordinate to label components and pass it to every (un)mixing call.
calibration = io.load_calibration(
    data_dir / "reference spectra (original).xlsx",
    output_name=Dim.components,
    input_name=Dim.spectrum,
)

## Ground truth

In [None]:
class GroundTruth:
    """Noise-free reference: synthetic component images and their mixed spectrum."""

    # 256x256 unblurred stock images, one per calibration component.
    components = build_components(
        order=calibration.coords[Dim.components],
        dim=Dim.components,
        shape=(256, 256),
        sigma=0,
    )
    # Spectrum obtained by mixing the components with the calibration.
    spectrum = mixing.matmul(calibration, components)

## Measurement

In [None]:
rng = np.random.default_rng(0)


class Measured:
    spectrum = xarray.apply_ufunc(rng.poisson, GroundTruth.spectrum)
    components = unmixing.lstsq(calibration, spectrum)
    components_weighted = unmixing.weighted_least_squares(calibration, spectrum)

## Denoising

In [None]:
# Shared `sigma` passed to every binlets denoiser below; presumably a significance
# threshold in standard deviations — TODO confirm against spectral.binlets.
sigma = 3

In [None]:
class ContinuousSpectrum:
    """Binlets Poisson denoising of the spectrum (dim=None, independent=True), then unmixing."""

    spectrum = binlets_poisson(
        Measured.spectrum,
        dim=None,
        independent=True,
        sigma=sigma,
    )
    # Same two unmixing variants as `Measured`, applied to the denoised spectrum.
    components = unmixing.lstsq(calibration, spectrum)
    components_weighted = unmixing.weighted_least_squares(calibration, spectrum)

In [None]:
class FullSpectrum:
    """Binlets Poisson denoising along the spectral dim (independent=False), then unmixing."""

    spectrum = binlets_poisson(
        Measured.spectrum,
        dim=Dim.spectrum,
        independent=False,
        sigma=sigma,
    )
    # Same two unmixing variants as `Measured`, applied to the denoised spectrum.
    components = unmixing.lstsq(calibration, spectrum)
    components_weighted = unmixing.weighted_least_squares(calibration, spectrum)

In [None]:
class SingleSpectrum:
    """Binlets Poisson denoising along the spectral dim (independent=True), then unmixing."""

    spectrum = binlets_poisson(
        Measured.spectrum,
        dim=Dim.spectrum,
        independent=True,
        sigma=sigma,
    )
    # Same two unmixing variants as `Measured`, applied to the denoised spectrum.
    components = unmixing.lstsq(calibration, spectrum)
    components_weighted = unmixing.weighted_least_squares(calibration, spectrum)

In [None]:
class SingleComponent:
    """Denoise directly into components with `binlets_independent_components`.

    Both attributes refer to the same result, so this class can be compared
    with the others via either `components` or `components_weighted`.
    """

    components_weighted = binlets_independent_components(
        calibration,
        Measured.spectrum,
        sigma=sigma,
    )
    components = components_weighted

In [None]:
class FullComponent:
    """Denoise with `binlets_components_transform` (given the calibration), then unmix."""

    spectrum = binlets_components_transform(
        Measured.spectrum,
        calibration=calibration,
        dim=Dim.spectrum,
        sigma=sigma,
    )
    # Same two unmixing variants as `Measured`, applied to the denoised spectrum.
    components = unmixing.lstsq(calibration, spectrum)
    components_weighted = unmixing.weighted_least_squares(calibration, spectrum)

## Results

In [None]:
# Denoising strategies compared below, in display order: rows of the
# error-histogram grid and rows of the summary table.
methods = [SingleSpectrum, ContinuousSpectrum, FullSpectrum, SingleComponent, FullComponent]

In [None]:
# Ground-truth component images, one panel per component. Use the shared
# dimension name instead of a hard-coded string.
GroundTruth.components.plot.imshow(col=Dim.components);
# Every 8th spectral slice of the noise-free mixed spectrum (subsampled to keep
# the facet grid readable). The trailing `;` hides the FacetGrid repr.
GroundTruth.spectrum.isel({Dim.spectrum: slice(None, None, 8)}).plot.imshow(
    col=Dim.spectrum
);

In [None]:
def plot():
    """2D-histogram each method's unmixing error against the raw-data error.

    One row per entry of ``methods`` and one column per component of
    ``calibration``. In every panel the x axis is the error of
    ``Measured.components_weighted`` and the y axis the error of the method's
    ``components_weighted``, both relative to ``GroundTruth.components``.
    The red line is y = x (no change in error).

    Returns:
        The matplotlib Figure, so callers can save or adjust it.
    """
    component_names = calibration.coords[Dim.components].values

    fig, axes = plt.subplots(
        nrows=len(methods),
        # One column per component. A hard-coded count left empty columns or
        # silently dropped components through zip().
        ncols=len(component_names),
        sharex="col",
        sharey="col",
        squeeze=False,  # keep `axes` 2-D even for a single method or component
        figsize=(12, 6),
    )
    fig.suptitle("Unmixing error per component: Measured (x) vs. method (y)")

    reference_error = Measured.components_weighted - GroundTruth.components

    for axrow, method in zip(axes, methods):
        axrow[0].set_ylabel(method.__name__, rotation=0, ha="right")

        method_error = method.components_weighted - GroundTruth.components
        for ax, name in zip(axrow, component_names):
            sel = {Dim.components: name}
            x_err = reference_error.sel(sel).values.ravel()
            y_err = method_error.sel(sel).values.ravel()
            ax.set(aspect="equal")
            # Log colour scale for the per-bin pixel counts.
            ax.hist2d(x_err, y_err, bins=100, norm=LogNorm())
            ax.axline((0, 0), slope=1, color="red")

    for ax, name in zip(axes[0], component_names):
        ax.set(title=name)

    return fig


plot();

In [None]:
def yield_cov():
    """Yield, per method and component, how much the unmixing error shrinks.

    Errors are taken relative to ``GroundTruth.components`` for
    ``Measured.components_weighted`` (the reference) and for each method's
    ``components_weighted``. Both ratios are reference / method, so values
    above 1 mean the method beats unmixing the raw data.

    - ``"RMSE ratio"``: ratio of root-mean-squared errors. It includes any
      systematic bias, e.g. one introduced by smoothing.
    - ``"error std ratio"``: ratio of error standard deviations. It ignores
      bias, because the mean of each error is removed first.
    """
    reference_error = Measured.components_weighted - GroundTruth.components

    for method in methods:
        method_error = method.components_weighted - GroundTruth.components
        for name in calibration.coords[Dim.components].values:
            sel = {Dim.components: name}
            x_err = reference_error.sel(sel).values.ravel()
            y_err = method_error.sel(sel).values.ravel()
            rmse_ratio = np.sqrt(np.mean(x_err**2) / np.mean(y_err**2))
            # The ddof choice cancels in the ratio, so plain np.std is enough.
            std_ratio = np.std(x_err) / np.std(y_err)
            yield {
                "method": method.__name__,
                "channel": name,
                "RMSE ratio": rmse_ratio,
                "error std ratio": std_ratio,
            }


cov_ratios = pd.DataFrame(list(yield_cov()))
# Rows follow the order of `methods`; columns are (metric, channel).
cov_ratios.pivot(index="method", columns="channel").round(3).loc[
    [m.__name__ for m in methods]
]