# **Quality Control notebook**

---

<font size = 4>

---

<font size = 4>*Disclaimer*:

<font size = 4>This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.


<font size = 4>**Please also cite this original paper when using or developing this notebook.**

# **How to use this notebook?**

---

<font size = 4>Videos describing how to use our notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run-through of the workflow to obtain the notebooks and the provided test datasets, as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
###**Structure of a notebook**

<font size = 4>The notebook contains two types of cell:  

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor to the `[ ]`-mark on the left side of the cell (a play button appears) and click it. Once execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.

---
###**Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you will find three tabs, which contain, from top to bottom:

<font size = 4>*Table of contents* = contains the structure of the notebook. Click an entry to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.1) you will find your files and folders here.

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.1.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
###**Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.

<font size = 4>To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment.

#**0. Before getting started**
---
<font size = 4> To use this notebook, pay attention to the data structure. The images you want to compare need to be organised in separate folders, and corresponding images need to have the same file name in each folder.

<font size = 4>Here's a common data structure that can work:
*   Experiment A
    - **Training_source**
        - img_1.tif, img_2.tif, ... 
    - **Training_target**
        - img_1.tif, img_2.tif, ...        
    - **Prediction**
        - img_1.tif, img_2.tif, ... 
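
<font size = 4>Because the images are paired by file name, it can help to check the folders before running the notebook. The snippet below is a minimal sketch of such a check; the helper name `find_unmatched_files` and the example paths are hypothetical and not part of the notebook.

```python
import os

def find_unmatched_files(*folders):
    """For each folder, list the file names that are not present in every folder."""
    listings = {folder: set(os.listdir(folder)) for folder in folders}
    # File names shared by all the folders
    common = set.intersection(*listings.values())
    return {folder: sorted(names - common) for folder, names in listings.items()}

# Hypothetical usage, replace the paths with your own folders:
# unmatched = find_unmatched_files("Experiment A/Training_source",
#                                  "Experiment A/Training_target",
#                                  "Experiment A/Prediction")
# for folder, names in unmatched.items():
#     print(folder, "->", names)
```

<font size = 4>An empty list for every folder means that each image has a counterpart with the same name in the other folders.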

---

# **1. Initialise the Colab session**
---

## **1.1. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:

#@markdown ##Run this cell to connect your Google Drive to Colab

#@markdown * Click on the URL. 

#@markdown * Sign in to your Google Account. 

#@markdown * Copy the authorization code. 

#@markdown * Enter the authorization code. 

#@markdown * Click on the "Files" tab on the left and refresh it. Your Google Drive folder should now be available here as "gdrive". 

#mounts user's Google Drive to Google Colab.

from google.colab import drive
drive.mount('/content/gdrive')




## **1.2. Install the dependencies**
---


In [None]:
#@markdown ##Install the dependencies

Notebook_version = '1.13'
Network = 'Quality_control'

!pip install tifffile # contains tools to operate tiff-files
!pip install wget
!pip install memory_profiler
!pip install fpdf
%load_ext memory_profiler

# ------- Common variable to all ZeroCostDL4Mic notebooks -------
import numpy as np
from matplotlib import pyplot as plt
import urllib
import os, random
import shutil 
import zipfile
from tifffile import imread, imsave
import time
import sys
import wget
from pathlib import Path
import pandas as pd
import csv
from glob import glob
from scipy import signal
from scipy import ndimage
from skimage import io
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
import matplotlib as mpl
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr
from astropy.visualization import simple_norm
from skimage import img_as_float32
from skimage.util import img_as_ubyte
from tqdm import tqdm 
from fpdf import FPDF, HTMLMixin
from datetime import datetime
import subprocess
from pip._internal.operations.freeze import freeze

from tabulate import tabulate
from astropy.visualization import simple_norm

from ipywidgets import interact

# Colors for the warning messages
class bcolors:
  WARNING = '\033[31m'

W  = '\033[0m'  # white (normal)
R  = '\033[31m' # red

#Disable some of the tensorflow warnings
import warnings
warnings.filterwarnings("ignore")

print("Libraries installed")

# Check if this is the latest version of the notebook
All_notebook_versions = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv", dtype=str)
print('Notebook version: '+Notebook_version)
Latest_Notebook_version = All_notebook_versions[All_notebook_versions["Notebook"] == Network]['Version'].iloc[0]
print('Latest notebook version: '+Latest_Notebook_version)
if Notebook_version == Latest_Notebook_version:
  print("This notebook is up-to-date.")
else:
  print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki")


## ------------------- Instance segmentation metrics ------------------------------

# Here we define the functions that perform the QC; code adapted from the StarDist repo https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py

import numpy as np
from numba import jit
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
from collections import namedtuple


matching_criteria = dict()

def label_are_sequential(y):
    """ returns true if y has only sequential labels from 1... """
    labels = np.unique(y)
    return (set(labels)-{0}) == set(range(1,1+labels.max()))


def is_array_of_integers(y):
    return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)


def _raise(e):
    """Raise the given exception (helper used by the checks in this cell)."""
    raise e


def _check_label_array(y, name=None, check_sequential=False):
    # The checks below only print a warning, so the QC can carry on with non-integer label images
    err = ValueError("{label} must be an array of {integers}.".format(
        label = 'labels' if name is None else name,
        integers = ('sequential ' if check_sequential else '') + 'non-negative integers',
    ))
    is_array_of_integers(y) or print("Warning: "+str(err))
    if check_sequential:
        label_are_sequential(y) or print("Warning: "+str(err))
    else:
        y.min() >= 0 or print("Warning: "+str(err))
    return True


def label_overlap(x, y, check=True):
    if check:
        _check_label_array(x,'x',True)
        _check_label_array(y,'y',True)
        x.shape == y.shape or _raise(ValueError("x and y must have the same shape"))
    return _label_overlap(x, y)

@jit(nopython=True)
def _label_overlap(x, y):
    x = x.ravel()
    y = y.ravel()
    overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
    for i in range(len(x)):
        overlap[x[i],y[i]] += 1
    return overlap


def intersection_over_union(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return overlap / (n_pixels_pred + n_pixels_true - overlap)

matching_criteria['iou'] = intersection_over_union


def intersection_over_true(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return overlap / n_pixels_true

matching_criteria['iot'] = intersection_over_true


def intersection_over_pred(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    return overlap / n_pixels_pred

matching_criteria['iop'] = intersection_over_pred


def precision(tp,fp,fn):
    return tp/(tp+fp) if tp > 0 else 0
def recall(tp,fp,fn):
    return tp/(tp+fn) if tp > 0 else 0
def accuracy(tp,fp,fn):
    return tp/(tp+fp+fn) if tp > 0 else 0
def f1(tp,fp,fn):    
    return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0

def _safe_divide(x,y):
    return x/y if y>0 else 0.0

def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):
 
    _check_label_array(y_true,'y_true')
    _check_label_array(y_pred,'y_pred')
    y_true.shape == y_pred.shape or _raise(ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred)))
    criterion in matching_criteria or _raise(ValueError("Matching criterion '%s' not supported." % criterion))
    if thresh is None: thresh = 0
    thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)

    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    assert 0 <= np.min(scores) <= np.max(scores) <= 1

    # ignoring background
    scores = scores[1:,1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        not_trivial = n_matched > 0 and np.any(scores >= thr)
        if not_trivial:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= thr).astype(float) - scores / (2*n_matched)
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            match_ok = scores[true_ind,pred_ind] >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp


        # the score sum over all matched objects (tp)
        sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0

        # the score average over all matched objects (tp)
        mean_matched_score = _safe_divide(sum_matched_score, tp)
        # the score average over all gt/true objects
        mean_true_score    = _safe_divide(sum_matched_score, n_true)
        panoptic_quality   = _safe_divide(sum_matched_score, tp+fp/2+fn/2)

        stats_dict = dict (
            criterion          = criterion,
            thresh             = thr,
            fp                 = fp,
            tp                 = tp,
            fn                 = fn,
            precision          = precision(tp,fp,fn),
            recall             = recall(tp,fp,fn),
            accuracy           = accuracy(tp,fp,fn),
            f1                 = f1(tp,fp,fn),
            n_true             = n_true,
            n_pred             = n_pred,
            mean_true_score    = mean_true_score,
            mean_matched_score = mean_matched_score,
            panoptic_quality   = panoptic_quality,
        )
        if bool(report_matches):
            if not_trivial:
                stats_dict.update (
                    # int() to be json serializable
                    matched_pairs  = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),
                    matched_scores = tuple(scores[true_ind,pred_ind]),
                    matched_tps    = tuple(map(int,np.flatnonzero(match_ok))),
                )
            else:
                stats_dict.update (
                    matched_pairs  = (),
                    matched_scores = (),
                    matched_tps    = (),
                )
        return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())

    return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))


def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
    """matching metrics for list of images, see `stardist.matching.matching`
    """
    len(y_true) == len(y_pred) or _raise(ValueError("y_true and y_pred must have the same length."))
    return matching_dataset_lazy (
        tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,
    )


def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):

    expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))

    single_thresh = False
    if np.isscalar(thresh):
        single_thresh = True
        thresh = (thresh,)

    tqdm_kwargs = {}
    tqdm_kwargs['disable'] = not bool(show_progress)
    if int(show_progress) > 1:
        tqdm_kwargs['total'] = int(show_progress)

    # compute matching stats for every pair of label images
    if parallel:
        from concurrent.futures import ThreadPoolExecutor
        fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)
        with ThreadPoolExecutor() as pool:
            stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs)))
    else:
        stats_all = tuple (
            matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)
            for y_t,y_p in tqdm(y_gen,**tqdm_kwargs)
        )

    # accumulate results over all images for each threshold separately
    n_images, n_threshs = len(stats_all), len(thresh)
    accumulate = [{} for _ in range(n_threshs)]
    for stats in stats_all:
        for i,s in enumerate(stats):
            acc = accumulate[i]
            for k,v in s._asdict().items():
                if k == 'mean_true_score' and not bool(by_image):
                    # convert mean_true_score to "sum_matched_score"
                    acc[k] = acc.setdefault(k,0) + v * s.n_true
                else:
                    try:
                        acc[k] = acc.setdefault(k,0) + v
                    except TypeError:
                        pass

    # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
    for thr,acc in zip(thresh,accumulate):
        set(acc.keys()) == expected_keys or _raise(ValueError("unexpected keys"))
        acc['criterion'] = criterion
        acc['thresh'] = thr
        acc['by_image'] = bool(by_image)
        if bool(by_image):
            for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
                acc[k] /= n_images
        else:
            tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']
            sum_matched_score = acc['mean_true_score']

            mean_matched_score = _safe_divide(sum_matched_score, tp)
            mean_true_score    = _safe_divide(sum_matched_score, n_true)
            panoptic_quality   = _safe_divide(sum_matched_score, tp+fp/2+fn/2)

            acc.update(
                precision          = precision(tp,fp,fn),
                recall             = recall(tp,fp,fn),
                accuracy           = accuracy(tp,fp,fn),
                f1                 = f1(tp,fp,fn),
                mean_true_score    = mean_true_score,
                mean_matched_score = mean_matched_score,
                panoptic_quality   = panoptic_quality,
            )

    accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)
    return accumulate[0] if single_thresh else accumulate


# copied from scikit-image master for now (remove when part of a release)
def relabel_sequential(label_field, offset=1):
    
    offset = int(offset)
    if offset <= 0:
        raise ValueError("Offset must be strictly positive.")
    if np.min(label_field) < 0:
        raise ValueError("Cannot relabel array that contains negative values.")
    max_label = int(label_field.max()) # Ensure max_label is an integer
    if not np.issubdtype(label_field.dtype, np.integer):
        new_type = np.min_scalar_type(max_label)
        label_field = label_field.astype(new_type)
    labels = np.unique(label_field)
    labels0 = labels[labels != 0]
    new_max_label = offset - 1 + len(labels0)
    new_labels0 = np.arange(offset, new_max_label + 1)
    output_type = label_field.dtype
    required_type = np.min_scalar_type(new_max_label)
    if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:
        output_type = required_type
    forward_map = np.zeros(max_label + 1, dtype=output_type)
    forward_map[labels0] = new_labels0
    inverse_map = np.zeros(new_max_label + 1, dtype=output_type)
    inverse_map[offset:] = labels0
    relabeled = forward_map[label_field]
    return relabeled, forward_map, inverse_map


## ------------------- Image-to-image comparison metrics ------------------------------


## Pearson correlation


## lpips ?


def ssim(img1, img2):
  return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)


def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):
    """Percentile-based image normalization.

    This function is adapted from Martin Weigert.
    """

    mi = np.percentile(x,pmin,axis=axis,keepdims=True)
    ma = np.percentile(x,pmax,axis=axis,keepdims=True)
    return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)


def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):
    """This function is adapted from Martin Weigert"""
    if dtype is not None:
        x   = x.astype(dtype,copy=False)
        mi  = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)
        ma  = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)
        eps = dtype(eps)

    try:
        import numexpr
        x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )")
    except ImportError:
        x =                   (x - mi) / ( ma - mi + eps )

    if clip:
        x = np.clip(x,0,1)

    return x

def norm_minmse(gt, x, normalize_gt=True):
    """Normalize and affinely scale an image pair such that the MSE is minimized.

    This function is adapted from Martin Weigert.

    Parameters
    ----------
    gt: ndarray
        the ground truth image
    x: ndarray
        the image that will be affinely scaled
    normalize_gt: bool
        set to True if the gt image should be normalized (default)

    Returns
    -------
    gt_scaled, x_scaled
    """
    if normalize_gt:
        gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)
    x = x.astype(np.float32, copy=False) - np.mean(x)    
    gt = gt.astype(np.float32, copy=False) - np.mean(gt)    
    scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())
    return gt, scale * x


#--------------------- Display functions --------------------------------

def visualise_image_comparison_QC(image, dimension, Source_folder, Prediction_folder, Ground_truth_folder, QC_folder, QC_scores):
  
  img_Source = io.imread(os.path.join(Source_folder, image))
  img_Prediction = io.imread(os.path.join(Prediction_folder, image))
  img_GT = io.imread(os.path.join(Ground_truth_folder, image))

  if dimension == "3D":
    Z_plane = img_GT.shape[0] // 2  # index of the middle Z plane (stays in range for short stacks)
  
  img_SSIM_GTvsSource = io.imread(os.path.join(QC_folder, 'SSIM_GTvsSource_'+image))
  img_SSIM_GTvsPrediction = io.imread(os.path.join(QC_folder, 'SSIM_GTvsPrediction_'+image))
  img_RSE_GTvsSource = io.imread(os.path.join(QC_folder, 'RSE_GTvsSource_'+image))
  img_RSE_GTvsPrediction = io.imread(os.path.join(QC_folder, 'RSE_GTvsPrediction_'+image))
  
  SSIM_GTvsP_forDisplay = QC_scores.loc[[image], 'Prediction v. GT mSSIM'].tolist()
  SSIM_GTvsS_forDisplay = QC_scores.loc[[image], 'Input v. GT mSSIM'].tolist()
  NRMSE_GTvsP_forDisplay = QC_scores.loc[[image], 'Prediction v. GT NRMSE'].tolist()
  NRMSE_GTvsS_forDisplay = QC_scores.loc[[image], 'Input v. GT NRMSE'].tolist()
  PSNR_GTvsP_forDisplay = QC_scores.loc[[image], 'Prediction v. GT PSNR'].tolist()
  PSNR_GTvsS_forDisplay = QC_scores.loc[[image], 'Input v. GT PSNR'].tolist()

  plt.figure(figsize=(15,15))

#-------------------Target (Ground-truth)-------------
  plt.subplot(3,3,1)
  plt.axis('off')

  if dimension == "2D":
    plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))
  
  if dimension == "3D":
    plt.imshow(img_GT[Z_plane], norm=simple_norm(img_GT, percent = 99))
  plt.title('Target',fontsize=15)

#-----------------------Source---------------------
  plt.subplot(3,3,2)
  plt.axis('off')

  if dimension == "2D":  
    plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))

  if dimension == "3D":
    plt.imshow(img_Source[Z_plane], norm=simple_norm(img_Source, percent = 99))
  plt.title('Source',fontsize=15)

#---------------------Prediction------------------------------
  plt.subplot(3,3,3)
  plt.axis('off')
  
  if dimension == "2D":
    plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))

  if dimension == "3D":
    plt.imshow(img_Prediction[Z_plane], norm=simple_norm(img_Prediction, percent = 99))
  plt.title('Prediction',fontsize=15)

  #Setting up colours
  cmap = plt.cm.CMRmap

#---------------------SSIM between GT and Source---------------------
  plt.subplot(3,3,5)
  #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)
     
  if dimension == "2D":
    imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)
  if dimension == "3D":
    imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[Z_plane], cmap = cmap, vmin=0, vmax=1)
  
  plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)
  plt.title('Target vs. Source',fontsize=15)
  plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay[0],3)),fontsize=14)
  plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)

#---------------------SSIM between GT and Prediction---------------------
  plt.subplot(3,3,6)
    #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)
  if dimension == "2D":    
    imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)
  
  if dimension == "3D":  
    imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[Z_plane], cmap = cmap, vmin=0,vmax=1)
  
  plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)
  plt.title('Target vs. Prediction',fontsize=15)
  plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay[0],3)),fontsize=14)

#---------------------Root Squared Error between GT and Source---------------------
  plt.subplot(3,3,8)
    #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False) 
  
  if dimension == "2D":  
    imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)

  if dimension == "3D": 
    imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[Z_plane], cmap = cmap, vmin=0, vmax = 1)
  
  plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)
  plt.title('Target vs. Source',fontsize=15)
  plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay[0],3))+', PSNR: '+str(round(PSNR_GTvsS_forDisplay[0],3)),fontsize=14)  
  plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)

#---------------------Root Squared Error between GT and Prediction---------------------
  plt.subplot(3,3,9)
    #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)
  
  if dimension == "2D":
    imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)
  
  if dimension == "3D": 
    imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[Z_plane], cmap = cmap, vmin=0, vmax=1)
  
  plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)
  plt.title('Target vs. Prediction',fontsize=15)
  plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay[0],3))+', PSNR: '+str(round(PSNR_GTvsP_forDisplay[0],3)),fontsize=14)
  plt.savefig(QC_folder+"/QC_example_data.png",bbox_inches='tight',pad_inches=0)


def visualise_segmentation_QC(image, dimension, Source_folder, Prediction_folder, Ground_truth_folder, QC_folder, QC_scores):

  plt.figure(figsize=(25,5))
  
  source_image = io.imread(os.path.join(Source_folder, image))  

  target_image = io.imread(os.path.join(Ground_truth_folder, image))
  prediction = io.imread(os.path.join(Prediction_folder, image))

  IoU_forDisplay = QC_scores.loc[[image], 'Prediction v. GT Intersection over Union'].tolist()

  if dimension == "3D":  
    Z_plane = target_image.shape[0] // 2  # index of the middle Z plane (stays in range for short stacks)
      
  # Binarise into new arrays so that the loaded label images are left untouched
  target_image_mask = (target_image > 0).astype(np.uint8) * 255

  prediction_mask = (prediction > 0).astype(np.uint8) * 255

  intersection = np.logical_and(target_image_mask, prediction_mask)
  union = np.logical_or(target_image_mask, prediction_mask)
  iou_score =  np.sum(intersection) / np.sum(union)

  norm = simple_norm(source_image, percent = 99)

  # Input
  plt.subplot(1,4,1)
  plt.axis('off')
  if dimension == "2D":
    n_channel = 1 if source_image.ndim == 2 else source_image.shape[-1]

    if n_channel > 1:
      plt.imshow(source_image)
    if n_channel == 1:
      plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')

  if dimension == "3D":
    plt.imshow(source_image[Z_plane], aspect='equal', norm=norm, cmap='magma', interpolation='nearest')

  plt.title('Input')

    #Ground-truth
  plt.subplot(1,4,2)
  plt.axis('off')
  if dimension == "2D":
    plt.imshow(target_image_mask, aspect='equal', cmap='Greens')
  
  if dimension == "3D":
    plt.imshow(target_image_mask[Z_plane], aspect='equal', cmap='Greens')

  plt.title('Ground Truth')

    #Prediction
  plt.subplot(1,4,3)
  plt.axis('off')
  if dimension == "2D":
    plt.imshow(prediction_mask, aspect='equal', cmap='Purples')
  if dimension == "3D":
    plt.imshow(prediction_mask[Z_plane], aspect='equal', cmap='Purples')

  plt.title('Prediction')

    #Overlay
  plt.subplot(1,4,4)
  plt.axis('off')
  if dimension == "2D":
    plt.imshow(target_image_mask, cmap='Greens')
    plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')
  
  if dimension == "3D":
    plt.imshow(target_image_mask[Z_plane], cmap='Greens')
    plt.imshow(prediction_mask[Z_plane], alpha=0.5, cmap='Purples')  

  plt.title('Ground Truth and Prediction, Intersection over Union: '+str(round(IoU_forDisplay[0],3)))
  plt.savefig(QC_folder+"/QC_example_data.png",bbox_inches='tight',pad_inches=0)





# **2. Error mapping and quality metrics estimation**

---


## **Image similarity metrics**
---

<font size = 4>**The SSIM (structural similarity) map** 

<font size = 4>The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric at each pixel from the structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). 

<font size=4>**mSSIM** is the mean SSIM value, obtained by averaging the SSIM map over the image.
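
<font size = 4>As a minimal sketch of how an SSIM map and the mSSIM can be obtained, the example below calls `structural_similarity` from scikit-image with the same settings as the `ssim` helper defined in section 1.2 (Gaussian weighting, sigma of 1.5, data range of 1). The two images are synthetic and only serve as an illustration.

```python
import numpy as np
from skimage.metrics import structural_similarity

rng = np.random.default_rng(0)

# Synthetic "target" in [0, 1] and a noisy copy standing in for a prediction
target = rng.random((64, 64))
prediction = np.clip(target + 0.05 * rng.standard_normal(target.shape), 0, 1)

# full=True returns both the mean SSIM and the per-pixel SSIM map
mssim, ssim_map = structural_similarity(
    target, prediction,
    data_range=1.0, full=True,
    gaussian_weights=True, use_sample_covariance=False, sigma=1.5)

print("mSSIM:", round(float(mssim), 3))
print("SSIM map shape:", ssim_map.shape)
```

<font size = 4>The SSIM map has the same shape as the input images, so it can be displayed next to them to see where the two images differ.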

<font size=4>**The output below shows the SSIM maps with the mSSIM**

<font size = 4>**The RSE (Root Squared Error) map** 

<font size = 4>This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).


<font size =4>**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.

<font size = 4>**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.
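
<font size = 4>The relationship between these three quantities can be sketched with NumPy as below. This is an illustration under stated assumptions, not necessarily the exact computation used by the notebook: both images are assumed to be already normalized to a data range of 1, the RMSE is normalised by that data range, and the helper name `rse_metrics` is hypothetical.

```python
import numpy as np

def rse_metrics(target, image, data_range=1.0):
    """Return the RSE map, a range-normalised RMSE and the PSNR in decibels."""
    target = np.asarray(target, dtype=np.float64)
    image = np.asarray(image, dtype=np.float64)

    squared_error = (target - image) ** 2
    rse_map = np.sqrt(squared_error)        # per-pixel root squared error
    mse = squared_error.mean()              # mean squared error over all pixels

    nrmse = np.sqrt(mse) / data_range       # assumption: normalised by the data range
    psnr = 10 * np.log10(data_range ** 2 / mse) if mse > 0 else np.inf
    return rse_map, nrmse, psnr

# Toy example: a prediction that is off by 0.1 at every pixel
target = np.zeros((4, 4))
prediction = np.full((4, 4), 0.1)
rse_map, nrmse, psnr = rse_metrics(target, prediction)
print(round(nrmse, 3), round(psnr, 1))
```

<font size = 4>In this toy case every pixel of the RSE map is about 0.1, the NRMSE is about 0.1 and the PSNR is about 20 dB. A smaller error lowers the NRMSE and raises the PSNR.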


---
## **Image segmentation metrics**
---


<font size = 4>The **Intersection over Union** (IoU) metric is a method that can be used to quantify the overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess how accurately your model predicts nuclei. 

<font size = 4>Here, the IoU is calculated both over the whole image and on a per-object basis. The value displayed below is the IoU value calculated over the entire image. The IoU value calculated on a per-object basis is used to calculate the other metrics displayed.
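
<font size = 4>The whole-image value can be sketched as follows: both label images are binarised, and the number of pixels in the intersection is divided by the number of pixels in the union. The helper name `binary_iou` is hypothetical, and returning 1 for two empty masks is an assumption made for this example.

```python
import numpy as np

def binary_iou(target_mask, prediction_mask):
    """Intersection over Union of two masks, treating every non-zero pixel as foreground."""
    target = np.asarray(target_mask) > 0
    prediction = np.asarray(prediction_mask) > 0

    union = np.logical_or(target, prediction).sum()
    if union == 0:
        return 1.0  # assumption: two empty masks are considered to agree perfectly
    return np.logical_and(target, prediction).sum() / union

# Toy example: the masks share 1 pixel out of 3 foreground pixels in total
target = np.array([[1, 1, 0, 0]])
prediction = np.array([[0, 2, 2, 0]])
print(binary_iou(target, prediction))
```

<font size = 4>Note that the label values themselves do not matter here; only whether a pixel belongs to the foreground.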

<font size = 4>“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. 

<font size = 4>When a segmented object has an IoU value above 0.5 (compared to the corresponding ground truth), it is considered a true positive. The number of “**true positives**” is available in the table below. The number of “**false positives**” is then defined as “false positives” = “n_pred” - “true positives”. The number of “**false negatives**” is defined as “false negatives” = “n_true” - “true positives”.
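
<font size = 4>As a worked example of these definitions, the sketch below derives the false positives and false negatives from hypothetical object counts, together with the precision, recall and F1 score computed with the same formulas as the `precision`, `recall` and `f1` functions of section 1.2. The helper name `detection_summary` and the numbers are made up for illustration.

```python
def detection_summary(n_true, n_pred, true_positives):
    """Derive detection counts and scores from object counts and matched objects."""
    false_positives = n_pred - true_positives   # predicted objects without a match
    false_negatives = n_true - true_positives   # ground-truth objects that were missed

    tp, fp, fn = true_positives, false_positives, false_negatives
    return {
        "fp": fp,
        "fn": fn,
        "precision": tp / (tp + fp) if tp > 0 else 0,
        "recall": tp / (tp + fn) if tp > 0 else 0,
        "f1": (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0,
    }

# Hypothetical image: 10 ground-truth objects, 12 predicted objects, 8 matches
summary = detection_summary(n_true=10, n_pred=12, true_positives=8)
print(summary)
```

<font size = 4>Here 4 predicted objects have no match (false positives) and 2 ground-truth objects are missed (false negatives), giving a precision of 8/12 and a recall of 8/10.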

<font size = 4>The mean_matched_score is the mean IoU of the matched true positives. The mean_true_score is the sum of the IoUs of the matched true positives divided by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 2019](https://arxiv.org/abs/1801.00868).
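
<font size = 4>These three scores can be sketched from the IoUs of the matched true positives. The values below are hypothetical, and the panoptic quality follows the formula of Kirillov et al.: the sum of the matched IoUs divided by (TP + 0.5·FP + 0.5·FN).

```python
# Hypothetical values: IoUs of the 3 matched true positives, with 1 false positive,
# 2 false negatives and therefore 5 ground-truth objects in total
matched_iou = [0.9, 0.75, 0.6]
true_positive = len(matched_iou)
false_positive = 1
false_negative = 2
n_true = true_positive + false_negative

# Mean IoU over the matched objects only
mean_matched_score = sum(matched_iou) / true_positive

# Same sum, but normalised by all ground-truth objects (missed objects count as 0)
mean_true_score = sum(matched_iou) / n_true

# Panoptic quality: penalises both false positives and false negatives
panoptic_quality = sum(matched_iou) / (
    true_positive + 0.5 * false_positive + 0.5 * false_negative
)

print(round(mean_matched_score, 3), round(mean_true_score, 3), round(panoptic_quality, 3))
```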

<font size = 4>For more information about the other metrics displayed, please consult the SI of the paper describing ZeroCostDL4Mic.

<font size = 4>The results can be found in the "*Quality Control*" folder, which is created inside your "Prediction_folder".



In [None]:
from tabulate import tabulate
from astropy.visualization import simple_norm

from ipywidgets import interact

#@markdown ##Choose the folders that contain the data to analyse

Source_folder = "" #@param{type:"string"}
Prediction_folder = "" #@param{type:"string"}
Ground_truth_folder = "" #@param{type:"string"}

#@markdown ##Choose the type of QC you want to perform

QC_type = "Image-to-image comparison" #@param ["Image-to-image comparison", "Segmentation", "Instance segmentation"]

#@markdown ###Are your data 2D or 3D images?

Data_type = "2D" #@param ["2D", "3D"]

# Create a quality control in the Prediction Folder

QC_folder = Prediction_folder+"/Quality Control"

if os.path.exists(QC_folder):
  shutil.rmtree(QC_folder)
os.makedirs(QC_folder)

# List images in Source_folder
Z = os.listdir(Source_folder)
print('Number of test files found in the folder: '+str(len(Z)))

random_choice = random.choice(os.listdir(Source_folder))
X = io.imread(Source_folder+"/"+random_choice)
n_channel = 1 if X.ndim == 2 else X.shape[-1]

# ------------------ Image-to-image comparison 2D -------------------------------------------------

if QC_type == "Image-to-image comparison" and Data_type == "2D" :

# Open and create the csv file that will contain all the QC metrics
  with open(QC_folder+"/QC_metrics.csv", "w", newline='') as file:
      writer = csv.writer(file)

    # Write the header in the csv file
      writer.writerow(["image","Prediction v. GT mSSIM","Input v. GT mSSIM", "Prediction v. GT NRMSE", "Input v. GT NRMSE", "Prediction v. GT PSNR", "Input v. GT PSNR"])  

    # Let's loop through the provided dataset in the QC folders

      for i in os.listdir(Source_folder):
        if not os.path.isdir(os.path.join(Source_folder,i)):
          print('Running QC on: '+i)
      # -------------------------------- Target test data (Ground truth) --------------------------------
          test_GT = io.imread(os.path.join(Ground_truth_folder, i))

      # -------------------------------- Source test data --------------------------------
          test_source = io.imread(os.path.join(Source_folder,i))

      # Normalize the images wrt each other by minimizing the MSE between GT and Source image
          test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)

      # -------------------------------- Prediction --------------------------------
          test_prediction = io.imread(os.path.join(Prediction_folder,i))

      # Normalize the images wrt each other by minimizing the MSE between GT and prediction
          test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True)        

      # -------------------------------- Calculate the metric maps and save them --------------------------------

      # Calculate the SSIM maps
          index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)
          index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)

      #Save ssim_maps
          img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)
          io.imsave(QC_folder+'/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)
          img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)
          io.imsave(QC_folder+'/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)
      
      # Calculate the Root Squared Error (RSE) maps
          img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))
          img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))

      # Save RSE maps
          img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)
          img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)
          io.imsave(QC_folder+'/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)
          io.imsave(QC_folder+'/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)


      # -------------------------------- Calculate the RSE metrics and save them --------------------------------

      # Normalised Root Mean Squared Error: root of the mean of the squared errors
      # (the RSE map holds absolute errors, so it is squared again before averaging)
          NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(img_RSE_GTvsPrediction)))
          NRMSE_GTvsSource = np.sqrt(np.mean(np.square(img_RSE_GTvsSource)))
        
      # We can also measure the peak signal to noise ratio between the images
          PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)
          PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)

          writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])

  # ------------- For display ------------

  df = pd.read_csv (QC_folder+"/QC_metrics.csv")
  df.set_index("image", inplace=True)
  print(tabulate(df, headers='keys', tablefmt='psql'))


  print('--------------------------------------------------------------')
  @interact
  def show_QC_results(file = os.listdir(Source_folder)):

    visualise_image_comparison_QC(image = file, dimension=Data_type, Source_folder=Source_folder , Prediction_folder= Prediction_folder, Ground_truth_folder=Ground_truth_folder, QC_folder=QC_folder, QC_scores= df )  

  print('-----------------------------------')

# ------------------ Image-to-image comparison 3D -------------------------------------------------

if QC_type == "Image-to-image comparison" and Data_type == "3D" :

# Open and create the csv file that will contain all the QC metrics
  with open(QC_folder+"/QC_metrics.csv", "w", newline='') as file:
      writer = csv.writer(file)

    # Write the header in the csv file
      writer.writerow(["File name","Slice #","Prediction v. GT mSSIM","Input v. GT mSSIM", "Prediction v. GT NRMSE", "Input v. GT NRMSE", "Prediction v. GT PSNR", "Input v. GT PSNR"])  
    
    # These lists will be used to collect all the metrics values per slice
      file_name_list = []
      slice_number_list = []
      mSSIM_GvP_list = []
      mSSIM_GvS_list = []
      NRMSE_GvP_list = []
      NRMSE_GvS_list = []
      PSNR_GvP_list = []
      PSNR_GvS_list = []

    # These lists will be used to display the mean metrics for the stacks
      mSSIM_GvP_list_mean = []
      mSSIM_GvS_list_mean = []
      NRMSE_GvP_list_mean = []
      NRMSE_GvS_list_mean = []
      PSNR_GvP_list_mean = []
      PSNR_GvS_list_mean = []

    # Let's loop through the provided dataset in the QC folders
      for thisFile in os.listdir(Source_folder):
        if not os.path.isdir(os.path.join(Source_folder, thisFile)):
          print('Running QC on: '+thisFile)

          test_GT_stack = io.imread(os.path.join(Ground_truth_folder, thisFile))
          test_source_stack = io.imread(os.path.join(Source_folder,thisFile))
          test_prediction_stack = io.imread(os.path.join(Prediction_folder, thisFile))
          n_slices = test_GT_stack.shape[0]

        # Calculating the position of the mid-plane slice
          z_mid_plane = int(n_slices / 2)+1

          img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))
          img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))
          img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))
          img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))

          for z in range(n_slices): 
          # -------------------------------- Normalising the dataset --------------------------------

            test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)
            test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)

          # -------------------------------- Calculate the SSIM metric and maps --------------------------------

          # Calculate the SSIM maps and index
            index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)
            index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)

          # Store the SSIM maps in the stacks
            img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)
            img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False) 

          # -------------------------------- Calculate the NRMSE metrics --------------------------------

          # Calculate the Root Squared Error (RSE) maps
            img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))
            img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))

          # Store the RSE maps in the stacks
            img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)
            img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)

          # Normalised Root Mean Squared Error: root of the mean of the squared errors
          # (the RSE map holds absolute errors, so it is squared again before averaging)
            NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(img_RSE_GTvsPrediction)))
            NRMSE_GTvsSource = np.sqrt(np.mean(np.square(img_RSE_GTvsSource)))

          # Calculate the PSNR between the images
            PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)
            PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)

            writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])
          
          # Collect values to display in dataframe output
            slice_number_list.append(z)
            mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)
            mSSIM_GvS_list.append(index_SSIM_GTvsSource)
            NRMSE_GvP_list.append(NRMSE_GTvsPrediction)
            NRMSE_GvS_list.append(NRMSE_GTvsSource)
            PSNR_GvP_list.append(PSNR_GTvsPrediction)
            PSNR_GvS_list.append(PSNR_GTvsSource)
        
        # Average the metrics over the slices of the current stack only
        # (the per-slice lists accumulate across files, so use the last n_slices entries)
          file_name_list.append(thisFile)
          mSSIM_GvP_list_mean.append(np.mean(mSSIM_GvP_list[-n_slices:]))
          mSSIM_GvS_list_mean.append(np.mean(mSSIM_GvS_list[-n_slices:]))
          NRMSE_GvP_list_mean.append(np.mean(NRMSE_GvP_list[-n_slices:]))
          NRMSE_GvS_list_mean.append(np.mean(NRMSE_GvS_list[-n_slices:]))
          PSNR_GvP_list_mean.append(np.mean(PSNR_GvP_list[-n_slices:]))
          PSNR_GvS_list_mean.append(np.mean(PSNR_GvS_list[-n_slices:]))

         # ----------- Change the stacks to 32 bit images -----------

          img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)
          img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)
          img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)
          img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)

        # ----------- Saving the error map stacks -----------
          io.imsave(QC_folder+"/SSIM_GTvsSource_"+thisFile,img_SSIM_GTvsSource_stack_32)
          io.imsave(QC_folder+"/SSIM_GTvsPrediction_"+thisFile,img_SSIM_GTvsPrediction_stack_32)
          io.imsave(QC_folder+"/RSE_GTvsSource_"+thisFile,img_RSE_GTvsSource_stack_32)
          io.imsave(QC_folder+"/RSE_GTvsPrediction_"+thisFile,img_RSE_GTvsPrediction_stack_32)

#Averages of the metrics per stack as dataframe output
  pdResults = pd.DataFrame(file_name_list, columns = ["image"])
  pdResults["Prediction v. GT mSSIM"] = mSSIM_GvP_list_mean
  pdResults["Input v. GT mSSIM"] = mSSIM_GvS_list_mean
  pdResults["Prediction v. GT NRMSE"] = NRMSE_GvP_list_mean
  pdResults["Input v. GT NRMSE"] = NRMSE_GvS_list_mean
  pdResults["Prediction v. GT PSNR"] = PSNR_GvP_list_mean
  pdResults["Input v. GT PSNR"] = PSNR_GvS_list_mean

  print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')
  pdResults.set_index("image", inplace=True)
  pdResults.head()
  print(tabulate(pdResults, headers='keys', tablefmt='psql'))

  print('--------------------------------------------------------------')
  @interact
  def show_QC_results(file = os.listdir(Source_folder)):
    
    visualise_image_comparison_QC(image = file, dimension=Data_type, Source_folder=Source_folder , Prediction_folder= Prediction_folder, Ground_truth_folder=Ground_truth_folder, QC_folder=QC_folder, QC_scores= pdResults )  
    

  print('-----------------------------------')

# ------------------ Segmentation 2D -------------------------------------------------

if QC_type == "Segmentation" and Data_type == "2D":

  with open(QC_folder+"/QC_metrics.csv", "w", newline='') as file:
    writer = csv.writer(file, delimiter=",")
    writer.writerow(["image","Prediction v. GT Intersection over Union"])  

    for n in os.listdir(Source_folder):
    
      if not os.path.isdir(os.path.join(Source_folder,n)):
        print('Running QC on: '+n)
        test_input = io.imread(os.path.join(Source_folder,n))
        test_prediction = io.imread(os.path.join(Prediction_folder,n))
        test_ground_truth_image = io.imread(os.path.join(Ground_truth_folder, n))

       #Convert pixel values to 0 or 255
        test_prediction_0_to_255 = test_prediction
        test_prediction_0_to_255[test_prediction_0_to_255>0] = 255

      #Convert pixel values to 0 or 255
        test_ground_truth_0_to_255 = test_ground_truth_image
        test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255

      # Intersection over Union metric
        intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)
        union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)
        iou_score =  np.sum(intersection) / np.sum(union)
        writer.writerow([n, str(iou_score)])

  df = pd.read_csv (QC_folder+"/QC_metrics.csv")
  df.set_index("image", inplace=True)
  print(tabulate(df, headers='keys', tablefmt='psql'))


  # ------------- For display ------------
  print('--------------------------------------------------------------')
  @interact
  def show_QC_results(images = os.listdir(Source_folder)):  

    visualise_segmentation_QC(image=images, dimension=Data_type, Source_folder=Source_folder, Prediction_folder=Prediction_folder, Ground_truth_folder=Ground_truth_folder, QC_folder=QC_folder, QC_scores=df)
      
  print('-----------------------------------')


# ------------------ Segmentation 3D -------------------------------------------------

if QC_type == "Segmentation" and Data_type == "3D":

  with open(QC_folder+"/QC_metrics.csv", "w", newline='') as file:
    writer = csv.writer(file, delimiter=",")
    writer.writerow(["image","Slice #","Prediction v. GT Intersection over Union"])

    file_name_list = []
    slice_number_list = []
    iou_score_list = []  

    # These lists will be used to display the mean metrics for the stacks
    iou_score_list_mean = []

    for n in os.listdir(Source_folder):
    
      if not os.path.isdir(os.path.join(Source_folder,n)):
        print('Running QC on: '+n)
        test_input = io.imread(os.path.join(Source_folder,n))
        test_prediction = io.imread(os.path.join(Prediction_folder,n))
        test_ground_truth_image = io.imread(os.path.join(Ground_truth_folder, n))

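        # Number of slices in the current stack (was previously undefined in this branch)
        n_slices = test_ground_truth_image.shape[0]

        for z in range(n_slices):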

       #Convert pixel values to 0 or 255
          test_prediction_0_to_255 = test_prediction[z]
          test_prediction_0_to_255[test_prediction_0_to_255>0] = 255

      #Convert pixel values to 0 or 255
          test_ground_truth_0_to_255 = test_ground_truth_image[z]
          test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255

      # Intersection over Union metric
          intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)
          union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)
          iou_score =  np.sum(intersection) / np.sum(union)

          slice_number_list.append(z)
          iou_score_list.append(iou_score)

          writer.writerow([n, str(z), str(iou_score)])

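        # Keep only the slices of the current stack (the list accumulates across files)
        # and ignore slices with an IoU of 0 when averaging
        iou_score_array = np.array(iou_score_list[-test_ground_truth_image.shape[0]:], dtype=float)
        iou_score_array[iou_score_array==0.0] = np.nan          

        # If calculating average metrics for dataframe output
        file_name_list.append(n)
        iou_score_list_mean.append(np.nanmean(iou_score_array))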

  df = pd.read_csv (QC_folder+"/QC_metrics.csv")

#Averages of the metrics per stack as dataframe output
  pdResults = pd.DataFrame(file_name_list, columns = ["image"])
  pdResults["Prediction v. GT Intersection over Union"] = iou_score_list_mean

  print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')
  pdResults.set_index("image", inplace=True)
  pdResults.head()
  print(tabulate(pdResults, headers='keys', tablefmt='psql'))


  # ------------- For display ------------
  print('--------------------------------------------------------------')
  @interact
  def show_QC_results(images = os.listdir(Source_folder)):  

    visualise_segmentation_QC(image=images, dimension=Data_type, Source_folder=Source_folder, Prediction_folder=Prediction_folder, Ground_truth_folder=Ground_truth_folder, QC_folder=QC_folder, QC_scores=pdResults)
      
  print('-----------------------------------')



# ------------------ Instance Segmentation 2D -------------------------------------------------

if QC_type == "Instance segmentation" and Data_type == "2D":

  with open(QC_folder+"/QC_metrics.csv", "w", newline='') as file:
    writer = csv.writer(file, delimiter=",")
    writer.writerow(["image","Prediction v. GT Intersection over Union", "false positive", "true positive", "false negative", "precision", "recall", "accuracy", "f1 score", "n_true", "n_pred", "mean_true_score", "mean_matched_score", "panoptic_quality"])  

  # define the images

    for n in os.listdir(Source_folder):
    
      if not os.path.isdir(os.path.join(Source_folder,n)):
        print('Running QC on: '+n)
        test_input = io.imread(os.path.join(Source_folder,n))
        test_prediction = io.imread(os.path.join(Prediction_folder,n))
        test_ground_truth_image = io.imread(os.path.join(Ground_truth_folder, n))

        # Calculate the matching (with IoU threshold `thresh`) and all metrics

        stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)      
      

       #Convert pixel values to 0 or 255
        test_prediction_0_to_255 = test_prediction
        test_prediction_0_to_255[test_prediction_0_to_255>0] = 255

      #Convert pixel values to 0 or 255
        test_ground_truth_0_to_255 = test_ground_truth_image
        test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255

      # Intersection over Union metric
        intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)
        union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)
        iou_score =  np.sum(intersection) / np.sum(union)
        writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])

  df = pd.read_csv (QC_folder+"/QC_metrics.csv")
  df.set_index("image", inplace=True)
  print(tabulate(df, headers='keys', tablefmt='psql'))


  # ------------- For display ------------
  print('--------------------------------------------------------------')
  @interact
  def show_QC_results(images = os.listdir(Source_folder)):
        
    visualise_segmentation_QC(image=images,dimension=Data_type, Source_folder=Source_folder, Prediction_folder=Prediction_folder, Ground_truth_folder=Ground_truth_folder, QC_folder=QC_folder, QC_scores=df)  

  print('-----------------------------------')


# ------------------ Instance Segmentation 3D in progress -------------------------------------------------

if QC_type == "Instance segmentation" and Data_type == "3D":

  with open(QC_folder+"/QC_metrics.csv", "w", newline='') as file:
    writer = csv.writer(file, delimiter=",")
    writer.writerow(["image","Slice #","Prediction v. GT Intersection over Union", "false positive", "true positive", "false negative", "precision", "recall", "accuracy", "f1 score", "n_true", "n_pred", "mean_true_score", "mean_matched_score", "panoptic_quality"])  

    # These lists will be used to collect all the metrics values per slice
    file_name_list = []
    slice_number_list = []
    iou_score_list = []
    fp_list = []
    tp_list = []
    fn_list = []
    precision_list = []
    recall_list = []
    accuracy_list = []
    f1_list = []
    n_true_list = []
    n_pred_list = []
    mean_true_score_list = []
    mean_matched_score_list = []
    panoptic_quality_list = []

    # These lists will be used to display the mean metrics for the stacks
    iou_score_list_mean = []
    fp_list_mean = []
    tp_list_mean = []
    fn_list_mean = []
    precision_list_mean = []
    recall_list_mean = []
    accuracy_list_mean = []
    f1_list_mean = []
    n_true_list_mean = []
    n_pred_list_mean = []
    mean_true_score_list_mean = []
    mean_matched_score_list_mean = []
    panoptic_quality_list_mean = []

    for n in os.listdir(Source_folder):
    
      if not os.path.isdir(os.path.join(Source_folder,n)):
        print('Running QC on: '+n)
        test_input = io.imread(os.path.join(Source_folder,n))
        test_prediction = io.imread(os.path.join(Prediction_folder,n))
        test_ground_truth_image = io.imread(os.path.join(Ground_truth_folder, n))

        n_slices = test_ground_truth_image.shape[0]

        for z in range(n_slices):

        # Calculate the matching (with IoU threshold `thresh`) and all metrics

          stats = matching(test_ground_truth_image[z], test_prediction[z], thresh=0.5) 
      

       #Convert pixel values to 0 or 255
          test_prediction_0_to_255 = test_prediction[z]
          test_prediction_0_to_255[test_prediction_0_to_255>0] = 255

      #Convert pixel values to 0 or 255
          test_ground_truth_0_to_255 = test_ground_truth_image[z]
          test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255

      # Intersection over Union metric
          intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)
          union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)
          iou_score =  np.sum(intersection) / np.sum(union)
          
          # Collect values to display in dataframe output
          slice_number_list.append(z)
          iou_score_list.append(iou_score)
          fp_list.append(stats.fp)
          tp_list.append(stats.tp)
          fn_list.append(stats.fn)
          precision_list.append(stats.precision)
          recall_list.append(stats.recall)
          accuracy_list.append(stats.accuracy)
          f1_list.append(stats.f1)
          n_true_list.append(stats.n_true)
          n_pred_list.append(stats.n_pred)
          mean_true_score_list.append(stats.mean_true_score)
          mean_matched_score_list.append(stats.mean_matched_score)
          panoptic_quality_list.append(stats.panoptic_quality)
  

          writer.writerow([n, str(z), str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])
        
        #Here we transform the lists into arrays so that 0 can be removed when computing the average over the stack.
        #Only the last n_slices entries are used, because the per-slice lists accumulate across files.
        #dtype=float is required so that NaN can be assigned even when all values are integers.

        iou_score_array = np.array(iou_score_list[-n_slices:], dtype=float)
        iou_score_array[iou_score_array==0.0] = np.nan
        precision_array = np.array(precision_list[-n_slices:], dtype=float)
        precision_array[precision_array==0.0] = np.nan        
        recall_array = np.array(recall_list[-n_slices:], dtype=float)
        recall_array[recall_array==0.0] = np.nan
        accuracy_array = np.array(accuracy_list[-n_slices:], dtype=float)
        accuracy_array[accuracy_array==0.0] = np.nan 
        f1_array = np.array(f1_list[-n_slices:], dtype=float)
        f1_array[f1_array==0.0] = np.nan
        mean_true_score_array = np.array(mean_true_score_list[-n_slices:], dtype=float)
        mean_true_score_array[mean_true_score_array==0.0] = np.nan          
        mean_matched_score_array = np.array(mean_matched_score_list[-n_slices:], dtype=float)
        mean_matched_score_array[mean_matched_score_array==0.0] = np.nan
        panoptic_quality_array = np.array(panoptic_quality_list[-n_slices:], dtype=float)
        panoptic_quality_array[panoptic_quality_array==0.0] = np.nan

        # Per-stack summary: counts are summed over the slices, scores are averaged
        file_name_list.append(n)
        iou_score_list_mean.append(np.nanmean(iou_score_array))
        fp_list_mean.append(sum(fp_list[-n_slices:]))
        tp_list_mean.append(sum(tp_list[-n_slices:]))
        fn_list_mean.append(sum(fn_list[-n_slices:]))
        precision_list_mean.append(np.nanmean(precision_array))
        recall_list_mean.append(np.nanmean(recall_array))
        accuracy_list_mean.append(np.nanmean(accuracy_array))
        f1_list_mean.append(np.nanmean(f1_array))
        n_true_list_mean.append(sum(n_true_list[-n_slices:]))
        n_pred_list_mean.append(sum(n_pred_list[-n_slices:]))
        mean_true_score_list_mean.append(np.nanmean(mean_true_score_array))
        mean_matched_score_list_mean.append(np.nanmean(mean_matched_score_array))
        panoptic_quality_list_mean.append(np.nanmean(panoptic_quality_array))

  df = pd.read_csv (QC_folder+"/QC_metrics.csv")

#Averages of the metrics per stack as dataframe output
  pdResults = pd.DataFrame(file_name_list, columns = ["image"])
  pdResults["Prediction v. GT Intersection over Union"] = iou_score_list_mean
  pdResults["false positive"] = fp_list_mean
  pdResults["true positive"] = tp_list_mean
  pdResults["false negative"] = fn_list_mean
  pdResults["precision"] = precision_list_mean
  pdResults["recall"] = recall_list_mean
  pdResults["accuracy"] = accuracy_list_mean
  pdResults["f1 score"] = f1_list_mean
  pdResults["n_true"] = n_true_list_mean
  pdResults["n_pred"] = n_pred_list_mean
  pdResults["mean_true_score"] = mean_true_score_list_mean
  pdResults["mean_matched_score"] = mean_matched_score_list_mean
  pdResults["panoptic_quality"] = panoptic_quality_list_mean

  print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')
  pdResults.set_index("image", inplace=True)
  pdResults.head()
  print(tabulate(pdResults, headers='keys', tablefmt='psql'))



  # ------------- For display ------------
  print('--------------------------------------------------------------')
  @interact
  def show_QC_results(images = os.listdir(Source_folder)):
        
    visualise_segmentation_QC(image=images,dimension=Data_type, Source_folder=Source_folder, Prediction_folder=Prediction_folder, Ground_truth_folder=Ground_truth_folder, QC_folder=QC_folder, QC_scores=pdResults)  

  print('-----------------------------------')





# **3. Version log**
---
<font size = 4>**v1.13**:  


*  This version now includes a built-in version check and the version log that you're reading now.


#**Thank you for using our Quality Control notebook!**