# WBC quality-control figure for the paper: threshold and spot masks

In [None]:
from scip_workflows.common import *


In [None]:
import matplotlib.gridspec as gridspec
import zarr
from matplotlib.colors import LinearSegmentedColormap, ListedColormap, Normalize

from scip.masking import spot, threshold

# Default resolution for every figure created in this notebook.
plt.rcParams.update({"figure.dpi": 200})


In [None]:
try:
    # Running inside a Snakemake rule: every path comes from the workflow.
    smk_in = snakemake.input
    features, index, columns = smk_in.features, smk_in.index, smk_in.columns
    images_parent = snakemake.config["images_parent"]
    output = snakemake.output[0]
except NameError:
    # `snakemake` is undefined (interactive session): fall back to a fixed run.
    # Alternative root: Path(os.environ["VSC_DATA_VO_USER"]) / "datasets/wbc"
    data_root = Path("/home/maximl/scratch/data/vsc/datasets/wbc/")
    data = data_root / "scip" / "20220713131400"
    indices_dir = data / "indices"
    features = data / "features.parquet"
    index = indices_dir / "index.npy"
    columns = indices_dir / "columns.npy"
    # Same root as data_root; kept as a str like the Snakemake config value.
    images_parent = str(data_root / "images")
    output = data / "figures" / "wbc_qc_masks.png"


In [None]:
# Read all SCIP features, then keep only the saved column and row selections
# (loaded with allow_pickle=True, presumably object arrays of labels).
df_scip = pq.read_table(features).to_pandas()
df_scip = df_scip[numpy.load(columns, allow_pickle=True)].loc[
    numpy.load(index, allow_pickle=True)
]
# set_levels replaces the meta_group level values positionally, so this
# assumes that level holds exactly three values — NOTE(review): confirm.
df_scip.index = df_scip.index.set_levels(levels=[2, 3, 4], level="meta_group")
# Notebook display of the filtered table's shape.
df_scip.shape


In [None]:
# Re-root every image path under `images_parent`, keeping only the last two
# components of the stored path.
_images_root = Path(images_parent)
df_scip["meta_path"] = df_scip["meta_path"].apply(
    lambda stored: _images_root.joinpath(*Path(stored).parts[-2:])
)


In [None]:
# (index into the first axis of each reshaped image, display name) per channel.
_channel_table = ((0, "BF1"), (8, "BF2"), (5, "SSC"))
channel_ind = [ind for ind, _ in _channel_table]
channel_names = [name for _, name in _channel_table]
# Number of example cells (figure columns) to show.
n = 15


In [None]:
# Keep cells whose SSC spot area lies strictly between 50 and 100
# (NaN areas fail both comparisons and are dropped).
spot_area = df_scip["feat_spot_area_SSC"]
sel1 = spot_area < 100
sel2 = spot_area > 50
spot_cells = df_scip.loc[sel1 & sel2]


In [None]:
# Load the selected channels of the first `n` spot cells and compute a
# threshold mask and a spot mask for each image.
pixels = []
masks = dict(threshold=[], spot=[])
# head(n) rather than iloc[i] over range(n): no IndexError when fewer than
# n cells pass the selection (the figure simply gets fewer columns).
for _, r in spot_cells.head(n).iterrows():
    print(r.meta_path, r.meta_zarr_idx)
    z = zarr.open(r.meta_path, mode="r")
    # Entries are reshaped using the per-entry shape kept in z.attrs["shape"],
    # then restricted to the channels listed in channel_ind.
    img = z[r.meta_zarr_idx].reshape(z.attrs["shape"][r.meta_zarr_idx])[channel_ind]
    pixels.append(img)
    masks["threshold"].append(
        threshold.get_mask(dict(pixels=img), 0, smooth=[0.5, 0, 0.5])["mask"]
    )
    masks["spot"].append(spot.get_mask(dict(pixels=img), 0, spotsize=5)["mask"])


In [None]:
# basec = plt.get_cmap("Reds")(100)[:3]
# Base RGB of the mask overlay (black). Alternative: plt.get_cmap("Reds")(100)[:3]
basec = (0, 0, 0)
# Two-entry colormap: index 0 -> fully transparent, index 1 -> opaque basec.
cm = ListedColormap([(*basec, 0), (*basec, 1)], name="test")


In [None]:
# Two stacked panels, one per mask type in `masks`: threshold masks (one row
# per channel) and spot masks (SSC channel only); one column per example cell.
fig = plt.figure(dpi=200, figsize=(n * 0.5, len(channel_ind) * 1.7), tight_layout=True)
grid = gridspec.GridSpec(2, 1, figure=fig)
# Position of the SSC channel in each image's first axis (same order as
# channel_names); looked up by name instead of hard-coding index 2.
ssc = channel_names.index("SSC")

gs = {k: grid[i, 0].subgridspec(len(channel_ind), n) for i, k in enumerate(masks)}
for k, v in masks.items():
    for i, (mask, pixel) in enumerate(zip(v, pixels)):
        if k == "spot":
            # Fancy-index with a list to keep a length-1 leading axis, so the
            # per-channel loop below still works.
            mask = mask[[ssc]]
            pixel = pixel[[ssc]]
        for j, (m, p) in enumerate(zip(mask, pixel)):
            ax = fig.add_subplot(gs[k][j, i])
            ax.imshow(p)
            # ~m indexes the 2-entry `cm`: pixels outside the mask get its
            # opaque color, pixels inside stay transparent (assumes m is boolean).
            ax.imshow(cm(~m), alpha=0.5)
            ax.set_axis_off()
            if i == 0:
                # Label each row once, on its left-most image.
                channel = channel_names[ssc] if k == "spot" else channel_names[j]
                ax.set_title(
                    f"{k.capitalize()} mask {channel}",
                    loc="left",
                    fontdict=dict(fontsize=9),
                )

fig.savefig(output, bbox_inches="tight")
