# Linear Unmixing of Spectral Fluorescence Microscopy Data

In this notebook, we perform linear unmixing of spectral fluorescence microscopy data.

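For a given pixel, we assume we have a set of intensity measurements at different wavelengths, i.e., $y = [y(\lambda_1),y(\lambda_2),\dots,y(\lambda_n)]$, with, for instance, $n=32$. For each spectral band $\lambda_i$, with $i=1,\dots,n$, and for each fluorophore $f$, with $f=1,\dots,m$, we assume the reference spectrum $R_f=[R_f(\lambda_1), R_f(\lambda_2), \dots, R_f(\lambda_n)]$ to be known. Stacking the reference spectra as the columns of an $n \times m$ matrix $\mathbf{R}$, the measured spectrum is modeled as the linear combination $y = \mathbf{R}c$, where $c = [c_1, \dots, c_m]$ holds the unknown concentration of each fluorophore in that pixel. Unmixing amounts to estimating $c$, here in the least-squares sense.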
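
As a toy illustration of this model, the sketch below builds a made-up reference matrix from Gaussian-shaped spectra (an assumption for illustration only; the real spectra are fetched from FPbase later on), mixes three concentrations for a single pixel, and recovers them with NumPy's least-squares solver.

```python
import numpy as np

n, m = 32, 3                                  # spectral bands, fluorophores
wavelengths = np.linspace(450.0, 700.0, n)    # toy band centers in nm (hypothetical)

# Toy reference spectra: one Gaussian emission peak per fluorophore (made-up values)
peaks = np.array([510.0, 580.0, 640.0])
R = np.exp(-0.5 * ((wavelengths[:, None] - peaks[None, :]) / 25.0) ** 2)  # shape (n, m)

# Forward model for one pixel: y = R c
c_true = np.array([2.0, 0.5, 1.0])
y = R @ c_true

# Unmixing: least-squares estimate of c from y and R
c_hat, *_ = np.linalg.lstsq(R, y, rcond=None)
print(np.round(c_hat, 3))
```

Without noise and with linearly independent reference spectra, the estimate matches the true concentrations up to floating-point error.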

### 1. Data Preparation

Load mixed image & metadata:

In [None]:
import os
import json
import tifffile as tiff

DATA_DIR = '/group/jug/federico/microsim/sim_spectral_data/240715_v1'
load_mip = False

In [None]:
mixed_opt_img = tiff.imread(
    os.path.join(
        DATA_DIR, 
        # Single quotes inside the f-string keep this valid on Python < 3.12
        f"{'mips' if load_mip else 'imgs'}/optical_mixed{'_mip' if load_mip else ''}.tif"
    )
)
print("Loaded optical mixed image!")

In [None]:
mixed_digital_img = tiff.imread(
    os.path.join(
        DATA_DIR, 
        # Single quotes inside the f-string keep this valid on Python < 3.12
        f"{'mips' if load_mip else 'imgs'}/digital_mixed{'_mip' if load_mip else ''}.tif"
    )
)
print("Loaded digital mixed image!")

In [None]:
with open(os.path.join(DATA_DIR, "sim_coords.json"), "r") as f:
    coords_metadata = json.load(f)

try:
    with open(os.path.join(DATA_DIR, "sim_metadata.json"), "r") as f:
        sim_metadata = json.load(f)
except FileNotFoundError:
    print("Metadata file not found!")
    sim_metadata = None

In [None]:
# Load GT
gt_img = tiff.imread(os.path.join(DATA_DIR, "ground_truth_img.tif"))

In [None]:
from utils import coarsen_img

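try:
    downscaling = int(sim_metadata["downscale"])
except (TypeError, KeyError, ValueError):
    # Metadata missing (None), without a "downscale" entry, or not an integer: default to 2
    downscaling = 2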
gt_img_downsc = coarsen_img(gt_img, downscaling)
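
`coarsen_img` comes from the local `utils` module, whose code is not shown here. To give a rough idea of block-wise downscaling, the sketch below aggregates non-overlapping `factor x factor` blocks over the last two axes, assumed to be (Y, X). The function name `coarsen_yx` and its `reduce` option are hypothetical, and whether `coarsen_img` sums or averages is an assumption to check against its source.

```python
import numpy as np

def coarsen_yx(img: np.ndarray, factor: int, reduce: str = "mean") -> np.ndarray:
    """Aggregate non-overlapping factor x factor blocks over the last two axes (Y, X)."""
    *lead, h, w = img.shape
    h_c, w_c = h // factor, w // factor
    # Crop so that Y and X are divisible by the factor
    img = img[..., : h_c * factor, : w_c * factor]
    # Split Y into (h_c, factor) and X into (w_c, factor)
    blocks = img.reshape(*lead, h_c, factor, w_c, factor)
    if reduce == "sum":
        return blocks.sum(axis=(-3, -1))
    return blocks.mean(axis=(-3, -1))

x = np.arange(16).reshape(4, 4)
print(coarsen_yx(x, 2))
print(coarsen_yx(x, 2, reduce="sum"))
```

Summing preserves total counts (relevant for a per-pixel fluorophore count such as the ground truth), while averaging preserves the intensity scale.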

In [None]:
mixed_opt_img.shape, mixed_digital_img.shape, gt_img.shape, gt_img_downsc.shape, coords_metadata.keys(), sim_metadata.keys()

In [None]:
for k, v in sim_metadata.items():
    print(f"{k}: {v}")

Compute the *PSNR* of the digital image with respect to the downscaled optical image:
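
For reference, PSNR compares a prediction with a reference image through their mean squared error, $\mathrm{PSNR} = 10\log_{10}(\mathrm{MAX}^2/\mathrm{MSE})$. The sketch below is a generic version that takes MAX as the data range of the reference. `spectral_PSNR` in `utils.metrics` may differ (for example, in how the data range is chosen or whether it averages over spectral bands), so treat this as an assumption rather than its implementation.

```python
import numpy as np

def psnr(gt, pred, data_range=None):
    """Generic PSNR in dB between a reference `gt` and a prediction `pred`."""
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    if data_range is None:
        # Assumption: use the dynamic range of the reference as MAX
        data_range = gt.max() - gt.min()
    mse = np.mean((gt - pred) ** 2)
    return float(10.0 * np.log10(data_range ** 2 / mse))

# Toy example: a constant error of 0.1 on a reference spanning [0, 3]
gt = np.array([0.0, 1.0, 2.0, 3.0])
pred = gt + 0.1
print(f"PSNR: {psnr(gt, pred):.2f} dB")
```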

In [None]:
# Reuse the downscaling factor computed above
mixed_opt_img_downsc = coarsen_img(mixed_opt_img, downscaling)

In [None]:
from utils.metrics import spectral_PSNR

dig_psnr = spectral_PSNR(gt=mixed_opt_img_downsc, pred=mixed_digital_img)
print(f"PSNR digital wrt optical: {dig_psnr:.2f}")

Fetch the reference spectra from FPbase using the `microsim` API:

In [None]:
from microsim.schema.sample import Fluorophore

def fetch_FPs(fp_names: list[str]) -> list[Fluorophore]:
    return [Fluorophore.from_fpbase(name=fp_name) for fp_name in fp_names]

fp1, fp2, fp3 = fetch_FPs(sim_metadata["fluorophores"])

In [None]:
import xarray as xr

fp1_em = xr.DataArray(fp1.emission_spectrum.intensity, coords=[fp1.emission_spectrum.wavelength.magnitude], dims=["w"])
fp2_em = xr.DataArray(fp2.emission_spectrum.intensity, coords=[fp2.emission_spectrum.wavelength.magnitude], dims=["w"])
fp3_em = xr.DataArray(fp3.emission_spectrum.intensity, coords=[fp3.emission_spectrum.wavelength.magnitude], dims=["w"])

In [None]:
# Bin the emission spectra into the same spectral bands as the data (summing within each band)
em_bins = coords_metadata["w_bins"]
sbins = sorted(set([bins[0] for bins in em_bins] + [em_bins[-1][1]]))

fp1_em_binned = fp1_em.groupby_bins(fp1_em["w"], sbins).sum()
fp2_em_binned = fp2_em.groupby_bins(fp2_em["w"], sbins).sum()
fp3_em_binned = fp3_em.groupby_bins(fp3_em["w"], sbins).sum()

In [None]:
# Bands that contain no spectral samples are NaN after binning: replace them with 0
fp1_em_binned = fp1_em_binned.fillna(0)
fp2_em_binned = fp2_em_binned.fillna(0)
fp3_em_binned = fp3_em_binned.fillna(0)
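
The binning step sums each emission spectrum over the spectral bands. As a NumPy-only illustration on a made-up spectrum and made-up bands (with edges built the same way as `sbins`), `np.histogram` with `weights` computes the same per-band sums and returns 0 rather than NaN for bands without samples. One caveat: `groupby_bins` uses right-closed intervals by default, whereas `np.histogram` uses left-closed ones, so a sample lying exactly on a band edge may be assigned differently.

```python
import numpy as np

# Toy emission spectrum sampled every 1 nm (values are made up)
wavelengths = np.arange(500, 520)                  # 500, 501, ..., 519 nm
intensity = np.linspace(0.0, 1.0, wavelengths.size)

# Toy bands as [low, high] pairs, mimicking coords_metadata["w_bins"] (hypothetical values)
w_bins = [[500, 505], [505, 510], [510, 515], [515, 520], [520, 525]]
edges = sorted({low for low, _ in w_bins} | {w_bins[-1][1]})

# Sum the intensity falling inside each band; the last band has no samples, so it gets 0
binned, _ = np.histogram(wavelengths, bins=edges, weights=intensity)
print(binned)
```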

**OBSERVATION**
The mixed image is a 16-bit image (range 0 to 65535), whereas the fluorophore emission spectra take values in [0, 1] before binning.

Intuitively, the two should be on the same intensity scale. But does this really matter?

In my understanding, the answer is NO. Let's see why:

- Consider normalization to the range [0, 1] obtained by dividing each pixel's intensity by the maximum intensity in the image, i.e., by a scalar. If we normalize both the mixed image and the reference spectra in this way, the linear system $y = \mathbf{R}c$ becomes:

\begin{equation}
\frac{1}{k_I}y = \frac{1}{k_R}\mathbf{R}c'
\end{equation}

where $k_I$ and $k_R$ are scalars. Its solution is $c' = \frac{k_R}{k_I}c$, i.e., the same as the original one up to a multiplicative constant. Strictly, this holds for pure scaling: the min-max normalization used below also subtracts the minimum, so it is equivalent only when that minimum is (close to) zero.

Therefore, we can normalize everything to the range [0, 1] so that all quantities are on the same scale.
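
A quick numerical check of this argument on made-up data (an illustration, not part of the pipeline). Since the next cell normalizes each reference spectrum separately, each column of $\mathbf{R}$ actually gets its own constant $k_f$, so each concentration is rescaled by its own factor $k_f/k_I$. This is still harmless for the evaluation, which normalizes each fluorophore channel independently. The sketch scales a noisy spectrum by $1/k_I$ and each column by $1/k_f$, then compares the least-squares solutions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 32, 3
R = rng.uniform(0.0, 1.0, size=(n, m))         # made-up reference matrix
c = np.array([3.0, 1.0, 2.0])                  # made-up concentrations
y = R @ c + rng.normal(0.0, 0.01, size=n)      # noisy mixed spectrum

def lstsq_solve(A, b):
    """Unconstrained least-squares solution of A x ≈ b."""
    return np.linalg.lstsq(A, b, rcond=None)[0]

c_hat = lstsq_solve(R, y)

# Scale the spectrum by 1/k_I and each reference column by its own 1/k_f
k_I = 7.0
k_f = np.array([2.0, 5.0, 0.5])
c_scaled = lstsq_solve(R / k_f, y / k_I)       # broadcasting divides column f by k_f[f]

# Same solution up to the per-fluorophore factor k_f / k_I
same_up_to_scale = np.allclose(c_scaled, c_hat * k_f / k_I)
print(same_up_to_scale)
```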

In [None]:
# Min-max normalization to [0, 1]. Only the optical image is normalized here; by the
# argument above, leaving the digital image unnormalized should only rescale its LS solution.
mixed_opt_img = (mixed_opt_img - mixed_opt_img.min()) / (mixed_opt_img.max() - mixed_opt_img.min())
fp1_em_binned = (fp1_em_binned - fp1_em_binned.min()) / (fp1_em_binned.max() - fp1_em_binned.min())
fp2_em_binned = (fp2_em_binned - fp2_em_binned.min()) / (fp2_em_binned.max() - fp2_em_binned.min())
fp3_em_binned = (fp3_em_binned - fp3_em_binned.min()) / (fp3_em_binned.max() - fp3_em_binned.min())

Build the reference matrix $\mathbf{R}$ of fluorophore emission intensities, with one row per spectral band and one column per fluorophore:

In [None]:
import numpy as np

fp_ref_matrix = np.stack([fp1_em_binned.values, fp2_em_binned.values, fp3_em_binned.values], axis=1)

### 2. Compute the LS solution
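
The next cell imports `lstsq_fit` from the local `methods.LeastSquares` module, whose implementation is not shown here. A minimal sketch of pixel-wise unconstrained least-squares unmixing, assuming the spectral axis comes first (consistent with the `(F, Z, Y, X)` layout of the results), reshapes the stack so that each pixel becomes a column and solves all pixels in a single `np.linalg.lstsq` call. The function name `lstsq_unmix` is hypothetical.

```python
import numpy as np

def lstsq_unmix(img: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """Pixel-wise unconstrained least-squares unmixing (hypothetical helper).

    img: spectral stack of shape (n_bands, *spatial), spectral axis first (assumption).
    ref: reference matrix R of shape (n_bands, n_fluorophores).
    Returns concentration maps of shape (n_fluorophores, *spatial).
    """
    n_bands, *spatial = img.shape
    Y = img.reshape(n_bands, -1)                  # one column per pixel
    C, *_ = np.linalg.lstsq(ref, Y, rcond=None)   # solves R @ C ≈ Y for every pixel at once
    return C.reshape(ref.shape[1], *spatial)

# Usage on synthetic data: 32 bands, 3 fluorophores, a (Z, Y, X) = (4, 5, 6) volume
rng = np.random.default_rng(1)
R = rng.uniform(size=(32, 3))
conc_true = rng.uniform(size=(3, 4, 5, 6))
stack = np.einsum("bf,fzyx->bzyx", R, conc_true)
conc_hat = lstsq_unmix(stack, R)
```

Unconstrained least squares can return negative concentrations on noisy data; a non-negative variant such as `scipy.optimize.nnls` (which solves one pixel at a time) is a common alternative.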

In [None]:
from methods.LeastSquares import lstsq_fit

Solving for Optical Image:

In [None]:
fp_conc_opt_img = lstsq_fit(mixed_opt_img, fp_ref_matrix)

Solving for Digital Image:

In [None]:
fp_conc_digital_img = lstsq_fit(mixed_digital_img, fp_ref_matrix)

### 3. Visualizing results

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("MIP of FP concentrations computed from Optical (clean) image", fontsize=16)
if load_mip:
    ax[0].imshow(fp_conc_opt_img[0, :, :])
    ax[1].imshow(fp_conc_opt_img[1, :, :])
    ax[2].imshow(fp_conc_opt_img[2, :, :])
else:
    ax[0].imshow(fp_conc_opt_img.max(axis=1)[0, :, :])
    ax[1].imshow(fp_conc_opt_img.max(axis=1)[1, :, :])
    ax[2].imshow(fp_conc_opt_img.max(axis=1)[2, :, :])


fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("MIP of FP concentrations computed from Digital (noisy) image", fontsize=16)
if load_mip:
    ax[0].imshow(fp_conc_digital_img[0, :, :])
    ax[1].imshow(fp_conc_digital_img[1, :, :])
    ax[2].imshow(fp_conc_digital_img[2, :, :])
else:
    ax[0].imshow(fp_conc_digital_img.max(axis=1)[0, :, :])
    ax[1].imshow(fp_conc_digital_img.max(axis=1)[1, :, :])
    ax[2].imshow(fp_conc_digital_img.max(axis=1)[2, :, :])

### 4. Evaluation

Compute the error with respect to the ground truth.

The ground truth is a `(F, Z, Y, X)` array that reports the number of fluorophores per pixel.

In [None]:
# Visualize GT
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("MIP of Ground Truth FP distribution (original)", fontsize=16)
ax[0].imshow(gt_img.max(axis=1)[0, :, :])
ax[1].imshow(gt_img.max(axis=1)[1, :, :])
ax[2].imshow(gt_img.max(axis=1)[2, :, :])


fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("MIP of Ground Truth FP distribution (downscaled)", fontsize=16)
ax[0].imshow(gt_img_downsc.max(axis=1)[0, :, :])
ax[1].imshow(gt_img_downsc.max(axis=1)[1, :, :])
ax[2].imshow(gt_img_downsc.max(axis=1)[2, :, :])

We apply channel-wise min-max normalization to both the ground truth (#FP/pixel) and the LS result (FP concentration/pixel) to make the arrays comparable.
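
`channel_wise_norm` and `pixel_wise_sum_to_one` come from the local `utils` module and are not shown here. Judging by their names and the text above, a plausible sketch (an assumption, not their actual code) is: min-max rescaling of each fluorophore channel (axis 0) to [0, 1], versus dividing each pixel's fluorophore vector by its sum. Per-channel min-max is unaffected by a positive per-channel scale factor, so a per-fluorophore rescaling of the LS result does not change this comparison. The helper names below are hypothetical.

```python
import numpy as np

def channel_wise_minmax(arr: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Rescale each channel (axis 0) independently to [0, 1] (hypothetical helper)."""
    axes = tuple(range(1, arr.ndim))
    lo = arr.min(axis=axes, keepdims=True)
    hi = arr.max(axis=axes, keepdims=True)
    return (arr - lo) / np.maximum(hi - lo, eps)

def pixel_sum_to_one(arr: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Divide each pixel's channel vector (axis 0) by its sum (hypothetical helper)."""
    return arr / np.maximum(arr.sum(axis=0, keepdims=True), eps)

# Toy (F=2, Y=2, X=2) array
x = np.array([[[0.0, 2.0], [4.0, 8.0]],
              [[1.0, 1.0], [3.0, 5.0]]])
print(channel_wise_minmax(x))
print(pixel_sum_to_one(x))
```

The sum-to-one variant assumes non-negative values; negative LS concentrations would make the per-pixel fractions hard to interpret.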

In [None]:
from utils import channel_wise_norm, pixel_wise_sum_to_one

norm_fp_conc_opt_img = channel_wise_norm(fp_conc_opt_img)
norm_fp_conc_digital_img = channel_wise_norm(fp_conc_digital_img)
norm_gt_img = channel_wise_norm(gt_img)
norm_gt_img_downsc = channel_wise_norm(gt_img_downsc)

# norm_fp_conc_opt_img = pixel_wise_sum_to_one(fp_conc_opt_img)
# norm_fp_conc_digital_img = pixel_wise_sum_to_one(fp_conc_digital_img)
# norm_gt_img = pixel_wise_sum_to_one(gt_img)
# norm_gt_img_downsc = pixel_wise_sum_to_one(gt_img_downsc)


Now, we compute and visualize the error:
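
Judging by how it is used below (its result is max-projected along Z and then averaged), `pixel_wise_mse` appears to return the per-pixel squared error rather than a single number. A minimal sketch under that assumption, with the hypothetical name `pixel_sq_error`, shows how the error panels and the printed MSE relate.

```python
import numpy as np

def pixel_sq_error(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Per-pixel squared error with the same shape as the inputs (hypothetical helper)."""
    return (np.asarray(gt, dtype=np.float64) - np.asarray(pred, dtype=np.float64)) ** 2

gt = np.zeros((2, 3, 3))        # toy (Z, Y, X) volume
pred = np.full((2, 3, 3), 0.1)
err = pixel_sq_error(gt, pred)
err_mip = err.max(axis=0)       # worst error along Z, as shown in the error panels
mse = err.mean()                # scalar MSE, as printed in the panel text
print(err_mip.shape, mse)
```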

In [None]:
from utils.metrics import pixel_wise_mse
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, ax = plt.subplots(3, 3, figsize=(15, 15))
# fig.suptitle("Ground Truth FP distribution vs. Unmixed FP Concentrations (normalize in 0-1)", fontsize=16)

mse1 = pixel_wise_mse(norm_gt_img[0, ...], norm_fp_conc_opt_img[0, ...])
ax[0,0].set_title("GT (fluorophore distribution)")
im0 = ax[0,0].imshow(norm_gt_img.max(axis=1)[0, :, :])
ax[0,1].set_title("Unmixing Result (fluorophore distribution)")
im1 = ax[0,1].imshow(norm_fp_conc_opt_img.max(axis=1)[0, :, :])
ax[0,2].set_title("Pixel-wise MSE (fluorophore distribution)")
im2 = ax[0,2].imshow(mse1.max(axis=0), cmap="RdPu")
divider = make_axes_locatable(ax[0,2])
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im2, cax=cax)
ax[0,2].text(
    0.66, 0.1, f'MSE: {mse1.mean():.2e}', transform=ax[0,2].transAxes,
    fontsize=12, verticalalignment='center', bbox=dict(facecolor='white', alpha=0.5)
)

mse2 = pixel_wise_mse(norm_gt_img[1, ...], norm_fp_conc_opt_img[1, ...])
im3 = ax[1,0].imshow(norm_gt_img.max(axis=1)[1, :, :])
im4 = ax[1,1].imshow(norm_fp_conc_opt_img.max(axis=1)[1, :, :])
im5 = ax[1,2].imshow(mse2.max(axis=0), cmap="RdPu")
divider = make_axes_locatable(ax[1,2])
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im5, cax=cax)
ax[1,2].text(
    0.66, 0.1, f'MSE: {mse2.mean():.2e}', transform=ax[1,2].transAxes,
    fontsize=12, verticalalignment='center', bbox=dict(facecolor='white', alpha=0.5)
)

mse3 = pixel_wise_mse(norm_gt_img[2, ...], norm_fp_conc_opt_img[2, ...])
im6 = ax[2,0].imshow(norm_gt_img.max(axis=1)[2, :, :])
im7 = ax[2,1].imshow(norm_fp_conc_opt_img.max(axis=1)[2, :, :])
im8 = ax[2,2].imshow(mse3.max(axis=0), cmap="RdPu")
divider = make_axes_locatable(ax[2,2])
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im8, cax=cax)
ax[2,2].text(
    0.66, 0.1, f'MSE: {mse3.mean():.2e}', transform=ax[2,2].transAxes,
    fontsize=12, verticalalignment='center', bbox=dict(facecolor='white', alpha=0.5)
)

plt.tight_layout()

# Quantitative results
print(f"Optical Image MSE: {mse1.mean():.2e}, {mse2.mean():.2e}, {mse3.mean():.2e}")

In [None]:
fig, ax = plt.subplots(3, 3, figsize=(15, 15))
# fig.suptitle("Ground Truth FP distribution vs. Unmixed FP Concentrations (normalize in 0-1)", fontsize=16)

mse1 = pixel_wise_mse(norm_gt_img_downsc[0, ...], norm_fp_conc_digital_img[0, ...])
ax[0,0].set_title("GT (fluorophore distribution)")
im0 = ax[0,0].imshow(norm_gt_img_downsc.max(axis=1)[0, :, :])
ax[0,1].set_title("Unmixing Result (fluorophore distribution)")
im1 = ax[0,1].imshow(norm_fp_conc_digital_img.max(axis=1)[0, :, :])
ax[0,2].set_title("Pixel-wise MSE (fluorophore distribution)")
im2 = ax[0,2].imshow(mse1.max(axis=0), cmap="RdPu")
divider = make_axes_locatable(ax[0,2])
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im2, cax=cax)
ax[0,2].text(
    0.66, 0.1, f'MSE: {mse1.mean():.2e}', transform=ax[0,2].transAxes,
    fontsize=12, verticalalignment='center', bbox=dict(facecolor='white', alpha=0.5)
)

mse2 = pixel_wise_mse(norm_gt_img_downsc[1, ...], norm_fp_conc_digital_img[1, ...])
im3 = ax[1,0].imshow(norm_gt_img_downsc.max(axis=1)[1, :, :])
im4 = ax[1,1].imshow(norm_fp_conc_digital_img.max(axis=1)[1, :, :])
im5 = ax[1,2].imshow(mse2.max(axis=0), cmap="RdPu")
divider = make_axes_locatable(ax[1,2])
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im5, cax=cax)
ax[1,2].text(
    0.66, 0.1, f'MSE: {mse2.mean():.2e}', transform=ax[1,2].transAxes,
    fontsize=12, verticalalignment='center', bbox=dict(facecolor='white', alpha=0.5)
)

mse3 = pixel_wise_mse(norm_gt_img_downsc[2, ...], norm_fp_conc_digital_img[2, ...])
im6 = ax[2,0].imshow(norm_gt_img_downsc.max(axis=1)[2, :, :])
im7 = ax[2,1].imshow(norm_fp_conc_digital_img.max(axis=1)[2, :, :])
im8 = ax[2,2].imshow(mse3.max(axis=0), cmap="RdPu")
divider = make_axes_locatable(ax[2,2])
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im8, cax=cax)
ax[2,2].text(
    0.66, 0.1, f'MSE: {mse3.mean():.2e}', transform=ax[2,2].transAxes,
    fontsize=12, verticalalignment='center', bbox=dict(facecolor='white', alpha=0.5)
)

plt.tight_layout()

# Quantitative results
print(f"Digital Image MSE: {mse1.mean():.2e}, {mse2.mean():.2e}, {mse3.mean():.2e}")