In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray
from matplotlib.colors import LogNorm
from spectral import io

# Root folder of all input files read below. The path is relative, so it
# resolves against the current working directory of the kernel.
data_dir = Path("../data/")


class Dim:
    """Namespace for the dimension names shared by the arrays in this notebook."""

    # Passed as `input_name` to `io.load_calibration` and as `dim` when
    # reading the original stack.
    spectrum: str = "spectrum"

    # Passed as `output_name` to `io.load_calibration` and as `dim` when
    # reading the reference unmixed stack.
    components: str = "components"

In [None]:
# Everything yielded by `io.yield_crop_rectangles` is collected into a dict;
# the entry for "FOV 01751" is used below as the indexer for the original stack.
crop_rectangles = dict(
    io.yield_crop_rectangles(data_dir / "crop rectangles.txt"),
)
fov_crop = crop_rectangles["FOV 01751"]

# Calibration read from the reference-spectra workbook, with Dim.spectrum as
# its input name and Dim.components as its output name.
calibration = io.load_calibration(
    data_dir / "reference spectra (original).xlsx",
    input_name=Dim.spectrum,
    output_name=Dim.components,
)

# Original stack along Dim.spectrum, restricted to the crop of this FOV.
orig = io.read_stack(data_dir / "FOV 01751/orig", dim=Dim.spectrum).isel(fov_crop)

# Stack read from the "unmixed - using original ref spectra" folder; it is the
# reference that the least-squares result is compared against further down.
unmixed_reference = io.read_stack(
    data_dir / "FOV 01751/unmixed - using original ref spectra",
    dim=Dim.components,
)

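## Unmixing with lstsq

Solve `calibration @ x = orig` in the least-squares sense with `numpy.linalg.lstsq`, then compare the result with the reference unmixing loaded above.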

In [None]:
def unmixing_with_lstsq(
    A: xarray.DataArray,
    b: xarray.DataArray,
) -> xarray.DataArray:
    """Solve ``A @ x = b`` for ``x`` in the least-squares sense.

    The single dimension that ``A`` and ``b`` have in common is the one
    contracted by the matrix product; the other dimension of ``A`` becomes the
    leading dimension of the result. All remaining dimensions of ``b`` are
    flattened into independent right-hand-side columns for
    ``numpy.linalg.lstsq`` and restored afterwards.

    Entries along the shared dimension are matched by position, not by
    coordinate label, so ``A`` and ``b`` must already be ordered consistently.

    Parameters
    ----------
    A
        2-D matrix sharing exactly one dimension with ``b``. Its dimension
        order does not matter.
    b
        Right-hand side; must contain the shared dimension with the same
        length as in ``A``. Its dimension order does not matter.

    Returns
    -------
    xarray.DataArray
        Solution with dims ``(output_dim, *remaining_dims)``, where
        ``remaining_dims`` are the non-shared dimensions of ``b`` in their
        original order. Only the coordinate of the output dimension (taken
        from ``A``, if it has one) is attached; coordinates of ``b`` are not
        carried over.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D, if ``A`` and ``b`` do not share exactly one
        dimension, or if the shared dimension differs in length.
    """
    if A.ndim != 2:
        raise ValueError(f"A must be 2-dimensional, got dims {A.dims}")

    shared = [d for d in A.dims if d in b.dims]
    if len(shared) != 1:
        raise ValueError(
            "A and b must share exactly one dimension; "
            f"A has {A.dims}, b has {b.dims}, shared: {shared}"
        )
    (common_dim,) = shared
    (output_dim,) = (d for d in A.dims if d != common_dim)

    if A.sizes[common_dim] != b.sizes[common_dim]:
        raise ValueError(
            f"Length mismatch along {common_dim!r}: "
            f"A has {A.sizes[common_dim]}, b has {b.sizes[common_dim]}"
        )

    remaining_dims = [d for d in b.dims if d != common_dim]

    # Read `.values` once per array instead of on every use.
    a_values = A.transpose(common_dim, output_dim).values
    b_values = b.transpose(common_dim, *remaining_dims).values

    # lstsq accepts a 2-D right-hand side and solves each column independently,
    # so every non-shared dimension of b is flattened into the column axis.
    solution = np.linalg.lstsq(
        a_values,
        b_values.reshape(a_values.shape[0], -1),
        rcond=None,
    )[0]
    x = solution.reshape(a_values.shape[1], *b_values.shape[1:])

    # `A.coords[...]` raises KeyError for a dimension without a coordinate,
    # so the coordinate is attached only when it exists.
    coords = {output_dim: A.coords[output_dim]} if output_dim in A.coords else {}
    return xarray.DataArray(
        x,
        coords=coords,
        dims=[output_dim, *remaining_dims],
    )

In [None]:
# The calibration acts as the matrix A and the cropped original stack as the
# right-hand side b of the least-squares problem.
unmixed_with_lstsq = unmixing_with_lstsq(A=calibration, b=orig)

In [None]:
# Difference between the least-squares result and the reference, one panel
# per entry along Dim.components. The trailing semicolon keeps the plot
# object's repr out of the cell output.
difference = unmixed_with_lstsq - unmixed_reference
difference.plot.imshow(col=Dim.components, col_wrap=3);

In [None]:
components = unmixed_reference.coords[Dim.components].values
n_cols = max(len(components), 1)

# One column per component, so none is silently dropped by `zip`.
# squeeze=False keeps `axes` 2-D even when there is a single column.
fig, axes = plt.subplots(
    3,
    n_cols,
    sharex=True,
    sharey=True,
    squeeze=False,
    figsize=(1.5 * n_cols, 5),
)

for ax, c in zip(axes.T, components):
    reference = unmixed_reference.sel({Dim.components: c})
    estimate = unmixed_with_lstsq.sel({Dim.components: c})

    # Zero or negative pixels give inf/NaN in the log-ratio; silence the
    # warnings here and leave those pixels out of the colour scaling below.
    with np.errstate(divide="ignore", invalid="ignore"):
        log_ratio = np.log(estimate / reference).values
    finite = log_ratio[np.isfinite(log_ratio)]
    # Symmetric limits put a ratio of 1 (log-ratio 0) on the white centre of
    # the diverging "seismic" colormap; fall back to 1.0 for a degenerate range.
    limit = (np.abs(finite).max() if finite.size else 0.0) or 1.0

    ax[0].set(title=c)
    ax[0].imshow(reference)
    ax[1].imshow(estimate)
    ax[2].imshow(log_ratio, cmap="seismic", vmin=-limit, vmax=limit)

# Label the rows once, on the left-most column.
row_labels = ("reference", "lstsq", "log(lstsq / reference)")
for row_ax, label in zip(axes[:, 0], row_labels):
    row_ax.set(ylabel=label)

In [None]:
components = unmixed_reference.coords[Dim.components].values
n_cols = 4
# Ceiling division, so every component gets a panel (at least one row).
n_rows = max(1, -(-len(components) // n_cols))

fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(12, 2 * n_rows))

# Joint distribution of reference (x axis) and least-squares (y axis) values,
# one panel per component, with log-scaled counts.
for ax, c in zip(axes.flat, components):
    ax.set(title=c, xlabel="reference", ylabel="lstsq")
    ax.hist2d(
        unmixed_reference.sel({Dim.components: c}).values.ravel(),
        unmixed_with_lstsq.sel({Dim.components: c}).values.ravel(),
        bins=100,
        norm=LogNorm(),
    )

# Hide panels left over when the component count is not a multiple of n_cols.
for unused_ax in axes.ravel()[len(components):]:
    unused_ax.set_visible(False)

fig.tight_layout()