<a href="https://colab.research.google.com/github/marketakvasova/LSEC_segmentation/blob/main/LSEC_segmentation_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Automatic segmentation of electron microscope images**

This notebook is intended for training a neural network for the task of binary segmentation of fenestrations of liver sinusoidal endothelial cells (LSECs).

# How to use this notebook

To train a network, first connect to a GPU (**Runtime -> Change runtime type -> Hardware accelerator -> GPU**).

If you are using a pretrained network for inference and not training, being connected only to a **CPU** is slower, but possible.

This notebook works with data saved on your Google Drive. Network training requires pairs of images and their corresponding masks saved in two different folders. The image-mask pairs don't need to be named exactly the same, but they must correspond when both folders are sorted alphabetically.

In [None]:
# @title  { display-mode: "form" }
#@markdown ##**Run this cell to connect to Google Drive**
#@markdown A new window will open where you will be able to connect.

#@markdown When you are connected, you can see your Drive content in the left sidebar under **Files**.

# Mount the user's Google Drive under /content/gdrive so the notebook can read
# data from it and write results back to it.
from google.colab import drive
drive.mount('/content/gdrive')

# **1. Setup**

In [None]:
!pip install wandb
# !pip install torchmetrics
!pip install segmentation-models-pytorch

import segmentation_models_pytorch as smp
import os
import torch.cuda
from torch.utils.data import Dataset
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import transforms
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
from torchsummary import summary
import shutil
import cv2 as cv
from numpy.lib.stride_tricks import as_strided
import pywt
from scipy.stats import norm
from google.colab.patches import cv2_imshow
import gc
import wandb
from scipy.signal import convolve2d
import math
import seaborn as sns
import itertools

# Device used for every model and tensor in this notebook: the GPU when one is
# attached to the runtime, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Training on a CPU is impractically slow; CPU is only realistic for inference
print(DEVICE)

# **2. Utils**

In [None]:
# @title  { display-mode: "form" }
#@markdown ##**Data utils**
class MyDataset(Dataset):
    """Dataset of grayscale image/mask pairs stored in two folders.

    The regular files of ``image_dir`` and ``mask_dir`` are sorted alphabetically
    and paired by position, so file names do not have to be identical but both
    folders must hold the same number of files in corresponding order.

    Args:
        image_dir: Folder containing the input images.
        mask_dir: Folder containing the segmentation masks.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with the keys ``"image"`` and ``"mask"``.

    Raises:
        ValueError: If the two folders contain a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = self._list_files(self.image_dir)
        self.masks = self._list_files(self.mask_dir)
        # Pairing is positional, so a count mismatch means at least one image
        # would be trained against the wrong mask (or none at all).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in '{image_dir}' but "
                f"{len(self.masks)} masks in '{mask_dir}'; the folders must contain matching pairs."
            )

    @staticmethod
    def _list_files(folder):
        """Return the alphabetically sorted names of the regular files in ``folder``."""
        return sorted(f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Load, binarise and transform the pair at position ``index``.

        Raises:
            FileNotFoundError: If OpenCV cannot decode the image or the mask.
        """
        img_path = os.path.join(self.image_dir, self.images[index])
        # Image and mask are matched by sorted position, not by file name.
        mask_path = os.path.join(self.mask_dir, self.masks[index])

        image = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
        if image is None:  # cv.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image '{img_path}'")
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask '{mask_path}'")

        image = image.astype(np.float32)
        mask = mask.astype(np.float32)
        # Map the foreground value 255 to 1; every other value is left unchanged.
        mask[mask == 255.0] = 1

        augmentations = self.transform(image=image, mask=mask)
        return augmentations["image"], augmentations["mask"]

def normalize_hist(img):
    """Equalise local contrast with CLAHE, then smooth with a 3x3 median filter.

    CLAHE is created with a clip limit of 10 on an 11x11 tile grid.
    """
    equalizer = cv.createCLAHE(10, tileGridSize=(11, 11))
    equalized = equalizer.apply(img)
    return cv.medianBlur(equalized, 3)


def get_loaders(img_train, mask_train, img_val, mask_val, batch_size, num_workers=0, pin_memory=True):
    """Create the training and validation ``DataLoader`` objects.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``; only
    the training loader shuffles its samples.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    train_transform, val_transform = get_transforms()

    train_data = MyDataset(image_dir=img_train, mask_dir=mask_train, transform=train_transform)
    val_data = MyDataset(image_dir=img_val, mask_dir=mask_val, transform=val_transform)

    # Settings common to both loaders.
    shared_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    train_loader = DataLoader(train_data, shuffle=True, **shared_options)
    val_loader = DataLoader(val_data, shuffle=False, **shared_options)
    return train_loader, val_loader


def get_transforms():
    """Build the albumentations pipelines for training and validation.

    Training applies random flips and a random 0.9-1.1 scaling before the
    normalisation and tensor conversion; validation only normalises and converts.

    Returns:
        Tuple ``(train_transform, val_transform)``.
    """
    def to_model_input():
        # Fresh instances per pipeline, exactly as if they were written out twice.
        return [
            A.Normalize(mean=0.5, std=0.5, max_pixel_value=255.0),
            ToTensorV2(),
        ]

    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Affine(scale=(0.9, 1.1)),
    ]

    train_transform = A.Compose(augmentations + to_model_input())
    val_transform = A.Compose(to_model_input())
    return train_transform, val_transform

# Module-level inference pipeline (no augmentation): normalise with mean 0.5 and
# std 0.5 relative to a maximum pixel value of 255, then convert to a tensor.
# Used by inference_on_image_with_overlap.
test_transform = A.Compose(
    [
        A.Normalize(
        mean = 0.5,
        std = 0.5,
        max_pixel_value=255.0,
        ),
            ToTensorV2()
    ]
)


def merge_images(image, mask):
    """Overlay a mask on a grayscale image as a 3-channel uint8 (B, G, R) array.

    The blue channel holds the image, the green channel holds the mask, and the
    red channel holds the image except where the mask equals 255, where it is 255.
    """
    blue, green, red = 0, 1, 2
    overlay = np.zeros((mask.shape[0], mask.shape[1], 3))
    overlay[..., blue] = image
    overlay[..., green] = mask
    overlay[..., red] = np.where(mask == 255.0, 255, image)
    return overlay.astype('uint8')


def merge_original_mask(image_path, mask_path, output_folder):
    """Write an image/mask overlay next to the other outputs.

    The overlay is produced by ``merge_images`` and saved in ``output_folder``
    as ``<image name>_original_mask_merge<image extension>``.
    """
    overlay = merge_images(
        cv.imread(image_path, cv.IMREAD_GRAYSCALE),
        cv.imread(mask_path, cv.IMREAD_GRAYSCALE),
    )
    stem, extension = os.path.splitext(os.path.basename(image_path))
    target_path = os.path.join(output_folder, stem + "_original_mask_merge" + extension)
    cv.imwrite(target_path, overlay)


def merge_masks(mask1_path, mask2_path, output_folder):
    """Save a colour comparison of two binary masks.

    Pixels equal to 255 in the first mask are drawn in the green channel and
    pixels equal to 255 in the second mask in the red channel. The result is
    written to ``output_folder`` as ``<mask1 name>_mask_compare<mask1 extension>``.

    Raises:
        FileNotFoundError: If either mask cannot be read.
    """
    print('merging masks')
    mask1 = cv.imread(mask1_path, cv.IMREAD_GRAYSCALE)
    if mask1 is None:  # cv.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read mask '{mask1_path}'")
    mask2 = cv.imread(mask2_path, cv.IMREAD_GRAYSCALE)
    if mask2 is None:
        raise FileNotFoundError(f"Could not read mask '{mask2_path}'")

    # Build the canvas as uint8 (as merge_images does) so that cv.imwrite stores
    # an ordinary 8-bit colour image instead of receiving a float64 array.
    merge = np.zeros((mask1.shape[0], mask1.shape[1], 3), dtype=np.uint8)
    merge[:, :, 1][mask1 == 255.0] = 255  # G channel: first mask
    merge[:, :, 2][mask2 == 255.0] = 255  # R channel: second mask

    filename_ext = os.path.basename(mask1_path)
    filename, ext = os.path.splitext(filename_ext)
    cv.imwrite(os.path.join(output_folder, filename+"_mask_compare"+ext), merge)


def create_weighting_patches(patch_size, edge_size):
    """Build the nine blending windows used when stitching overlapping tiles.

    Every window is a ``patch_size`` x ``patch_size`` array of ones whose
    selected borders fade linearly between 0 and 1 over ``edge_size`` pixels.
    A window named after an image side or corner keeps full weight on that
    side and fades only on its remaining sides; ``middle`` fades on all four.

    Returns:
        Tuple ``(middle, top_left, top, top_right, right, bottom_right, bottom,
        bottom_left, left)``.
    """
    fade_out = np.tile(np.linspace(1, 0, num=edge_size), (patch_size, 1))
    fade_in = np.tile(np.linspace(0, 1, num=edge_size), (patch_size, 1))
    leading = slice(0, edge_size)
    trailing = slice(patch_size - edge_size, patch_size)

    def window(fade_left=False, fade_right=False, fade_top=False, fade_bottom=False):
        # Ramps are always applied in the order left, right, top, bottom.
        weights = np.ones((patch_size, patch_size), dtype=float)
        if fade_left:
            weights[:, leading] *= fade_in
        if fade_right:
            weights[:, trailing] *= fade_out
        if fade_top:
            weights[leading, :] *= fade_in.T
        if fade_bottom:
            weights[trailing, :] *= fade_out.T
        return weights

    middle = window(fade_left=True, fade_right=True, fade_top=True, fade_bottom=True)
    top_left = window(fade_right=True, fade_bottom=True)
    top = window(fade_left=True, fade_right=True, fade_bottom=True)
    top_right = window(fade_left=True, fade_bottom=True)
    right = window(fade_left=True, fade_top=True, fade_bottom=True)
    bottom_right = window(fade_left=True, fade_top=True)
    bottom = window(fade_left=True, fade_right=True, fade_top=True)
    bottom_left = window(fade_right=True, fade_top=True)
    left = window(fade_right=True, fade_top=True, fade_bottom=True)

    return middle, top_left, top, top_right, right, bottom_right, bottom, bottom_left, left


def add_mirrored_border(image, border_size, window_size):
    """Pad a 2-D image with mirrored content so it can be tiled with overlap.

    ``border_size`` mirrored pixels are added at the top and left. The bottom
    and right receive ``border_size`` plus an extra margin computed from
    ``window_size`` and the image size, intended to let windows of
    ``window_size`` stepped by ``window_size - border_size`` cover the image.

    The mirroring repeats the edge pixel (NumPy's ``symmetric`` mode), which
    matches flipping the outermost rows/columns. Unlike manual slicing it also
    works when the required padding is larger than the image itself.

    Args:
        image: 2-D array.
        border_size: Overlap between neighbouring windows, in pixels.
        window_size: Side length of the square sliding window.

    Returns:
        The padded array, with the same dtype as ``image``.
    """
    height, width = image.shape
    stride = window_size - border_size

    bottom_edge = window_size - ((height + border_size) % stride)
    right_edge = window_size - ((width + border_size) % stride)

    pad_rows = (border_size, border_size + bottom_edge)
    pad_cols = (border_size, border_size + right_edge)
    return np.pad(image, (pad_rows, pad_cols), mode='symmetric')

def inference_on_image_with_overlap(model, image_path, window_size=224, overlap=20, threshold=0.4):
    """Segment a whole image by blending predictions of overlapping tiles.

    The image is padded with a mirrored border, cut into ``window_size`` tiles
    that overlap by ``overlap`` pixels, and every tile is histogram-normalised
    and passed through ``model``. Tile probabilities (scaled to 0-255) are
    accumulated with the linear-ramp windows from ``create_weighting_patches``;
    the ramps of neighbouring tiles are complementary, so the accumulated
    values are thresholded directly without renormalisation.

    Args:
        model: Segmentation network returning logits for a (1, 1, H, W) input.
        image_path: Path of the grayscale image to segment.
        window_size: Side length of the square tiles.
        overlap: Overlap between neighbouring tiles, in pixels.
        threshold: Probability (0-1) above which a pixel is marked foreground.

    Returns:
        uint8 mask with the size of the input image; foreground pixels are 255.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    input_image = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
    if input_image is None:  # cv.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read image '{image_path}'")
    original_height, original_width = input_image.shape

    mirrored_image = add_mirrored_border(input_image, overlap, window_size)
    image_height, image_width = mirrored_image.shape

    output_probs = np.zeros((image_height, image_width))
    (middle, top_left, top, top_right, right,
     bottom_right, bottom, bottom_left, left) = create_weighting_patches(window_size, overlap)

    stride = window_size - overlap
    last_x = image_height - window_size
    last_y = image_width - window_size

    # validate_model leaves the network in train mode; batch-norm and dropout
    # layers must run in eval mode here. The previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        for x in range(0, last_x + 1, stride):
            # Pick the row of windows matching the vertical position of the tile ...
            if x == 0:
                window_row = (top_left, top, top_right)
            elif x == last_x:
                window_row = (bottom_left, bottom, bottom_right)
            else:
                window_row = (left, middle, right)
            for y in range(0, last_y + 1, stride):
                # ... and the entry matching its horizontal position.
                if y == 0:
                    weighting_window = window_row[0]
                elif y == last_y:
                    weighting_window = window_row[2]
                else:
                    weighting_window = window_row[1]

                square_section = mirrored_image[x:x + window_size, y:y + window_size]
                square_section = normalize_hist(square_section)
                # Add batch and channel dimension
                square_tensor = test_transform(image=square_section)['image'].unsqueeze(0).to(DEVICE)

                with torch.no_grad():
                    output = torch.sigmoid(model(square_tensor)).float()

                # Scale the probability to 0-255
                output = output*255
                tile_probs = output.squeeze(0).cpu().numpy().squeeze()
                output_probs[x:x + window_size, y:y + window_size] += tile_probs*weighting_window
    finally:
        model.train(was_training)

    # Drop the mirrored border again.
    output_probs = output_probs[overlap:original_height + overlap, overlap:original_width + overlap]
    cutoff = int(255*threshold)
    output_mask = np.where(output_probs > cutoff, 255, 0)
    return output_mask.astype(np.uint8)


def preprocess_image(image):
    """Apply the preprocessing used for training patches (currently ``normalize_hist`` only)."""
    preprocessed = normalize_hist(image)
    return preprocessed

def create_train_val_patches(train_image_folder, train_mask_folder, val_image_folder, val_mask_folder, output_folder, patch_size, reduction):
    """Cut the training and validation images into patches.

    Returns:
        Tuple ``(train_image_patches_path, train_mask_patches_path,
        val_image_patches_path, val_mask_patches_path)``.
    """
    train_paths = create_image_patches(
        train_image_folder, train_mask_folder, output_folder, patch_size, reduction, img_type='train')
    val_paths = create_image_patches(
        val_image_folder, val_mask_folder, output_folder, patch_size, reduction, img_type='val')
    train_image_patches_path, train_mask_patches_path = train_paths
    val_image_patches_path, val_mask_patches_path = val_paths
    return train_image_patches_path, train_mask_patches_path, val_image_patches_path, val_mask_patches_path

def create_image_patches(image_folder, mask_folder, output_folder, patch_size, reduction, img_type):
    """Cut every image/mask pair into non-overlapping square patches on disk.

    Images and masks are paired by alphabetical order. Each image is tiled into
    ``patch_size`` x ``patch_size`` patches (incomplete patches at the right and
    bottom edge are dropped); every ``reduction``-th patch is kept. Image
    patches are preprocessed with ``preprocess_image`` before being written.
    Existing ``<img_type>_image_patches`` / ``<img_type>_mask_patches`` folders
    inside ``output_folder`` are deleted and recreated.

    Args:
        image_folder: Folder with the source images.
        mask_folder: Folder with the corresponding masks.
        output_folder: Folder in which the two patch folders are created.
        patch_size: Side length of the square patches in pixels.
        reduction: Keep one patch out of ``reduction`` (1 keeps all patches).
        img_type: Prefix of the patch folders, e.g. ``'train'`` or ``'val'``.

    Returns:
        Tuple ``(image_patches_path, mask_patches_path)``.

    Raises:
        FileNotFoundError: If an image or mask cannot be read.
        ValueError: If an image and its mask differ in size.
    """
    image_patches_path = os.path.join(output_folder, img_type+'_image_patches')
    mask_patches_path = os.path.join(output_folder, img_type+'_mask_patches')

    os.makedirs(output_folder, exist_ok=True)

    # Start from empty patch folders so stale patches never leak into a new run.
    for patches_path in (image_patches_path, mask_patches_path):
        if os.path.exists(patches_path):
            shutil.rmtree(patches_path)
        os.mkdir(patches_path)

    image_filenames = sorted(f for f in os.listdir(image_folder) if os.path.isfile(os.path.join(image_folder, f)))
    mask_filenames = sorted(f for f in os.listdir(mask_folder) if os.path.isfile(os.path.join(mask_folder, f)))

    for image_name, mask_name in zip(image_filenames, mask_filenames):
        input_path = os.path.join(image_folder, image_name)
        mask_path = os.path.join(mask_folder, mask_name)

        img = cv.imread(input_path, cv.IMREAD_GRAYSCALE)
        if img is None:  # cv.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image '{input_path}'")
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask '{mask_path}'")
        if img.shape != mask.shape:
            raise ValueError(
                f"Image '{input_path}' {img.shape} and mask '{mask_path}' {mask.shape} differ in size")

        height, width = img.shape
        image_stem = os.path.splitext(os.path.basename(image_name))[0]

        # Plain slicing yields the same non-overlapping tiles as a strided view
        # without depending on the array's byte layout.
        k = 0
        for i in range(height // patch_size):
            rows = slice(i * patch_size, (i + 1) * patch_size)
            for j in range(width // patch_size):
                if k % reduction == 0: # this reduces the number of patches if the training takes too long (1 for no reduction)
                    cols = slice(j * patch_size, (j + 1) * patch_size)
                    img_patch = preprocess_image(np.ascontiguousarray(img[rows, cols]))
                    mask_patch = np.ascontiguousarray(mask[rows, cols])
                    patch_filename = f"{image_stem}_patch_{i}_{j}.tif"
                    cv.imwrite(os.path.join(image_patches_path, patch_filename), img_patch)
                    cv.imwrite(os.path.join(mask_patches_path, patch_filename), mask_patch)
                k += 1
    return image_patches_path, mask_patches_path


# Denoising
#   References for non-local means filtering and noise variance estimation:
#
#   [1] Antoni Buades, Bartomeu Coll, and Jean-Michel Morel, A Non-Local
#       Algorithm for Image Denoising, Computer Vision and Pattern
#       Recognition 2005. CVPR 2005, Volume 2, (2005), pp. 60-65.
#   [2] John Immerkaer, Fast Noise Variance Estimation, Computer Vision and
#       Image Understanding, Volume 64, Issue 2, (1996), pp. 300-302

def estimate_degree_of_smoothing(I):
    """Estimate the smoothing strength for non-local means from the image noise.

    Follows the estimation Matlab uses in ``imnlmfilt`` (see the references
    above): the image is convolved with a 3x3 kernel and the summed absolute
    response is scaled by ``sqrt(pi / 2) / (6 * (W - 2) * (H - 2))``. A result
    of exactly zero is replaced by the float32 machine epsilon.
    """
    rows, cols = I.shape
    noise_kernel = np.array([[1, -2, 1], [-2, 4, -2], [1, -2, 1]])
    response = convolve2d(I.astype(np.float32)[:, :], noise_kernel, mode='valid')
    total_response = np.sum(np.abs(response))
    estimate = (total_response * np.sqrt(0.5 * np.pi) / (6 * (cols - 2) * (rows - 2)))
    return np.finfo(np.float32).eps if estimate == 0 else estimate


def nlm_filt(image, template_window_size=5, search_window_size=21):
    """Denoise an image with OpenCV's non-local means filter.

    The filter strength ``h`` is estimated from the image itself with
    ``estimate_degree_of_smoothing``.

    Args:
        image: Image to denoise.
        template_window_size: Side length of the patch compared between pixels.
        search_window_size: Side length of the window searched for similar patches.

    Returns:
        The denoised image.
    """
    degree_of_smoothing = estimate_degree_of_smoothing(image)
    return cv.fastNlMeansDenoising(
        image,
        None,
        h=degree_of_smoothing,
        templateWindowSize=template_window_size,
        searchWindowSize=search_window_size,
    )


def anscombe_transform(data):
    """Apply the Anscombe transform ``2 * sqrt(data + 3/8)`` element-wise."""
    shifted = data + 3.0/8.0
    return 2.0 * np.sqrt(shifted)


def inverse_anscombe_transform(data):
    """Invert the Anscombe transform using a polynomial in ``1/data``.

    Reference:
    https://github.com/broxtronix/pymultiscale/blob/master/pymultiscale/anscombe.py
    """
    root = np.sqrt(3.0/2.0)
    quadratic = 1.0/4.0 * np.power(data, 2)
    inverse_1 = 1.0/4.0 * root * np.power(data, -1.0)
    inverse_2 = 11.0/8.0 * np.power(data, -2.0)
    inverse_3 = 5.0/8.0 * root * np.power(data, -3.0)
    return quadratic + inverse_1 - inverse_2 + inverse_3 - 1.0 / 8.0


def wavelet_denoising(data, threshold=1.5, wavelet='coif4', threshold_type='soft'):
    """Denoise by thresholding the detail coefficients of a 3-level 2-D wavelet decomposition.

    The last three entries of the coefficient list are thresholded with
    ``pywt.threshold``; the first entry is left untouched before reconstruction.
    """
    coeffs = pywt.wavedec2(data, wavelet = wavelet, level=3)
    for level in (-1, -2, -3):
        coeffs[level] = tuple(
            pywt.threshold(band, threshold, threshold_type) for band in coeffs[level]
        )
    return pywt.waverec2(coeffs, wavelet)


def wavelet_denoise(image, threshold):
    """Wavelet-denoise an image in the Anscombe domain and return it as uint8.

    The image is Anscombe-transformed, denoised with ``wavelet_denoising``,
    transformed back and rescaled so that its maximum becomes 255.
    """
    stabilised = anscombe_transform(image)
    denoised = wavelet_denoising(stabilised, threshold)
    restored = inverse_anscombe_transform(denoised)
    # TODO: not sure this is the correct way how to do this
    rescaled = restored/np.max(restored)*255
    return rescaled.astype(np.uint8)

def show_fitted_ellipses(image_path, ellipses):
    """Draw fitted ellipses and their centres on an image and display it.

    Args:
        image_path: Path of the image to draw on.
        ellipses: Iterable of ``(center, axes, angle)`` tuples as produced by
            ``fit_ellipses``. Entries that are ``None`` or the placeholder
            ``(None, None, None)`` (no ellipse could be fitted) are skipped.
    """
    image = cv.imread(image_path)
    for ellipse in ellipses:
        # fit_ellipses marks missing fits with (None, None, None), which is not
        # None itself, so the first element has to be checked as well.
        if ellipse is None or ellipse[0] is None:
            continue
        cv.ellipse(image, ellipse, (0, 0, 255), 1)
        center_x, center_y = ellipse[0]
        cv.circle(image, (int(center_x), int(center_y)), radius=1, color=(0, 0, 255), thickness=-1)

    cv2_imshow(image)

def fit_ellipses(filtered_contours, centers):
    """Fit an ellipse to every contour that allows a plausible fit.

    A fit is accepted when the contour has at least 5 points (required for
    ellipse fitting) and the fitted centre lies less than 20 units from the
    contour centre. Rejected contours are represented by ``(None, None, None)``
    so the result stays aligned with ``filtered_contours``.

    Returns:
        Tuple ``(ellipses, num_ellipses)`` where ``num_ellipses`` counts the
        accepted fits.
    """
    no_ellipse = (None, None, None)
    ellipses = []
    num_ellipses = 0
    for contour, cnt_center in zip(filtered_contours, centers):
        fitted = no_ellipse
        if len(contour) >= 5:
            candidate = cv.fitEllipse(contour)
            # Fits on partially visible fenestrations can land far from the contour.
            if cv.norm(cnt_center, candidate[0]) < 20:
                fitted = candidate
                num_ellipses += 1
        ellipses.append(fitted)
    return ellipses, num_ellipses

def find_fenestration_contours(image_path):
    """Return the external contours found in the segmentation mask at ``image_path``."""
    segmentation = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
    contours, _hierarchy = cv.findContours(
        segmentation,
        cv.RETR_EXTERNAL,
        cv.CHAIN_APPROX_SIMPLE,
    )
    return contours

def find_contour_centers(contours):
    """Return the integer centroid ``(x, y)`` of every contour, computed from its moments."""
    def centroid(cnt):
        moments = cv.moments(cnt)
        # The small offset avoids a division by zero for degenerate contours.
        return (
            int(moments['m10'] / (moments['m00'] + 1e-10)),
            int(moments['m01'] / (moments['m00'] + 1e-10)),
        )

    return [centroid(cnt) for cnt in contours]

def equivalent_circle_diameter(major_axis_length, minor_axis_length):
    """Return the geometric mean of the two ellipse axis lengths.

    For full axis lengths this is the diameter of the circle with the same
    area as the ellipse.
    """
    axes_product = major_axis_length * minor_axis_length
    return math.sqrt(axes_product)



def show_statistics(fenestration_areas, fenestration_areas_from_ellipses, roundness_of_ellipses, equivalent_diameters, min_roundness=0, min_d=None, max_d=None):
    """Plot four probability histograms describing the detected fenestrations.

    Args:
        fenestration_areas: Areas measured directly on the fenestrations.
        fenestration_areas_from_ellipses: Areas of the fitted ellipses.
        roundness_of_ellipses: Roundness values of the fitted ellipses.
        equivalent_diameters: Equivalent circle diameters.
        min_roundness: Lower x-limit of the roundness histogram.
        min_d: Lower x-limit of the diameter histogram (0 when ``None``).
        max_d: Upper x-limit of the diameter histogram (automatic when ``None``).
    """
    palette = itertools.cycle(sns.color_palette())
    plt.figure(figsize=(21, 5))

    # Histogram of fenestration areas (drawn in seaborn's default colour)
    plt.subplot(1, 4, 1)
    sns.histplot(fenestration_areas, stat='probability')
    plt.title('Histogram of Fenestration Areas')
    # Raw string: '\m' is not a valid escape sequence in a normal string literal.
    plt.xlabel(r'Area ($\mathrm{nm}^2$)')
    plt.grid(True)

    # Histogram of areas of fitted ellipses
    plt.subplot(1, 4, 2)
    sns.histplot(fenestration_areas_from_ellipses, stat='probability', color=next(palette)) # this will be the first color (blue)
    plt.title('Histogram of Fenestration Areas (fitted ellipses)')
    plt.xlabel(r'Area ($\mathrm{nm}^2$)')
    plt.grid(True)

    # Histogram of roundness
    plt.subplot(1, 4, 3)
    roundness_axes = sns.histplot(roundness_of_ellipses, stat='probability', color=next(palette), binwidth=0.025)
    roundness_axes.set(xlim=(min_roundness, None))
    plt.title('Histogram of Roundness')
    plt.xlabel('Roundness (-)')
    plt.grid(True)

    # Histogram of equivalent circle diameters
    plt.subplot(1, 4, 4)
    diameter_axes = sns.histplot(equivalent_diameters, stat='probability', color=next(palette), binwidth=10)
    # Honour min_d when given; keep the previous lower limit of 0 otherwise.
    diameter_axes.set(xlim=(0 if min_d is None else min_d, max_d))
    plt.title('Histogram of Equivalent Circle Diameters')
    plt.xlabel('Diameter (nm)')
    plt.grid(True)


In [None]:
# @title  { display-mode: "form" }
#@markdown ##**Training utils**
def save_checkpoint(model, model_path):
    """Announce the checkpoint and delegate saving to ``model.save(model_path)``.

    NOTE(review): a plain ``torch.nn.Module`` has no ``save`` method, so this
    only works for model objects that provide one — confirm it is still used;
    ``save_state_dict`` is the torch-native alternative in this notebook.
    """
    print("=> Saving checkpoint")
    save = model.save
    save(model_path)

def save_state_dict(model, model_path):
    """Write the model's ``state_dict`` to ``model_path`` with ``torch.save``."""
    print("=> Saving checkpoint")
    weights = model.state_dict()
    torch.save(weights, model_path)

def load_state_dict(model, model_path):
    """Load weights saved by ``save_state_dict`` into ``model`` in place.

    The checkpoint is deserialised onto the CPU first, so a file written on a
    GPU runtime can also be loaded on a CPU-only runtime;
    ``model.load_state_dict`` then copies the values into the model's own
    parameters, on whatever device they are on.

    Args:
        model: Model whose parameters are overwritten.
        model_path: Path (or file-like object) of the saved state dict.
    """
    print("=> Loading checkpoint")
    state = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state)

def validate_model(model, loader, loss_fn):
    """Evaluate the model on a loader and return the mean loss and Dice score.

    The loss is computed by ``get_loss`` with the loss name ``loss_fn``. The
    Dice score is computed per batch on predictions thresholded at 0.5; both
    metrics are averaged weighted by batch size. Batch losses are logged to
    wandb when ``WANDB_CONNECTED`` is set. The model is switched to eval mode
    for the evaluation and to train mode before returning.

    Returns:
        Tuple ``(average_loss, average_dice_score)``.
    """
    eps = 1e-8
    loss_sum = 0.0
    dice_sum = 0.0
    samples_seen = 0

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE).unsqueeze(1)
            batch_size = x.size(0)

            # Forward
            out = model(x)
            loss = get_loss(out, y, loss_fn)
            loss_sum += loss.item() * batch_size
            if WANDB_CONNECTED:
                wandb.log({"val/batch loss": loss.item()})

            predicted = (torch.sigmoid(out) > 0.5).float()
            intersection = torch.sum(predicted * y)
            dice_score = (2.0 * intersection + eps) / (torch.sum(predicted) + torch.sum(y) + eps)
            dice_sum += dice_score.item() * batch_size

            samples_seen += batch_size
    model.train()

    return loss_sum / samples_seen, dice_sum / samples_seen



def view_prediction(loader, model, device="cpu"):
    """Plot the predictions for the first batch of ``loader``.

    For every sample of the batch a figure is created showing the input image,
    the sigmoid output, the prediction thresholded at 0.5, the ground truth and
    an overlay of ground truth and prediction. The model is put into eval mode
    for the forward pass and into train mode before returning.
    """
    model.eval()
    first_batch = next(iter(loader), None)  # only the first batch is visualised
    if first_batch is not None:
        x, y = first_batch
        x = x.to(device=device)
        with torch.no_grad():
            output = torch.sigmoid(model(x))
            preds = (output > 0.5).float()
            preds = preds.cpu().data.numpy()
            output = output.cpu().data.numpy()

            n_samples = preds.shape[0]
            n_columns = 5 * n_samples
            for i in range(n_samples):
                plt.figure(figsize=(128, 32))
                x = x.cpu()
                ground_truth = y.unsqueeze(1)[i, 0, :, :]
                prediction = preds[i, 0, :, :]

                # Original image
                plt.subplot(1, n_columns, i + 1)
                plt.imshow(x[i, 0, :, :], cmap='gray')
                plt.title('Validation image')

                # NN output (probability)
                plt.subplot(1, n_columns, i + 2)
                plt.imshow(output[i, 0, :, :], interpolation='nearest', cmap='magma')
                plt.title('NN output')

                # Segmentation
                plt.subplot(1, n_columns, i + 3)
                plt.imshow(prediction, cmap='gray')
                plt.title('Prediction')

                # True mask
                plt.subplot(1, n_columns, i + 4)
                plt.imshow(ground_truth, cmap='gray')
                plt.title('Ground truth')

                # Overlay of ground truth and prediction
                plt.subplot(1, n_columns, i + 5)
                plt.imshow(ground_truth, alpha=0.8, cmap='Blues')
                plt.imshow(prediction, alpha=0.6, cmap='Oranges')
                plt.title('IoU')

            plt.show()
    model.train()

import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

def build_optimizer(model, config, beta1=None, beta2=None):
    """Create the optimizer selected by ``config.optimizer``.

    Args:
        model: Model whose parameters are optimised.
        config: Object with the attributes ``optimizer`` (``"sgd"`` or
            ``"adam"``), ``learning_rate`` and ``weight_decay``; ``"sgd"``
            additionally requires ``momentum``.
        beta1: Currently unused; kept for interface compatibility.
        beta2: Currently unused; kept for interface compatibility.

    Returns:
        The configured ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``config.optimizer`` names an unsupported optimizer.
    """
    if config.optimizer == "sgd":
        return optim.SGD(model.parameters(),
                         lr=config.learning_rate,
                         weight_decay=config.weight_decay,
                         momentum=config.momentum)
    if config.optimizer == "adam":
        return optim.Adam(model.parameters(),
                          lr=config.learning_rate,
                          weight_decay=config.weight_decay)
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(f"Unsupported optimizer '{config.optimizer}'; expected 'sgd' or 'adam'")


def build_dataloaders(config):
    """Create the train and validation loaders from the patch folders named in ``config``.

    Everything is read from ``config`` because the wandb training entry point
    cannot take additional arguments.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    loaders = get_loaders(
        config.train_image_patches_path,
        config.train_mask_patches_path,
        config.val_image_patches_path,
        config.val_mask_patches_path,
        config.batch_size,
        num_workers=0,
        pin_memory=True,
    )
    train_loader, val_loader = loaders
    return train_loader, val_loader

class EarlyStopper():
    """Signal when the validation loss has stopped improving.

    The counter increases each time the loss is strictly worse than the best
    loss seen so far and resets on every new best; a loss equal to the best
    changes nothing.
    """

    def __init__(self, patience):
        self.patience = patience
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True once ``patience`` is exhausted."""
        improved = validation_loss < self.min_validation_loss
        if improved:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        worse = validation_loss > self.min_validation_loss
        if not worse:
            return False

        self.counter += 1
        print(self.counter)
        if self.counter >= self.patience:
            return True
        return False

def train_epoch(model, train_loader, optimizer, loss_fn):
    """Train the model for one epoch with mixed precision.

    The loss named ``loss_fn`` is computed by ``get_loss`` under autocast and
    back-propagated through a gradient scaler. Batch losses are logged to wandb
    when ``WANDB_CONNECTED`` is set. The model is left in eval mode.

    Returns:
        The mean training loss of the epoch, weighted by batch size.
    """
    model.train()
    scaler = torch.cuda.amp.GradScaler()
    loss_sum = 0.0
    samples_seen = 0

    for data, targets in train_loader:
        data = data.to(device=DEVICE)
        targets = targets.unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = get_loss(predictions, targets, loss_fn)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = data.size(0)
        batch_loss = loss.item()
        loss_sum += batch_loss * batch_size
        samples_seen += batch_size

        if WANDB_CONNECTED:
            wandb.log({"train/batch loss": batch_loss})

    model.eval()
    return loss_sum / samples_seen

def _parse_model_name(model_name):
    """Extract ``(encoder, weights)`` from a '+'-separated model name."""
    if '+' not in model_name:
        raise ValueError(
            f"Cannot determine the encoder from model name '{model_name}'; "
            "expected '+'-separated parts such as '<encoder>+<architecture>' "
            "or '<architecture>+<encoder>+<weights>'."
        )
    name_parts = model_name.split('+')
    encoder = name_parts[-2]
    if encoder == '':
        # The '++' of 'Unet++' was split as well (e.g. 'resnet34+Unet++');
        # protect it and split again so the real encoder is found.
        name_parts = model_name.replace('Unet++', 'UnetPlusPlus').split('+')
        encoder = name_parts[-2]
    weights = name_parts[-1] if name_parts[-1] in ('imagenet', 'ssl') else None
    return encoder, weights


def build_model(model_name):
    """Create a single-channel binary segmentation model on ``DEVICE``.

    The architecture is chosen by the first of 'Unet++', 'Linknet', 'FPN' and
    'DeepLabV3' contained in ``model_name`` (plain U-Net otherwise). The
    encoder is the second-to-last '+'-separated part of the name; pretrained
    weights are used only when the last part is 'imagenet' or 'ssl'. The model
    outputs raw logits (no output activation).

    Args:
        model_name: For example ``'resnet34+Unet'`` or ``'FPN+resnet34+imagenet'``.

    Returns:
        The ``segmentation_models_pytorch`` model, moved to ``DEVICE``.

    Raises:
        ValueError: If ``model_name`` contains no '+', so no encoder can be derived.
    """
    encoder, weights = _parse_model_name(model_name)

    if 'Unet++' in model_name:
        architecture = smp.UnetPlusPlus
    elif 'Linknet' in model_name:
        architecture = smp.Linknet
    elif 'FPN' in model_name:
        architecture = smp.FPN
    elif 'DeepLabV3' in model_name:
        architecture = smp.DeepLabV3
    else:
        architecture = smp.Unet

    model = architecture(
        encoder_name=encoder,
        encoder_weights=weights,
        in_channels=1,
        classes=1,
        activation=None,
    ).to(DEVICE)
    return model

def get_loss(pred, target, func_name):
    """Compute the loss called ``func_name`` for logits ``pred`` and ``target``.

    Supported names: ``'dice'``, ``'bce'``, ``'jaccard'``, ``'weighted_bce'``
    (positive class weighted by 4), ``'focal'``, ``'tversky'`` (alpha 0.7,
    beta 0.3) and the weighted sums ``'dice+bce'``, ``'5dice+95bce'``,
    ``'20dice+80bce'`` and ``'dice+focal'``.

    Args:
        pred: Raw model outputs (logits).
        target: Ground-truth mask with the same shape as ``pred``.
        func_name: Name of the loss.

    Returns:
        The scalar loss tensor.

    Raises:
        ValueError: If ``func_name`` is not one of the supported names.
    """
    # Each component is built lazily so only the requested losses are created.
    def dice():
        return smp.losses.DiceLoss(mode='binary')(pred, target)

    def bce():
        return nn.BCEWithLogitsLoss()(pred, target)

    def jaccard():
        return smp.losses.JaccardLoss(mode='binary')(pred, target)

    def weighted_bce():
        # Keep the weight on the same device as the logits.
        pos_weight = torch.tensor(4.0, device=pred.device)
        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred, target)

    def focal():
        return smp.losses.FocalLoss(mode='binary')(pred, target)

    def tversky():
        return smp.losses.TverskyLoss(mode='binary', alpha=0.7, beta=0.3)(pred, target)

    single_losses = {
        'dice': dice,
        'bce': bce,
        'jaccard': jaccard,
        'weighted_bce': weighted_bce,
        'focal': focal,
        'tversky': tversky,
    }
    # name -> (weight of first loss, first loss, weight of second loss, second loss)
    blended_losses = {
        'dice+bce': (0.5, dice, 0.5, bce),
        '5dice+95bce': (0.05, dice, 0.95, bce),
        '20dice+80bce': (0.2, dice, 0.8, bce),
        'dice+focal': (0.5, dice, 0.5, focal),
    }

    if func_name in single_losses:
        return single_losses[func_name]()
    if func_name in blended_losses:
        first_weight, first, second_weight, second = blended_losses[func_name]
        return first_weight*first() + second_weight*second()

    supported = sorted(list(single_losses) + list(blended_losses))
    raise ValueError(f"Unknown loss function '{func_name}'; expected one of {supported}")

def wandb_train(config=None):
    """Train one model inside a wandb run; used as the target function of ``wandb.agent``.

    The hyperparameters are read from ``wandb.config`` after the run is
    initialised. The weights with the lowest validation loss seen so far are
    written into the folder ``config.model_path``.

    Args:
        config: Optional configuration forwarded to ``wandb.init``.
    """
    # Initialize a new wandb run
    with wandb.init(config=config):
        config = wandb.config

        train_loader, val_loader = build_dataloaders(config)
        model = build_model(config.model_type)
        optimizer = build_optimizer(model, config)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

        checkpoint_path = os.path.join(
            config.model_path,
            f'{config.model_type}_{config.loss_function}_{config.image_denoising_methods}.pth')
        log_path = './gdrive/MyDrive/lsecs/dice_score_test/train_log.txt'

        # Start from +inf so the first finite validation loss always produces a checkpoint.
        best_val_loss = float('inf')
        early_stopper = EarlyStopper(patience=10)
        for epoch in range(config.epochs):
            print(f'Epoch {epoch}')
            avg_loss = train_epoch(model, train_loader, optimizer, config.loss_function)
            metrics = {"train/loss": avg_loss, "train/epoch": epoch}
            val_loss, dice_score = validate_model(model, val_loader, config.loss_function)
            scheduler.step(val_loss)

            if early_stopper.early_stop(val_loss):
                print(f"early stop on epoch {epoch}")
                try:
                    with open(log_path, "a+") as file:
                        file.write(f'{config.model_type} early stop on epoch {epoch}\n')
                except OSError as err:
                    # The log is only a convenience; a missing folder must not abort the run.
                    print(f'Could not write to {log_path}: {err}')
                break

            # A NaN loss compares False here, so it can never replace the best value
            # (min() would have stored the NaN and blocked every later checkpoint).
            if val_loss < best_val_loss:
                torch.save(model.state_dict(), checkpoint_path)
                best_val_loss = val_loss

            val_metrics = {"val/val_loss": val_loss,
                           "val/dice_score": dice_score}
            wandb.log({**metrics, **val_metrics})

class DictObject:
    """Expose keyword arguments as attributes, e.g. ``DictObject(a=1).a == 1``."""

    def __init__(self, **entries):
        # Copy every key/value pair straight into the instance namespace.
        vars(self).update(entries)

def train(config, loaded_model=None):
    """Train a segmentation model and keep the weights with the lowest validation loss.

    Args:
        config: dict of training settings. Keys read here: ``num_epochs``,
            ``loss_function``, ``model_path`` and, when no model is passed in,
            ``encoder`` and ``model``. The whole configuration is also handed to
            ``build_dataloaders`` and ``build_optimizer``.
        loaded_model: Optional already-built model to continue training. When
            ``None`` a new model is created with ``build_model``.

    Returns:
        ``(train_losses, val_losses, dice_scores)`` with one entry per epoch that
        was run, including the epoch that triggered early stopping.
    """
    if WANDB_CONNECTED:
        wandb.init(
            project="LSEC_segmentation",
            config=config)
    config = DictObject(**config)
    train_loader, val_loader = build_dataloaders(config)
    if loaded_model is None:
        model = build_model(config.encoder+'+'+config.model)
    else:
        model = loaded_model
    optimizer = build_optimizer(model, config)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    # Name written to the early-stop log. The training cell's config has no
    # 'model_type' entry, so fall back to the encoder/model pair.
    model_name = getattr(config, 'model_type', None)
    if model_name is None:
        model_name = f"{getattr(config, 'encoder', 'unknown')}+{getattr(config, 'model', 'unknown')}"
    log_path = './gdrive/MyDrive/lsecs/dice_score_test/train_log.txt'

    # Start from +inf so the first finite validation loss always produces a checkpoint.
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    dice_scores = []
    early_stopper = EarlyStopper(patience=10)
    for epoch in range(config.num_epochs):
        print(f'Epoch {epoch}')
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, config.loss_function)
        val_loss, dice_score = validate_model(model, val_loader, config.loss_function)
        # Record all three metrics together so the returned lists always have equal length.
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        dice_scores.append(dice_score)

        scheduler.step(val_loss)
        current_lr = scheduler.get_last_lr()[0]
        print("Current learning rate:", current_lr)

        if early_stopper.early_stop(val_loss):
            print(f"early stop on epoch {epoch}")
            try:
                with open(log_path, "a+") as file:
                    file.write(f'{model_name} early stop on epoch {epoch}\n')
            except OSError as err:
                # The log is only a convenience; a missing folder must not lose the results.
                print(f'Could not write to {log_path}: {err}')
            break

        # A NaN loss compares False here, so it can never replace the best value
        # (min() would have stored the NaN and blocked every later checkpoint).
        if val_loss < best_val_loss:
            print(config.model_path)
            torch.save(model.state_dict(), config.model_path)
            best_val_loss = val_loss

        print(f'Dice score: {round(dice_score, 2)}')
        # view_prediction(val_loader, model, device = DEVICE)
        print(f'train loss: {train_loss}, val loss: {val_loss}')
        if WANDB_CONNECTED:
            wandb.log({"train/train_loss": train_loss,
                       "train/epoch": epoch,
                       "val/val_loss": val_loss,
                       "val/dice_score":dice_score,
                       })
    if WANDB_CONNECTED:
        wandb.finish()

    return train_losses, val_losses, dice_scores

# **2. Insert training images**

In [None]:
# @title  { display-mode: "form" }
#@markdown ##**Insert Google Drive paths:**

#@markdown All Google Drive paths should start with ./gdrive/MyDrive/ (Check the folder structure in the left sidebar under **Files**).

#@markdown If you want to create new 224x224 patches, check the following box.
#@markdown If you already have image patches, insert the folders below, uncheck the box, and leave output_patches_folder empty.
create_patches = True # @param {type:"boolean"}

# training_images = './gdrive/MyDrive/lsecs/masks_added_small/train_images' #@param {type:"string"}
# training_masks = './gdrive/MyDrive/lsecs/masks_added_small/train_masks' #@param {type:"string"}
# validation_images = './gdrive/MyDrive/lsecs/masks_added_small/val_images' #@param {type:"string"}
# validation_masks = './gdrive/MyDrive/lsecs/masks_added_small/val_masks' #@param {type:"string"}
training_images = './gdrive/MyDrive/lsecs/cropped_cells/train_images' #@param {type:"string"}
training_masks = './gdrive/MyDrive/lsecs/cropped_cells/train_masks' #@param {type:"string"}
validation_images = './gdrive/MyDrive/lsecs/cropped_cells/val_images' #@param {type:"string"}
validation_masks = './gdrive/MyDrive/lsecs/cropped_cells/val_masks' #@param {type:"string"}


output_patches_folder = './gdrive/MyDrive/lsecs/cropped_cells/patches' #@param {type:"string"}
#@markdown This can be used to reduce the number of patches when making them (1 means no reduction, 2 means half the number of patches will be saved...)
reduction_rate = 2 # @param {type:"number"}

# Remove whitespace that may have been pasted around the paths in the form.
(training_images,
 training_masks,
 validation_images,
 validation_masks,
 output_patches_folder) = (
    path.strip()
    for path in (training_images, training_masks, validation_images,
                 validation_masks, output_patches_folder)
)

# Report every input folder that cannot be found (execution continues regardless).
for _input_folder in (training_images, training_masks, validation_images, validation_masks):
    if not os.path.exists(_input_folder):
        print(f'{_input_folder} does not exist.')

SAVE_PATCHES_TO_DISK = True
patch_size = 224

if not create_patches:
    # The inserted folders already contain patches, so they are used directly.
    train_img_patches_path, train_mask_patches_path = training_images, training_masks
    val_img_patches_path, val_mask_patches_path = validation_images, validation_masks
else:
    if SAVE_PATCHES_TO_DISK:
        print(f'Saving patches to {output_patches_folder}')
    else:
        # Keep the patches in the working directory of the runtime instead of on Drive.
        output_patches_folder = os.getcwd()
    (train_img_patches_path,
     train_mask_patches_path,
     val_img_patches_path,
     val_mask_patches_path) = create_train_val_patches(
        training_images, training_masks, validation_images, validation_masks,
        output_patches_folder, patch_size, reduction_rate)

# Summary of where the patches used for training live.
print(f'Training image patches are located in {train_img_patches_path}, {len(os.listdir(train_img_patches_path))} patches.')
print(f'Training mask patches are located in {train_mask_patches_path}')
print(f'Validation image patches are located in {val_img_patches_path}, {len(os.listdir(val_img_patches_path))} patches.')
print(f'Validation mask patches are located in {val_mask_patches_path}')

# **3. Training**

In [None]:
#@markdown You can load model weights and retrain them, otherwise ImageNet weights will be loaded.
load_model = False # @param {type:"boolean"}
model_path = './gdrive/MyDrive/lsecs/model_weights.pth' #@param {type:"string"}
#@markdown Check, if you want to use wandb for training evaluation:
use_wandb = False # @param {type:"boolean"}
WANDB_CONNECTED = use_wandb
#@markdown Where to save the trained model:
out_model_path = './gdrive/MyDrive/lsecs/new_model_weights.pth' #@param {type:"string"}

# Remove whitespace that may have been pasted around the paths in the form.
model_path = model_path.strip()
out_model_path = out_model_path.strip()

# Settings consumed by train(); 'model_path' is where the best weights are saved.
config = {
    'num_epochs' : 2,
    'model': 'unet',
    'loss_function': 'bce',
    'encoder': 'resnet34',
    'batch_size' : 32,
    'optimizer' : 'sgd',
    'learning_rate' : 0.04,
    'weight_decay' : 0.01,
    'momentum' : 0.07,
    'train_image_patches_path': train_img_patches_path,
    'train_mask_patches_path': train_mask_patches_path,
    'val_image_patches_path': val_img_patches_path,
    'val_mask_patches_path': val_mask_patches_path,
    'model_path':out_model_path,
}

if load_model:
    model = build_model('resnet34+none')
    # map_location moves weights saved on a GPU onto the device of the current runtime,
    # so loading also works when only a CPU is connected.
    loaded_state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(loaded_state_dict)
    model.eval()
else:
    model = None
train_losses, val_losses, dice_scores = train(config, model)

# **4. Training log**

In [None]:
# Training and validation loss per epoch, drawn on one shared axis.
loss_fig, loss_ax = plt.subplots()
for _loss_curve, _curve_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    loss_ax.plot(_loss_curve, label=_curve_label)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss Over Time')
loss_ax.legend()
plt.show()

In [None]:
# Validation Dice score per epoch.
dice_fig, dice_ax = plt.subplots()
dice_ax.plot(dice_scores, label='Dice score')
dice_ax.set_xlabel('Epoch')
dice_ax.set_ylabel('Dice score')
dice_ax.set_title('Dice Score Over Time')
dice_ax.legend()
plt.show()

# Wandb sweep

In [None]:
# This can be used to train multiple networks in one run

# Choose training parameters
output_folder = "./gdrive/MyDrive/lsecs/cropped_selections"

# Objective reported for the sweep: the validation Dice score, to be maximised.
metric = {
    'name': 'val/dice_score',
    'goal': 'maximize',
}

# Sweep parameters. An entry with 'value' is fixed for every run; an entry with
# 'values' lists the options the sweep goes through.
parameters_dict = {
    # Optimizer settings ('adam' was tried as an alternative to 'sgd')
    'optimizer': {'value': 'sgd'},
    'learning_rate': {'value': 0.04},
    'weight_decay': {'value': 0.01},
    # sgd parameters
    'momentum': {'value': 0.07},
    'epochs': {'value': 1},

    # Dataloader params
    'train_image_patches_path': {'value': train_img_patches_path},
    'train_mask_patches_path': {'value': train_mask_patches_path},
    'val_image_patches_path': {'value': val_img_patches_path},
    'val_mask_patches_path': {'value': val_mask_patches_path},
    'batch_size': {'value': 32},

    # Label of the image denoising variant; it becomes part of the checkpoint file name.
    # Earlier sweeps compared e.g. 'no_denoise', 'med5', 'nlm', 'clahe+median5', 'wave2_5+med5'
    # (note from the author: also add various thresholds for the wavelet variant).
    'image_denoising_methods': {'value': '11_10_indiv'},

    # Loss names understood by get_loss, e.g. 'dice', 'dice+bce', 'dice+focal', 'focal', 'tversky'.
    'loss_function': {'values': ['bce']},

    # '<encoder>+<pretrained weights>' strings passed to build_model. Earlier sweeps also used
    # vgg11/13/16/19+imagenet, resnet18+ssl, resnet50+ssl and efficientnet-b7+imagenet.
    'model_type': {'values': ['resnet34+imagenet']},

    # Folder in which wandb_train stores the checkpoints.
    'model_path': {'value': './gdrive/MyDrive/lsecs'},
}

# wandb sweep config
sweep_config = {
    'method': 'grid',
    'metric': metric,
    'parameters': parameters_dict,
}
sweep_id = wandb.sweep(sweep_config, project="LSEC_segmentation")

WANDB_CONNECTED = True
wandb.agent(sweep_id, wandb_train, count=510)

# **4. Inference evaluation**

In [None]:
# @title  { display-mode: "form" }
#@markdown Insert folders with cell images and their ground truth masks for comparison:
images_path = './gdrive/MyDrive/lsecs/dice_score_test/images' #@param {type:"string"}
ground_truth_mask_folder = './gdrive/MyDrive/lsecs/dice_score_test/ground_truth_masks_1205' #@param {type:"string"}
semiauto_mask_folder = './gdrive/MyDrive/lsecs/dice_score_test/semiautomatic_masks' #@param {type:"string"}

models_path = './gdrive/MyDrive/lsecs/model_weights/model_weights.pth' #@param {type:"string"}
cell_mask_path = './gdrive/MyDrive/lsecs/dice_score_test/cell_masks' #@param {type:"string"}

# filter_by_diameter = False # @param {type:"boolean"}

remove_false_fenestrations = True # @param {type:"boolean"}
pixel_size_nm = 9.28 #@param {type:"number"}
min_diameter_nm = 50 #@param {type:"number"}
max_diameter_nm = 350 #@param {type:"number"}
min_roundness = 0.4 # @param {type:"slider", min:0, max:1, step:0.1}

log_file_path = './gdrive/MyDrive/lsecs/dice_score_test/log.txt'

from sklearn.metrics import r2_score
import scipy.stats

# Remove whitespace that may have been pasted around the paths in the form.
(ground_truth_mask_folder,
 images_path,
 models_path,
 cell_mask_path,
 semiauto_mask_folder) = (
    path.strip()
    for path in (ground_truth_mask_folder, images_path, models_path,
                 cell_mask_path, semiauto_mask_folder)
)


def _list_files(folder):
    """Return the alphabetically sorted names of the regular files directly inside ``folder``."""
    return sorted([name for name in os.listdir(folder) if os.path.isfile(os.path.join(folder, name))])


cells = _list_files(cell_mask_path)

# Only a message is printed for a missing folder; the listing below will then fail.
if not os.path.exists(images_path):
    print("Images folder does not exist")
if not os.path.exists(ground_truth_mask_folder):
    print("Folder with ground truth masks does not exist")

# The files of the different folders are paired by their position after sorting.
ground_truth_images = _list_files(ground_truth_mask_folder)
images = _list_files(images_path)
semiauto_images = _list_files(semiauto_mask_folder)

if not (len(ground_truth_images) == len(images) == len(semiauto_images)):
    print('The number of ground truths and images differs.')

def remove_contour_from_mask(contour, mask):
    """Erase one contour from ``mask`` by painting its filled area with 0 (black).

    The drawing happens on the array that was passed in, and that same array is returned.
    """
    background_value = 0
    draw_all = -1  # a negative index tells drawContours to draw every contour in the list
    cv.drawContours(mask, [contour], draw_all, background_value, thickness=cv.FILLED)
    return mask

def remove_fenestrations(mask, min_d, max_d, min_roundness, pixel_size_nm):
    """Remove from ``mask`` every object that does not look like a valid fenestration.

    An ellipse is fitted to each external contour. The object is kept only if
    the fit succeeded, the ellipse is not degenerate, its equivalent circle
    diameter lies within ``[min_d, max_d]`` and its roundness
    (minor axis / major axis) is at least ``min_roundness``. All other objects
    are filled with 0.

    Args:
        mask: Binary mask image; it is modified in place.
        min_d: Smallest accepted equivalent diameter, in nm.
        max_d: Largest accepted equivalent diameter, in nm.
        min_roundness: Smallest accepted minor/major axis ratio.
        pixel_size_nm: Size of one pixel in nm, used to convert the diameter.

    Returns:
        The filtered mask (the same array that was passed in).
    """
    contours, _ = cv.findContours(mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    contour_centers = find_contour_centers(contours)
    ellipses, _ = fit_ellipses(contours, contour_centers)

    for contour, ellipse in zip(contours, ellipses):
        keep = False
        # Contours for which no ellipse could be fitted are always removed.
        if ellipse is not None and ellipse != (None, None, None):
            _, axes, _ = ellipse
            minor_axis_length, major_axis_length = axes
            # The fitting algorithm can fail sometimes and return an extremely elongated ellipse.
            if major_axis_length != 0 and major_axis_length < 20*minor_axis_length:
                roundness = minor_axis_length/major_axis_length
                diameter = pixel_size_nm * equivalent_circle_diameter(major_axis_length, minor_axis_length)
                rejected = (diameter < min_d or diameter > max_d) or (roundness < min_roundness) or np.isnan(diameter)
                keep = not rejected
        if not keep:
            mask = remove_contour_from_mask(contour, mask)
    return mask

def compute_dice_score(image1, image2):
    """Return the Dice coefficient of two masks.

    Every non-zero pixel counts as foreground, so masks stored as 0/255 and as
    0/1 give the same result. The inputs are not modified.

    Args:
        image1: First mask (array-like).
        image2: Second mask, same shape as ``image1``.

    Returns:
        ``2*|A & B| / (|A| + |B|)``, stabilised with a small epsilon so that two
        empty masks give a score of 1.
    """
    eps = 1e-8
    # Binarise into new arrays instead of overwriting 255 with 1 in the caller's masks.
    foreground1 = np.asarray(image1) != 0
    foreground2 = np.asarray(image2) != 0
    intersection_sum = np.logical_and(foreground1, foreground2).sum()
    dice_score = (2*intersection_sum+eps)/(foreground1.sum() + foreground2.sum() + eps)
    return dice_score

# images_path = './gdrive/MyDrive/lsecs/dice_score_test/semiautomatic_masks'
# dice_scores = []
# with open(log_file_path, "a+") as file:
#     file.write(f'{len(images)} images\n')
#     for ground_truth_mask_name, image_name, cell_name  in zip(ground_truth_images, images, cells):
#         print(f'Compare: {ground_truth_mask_name} - {image_name} - {cell_name}')
#         file.write(f'Compare: {ground_truth_mask_name} - {image_name}\n')
#         ground_truth_mask_path = os.path.join(ground_truth_mask_folder, ground_truth_mask_name)
#         image_path = os.path.join(images_path, image_name)
#         cell_path = os.path.join(cell_mask_path, cell_name)
#         cell = cv.imread(cell_path, cv.IMREAD_GRAYSCALE)
#         ground_truth_mask = cv.imread(ground_truth_mask_path, cv.IMREAD_GRAYSCALE)
#         image_mask = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
#         image_mask[cell == 0] = 0
#         current_dice_score = compute_dice_score(ground_truth_mask, image_mask)
#         print(f'Image Dice score: {round(current_dice_score*100, 1)}')
#         file.write(f'Image Dice score: {round(current_dice_score*100, 1)}\n')
#         dice_scores.append(current_dice_score)

#     dice_scores = np.array(dice_scores)
#     mean_dice = round(np.mean(dice_scores)*100, 1)
#     std_dice = round(np.std(dice_scores)*100, 1)

#     print(f'Semiautomatic Mean dice: {mean_dice} +- {std_dice}\n')
#     file.write(f'Semiautomatic Mean dice: {mean_dice} += {std_dice}\n\n')


def get_fenestrations_from_image(mask):
    """Find the external contours in ``mask`` and fit an ellipse to each of them.

    Returns:
        ``(ellipses, num_ellipses, contours, num_contours)`` where the first two
        values come from ``fit_ellipses`` and the last is ``len(contours)``.
    """
    outlines, _hierarchy = cv.findContours(mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    centers = find_contour_centers(outlines)
    fitted_ellipses, fitted_count = fit_ellipses(outlines, centers)
    return fitted_ellipses, fitted_count, outlines, len(outlines)

def get_image_stats(contours, ellipses, pixel_size_nm, min_d=0, max_d=100000, min_roundness=0):
    """Collect diameter, roundness and area statistics of the fitted ellipses.

    Ellipses that could not be fitted, or whose major axis is zero or at least
    20 times the minor axis, are skipped entirely.

    Args:
        contours: Contours of the detected objects; only used to pair up with ``ellipses``.
        ellipses: One fitted ellipse ``(center, (minor_axis, major_axis), angle)`` per
            contour, or ``None`` / ``(None, None, None)`` where fitting failed.
        pixel_size_nm: Size of one pixel in nm.
        min_d: Smallest accepted equivalent diameter, in nm.
        max_d: Largest accepted equivalent diameter, in nm.
        min_roundness: Smallest accepted minor/major axis ratio.

    Returns:
        ``(equivalent_diameters, roundness_of_ellipses, fenestration_areas_from_ellipses)``:
        diameters in nm and areas in square pixels of the ellipses passing both the
        diameter and the roundness filter, and the roundness of every ellipse
        passing the roundness filter.
    """
    roundness_of_ellipses = []
    equivalent_diameters = []
    fenestration_areas_from_ellipses = []

    for _, ellipse in zip(contours, ellipses):
        # Skip contours that were too small to fit an ellipse.
        if ellipse is None or ellipse == (None, None, None):
            continue
        _, axes, _ = ellipse
        minor_axis_length, major_axis_length = axes
        # The fitting algorithm can fail sometimes
        if major_axis_length == 0 or not major_axis_length < 20*minor_axis_length:
            continue

        roundness = minor_axis_length/major_axis_length
        # NOTE(review): roundness is recorded before the diameter filter, so this list can
        # contain ellipses that are missing from the other two lists - confirm this is intended.
        if roundness >= min_roundness:
            roundness_of_ellipses.append(roundness)

        diameter_pix = equivalent_circle_diameter(major_axis_length, minor_axis_length)
        diameter = pixel_size_nm * diameter_pix
        if (diameter < min_d or diameter > max_d) or (roundness < min_roundness) or np.isnan(diameter):
            continue
        equivalent_diameters.append(diameter)
        # Area of the equivalent circle, in square pixels.
        fenestration_areas_from_ellipses.append((diameter_pix**2)/4*np.pi)
    return equivalent_diameters, roundness_of_ellipses, fenestration_areas_from_ellipses

import seaborn as sns
# Get stats

save_plots = False
save_image_masks = False
plot_path = './gdrive/MyDrive/lsecs/plots/'

# Accumulators filled by the per-image loop below. Prefixes: 'my' = automatic
# segmentation by the network, 'gt' = ground truth, 's' = semiautomatic.

# Equivalent diameters of all fenestrations, pooled over the images.
all_my_diameters = []
all_gt_diameters = []
all_s_diameters = []

# Mean equivalent diameter per image.
all_my_means = []
all_gt_means = []
all_s_means = []

# Roundness of the fitted ellipses, pooled over the images.
all_my_roundness = []
all_gt_roundness = []
all_s_roundness = []

# Number of accepted fenestrations per image.
num_all_my_ellipses = []
num_all_gt_ellipses = []
num_all_s_ellipses = []

# Dice score against the ground truth, per image.
my_dice_scores = []
s_dice_scores = []

# Porosity per image computed from the mask pixels.
all_gt_porosities_pix = []
all_s_porosities_pix = []
all_my_porosities_pix = []

# Porosity per image computed from the fitted ellipses.
all_gt_porosities_ell = []
all_s_porosities_ell = []
all_my_porosities_ell = []

# Fenestration frequency per image (count divided by cell area).
all_gt_frequencies = []
all_s_frequencies = []
all_my_frequencies = []

model = build_model('resnet34+none')
# map_location moves weights saved on a GPU onto the device of the current runtime,
# so inference also works when only a CPU is connected.
loaded_state_dict = torch.load(models_path, map_location=DEVICE) # TODO:these models do not include sigmoid and preprocessing yet
model.load_state_dict(loaded_state_dict)
model.eval()
# dice_scores = []
# dice_scores_filt = []
for ground_truth_mask_name, image_name, semiauto_image_name, cell_name  in zip(ground_truth_images, images, semiauto_images, cells):
    print(f'Compare: {ground_truth_mask_name} - {image_name} - {semiauto_image_name} -{cell_name}')
    ground_truth_mask_path = os.path.join(ground_truth_mask_folder, ground_truth_mask_name)
    image_path = os.path.join(images_path, image_name)
    semiauto_image_path = os.path.join(semiauto_mask_folder, semiauto_image_name)
    cell_path = os.path.join(cell_mask_path, cell_name)
    image = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
    cell = cv.imread(cell_path, cv.IMREAD_GRAYSCALE)
    ground_truth_mask = cv.imread(ground_truth_mask_path, cv.IMREAD_GRAYSCALE)
    semiauto_mask = cv.imread(semiauto_image_path, cv.IMREAD_GRAYSCALE)

    new_mask = inference_on_image_with_overlap(model, image_path)

    new_mask[cell == 0] = 0
    semiauto_mask[cell == 0] = 0
    ground_truth_mask[cell == 0] = 0


    # I need to fill the cell nucleus to compute the area
    contours, hierarchy = cv.findContours(
        cell, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)

    # Get the largest object in the image(the cell)
    empty_mask = np.zeros_like(cell)
    areas = []
    for cnt in contours:
        area = cv.contourArea(cnt)
        areas.append(area)
    areas = np.array(areas)
    max_area_idx = np.argmax(areas)

    # Fill this object and compute its area
    c = 0
    for cnt in contours:
        if c == max_area_idx:
            cv.drawContours(empty_mask, [cnt], -1, 1, thickness=cv.FILLED)
            cell_contour, hierarchy = cv.findContours(
                empty_mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        c += 1
    cell_area = float(np.sum(empty_mask))


    # Remove fenestrations
    new_mask_filt = remove_fenestrations(new_mask.copy(), min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)
    s_mask_filt = remove_fenestrations(semiauto_mask.copy(), min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)
    gt_mask_filt = remove_fenestrations(ground_truth_mask.copy(), min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)

    if save_image_masks:
        cv.imwrite(os.path.join(plot_path, image_name), new_mask_filt)

    # Save filtered masks
        merge = np.zeros((new_mask_filt.shape[0], new_mask_filt.shape[1], 3))
        merge = merge.astype('uint8')
        merge[:, :, 2][gt_mask_filt == 255.0] = 255 #R
        merge2 = merge.copy()
        merge3 = merge.copy()
        merge[:, :, 0][s_mask_filt == 255.0] = 255 #B
        merge[:, :, 1][s_mask_filt == 255.0] = 100 #G, so the blue is more visible

        merge[:, :, 1][(s_mask_filt == 255.0) & (gt_mask_filt == 255.0)] = 255 #G

        cv.imwrite(os.path.join(plot_path, image_name+"_mask_compare_semiautomatic"+'.png'), merge)

        merge2[:, :, 1][new_mask_filt == 255.0] = 255 #G
        merge2[:, :, 0][(new_mask_filt == 255.0) & (gt_mask_filt == 255.0)] = 255 #B
        cv.imwrite(os.path.join(plot_path, image_name+"_mask_compare_automatic"+'.png'), merge2)

        # merge3[:, :, 0][s_mask_filt == 255.0] = 255 #B
        # merge3[:, :, 1][s_mask_filt == 255.0] = 100 #G, so the blue is more visible
        # merge3[:, :, 1][new_mask_filt == 255.0] = 255 #G
        # cv.imwrite(os.path.join(plot_path, image_name+"_mask_compare"+'.png'), merge3)




    gt_cell_area_pix = np.sum(gt_mask_filt/255)
    s_cell_area_pix = np.sum(s_mask_filt/255)
    my_cell_area_pix = np.sum(new_mask_filt/255)

    gt_fen_area = float(gt_cell_area_pix)
    s_fen_area = float(s_cell_area_pix)
    my_fen_area = float(my_cell_area_pix)

    gt_porosity_pix = round(gt_fen_area/cell_area*100, 1)
    s_porosity_pix = round(s_fen_area/cell_area*100, 1)
    my_porosity_pix = round(my_fen_area/cell_area*100, 1)

    all_gt_porosities_pix.append(gt_porosity_pix)
    all_s_porosities_pix.append(s_porosity_pix)
    all_my_porosities_pix.append(my_porosity_pix)

    # print(f'gt: {gt_porosity_pix}, s: {s_porosity_pix}, my: {my_porosity_pix}')



    # Compute Dice scores
    my_current_dice = compute_dice_score(gt_mask_filt, new_mask_filt)
    s_current_dice = compute_dice_score(gt_mask_filt, s_mask_filt)

    my_dice_scores.append(my_current_dice)
    s_dice_scores.append(s_current_dice)
    # print(f'Image Dice score: {round(my_current_dice*100, 1)}')
    #

    # my data
    my_ellipses, num_my_ellipses, my_objects, num_all_my_objects = get_fenestrations_from_image(new_mask)
    my_equivalent_diameters, my_roundness_of_ellipses, my_fenestration_areas_from_ellipses = get_image_stats(my_objects, my_ellipses, pixel_size_nm, min_d=50, max_d=400, min_roundness=0.4)
    my_ell_area = np.sum(np.array(my_fenestration_areas_from_ellipses))
    all_my_porosities_ell.append(round(my_ell_area/cell_area*100, 1))
    # ground truth
    ellipses, num_ellipses, objects, num_all_objects = get_fenestrations_from_image(ground_truth_mask)
    equivalent_diameters, roundness_of_ellipses, fenestration_areas_from_ellipses = get_image_stats(objects, ellipses, pixel_size_nm, min_d=50, max_d=400, min_roundness=0.4)
    gt_ell_area = np.sum(np.array(fenestration_areas_from_ellipses))
    all_gt_porosities_ell.append(round(gt_ell_area/cell_area*100, 1))
    # semiautomatic_data
    s_ellipses, s_num_ellipses, s_objects, s_num_all_objects = get_fenestrations_from_image(semiauto_mask)
    s_equivalent_diameters, s_roundness_of_ellipses, s_fenestration_areas_from_ellipses = get_image_stats(s_objects, s_ellipses, pixel_size_nm, min_d=50, max_d=400, min_roundness=0.4)
    s_ell_area = np.sum(np.array(s_fenestration_areas_from_ellipses))
    all_s_porosities_ell.append(round(s_ell_area/cell_area*100, 1))

    all_my_diameters.extend(my_equivalent_diameters)
    all_gt_diameters.extend(equivalent_diameters)
    all_s_diameters.extend(s_equivalent_diameters)

    all_my_roundness.extend(my_roundness_of_ellipses)
    all_gt_roundness.extend(roundness_of_ellipses)
    all_s_roundness.extend(s_roundness_of_ellipses)


    n_my_ellipses = len(my_equivalent_diameters)
    n_s_ellipses = len(s_equivalent_diameters)
    n_gt_ellipses = len(equivalent_diameters)

    gt_freq = n_gt_ellipses/(cell_area*((pixel_size_nm/1000)**2))
    s_freq = n_s_ellipses/(cell_area*((pixel_size_nm/1000)**2))
    my_freq = n_my_ellipses/(cell_area*((pixel_size_nm/1000)**2))

    all_gt_frequencies.append(round(gt_freq, 1))
    all_s_frequencies.append(round(s_freq, 1))
    all_my_frequencies.append(round(my_freq, 1))

    num_all_my_ellipses.append(n_my_ellipses)
    num_all_gt_ellipses.append(n_gt_ellipses)
    num_all_s_ellipses.append(n_s_ellipses)

    gt_mean = round(np.mean(np.array(equivalent_diameters)))
    gt_std = round(np.std(np.array(equivalent_diameters)))
    my_mean = round(np.mean(np.array(my_equivalent_diameters)))
    my_std = round(np.std(np.array(my_equivalent_diameters)))
    s_mean = round(np.mean(np.array(s_equivalent_diameters)))
    s_std = round(np.std(np.array(s_equivalent_diameters)))


    all_my_means.append(my_mean)
    all_gt_means.append(gt_mean)
    all_s_means.append(s_mean)

    fig, ax = plt.subplots(3, 2, figsize=(15, 23))
    ax = ax.flatten()
    # plt.subplots(1, 2, figsize=(15, 7))
    sns.set_theme()  # This changes the look of plots.
    nbins = 20
    # Calculate density for each dataset
    hist_x, bin_edges_x  = np.histogram(equivalent_diameters, bins=20, range = (min_diameter_nm, max_diameter_nm), density=True)
    hist_y, bin_edges_y = np.histogram(my_equivalent_diameters, bins=20, range = (min_diameter_nm, max_diameter_nm), density=True)
    hist_z, bin_edges_z = np.histogram(s_equivalent_diameters, bins=20, range = (min_diameter_nm, max_diameter_nm), density=True)
    hist_x = hist_x*np.diff(bin_edges_x)
    hist_y = hist_y*np.diff(bin_edges_y)
    hist_z = hist_z*np.diff(bin_edges_z)

    bin_width = (max_diameter_nm - min_diameter_nm)/3/nbins
    shift1 = bin_width/2
    shift2 = bin_width/2 + bin_width
    shift3 = bin_width/2 + 2*bin_width


    ax[0].bar(bin_edges_x[:-1]+shift1, hist_x, color='r', alpha=0.9, width=bin_width)
    ax[0].bar(bin_edges_y[:-1]+shift2, hist_y, color='g', alpha=0.9, width=bin_width)
    ax[0].bar(bin_edges_z[:-1]+shift3, hist_z, color='b', alpha=0.9, width=bin_width)
    # ax[0].legend([f'Ground truth, mean: {gt_mean} nm\n{len(equivalent_diameters)} fenestrations', f'Automatic, mean: {my_mean} nm\n{len(my_equivalent_diameters)} fenestrations', f'Semiautomatic, mean: {s_mean} nm\n{len(s_equivalent_diameters)} fenestrations'], fontsize='small')
    ax[0].legend([f'Ground truth\nmean d = {gt_mean} nm', f'Automatic\nmean d = {my_mean} nm', f'Semiautomatic\nmean d = {s_mean} nm'], fontsize='medium')

    # Add title and axis labels
    # ax[0].set_title(image_name)
    ax[0].set_xlabel('Equivalent diameter (nm)', fontsize=16)
    ax[0].set_ylabel('Probability', fontsize=16)
    ax[0].set_xlim(min_diameter_nm-20, max_diameter_nm+20)

    # x_ticks = [50, 100, 150, 200, 250, 300, 350, 400]  # Define which x ticks to show
    x_ticks = list(range(min_diameter_nm, max_diameter_nm+1, 50))
    ax[0].set_xticks(x_ticks)  # Set the x ticks
    ax[0].set_xticklabels(x_ticks)  # Set the labels for the x ticks
    ax[0].tick_params(axis='x', labelsize=12)
    ax[0].tick_params(axis='y', labelsize=12)


    # Calculate density for each dataset
    nbins = 12
    hist_x, bin_edges_x  = np.histogram(roundness_of_ellipses, bins=nbins, range = (min_roundness, 1), density=True)
    hist_y, bin_edges_y = np.histogram(my_roundness_of_ellipses, bins=nbins, range = (min_roundness, 1), density=True)
    hist_z, bin_edges_z = np.histogram(s_roundness_of_ellipses, bins=nbins, range = (min_roundness, 1), density=True)
    hist_x = hist_x*np.diff(bin_edges_x)
    hist_y = hist_y*np.diff(bin_edges_y)
    hist_z = hist_z*np.diff(bin_edges_z)

    bin_width = (1 - min_roundness)/3/nbins
    shift1 = bin_width/2
    shift2 = bin_width/2 + bin_width
    shift3 = bin_width/2 + 2*bin_width

    ax[1].bar(bin_edges_x[:-1]+shift1, hist_x, color='r', alpha=0.9, width=bin_width)
    ax[1].bar(bin_edges_y[:-1]+shift2, hist_y, color='g', alpha=0.9, width=bin_width)
    ax[1].bar(bin_edges_z[:-1]+shift3, hist_z, color='b', alpha=0.9, width=bin_width)
    # # i += 1
    # plt.subplot(1, 2)
    # ax[1].hist([roundness_of_ellipses, my_roundness_of_ellipses, s_roundness_of_ellipses], color=['g','r','b'], alpha=0.8, density=True)
    ax[1].legend(['Ground truth', 'Automatic', 'Semiautomatic'], fontsize='medium')
    # Add title and axis labels
    # ax[1].set_title(image_name)
    ax[1].set_xlabel('Roundness of fitted ellipses', fontsize=16)
    ax[1].set_ylabel('Probability', fontsize=16)
    ax[1].set_xlim(min_roundness - 0.02, 1.02)
    ax[1].tick_params(axis='x', labelsize=12)
    ax[1].tick_params(axis='y', labelsize=12)

    # Show image patch
    r = slice(3600,4200)
    c = slice(2600,3200)
    image_patch = image[r, c]
    show_size = (300, 300)
    image_patch = cv.resize(image_patch, show_size)
    ax[2].imshow(image_patch, cmap='gray', vmin=0, vmax=255)  # Specify min and max values
    ax[2].set_title('Image patch example (600x600 pixels)', fontsize=16)
    ax[2].axis('off')

    ax[3].axis('off')

    gt_patch = ground_truth_mask[r, c]
    s_patch = semiauto_mask[r, c]
    my_patch = new_mask[r, c]
    # print(gt_patch.shape)

    gt_patch = remove_fenestrations(gt_patch, min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)
    s_patch = remove_fenestrations(s_patch, min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)
    my_patch = remove_fenestrations(my_patch, min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)

    # image_patch = cv.resize(image_patch, show_size, interpolation=cv.INTER_NEAREST)

    # Colour-coded overlays of the filtered mask patches (a pixel counts as
    # fenestration where the mask equals 255):
    #   R channel = ground truth, G channel = automatic, B channel = semiautomatic.
    gt_overlay = np.zeros((gt_patch.shape[0], gt_patch.shape[1], 3), dtype='uint8')
    gt_overlay[:, :, 0][gt_patch == 255] = 255 # R channel
    # RGB input: imshow ignores cmap/vmin/vmax for it, so they are not passed.
    ax[4].imshow(gt_overlay)
    ax[4].set_title('Ground truth mask', fontsize=16)
    ax[4].axis('off')

    # Build the comparison on a copy so the ground-truth panel above does not
    # depend on imshow having copied its input.
    merge = gt_overlay.copy()
    merge[:, :, 2][s_patch == 255] = 255 # B channel
    merge[:, :, 1][s_patch == 255] = 100 # G channel (partial green for semiautomatic)
    merge[:, :, 1][my_patch == 255] = 255 # G channel (overrides the partial green)

    ax[5].imshow(merge)
    ax[5].set_title('Mask comparison', fontsize=16)
    ax[5].axis('off')
    # Legend of the resulting colours. Cyan (G+B, no R) marks pixels that both
    # methods detected but the ground truth does not contain, i.e. FP for both.
    leg = 'R = automatic FN & semiautomatic FN\nG = automatic FP\nB = semiautomatic FP'
    leg += '\nC = automatic FP & semiautomatic FP'
    leg += '\nM = automatic FN & semiautomatic TP'
    leg += '\nY  = automatic TP & semiautomatic FN'
    leg += '\nW = automatic TP & semiautomatic TP'
    ax[5].text(0, merge.shape[0]+210, leg, ha='left', fontsize=16)
    if save_plots:
        plt.savefig(os.path.join(plot_path, f'stats_{image_name}.svg'), format='svg')


    plt.show()

# Stats for whole dataset
# Mean and standard deviation of the equivalent diameters pooled over all
# images, rounded to whole nanometres (the means are shown in the legend).
gt_mean = round(np.mean(np.array(all_gt_diameters)))
gt_std = round(np.std(np.array(all_gt_diameters)))
my_mean = round(np.mean(np.array(all_my_diameters)))
my_std = round(np.std(np.array(all_my_diameters)))
s_mean = round(np.mean(np.array(all_s_diameters)))
s_std = round(np.std(np.array(all_s_diameters)))

fig, ax = plt.subplots(1, 2, figsize=(13, 7))
ax = ax.flatten()
sns.set_theme()

# Left panel: grouped histogram of equivalent diameters.
# Every bin is split into three side-by-side bars (ground truth, automatic,
# semiautomatic), so each bar is one third of a bin wide.
nbins = 20
bin_width = (max_diameter_nm - min_diameter_nm)/3/nbins
for slot, (values, color) in enumerate(zip((all_gt_diameters, all_my_diameters, all_s_diameters), ('r', 'g', 'b'))):
    hist, bin_edges = np.histogram(values, bins=nbins, range=(min_diameter_nm, max_diameter_nm), density=True)
    # density * bin width = fraction of the values that fall into each bin
    hist = hist*np.diff(bin_edges)
    ax[0].bar(bin_edges[:-1] + (bin_width/2 + slot*bin_width), hist, color=color, alpha=0.9, width=bin_width)
ax[0].legend([f'Ground truth\nmean d = {gt_mean} nm', f'Automatic\nmean d = {my_mean} nm', f'Semiautomatic\nmean d = {s_mean} nm'], fontsize='medium')

# Add title and axis labels
ax[0].set_title('Whole dataset', fontsize=22)
ax[0].set_xlabel('Equivalent diameter (nm)', fontsize=16)
ax[0].set_ylabel('Probability', fontsize=16)
ax[0].set_xlim(min_diameter_nm-20, max_diameter_nm+20)

# One x tick every 50 nm across the accepted diameter range
x_ticks = list(range(min_diameter_nm, max_diameter_nm+1, 50))
ax[0].set_xticks(x_ticks)  # Set the x ticks
ax[0].set_xticklabels(x_ticks)  # Set the labels for the x ticks
ax[0].tick_params(axis='x', labelsize=12)
ax[0].tick_params(axis='y', labelsize=12)


# Roundness stats for whole dataset.
# The histogram range below is (min_roundness, 1), so these are rounded to two
# decimals; rounding to whole numbers would collapse them to 0 or 1.
gt_mean = round(np.mean(np.array(all_gt_roundness)), 2)
gt_std = round(np.std(np.array(all_gt_roundness)), 2)
my_mean = round(np.mean(np.array(all_my_roundness)), 2)
my_std = round(np.std(np.array(all_my_roundness)), 2)
s_mean = round(np.mean(np.array(all_s_roundness)), 2)
s_std = round(np.std(np.array(all_s_roundness)), 2)

# Right panel: grouped histogram of the roundness of the fitted ellipses,
# built the same way as the diameter histogram.
nbins = 12
bin_width = (1 - min_roundness)/3/nbins
for slot, (values, color) in enumerate(zip((all_gt_roundness, all_my_roundness, all_s_roundness), ('r', 'g', 'b'))):
    hist, bin_edges = np.histogram(values, bins=nbins, range=(min_roundness, 1), density=True)
    hist = hist*np.diff(bin_edges)
    ax[1].bar(bin_edges[:-1] + (bin_width/2 + slot*bin_width), hist, color=color, alpha=0.9, width=bin_width)

# Add title and axis labels
ax[1].set_title('Whole dataset', fontsize=22)
ax[1].set_xlabel('Roundness of fitted ellipses', fontsize=16)
ax[1].set_ylabel('Probability', fontsize=16)
ax[1].legend(['Ground truth', 'Automatic', 'Semiautomatic'], fontsize='medium')
ax[1].set_xlim(min_roundness-0.02, 1.02)
ax[1].tick_params(axis='x', labelsize=12)
ax[1].tick_params(axis='y', labelsize=12)
if save_plots:
    plt.savefig(os.path.join(plot_path, 'whole.svg'), format='svg')
plt.show()

# Correlation for the number of found ellipses
# Per-image fenestration counts: ground truth ("gt") against the automatic
# ("my") and semiautomatic ("s") detections.
num_all_my_ellipses = np.array(num_all_my_ellipses)
num_all_gt_ellipses = np.array(num_all_gt_ellipses)
num_all_s_ellipses =  np.array(num_all_s_ellipses)

# Order all three arrays by the ground-truth count so the fitted lines are
# drawn from left to right.
sorted_indices = np.argsort(num_all_gt_ellipses)
num_all_gt_ellipses = num_all_gt_ellipses[sorted_indices]
num_all_my_ellipses = num_all_my_ellipses[sorted_indices]
num_all_s_ellipses = num_all_s_ellipses[sorted_indices]

# Fit a linear function to the data (detected count as a function of the
# ground-truth count), once per method
my_coefficients = np.polyfit(num_all_gt_ellipses, num_all_my_ellipses, 1)
my_fit_line = np.poly1d(my_coefficients)

s_coefficients = np.polyfit(num_all_gt_ellipses, num_all_s_ellipses, 1)
s_fit_line = np.poly1d(s_coefficients)

# Both lines were fitted on the ground-truth counts, so both are evaluated there.
my_fitted_values = np.polyval(my_coefficients, num_all_gt_ellipses)
s_fitted_values = np.polyval(s_coefficients, num_all_gt_ellipses)

# R² is computed as in the other correlation plots of this notebook:
# r2_score(ground truth, fitted line evaluated at the ground truth).
# NOTE(review): this scores the fitted line against the ground-truth values,
# not against the detected counts as the usual coefficient of determination
# of the fit would - confirm this is the intended metric.
my_r_squared = '{:.2f}'.format(round(r2_score(num_all_gt_ellipses, my_fitted_values), 2))
s_r_squared = '{:.2f}'.format(round(r2_score(num_all_gt_ellipses, s_fitted_values), 2))

# Spread of the detected counts around each fitted line
s_residuals = num_all_s_ellipses - s_fitted_values
s_residuals_sd = round(np.std(s_residuals))

my_residuals = num_all_my_ellipses - my_fitted_values
my_residuals_sd = round(np.std(my_residuals))

# Slope of a degree-1 fit is its leading coefficient
s_tg = '{:.2f}'.format(round(s_coefficients[0], 2))
my_tg = '{:.2f}'.format(round(my_coefficients[0], 2))


print(f's tangent {s_tg}+-{s_residuals_sd}')
print(f'my tangent {my_tg}+-{my_residuals_sd}')


fig, ax = plt.subplots(1, 1, figsize=(7, 7))
sns.set_theme()
ax.plot(num_all_gt_ellipses, my_fitted_values, color='g', linestyle='--', alpha=0.6, label=f'Automatic linear fit: s = {my_tg}±{my_residuals_sd}')
ax.plot(num_all_gt_ellipses, s_fitted_values, color='b', linestyle='--', alpha=0.6, label=f'Semiautomatic linear fit: s = {s_tg}±{s_residuals_sd}')
ax.scatter(num_all_gt_ellipses, num_all_my_ellipses, color='g', marker='o', label=f'Automatic: R² = {my_r_squared}', alpha=np.ones_like(num_all_gt_ellipses))
ax.scatter(num_all_gt_ellipses, num_all_s_ellipses, color='b', marker='s', label=f'Semiautomatic: R² = {s_r_squared}', alpha=np.ones_like(num_all_gt_ellipses))

ax.legend(fontsize='medium')
ax.set_xlabel('Ground truth number of fenestrations per image', fontsize=16)
ax.set_ylabel('Detected number of fenestrations per image', fontsize=16)
ax.set_title(f'Correlation plot of the number of detected fenestrations\nwith the automatic and semiautomatic method', fontsize=22)
ax.grid(True)
ax.tick_params(axis='x', labelsize=12)
ax.tick_params(axis='y', labelsize=12)
if save_plots:
    plt.savefig(os.path.join(plot_path, 'corr_num.svg'), format='svg')
plt.show()

# print(r2_score(num_all_gt_ellipses,num_all_s_ellipses), r2_score(num_all_gt_ellipses,num_all_my_ellipses))
# print(scipy.stats.pearsonr(num_all_gt_ellipses, num_all_s_ellipses)[0]**2, scipy.stats.pearsonr(num_all_gt_ellipses, num_all_my_ellipses)[0]**2)


# Total number of fenestrations over all images for each mask source,
# followed by the raw Dice score collections of both methods.
ellipse_totals = (np.sum(num_all_gt_ellipses), np.sum(num_all_my_ellipses), np.sum(num_all_s_ellipses))
print('num ellipses')
print('gt {}, my {}, s {}'.format(*ellipse_totals))
print(my_dice_scores)
print(s_dice_scores)

# Automatic method: one rounded Dice score per line, then mean +- std
print('my')
for dice in my_dice_scores:
    print(round(dice, 2))
my_dice_scores = np.array(my_dice_scores)
my_mean_dice, my_std_dice = round(np.mean(my_dice_scores), 2), round(np.std(my_dice_scores), 2)
print(f'Auto mean dice: {my_mean_dice} +- {my_std_dice}')

# Semiautomatic method: same report
print('s')
for dice in s_dice_scores:
    print(round(dice, 2))
s_dice_scores = np.array(s_dice_scores)
s_mean_dice, s_std_dice = round(np.mean(s_dice_scores), 2), round(np.std(s_dice_scores), 2)
print(f'Semiauto mean dice: {s_mean_dice} +- {s_std_dice}')

# Correlation of the per-image mean equivalent diameters:
# ground truth ("gt") against automatic ("my") and semiautomatic ("s").

# Convert to arrays and order everything by the ground-truth mean.
order = np.argsort(np.array(all_gt_means))
all_gt_means, all_my_means, all_s_means = [
    np.array(means)[order] for means in (all_gt_means, all_my_means, all_s_means)
]

# Degree-1 least-squares fits of the detected means against the ground truth
my_coefficients = np.polyfit(all_gt_means, all_my_means, 1)
s_coefficients = np.polyfit(all_gt_means, all_s_means, 1)
my_fit_line, s_fit_line = np.poly1d(my_coefficients), np.poly1d(s_coefficients)

# Standard deviation of the residuals around each fitted line, as a whole number
s_residuals_sd = round(np.std(all_s_means - np.polyval(s_coefficients, all_gt_means)))
my_residuals_sd = round(np.std(all_my_means - np.polyval(my_coefficients, all_gt_means)))

# R² between the ground-truth means and each fitted line evaluated at them
my_r_squared = f'{round(r2_score(all_gt_means, my_fit_line(all_gt_means)), 2):.2f}'
s_r_squared = f'{round(r2_score(all_gt_means, s_fit_line(all_gt_means)), 2):.2f}'

# Slope of each fit, taken as the derivative of the fitted polynomial at 150
s_tg = f'{round(np.polyval(np.polyder(s_coefficients), 150), 2):.2f}'
my_tg = f'{round(np.polyval(np.polyder(my_coefficients), 150), 2):.2f}'

print('mean values')
print(f's tangent {s_tg}+-{s_residuals_sd}')
print(f'my tangent {my_tg}+-{my_residuals_sd}')

fig, ax = plt.subplots(1, 1, figsize=(7, 7))
sns.set_theme()
fit_style = dict(linestyle='--', alpha=0.6)
ax.plot(all_gt_means, my_fit_line(all_gt_means), color='g', label=f'Automatic linear fit: s = {my_tg}±{my_residuals_sd}', **fit_style)
ax.plot(all_gt_means, s_fit_line(all_gt_means), color='b', label=f'Semiautomatic linear fit: s = {s_tg}±{s_residuals_sd}', **fit_style)
ax.scatter(all_gt_means, all_my_means, color='g', marker='o', label=f'Automatic: R² = {my_r_squared}', alpha=np.ones_like(all_gt_means))
ax.scatter(all_gt_means, all_s_means, color='b', marker='s', label=f'Semiautomatic: R² = {s_r_squared}', alpha=np.ones_like(all_gt_means))

ax.legend(fontsize='medium')
ax.set_xlabel('Ground truth mean fenestration\nequivalent diameter per image (nm)', fontsize=16)
ax.set_ylabel('Detected mean fenestration\nequivalent diameter per image (nm)', fontsize=16)
ax.set_title(f'Correlation plot of the mean fenestration diameter\nwith the automatic and semiautomatic method', fontsize=22)
ax.grid(True)
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=12)
# Same fixed ticks (110..180 in steps of 10) on both axes
x_ticks = list(range(110, 180+1, 10))
ax.set_xticks(x_ticks)
ax.set_yticks(x_ticks)
if save_plots:
    plt.savefig(os.path.join(plot_path, 'corr_diameter.svg'), format='svg')
plt.show()

# Dump the sorted per-image means, one value per line
for heading, means in (('gt means', all_gt_means), ('my means', all_my_means), ('s means', all_s_means)):
    print(heading)
    for mean in means:
        print(mean)



# Dump the per-image porosities (still in their original order), one per line
print('porosity')
for heading, porosities in (('gt', all_gt_porosities_pix), ('s', all_s_porosities_pix), ('my', all_my_porosities_pix)):
    print(heading)
    for p in porosities:
        print(p)


# Porosity correlation for the `_pix` porosity values:
# ground truth ("gt") against semiautomatic ("s") and automatic ("my").

# Convert to arrays and order everything by the ground-truth porosity.
order = np.argsort(np.array(all_gt_porosities_pix))
all_gt_porosities_pix, all_s_porosities_pix, all_my_porosities_pix = [
    np.array(porosities)[order]
    for porosities in (all_gt_porosities_pix, all_s_porosities_pix, all_my_porosities_pix)
]

# Degree-1 least-squares fits of the detected porosity against the ground truth
my_coefficients = np.polyfit(all_gt_porosities_pix, all_my_porosities_pix, 1)
s_coefficients = np.polyfit(all_gt_porosities_pix, all_s_porosities_pix, 1)
my_fit_line, s_fit_line = np.poly1d(my_coefficients), np.poly1d(s_coefficients)

# Standard deviation of the residuals around each fitted line (two decimals)
s_residuals_sd = round(np.std(all_s_porosities_pix - np.polyval(s_coefficients, all_gt_porosities_pix)), 2)
my_residuals_sd = round(np.std(all_my_porosities_pix - np.polyval(my_coefficients, all_gt_porosities_pix)), 2)

# R² between the ground-truth porosities and each fitted line evaluated at them
my_r_squared = f'{round(r2_score(all_gt_porosities_pix, my_fit_line(all_gt_porosities_pix)), 2):.2f}'
s_r_squared = f'{round(r2_score(all_gt_porosities_pix, s_fit_line(all_gt_porosities_pix)), 2):.2f}'

# Slope of each fit, taken as the derivative of the fitted polynomial at 5
s_tg = f'{round(np.polyval(np.polyder(s_coefficients), 5), 2):.2f}'
my_tg = f'{round(np.polyval(np.polyder(my_coefficients), 5), 2):.2f}'

fig, ax = plt.subplots(1, 1, figsize=(7, 7))
sns.set_theme()
fit_style = dict(linestyle='--', alpha=0.6)
ax.plot(all_gt_porosities_pix, my_fit_line(all_gt_porosities_pix), color='g', label=f'Automatic linear fit: s = {my_tg}±{my_residuals_sd}', **fit_style)
ax.plot(all_gt_porosities_pix, s_fit_line(all_gt_porosities_pix), color='b', label=f'Semiautomatic linear fit: s = {s_tg}±{s_residuals_sd}', **fit_style)
ax.scatter(all_gt_porosities_pix, all_my_porosities_pix, color='g', marker='o', label=f'Automatic: R² = {my_r_squared}', alpha=np.ones_like(all_gt_porosities_pix))
ax.scatter(all_gt_porosities_pix, all_s_porosities_pix, color='b', marker='s', label=f'Semiautomatic: R² = {s_r_squared}', alpha=np.ones_like(all_gt_porosities_pix))

ax.legend(fontsize='medium')
ax.set_xlabel('Ground truth porosity per image (%)', fontsize=16)
ax.set_ylabel('Detected porosity per image (%)', fontsize=16)
ax.set_title(f'Correlation plot of cell porosity of fenestrations detected\nwith the automatic and semiautomatic method', fontsize=22)
ax.grid(True)
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=12)
if save_plots:
    plt.savefig(os.path.join(plot_path, 'corr_porosity_pix.svg'), format='svg')
plt.show()



# Dump the per-image porosities (still in their original order), one per line
print('porosity')
print('gt')
for p in all_gt_porosities_ell:
    print(p)
print('s')
for p in all_s_porosities_ell:
    print(p)
print('my')
for p in all_my_porosities_ell:
    print(p)



# Porosity correlation for the `_ell` porosity values (the plot title below
# describes them as computed with fitted ellipses):
# ground truth ("gt") against semiautomatic ("s") and automatic ("my").
all_gt_porosities_ell = np.array(all_gt_porosities_ell)
all_s_porosities_ell = np.array(all_s_porosities_ell)
all_my_porosities_ell =  np.array(all_my_porosities_ell)

# Order all three arrays by the ground-truth porosity so the fitted lines are
# drawn from left to right.
sorted_indices = np.argsort(all_gt_porosities_ell)
all_gt_porosities_ell = all_gt_porosities_ell[sorted_indices]
all_s_porosities_ell = all_s_porosities_ell[sorted_indices]
all_my_porosities_ell = all_my_porosities_ell[sorted_indices]

# Fit a linear function to the data, once per method.
# This has to happen before R² is computed: otherwise my_fit_line/s_fit_line
# would still be the lines fitted in an earlier section of the notebook.
my_coefficients = np.polyfit(all_gt_porosities_ell, all_my_porosities_ell, 1)
my_fit_line = np.poly1d(my_coefficients)

s_coefficients = np.polyfit(all_gt_porosities_ell, all_s_porosities_ell, 1)
s_fit_line = np.poly1d(s_coefficients)

# R² between the ground-truth porosities and each freshly fitted line
my_r_squared = round(r2_score(all_gt_porosities_ell, my_fit_line(all_gt_porosities_ell)), 2)
my_r_squared ='{:.2f}'.format(my_r_squared)

s_r_squared = round(r2_score(all_gt_porosities_ell, s_fit_line(all_gt_porosities_ell)), 2)
s_r_squared ='{:.2f}'.format(s_r_squared)

# Spread of the detected porosities around each fitted line (two decimals)
s_fitted_values = np.polyval(s_coefficients, all_gt_porosities_ell)
s_residuals = all_s_porosities_ell - s_fitted_values
s_residuals_sd = round(np.std(s_residuals), 2)

my_fitted_values = np.polyval(my_coefficients, all_gt_porosities_ell)
my_residuals = all_my_porosities_ell - my_fitted_values
my_residuals_sd = round(np.std(my_residuals), 2)

# Slope of a degree-1 fit is its leading coefficient
s_tg = '{:.2f}'.format(round(s_coefficients[0], 2))
my_tg = '{:.2f}'.format(round(my_coefficients[0], 2))

fig, ax = plt.subplots(1, 1, figsize=(7, 7))
sns.set_theme()
ax.plot(all_gt_porosities_ell, my_fit_line(all_gt_porosities_ell), color='g', linestyle='--', alpha=0.6, label=f'Automatic linear fit: s = {my_tg}±{my_residuals_sd}')
ax.plot(all_gt_porosities_ell, s_fit_line(all_gt_porosities_ell), color='b', linestyle='--', alpha=0.6, label=f'Semiautomatic linear fit: s = {s_tg}±{s_residuals_sd}')
ax.scatter(all_gt_porosities_ell, all_my_porosities_ell, color='g', marker='o', label=f'Automatic: R² = {my_r_squared}', alpha=np.ones_like(all_gt_porosities_ell))
ax.scatter(all_gt_porosities_ell, all_s_porosities_ell, color='b', marker='s', label=f'Semiautomatic: R² = {s_r_squared}', alpha=np.ones_like(all_gt_porosities_ell))

ax.legend(fontsize='medium')
ax.set_xlabel('Ground truth porosity per image (%)', fontsize=16)
ax.set_ylabel('Detected porosity per image (%)', fontsize=16)
ax.set_title(f'Correlation plot of cell porosity computed with fitted ellipses\nwith the automatic and semiautomatic method', fontsize=22)
ax.grid(True)
ax.tick_params(axis='x', labelsize=12)
ax.tick_params(axis='y', labelsize=12)
if save_plots:
    plt.savefig(os.path.join(plot_path, 'corr_porosity_ell.svg'), format='svg')
plt.show()


# Fenestration frequency
# Dump the per-image frequencies (still in their original order), one per line
print('frequency')
for heading, frequencies in (('gt', all_gt_frequencies), ('s', all_s_frequencies), ('my', all_my_frequencies)):
    print(heading)
    for p in frequencies:
        print(p)


# Frequency correlation:
# ground truth ("gt") against semiautomatic ("s") and automatic ("my").

# Convert to arrays and order everything by the ground-truth frequency.
order = np.argsort(np.array(all_gt_frequencies))
all_gt_frequencies, all_s_frequencies, all_my_frequencies = [
    np.array(frequencies)[order]
    for frequencies in (all_gt_frequencies, all_s_frequencies, all_my_frequencies)
]

# Degree-1 least-squares fits of the detected frequency against the ground truth
my_coefficients = np.polyfit(all_gt_frequencies, all_my_frequencies, 1)
s_coefficients = np.polyfit(all_gt_frequencies, all_s_frequencies, 1)
my_fit_line, s_fit_line = np.poly1d(my_coefficients), np.poly1d(s_coefficients)

# Standard deviation of the residuals around each fitted line (two decimals)
s_residuals_sd = round(np.std(all_s_frequencies - np.polyval(s_coefficients, all_gt_frequencies)), 2)
my_residuals_sd = round(np.std(all_my_frequencies - np.polyval(my_coefficients, all_gt_frequencies)), 2)

# R² between the ground-truth frequencies and each fitted line evaluated at them
my_r_squared = f'{round(r2_score(all_gt_frequencies, my_fit_line(all_gt_frequencies)), 2):.2f}'
s_r_squared = f'{round(r2_score(all_gt_frequencies, s_fit_line(all_gt_frequencies)), 2):.2f}'

# Slope of each fit, taken as the derivative of the fitted polynomial at 5
s_tg = f'{round(np.polyval(np.polyder(s_coefficients), 5), 2):.2f}'
my_tg = f'{round(np.polyval(np.polyder(my_coefficients), 5), 2):.2f}'


fig, ax = plt.subplots(1, 1, figsize=(7, 7))
sns.set_theme()
fit_style = dict(linestyle='--', alpha=0.6)
ax.plot(all_gt_frequencies, my_fit_line(all_gt_frequencies), color='g', label=f'Automatic linear fit: s = {my_tg}±{my_residuals_sd}', **fit_style)
ax.plot(all_gt_frequencies, s_fit_line(all_gt_frequencies), color='b', label=f'Semiautomatic linear fit: s = {s_tg}±{s_residuals_sd}', **fit_style)
ax.scatter(all_gt_frequencies, all_my_frequencies, color='g', marker='o', label=f'Automatic: R² = {my_r_squared}', alpha=np.ones_like(all_gt_frequencies))
ax.scatter(all_gt_frequencies, all_s_frequencies, color='b', marker='s', label=f'Semiautomatic: R² = {s_r_squared}', alpha=np.ones_like(all_gt_frequencies))

ax.legend(fontsize='medium')
ax.set_xlabel('Fenestration frequency per image\n(fenestrations/μm²)', fontsize=16)
ax.set_ylabel('Detected fenestration frequency per image\n(fenestrations/μm²)', fontsize=16)
ax.set_title(f'Correlation plot of fenestration frequency detected\nwith the automatic and semiautomatic method', fontsize=22)
ax.grid(True)
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=12)
if save_plots:
    plt.savefig(os.path.join(plot_path, 'corr_freq.svg'), format='svg')
plt.show()




#     # if remove_false_fenestrations:
#     #     new_mask_filt = remove_fenestrations(new_mask, min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)
#     #     # current_dice_score_filt = compute_dice_score(ground_truth_mask, new_mask_filt)
#     #     # dice_scores_filt.append(current_dice_score_filt)
#     #     # print(f'Image Dice score: {round(current_dice_score*100, 1)}, ({round(current_dice_score_filt*100, 1)})')
#     #     # file.write(f'Image Dice score: {round(current_dice_score*100, 1)}, ({round(current_dice_score_filt*100, 1)})\n')
#     # else:
#     #     # print(f'Image Dice score: {round(current_dice_score*100, 1)}')
#     #     # file.write(f'Image Dice score: {round(current_dice_score*100, 1)}\n')

# # dice_scores = np.array(dice_scores)
# # mean_dice = round(np.mean(dice_scores)*100, 1)
# # std_dice = round(np.std(dice_scores)*100, 1)
# # if remove_false_fenestrations:
# #     dice_scores_filt = np.array(dice_scores_filt)
# #     mean_dice_filt = round(np.mean(dice_scores_filt)*100, 1)
# #     std_dice_filt = round(np.std(dice_scores_filt)*100, 1)
# #     print(f'{model_name} Mean dice: {mean_dice} +- {std_dice} ({mean_dice_filt} +- {std_dice_filt})\n')
# #     file.write(f'{model_name} Mean dice: {mean_dice} += {std_dice} ({mean_dice_filt} +- {std_dice_filt})\n\n')
# # else:
# #     print(f'{model_name} Mean dice: {mean_dice} +- {std_dice}\n')
# #     file.write(f'{model_name} Mean dice: {mean_dice} += {std_dice}\n\n')



