# LaQuacco 🍅

## Laboratory Quality Control

### Module Imports

In [None]:
import multiprocessing
import os
import platform
import sys
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import laquacco as laq  # required by Jupyter

### User Input

In [None]:
# --- parallelism ----------------------------------------------------------
# use half of the CPU cores, but always at least one worker
processes = max(1, multiprocessing.cpu_count() // 2)

# --- normalization --------------------------------------------------------
# relative size of the image sample drawn for normalization
sample_perc = 20

# --- channel statistics ---------------------------------------------------
# signal threshold applied to every channel; leave as NaN to use the mean
# threshold measured on the sampled images for each channel instead
chan_thrld = np.nan

# --- file search ----------------------------------------------------------
data_dir = r".\tests\Polaris"  # root folder -- keep the raw-string prefix r""
data_ext = "*.tif"  # include files matching this pattern
anti_ext = ""  # exclude files matching this pattern
recurse = True  # also search subdirectories

# --- output ---------------------------------------------------------------
# additionally render the channel image of every outlier
show_img = False

### Check Files

In [None]:
# get a list of all image files, in case-insensitive alphabetical order
files = sorted(
    laq.get_files(
        path=data_dir,
        pat=data_ext,
        anti=anti_ext,
        recurse=recurse,
    ),
    key=str.lower,
)

print(f"Found {len(files)} image files in {os.path.abspath(data_dir)}:")
for file in files:
    # Show each path relative to data_dir. str.replace() would also rewrite
    # any other occurrence of the directory string inside the path, and fails
    # for absolute paths or a data_dir with a trailing separator.
    print(os.path.join(".", os.path.relpath(file, data_dir)))

### Main Program

In [None]:
if __name__ == "__main__":
    # Guarding the main module keeps worker processes started by
    # 'multiprocessing' from re-running this code when they import it.
    if platform.system() == "Windows":
        multiprocessing.freeze_support()  # required by 'multiprocessing'

    # draw a sample of the experimental images (size set by sample_perc)
    try:
        samples = sorted(
            laq.get_samples(population=files, perc=sample_perc), key=str.lower
        )
        sample_args = [(sample, None) for sample in samples]
    except ValueError:
        print("Could not draw samples from experimental population.")
        sys.exit(1)

    # NOTE: parallel variant of the sample analysis. It is a bare string and is
    # never executed; it is kept for reference. The serial code below runs.
    """
    # analyze the sample data
    with multiprocessing.Pool(processes) as pool:
        sample_results = pool.starmap(laq.read_img_data, sample_args)
        pool.close()  # wait for worker tasks to complete
        pool.join()  # wait for worker process to exit
    samples_img_data = {sample: img_data for (sample, img_data) in sample_results}
    print(samples_img_data)
    print()
    """
    # analyze the sample data serially; each result is a (key, img_data) pair
    sample_results = [laq.read_img_data(sample) for sample in samples]
    samples_img_data = {sample: img_data for (sample, img_data) in sample_results}
    print()

    # Collect channel names in first-seen order. Iterating a plain set of
    # strings gives a hash-randomized order that changes between interpreter
    # runs, which reshuffles plot order and colors.
    chans = list(
        dict.fromkeys(
            chan
            for img_data in samples_img_data.values()
            for chan in img_data
            if chan != "metadata"
        )
    )

    # prepare colormap (one color per channel)
    color_map = laq.get_colormap(len(chans))

    # Per-channel power-transform lambda and signal threshold, both averaged
    # over the samples. A non-NaN user threshold (chan_thrld) overrides the
    # sampled threshold for every channel.
    chan_lmbdas = {}
    chan_thrlds = {}
    for chan in chans:
        chan_lmbdas[chan] = laq.get_mean(
            laq.get_chan_data(samples_img_data, chan, "chan_lmbda")
        )
        chan_thrlds[chan] = (
            laq.get_mean(laq.get_chan_data(samples_img_data, chan, "chan_thrld"))
            if np.isnan(chan_thrld)
            else float(chan_thrld)
        )
        print(f"{chan}:", flush=True)
        print(f"\tLambda: {chan_lmbdas[chan]},\n\tThreshold: {chan_thrlds[chan]}")
    print()

    # NOTE: parallel variant of the full analysis. It is a bare string and is
    # never executed; it is kept for reference.
    """
    # analyze experimental image data
    image_args = [(image, chan_thrlds) for image in files]
    with multiprocessing.Pool(processes) as pool:
        image_results = pool.starmap(laq.read_img_data, image_args)
        pool.close()
        pool.join()
    images_img_data = {image: img_data for (image, img_data) in image_results}
    """
    # analyze all experimental images serially with the per-channel thresholds
    image_results = [laq.read_img_data(file, chan_thrlds) for file in files]
    images_img_data = {image: img_data for (image, img_data) in image_results}
    print(images_img_data)
    print()

    # Sort experimental image data by time stamp. Downstream per-channel data
    # follows this order, NOT the alphabetical order of `files`.
    images_img_data = dict(
        sorted(images_img_data.items(), key=lambda v: v[1]["metadata"]["date_time"])
    )

    # get min and max values for plotting: per channel, the NaN-ignoring mean
    # of the per-image minima and of the per-image maxima
    chans_minmax = {}
    for chan in chans:
        chan_minmaxs = laq.get_chan_data(images_img_data, chan, "sign_minmax")
        chans_minmax[chan] = (
            np.nanmean([lo for lo, _ in chan_minmaxs]),
            np.nanmean([hi for _, hi in chan_minmaxs]),
        )

### Data Plots I - Distribution Chart

In [None]:
# global figure size: at least 1600 x 1200 pixels at the current DPI
dpi = plt.rcParams["figure.dpi"]
min_pixw, min_pixh = 1600, 1200
min_width = min_pixw / dpi
min_height = min_pixh / dpi
plt.rcParams["figure.figsize"] = [min_width, min_height]

# per-channel image signal means ("measured") ...
data_means = [
    laq.get_chan_data(images_img_data, chan, "sign_mean") for chan in chans
]
# ... and their Box-Cox transform with the channel's lambda ("normalized")
data_norms = [
    laq.boxcox_transform(np.array(means), lmbda=chan_lmbdas[chan])[0]
    for chan, means in zip(chans, data_means)
]

fig, ax = plt.subplots()

# violins: distribution of the measured means, median drawn in gray
vp = ax.violinplot(data_means, showmeans=False, showmedians=True, showextrema=False)
for body in vp["bodies"]:
    body.set_facecolor("black")
    body.set_edgecolor("black")
vp["cmedians"].set_edgecolor("dimgray")

# boxes: distribution of the normalized means, mean drawn as a dashed line
bp = ax.boxplot(data_norms, meanline=True, showmeans=True)
for median_line in bp["medians"]:
    median_line.set_color("black")
for mean_line in bp["means"]:
    mean_line.set_color("black")
    mean_line.set_linestyle("dashed")
ax.set_xticks(
    list(range(1, len(chans) + 1)),
    labels=chans,
    rotation=90,
    fontsize="small",
)

# legend outside the axes, to the right
legend = plt.legend(
    [vp["bodies"][0], bp["boxes"][0]],
    ["measured", "normalized"],
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    fontsize="small",
)

plt.show()

### Data Plots I - Extreme Values

In [None]:
# box-plot statistics per channel; "w1"/"w3" are used below as the lower/upper
# limits beyond which a normalized value is reported
boxplots_stats = laq.get_boxplots_stats(bp, chans)

# data_norms/data_means were built from images_img_data, which the main cell
# re-sorted by time stamp, so a position n must be mapped back through that
# order -- `files` is alphabetical and can point at the wrong image.
# NOTE(review): assumes laq.get_chan_data yields values in images_img_data order.
images = list(images_img_data)

# list all images with extreme values per channel
for c, chan in enumerate(chans):
    print(f"\n{chan}:", flush=True)
    outliers_bp = []
    for n, data_norm in enumerate(data_norms[c]):
        if data_norm > boxplots_stats[chan]["w3"]:
            outliers_bp.append(("▲  ", n, images[n], chan, data_means[c][n]))
        elif data_norm < boxplots_stats[chan]["w1"]:
            outliers_bp.append(("▼  ", n, images[n], chan, data_means[c][n]))
    # show color bar at top (value range used for the channel images)
    cmap = mpl.cm.nipy_spectral
    norm = mpl.colors.Normalize(vmin=chans_minmax[chan][0],
                                vmax=chans_minmax[chan][1])
    scalarmappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    scalarmappable.set_array([])
    fig = plt.figure(figsize=(min_width, 1))
    ax = fig.add_axes([0.0, 0.0, 1, 0.5])
    cbar = fig.colorbar(scalarmappable, cax=ax, orientation='horizontal')
    plt.show()
    if outliers_bp:
        # print list of outliers with optional channel images
        for indicator, position, file, channel, mean in outliers_bp:
            # implicit string concatenation (no comma), matching the
            # Levey-Jennings listing instead of inserting an extra space
            print(
                f"\n\t{indicator} {position} = {os.path.basename(file)}"
                f"  ({mean}) [{chans_minmax[chan][0]}-{chans_minmax[chan][1]}]"
            )
            if show_img:
                plt.imshow(
                    laq.get_chan_img(file, channel),
                    cmap="nipy_spectral",
                    vmin=chans_minmax[chan][0],
                    vmax=chans_minmax[chan][1],
                    resample=False,
                )
                plt.show()
    else:
        print("\t►  (none)")

### Data Plots II - Levey-Jennings Charts

In [None]:
# Levey-Jennings charts: per channel, plot the image signal means (in the
# time-stamp order of images_img_data) against a trend line with shaded
# +/-1 and +/-2 standard-deviation bands and dotted running limits.
slice_margin = len(files) - 1  # extend slice to either side
fit_trend = False  # fit a linear regression model of the mean
file_len = len(files)
slice_size = min(file_len, 2 * slice_margin + 1)
# get_stdev(..., ddof=3) below needs at least 4 values per slice
if slice_size <= 3:
    raise ValueError(
        "Zero degrees of freedom to estimate the standard deviation from the trend line."
    )
# numeric x positions, one per image; an ndarray so `slope * xs` is vectorized
xs = np.arange(file_len)
np_nan = np.full(file_len, np.nan)
signals_lj = {}  # chan -> signal means (read by the outlier listing)
extrema_lj = {}  # chan -> {"p2stdev", "p1stdev", "m1stdev", "m2stdev"} arrays
for c, chan in enumerate(chans):
    # prepare variables (NaN wherever a statistic cannot be computed)
    run_stats = {stat: np_nan.copy() for stat in ["slice", "means", "stdevs"]}
    trend_stats = {stat: np_nan.copy() for stat in ["slice", "vals", "stdevs"]}
    # get image statistics
    signal_means = laq.get_chan_data(images_img_data, chan, "sign_mean")
    signal_stdevs = laq.get_chan_data(images_img_data, chan, "sign_stdev")
    signal_stderrs = laq.get_chan_data(images_img_data, chan, "sign_stderr")
    signals_lj[chan] = signal_means
    # get trend statistics: linear fit, or a flat line at the overall mean
    if fit_trend:
        slope, inter = np.polyfit(xs, signal_means, deg=1)
        trend_stats["vals"] = slope * xs + inter
    else:
        trend_stats["vals"].fill(laq.get_mean(signal_means))
    # get running statistics, only where a full-size slice is available
    for i, _ in enumerate(signal_means):
        run_stats["slice"] = laq.get_run_slice(signal_means, i, slice_margin)
        if run_stats["slice"].size == slice_size:
            run_stats["means"][i] = laq.get_mean(run_stats["slice"])
            run_stats["stdevs"][i] = laq.get_mean(
                laq.get_run_slice(signal_stdevs, i, slice_margin)
            )
            trend_stats["slice"] = laq.get_run_slice(
                trend_stats["vals"], i, slice_margin
            )
            trend_stats["stdevs"][i] = laq.get_stdev(
                run_stats["slice"],
                mean=laq.get_mean(trend_stats["slice"]),
                ddof=3,  # estimated: slope, intercept, and mean
            )
    # get extrema from trend line
    extrema_lj_keys = [("p2stdev", "m2stdev"), ("p1stdev", "m1stdev")]
    extrema_lj[chan] = {
        extrema_lj_keys[0][0]: trend_stats["vals"] + 2.0 * trend_stats["stdevs"],
        extrema_lj_keys[1][0]: trend_stats["vals"] + 1.0 * trend_stats["stdevs"],
        extrema_lj_keys[1][1]: trend_stats["vals"] - 1.0 * trend_stats["stdevs"],
        extrema_lj_keys[0][1]: trend_stats["vals"] - 2.0 * trend_stats["stdevs"],
    }
    # running limits as dotted lines: dense dots at +/-1, sparse at +/-2
    for dist in [2.0, 1.0, -1.0, -2.0]:
        linestyle = (0, (1, 4)) if abs(dist) == 2.0 else (0, (1, 2))
        plt.plot(
            xs,
            run_stats["means"] + dist * run_stats["stdevs"],
            color="black",
            linewidth=1,
            linestyle=linestyle,
        )
    # shaded bands around the trend line
    for upper, lower in extrema_lj_keys:
        plt.fill_between(
            xs,
            extrema_lj[chan][upper],
            extrema_lj[chan][lower],
            color="black",
            alpha=0.2,
        )
    plt.plot(xs, trend_stats["vals"], color="black", linewidth=1, linestyle="solid")
    plt.plot(xs, run_stats["means"], color="black", linewidth=1, linestyle="dashed")
    plt.errorbar(
        xs,
        signal_means,
        yerr=signal_stderrs,
        fmt="o-",
        linewidth=1,
        markersize=2,
        color=color_map[c],
        label=chan + " [SIG]",
    )
    if chan == chans[-1]:
        # Label the last chart with file names. Explicit ticks at numeric
        # positions, rather than a categorical string axis, keep duplicate
        # base names (possible with a recursive search) at their own positions.
        plt.xticks(
            xs,
            [os.path.basename(image) for image in images_img_data],
            rotation=90,
            fontsize="small",
        )
    legend = plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize="small")
    plt.ylim(bottom=0.0)
    plt.show()

### Data Plots II - Extreme Values

In [None]:
# signals_lj/extrema_lj follow the time-stamp order of images_img_data (set in
# the main cell), so a position s must be mapped back through that order --
# `files` is alphabetical and can point at the wrong image.
# NOTE(review): assumes laq.get_chan_data yields values in images_img_data order.
images = list(images_img_data)

# list all images with extreme values per channel
for c, chan in enumerate(chans):
    print(f"\n{chan}:", flush=True)
    outliers_lj = []
    limits = extrema_lj[chan]
    for s, signal_lj in enumerate(signals_lj[chan]):
        # NaN limits (no full running slice) compare False: not reported
        if signal_lj > limits["p2stdev"][s]:
            outliers_lj.append(("▲▲ ", s, images[s], chan, signal_lj))
        elif signal_lj > limits["p1stdev"][s]:
            outliers_lj.append(("▲  ", s, images[s], chan, signal_lj))
        elif signal_lj < limits["m2stdev"][s]:
            outliers_lj.append(("▼▼ ", s, images[s], chan, signal_lj))
        elif signal_lj < limits["m1stdev"][s]:
            outliers_lj.append(("▼  ", s, images[s], chan, signal_lj))
    # show color bar at top (value range used for the channel images)
    cmap = mpl.cm.nipy_spectral
    norm = mpl.colors.Normalize(vmin=chans_minmax[chan][0],
                                vmax=chans_minmax[chan][1])
    scalarmappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    scalarmappable.set_array([])
    fig = plt.figure(figsize=(min_width, 1))
    ax = fig.add_axes([0.0, 0.0, 1, 0.5])
    cbar = fig.colorbar(scalarmappable, cax=ax, orientation='horizontal')
    plt.show()
    if outliers_lj:
        # print list of outliers with optional channel images
        for indicator, position, file, channel, mean in outliers_lj:
            print(
                f"\n\t{indicator} {position} = {os.path.basename(file)}"
                f"  ({mean}) [{chans_minmax[chan][0]}-{chans_minmax[chan][1]}]"
            )
            if show_img:
                plt.imshow(
                    laq.get_chan_img(file, channel),
                    cmap="nipy_spectral",
                    vmin=chans_minmax[chan][0],
                    vmax=chans_minmax[chan][1],
                    resample=False,
                )
                plt.show()
    else:
        print("\t►  (none)")


### Data Plots III - H-Scores

In [None]:
# NOTE(review): tifffile is not referenced in this cell; presumably a leftover.
import tifffile


# H-scores of the first image only (files[0]): one stacked bar per scored
# channel, one bar segment per score category
score_img_datas = laq.score_img_data(files[0], chans_minmax)

fig, ax = plt.subplots()
score_labels = tuple(score_img_datas.keys())
# Stacking baseline with one entry per bar. It is sized by the channels that
# were actually scored; len(chans) would cause a shape mismatch if they differ.
bottom = np.zeros(len(score_labels))
# regroup {channel: {score: value}} into {score: [value per channel]}
score_values = {"score_1": [], "score_2": [], "score_3": []}
for scores in score_img_datas.values():
    for score in scores:
        score_values[score].append(scores[score])
for score, values in score_values.items():
    ax.bar(score_labels, values, width=0.5, label=score, bottom=bottom)
    bottom += values
ax.set_yscale("log")
ax.set_ylim(1e-1, 1.025e2)
# reference lines at 1, 10 and 100 in tab:blue/orange/green, the first three
# colors of matplotlib's default color cycle
plt.axhline(y=1, color="tab:blue", linestyle="dashed")
plt.axhline(y=10, color="tab:orange", linestyle="dashed")
plt.axhline(y=100, color="tab:green", linestyle="dashed")
# reverse legend entries so they read top-down like the stacked segments
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1],
          title='H-Score (total)',
          loc="center left",
          bbox_to_anchor=(1, 0.5),
          fontsize="small",)

plt.show()