<a href="https://colab.research.google.com/github/marketakvasova/LSEC_segmentation/blob/main/LSEC_fenestration_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Automatic segmentation of electron microscope images**
---

This notebook is intended for segmenting fenestrations in SEM images of liver sinusoidal endothelial cells (LSECs).
You can run it either in Colab, or download it and run it on your PC (in VS Code for example).
If you want to run this on your pc, follow the steps described here: https://github.com/marketakvasova/LSEC_segmentation

Download the model weights from here: https://drive.google.com/drive/folders/18O8pFbqFLx34X1dliWbPf9EkqeFO0ASK and save them on your Google Drive or on your pc.

If you are using Colab, you can connect to a GPU in Runtime > Change runtime type > Hardware accelerator.
(If connecting to a GPU is not possible, you can use a CPU; it is just ~10x slower.)

---
Run sections 1 and 2 to load the necessary functions and then edit the parameters in the following sections and run them.

(The cell runs when you click on the arrow on the left side of the cell.)

# **1. Import necessary libraries and connect to Drive if you are in Colab**

In [None]:
# @title  { display-mode: "form" }
#@markdown ##**Import necessary libraries and connect to Drive if you are in Colab**
#@markdown In Colab a popup window will appear to connect to Google Drive.
# Detect whether the notebook runs inside Google Colab: IN_COLAB is True only
# when the google.colab package can be imported.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Install the U-Net library before it is imported further down.
    # (IPython magic: only valid inside a notebook cell.)
    %pip install segmentation-models-pytorch
    from google.colab import drive
    # Mount Google Drive under /content/gdrive (triggers the authorisation popup).
    drive.mount('/content/gdrive')
    from google.colab.patches import cv2_imshow

from segmentation_models_pytorch import Unet
import os
import torch.cuda
from torch.utils.data import Dataset
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2 as cv
import math

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# DEVICE is read by build_model and by the inference loop.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Running on {DEVICE}.')

# **2. Load necessary functions**
---

In [2]:
# @title  { display-mode: "form" }
#@markdown ##**Load necessary functions**
class MyDataset(Dataset):
    """Dataset of grayscale images and their segmentation masks.

    Images and masks live in two separate directories. Both listings are
    sorted and paired by position, so an image and its mask must sort to the
    same index (in practice: carry the same file name).
    """

    def __init__(self, image_dir, mask_dir, transform):
        """Remember the directories and the transform, and index the files.

        Args:
            image_dir: Directory holding the input images.
            mask_dir: Directory holding the masks.
            transform: Callable invoked as ``transform(image=..., mask=...)``
                that returns a mapping with ``"image"`` and ``"mask"`` keys.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = self._sorted_files(self.image_dir)
        self.masks = self._sorted_files(self.mask_dir)

    @staticmethod
    def _sorted_files(folder):
        """Return the sorted names of the regular files directly in *folder*."""
        names = []
        for entry in os.listdir(folder):
            if os.path.isfile(os.path.join(folder, entry)):
                names.append(entry)
        names.sort()
        return names

    def __len__(self):
        """Number of samples, taken from the image listing."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, binarise and transform the image/mask pair at *index*."""
        image_file = os.path.join(self.image_dir, self.images[index])
        # Pairing is positional: the mask must sort to the same index as the image.
        mask_file = os.path.join(self.mask_dir, self.masks[index])

        image = cv.imread(image_file, cv.IMREAD_GRAYSCALE).astype(np.float32)
        mask = cv.imread(mask_file, cv.IMREAD_GRAYSCALE).astype(np.float32)
        # Foreground is stored as 255 on disk; map it to 1. Other values are
        # left untouched.
        mask[mask == 255.0] = 1

        transformed = self.transform(image=image, mask=mask)
        return transformed["image"], transformed["mask"]

def normalize_hist(img):
    """Equalise local contrast with CLAHE, then apply a 3x3 median blur.

    Args:
        img: Single-channel image as accepted by OpenCV's CLAHE
            (presumably uint8, as produced by cv.imread in grayscale mode).

    Returns:
        The contrast-equalised and median-filtered image.
    """
    # Clip limit 10 on an 11x11 tile grid.
    equalizer = cv.createCLAHE(10, tileGridSize=(11, 11))
    return cv.medianBlur(equalizer.apply(img), 3)

# Inference-time preprocessing: normalise with mean 0.5 / std 0.5 relative to
# max_pixel_value=255 (per albumentations' Normalize this should map 0..255 to
# roughly -1..1 -- confirm against the library version in use), then convert
# the array to a torch tensor.
test_transform = A.Compose([
    A.Normalize(mean=0.5, std=0.5, max_pixel_value=255.0),
    ToTensorV2(),
])


def create_weighting_patches(patch_size, edge_size):
    """Build the nine blending windows used to stitch overlapping tiles.

    Each window is a ``patch_size`` x ``patch_size`` float array of ones whose
    selected borders (``edge_size`` pixels wide) fade linearly: 0 -> 1 on the
    left/top side and 1 -> 0 on the right/bottom side. A window named after an
    image border (e.g. ``top_left``) is NOT faded on the sides that touch that
    border, because no neighbouring tile overlaps there.

    Args:
        patch_size: Side length of the square window in pixels.
        edge_size: Width of the faded band in pixels.

    Returns:
        Tuple ``(middle, top_left, top, top_right, right, bottom_right,
        bottom, bottom_left, left)`` of independent arrays.
    """
    rising = np.linspace(0, 1, num=edge_size)
    falling = np.linspace(1, 0, num=edge_size)
    far_edge = patch_size - edge_size

    def build(fade_left, fade_right, fade_top, fade_bottom):
        """Return a window of ones with the requested sides faded."""
        window = np.ones((patch_size, patch_size), dtype=float)
        if fade_left:
            window[:, 0:edge_size] *= rising
        if fade_right:
            window[:, far_edge:patch_size] *= falling
        if fade_top:
            window[0:edge_size, :] *= rising[:, np.newaxis]
        if fade_bottom:
            window[far_edge:patch_size, :] *= falling[:, np.newaxis]
        return window

    # (fade_left, fade_right, fade_top, fade_bottom), listed in return order.
    layouts = (
        (True, True, True, True),      # middle
        (False, True, False, True),    # top_left
        (True, True, False, True),     # top
        (True, False, False, True),    # top_right
        (True, False, True, True),     # right
        (True, False, True, False),    # bottom_right
        (True, True, True, False),     # bottom
        (False, True, True, False),    # bottom_left
        (False, True, True, True),     # left
    )
    return tuple(build(*sides) for sides in layouts)


def add_mirrored_border(image, border_size, window_size):
    """Pad a 2-D image with mirrored content so it can be tiled with overlap.

    ``border_size`` mirrored pixels are added on the top and left. The bottom
    and right get ``border_size`` plus an extra amount derived from
    ``window_size`` and the tiling stride (``window_size - border_size``), so
    that windows of ``window_size`` placed at that stride cover the original
    image area completely.

    The padding is done with ``np.pad(mode='symmetric')``. When the border
    fits inside the image this is the same as stacking flipped edge strips.
    It also stays correct when the requested border is larger than the image
    (small inputs), where plain slicing would silently return a strip that is
    too short.

    Args:
        image: 2-D array (height, width).
        border_size: Overlap between neighbouring windows, in pixels.
        window_size: Side length of the sliding window, in pixels.

    Returns:
        The padded image as a new array with the same dtype as ``image``.

    Raises:
        ZeroDivisionError: If ``window_size == border_size``.
        ValueError: If ``image`` is not 2-D.
    """
    height, width = image.shape
    stride = window_size - border_size

    # Extra rows/columns needed after the far border.
    bottom_edge = window_size - ((height + border_size) % stride)
    right_edge = window_size - ((width + border_size) % stride)

    pad_rows = (border_size, border_size + bottom_edge)
    pad_cols = (border_size, border_size + right_edge)
    # 'symmetric' repeats the edge pixel, matching flipud/fliplr of an edge strip.
    return np.pad(image, (pad_rows, pad_cols), mode='symmetric')

def _select_weighting_window(x, y, last_x, last_y, patches):
    """Pick the blending window for the tile whose top-left corner is (x, y).

    ``patches`` is the 9-tuple returned by create_weighting_patches. A tile at
    offset 0 counts as touching the top/left border; a tile at ``last_x`` /
    ``last_y`` counts as touching the bottom/right border. Offset 0 wins when
    both apply.
    """
    (middle, top_left, top, top_right, right,
     bottom_right, bottom, bottom_left, left) = patches
    row = 0 if x == 0 else (2 if x == last_x else 1)
    col = 0 if y == 0 else (2 if y == last_y else 1)
    table = (
        (top_left, top, top_right),
        (left, middle, right),
        (bottom_left, bottom, bottom_right),
    )
    return table[row][col]


def inference_on_image_with_overlap(model, image_path):
    """Segment one image with a sliding window and blended overlaps.

    The image is read as grayscale, padded with mirrored borders and processed
    in 224x224 tiles that overlap by 20 pixels. Each tile is contrast-
    normalised, run through ``model`` (sigmoid applied here), scaled to 0-255
    and accumulated with a linear blending window. The accumulated map is
    cropped back to the original size and thresholded at 40 %.

    No division by the accumulated weights is performed. The fade-in and
    fade-out ramps built by create_weighting_patches are complementary, so
    overlapping contributions are expected to sum to full weight.

    Args:
        model: Network returning one logit map per input tile.
        image_path: Path of the image to segment.

    Returns:
        uint8 mask with the original height and width; 255 marks foreground.

    Raises:
        ValueError: If the image cannot be read by OpenCV.
    """
    window_size = 224
    oh, ow = 20, 20

    input_image = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
    if input_image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise ValueError(f'Could not read image: {image_path}')
    original_height, original_width = input_image.shape

    mirrored_image = add_mirrored_border(input_image, oh, window_size)
    image_height, image_width = mirrored_image.shape
    last_x = image_height - window_size
    last_y = image_width - window_size

    output_probs = np.zeros((image_height, image_width))
    patches = create_weighting_patches(window_size, oh)

    with torch.no_grad():
        for x in range(0, last_x + 1, window_size - oh):
            for y in range(0, last_y + 1, window_size - ow):
                weighting_window = _select_weighting_window(x, y, last_x, last_y, patches)

                square_section = mirrored_image[x:x + window_size, y:y + window_size]
                square_section = normalize_hist(square_section)
                # unsqueeze(0) adds the batch dimension.
                square_tensor = test_transform(image=square_section)['image'].unsqueeze(0).to(DEVICE)

                output = torch.sigmoid(model(square_tensor)).float()

                # Scale the probability to 0-255 and blend it into the canvas.
                tile_probs = (output * 255).squeeze(0).cpu().numpy().squeeze()
                output_probs[x:x + window_size, y:y + window_size] += tile_probs * weighting_window

    # Drop the mirrored padding again.
    output_probs = output_probs[oh:original_height + oh, ow:original_width + ow]

    threshold = int(255 * 0.4)
    return np.where(output_probs > threshold, 255, 0).astype(np.uint8)

def build_model(model_name):
    """Create an untrained single-channel U-Net and move it to ``DEVICE``.

    Args:
        model_name: Encoder name forwarded to ``Unet`` (e.g. ``'resnet34'``).

    Returns:
        The ``Unet`` instance with one input channel, one output class, no
        pretrained encoder weights and no final activation (raw logits).
    """
    network = Unet(
        encoder_name=model_name,
        encoder_weights=None,
        in_channels=1,
        classes=1,
        activation=None,
    )
    return network.to(DEVICE)


def remove_contour_from_mask(contour, mask):
    """Erase one contour from ``mask`` by filling it with zeros.

    ``mask`` is modified in place and also returned.
    """
    all_contours = -1  # draw every contour in the (single-element) list
    black = 0
    cv.drawContours(mask, [contour], all_contours, black, thickness=cv.FILLED)
    return mask

def _keeps_fenestration(ellipse, min_d, max_d, min_roundness, pixel_size_nm):
    """Return True when a fitted ellipse passes the size and roundness limits."""
    if ellipse is None or ellipse == (None, None, None):
        return False

    _, axes, _ = ellipse
    # assumes cv.fitEllipse reports the axes as (minor, major) -- TODO confirm
    minor_axis_length, major_axis_length = axes

    # Written as a negated conjunction so NaN axes (every comparison False)
    # are rejected as well, along with degenerate or extremely elongated fits.
    if not (major_axis_length != 0 and major_axis_length < 20 * minor_axis_length):
        return False

    roundness = minor_axis_length / major_axis_length
    diameter = pixel_size_nm * equivalent_circle_diameter(major_axis_length, minor_axis_length)

    out_of_range = diameter < min_d or diameter > max_d
    return not (out_of_range or roundness < min_roundness or np.isnan(diameter))


def remove_fenestrations(mask, min_d, max_d, min_roundness, pixel_size_nm):
    """Erase objects from a binary mask based on fitted-ellipse size and shape.

    Every external contour is fitted with an ellipse. The contour is filled
    with zeros unless all of the following hold: a usable ellipse was fitted,
    the major axis is non-zero and less than 20x the minor axis, the roundness
    (minor / major) is at least ``min_roundness``, and the equivalent circle
    diameter scaled by ``pixel_size_nm`` lies within ``[min_d, max_d]``.

    Args:
        mask: Single-channel binary mask; modified in place.
        min_d: Smallest accepted equivalent diameter (same unit as
            ``pixel_size_nm`` times pixels).
        max_d: Largest accepted equivalent diameter.
        min_roundness: Smallest accepted minor/major axis ratio.
        pixel_size_nm: Physical size of one pixel.

    Returns:
        The same ``mask`` object with rejected contours filled with zeros.
    """
    contours, _ = cv.findContours(mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    ellipses, _ = fit_ellipses(contours, find_contour_centers(contours))

    for contour, ellipse in zip(contours, ellipses):
        if not _keeps_fenestration(ellipse, min_d, max_d, min_roundness, pixel_size_nm):
            mask = remove_contour_from_mask(contour, mask)
    return mask



def fit_ellipses(filtered_contours, centers):
    """Fit an ellipse to every contour, keeping list positions aligned.

    A contour gets the placeholder ``(None, None, None)`` when it has fewer
    than 5 points (too few for ellipse fitting) or when the fitted ellipse
    centre lies 20 or more away from the contour centre given in ``centers``
    (distance as computed by ``cv.norm``).

    Args:
        filtered_contours: Sequence of contours.
        centers: Sequence of contour centres, same order as the contours.

    Returns:
        ``(ellipses, num_ellipses)``: the per-contour results and the number
        of accepted fits.
    """
    no_fit = (None, None, None)
    max_center_offset = 20

    ellipses = []
    num_ellipses = 0
    for contour, cnt_center in zip(filtered_contours, centers):
        fitted = no_fit
        if len(contour) >= 5:
            candidate = cv.fitEllipse(contour)
            # Reject fits whose centre drifted away from the contour centre.
            if cv.norm(cnt_center, candidate[0]) < max_center_offset:
                fitted = candidate
                num_ellipses += 1
        ellipses.append(fitted)
    return ellipses, num_ellipses

def find_fenestration_contours(image_path):
    """Read a segmentation mask from disk and return its external contours.

    Args:
        image_path: Path of the mask image; it is read as grayscale.

    Returns:
        The contours reported by ``cv.findContours`` (external contours only,
        simple chain approximation).
    """
    mask_image = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
    outlines, _hierarchy = cv.findContours(
        mask_image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE
    )
    return outlines

def find_contour_centers(contours):
    """Return the integer centroid ``(x, y)`` of every contour.

    Centroids come from the image moments (m10/m00, m01/m00). A tiny epsilon
    is added to m00 so zero-area contours do not divide by zero.
    """
    eps = 1e-10

    def centroid(cnt):
        moments = cv.moments(cnt)
        area = moments['m00'] + eps
        return int(moments['m10'] / area), int(moments['m01'] / area)

    return [centroid(cnt) for cnt in contours]

def equivalent_circle_diameter(major_axis_length, minor_axis_length):
    """Return the geometric mean of the two axis lengths.

    For full ellipse axis lengths this is the diameter of the circle whose
    area equals that of the ellipse.
    """
    axis_product = major_axis_length * minor_axis_length
    return math.sqrt(axis_product)



# **3. Insert input and output folders and the path of model weights**
---

In [3]:
# @title  { display-mode: "form" }
#@markdown All Google Drive paths should start with ./gdrive/MyDrive/ (Check the folder structure in the left sidebar under **Files**).

# -----------------------------------------------------------------------------#
# |                       CHANGE THESE PARAMETERS                             |#
# -----------------------------------------------------------------------------#

#!!! If running locally on Windows, use / or \\ as folder separator !!! (not \)

#@markdown Insert folder containing LSEC images:
input_folder = './gdrive/MyDrive/lsecs/images' #@param {type:"string"}
#@markdown Insert where to save the output masks
#@markdown (the folder will be created if it does not exist yet)
#@markdown If the folder contains images, they may be overwritten:
output_folder = './gdrive/MyDrive/lsecs/my_masks' #@param {type:"string"}
#@markdown Insert model weights path:
model_path = './gdrive/MyDrive/lsecs/model_weights.pth' #@param {type:"string"}
# -----------------------------------------------------------------------------#



# Form fields are free text: drop accidental leading/trailing whitespace.
model_path = model_path.strip()
input_folder = input_folder.strip()
output_folder = output_folder.strip()


# Build the ResNet-34 U-Net and load the trained weights. The network outputs
# raw logits; the sigmoid is applied in the inference loop.
model = build_model('resnet34')
# map_location puts the tensors on the device selected above, so one call
# covers both the GPU and the CPU case (and checkpoints saved on another
# CUDA device index).
loaded_state_dict = torch.load(model_path, map_location=torch.device(DEVICE))
model.load_state_dict(loaded_state_dict)
model.eval()


if not os.path.exists(input_folder):
    # Only reported here; listing the folder in the next section will fail.
    print("Input folder does not exist")
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
    print(f'Created folder {output_folder}')



# **4. Run image segmentation**
---

In [None]:
# @title  { display-mode: "form" }
#@markdown ##**Run this cell to segment images:**
#@markdown You can choose whether to remove objects from the masks based on their parameters.


#@markdown If the remove_fenestrations_based_on_params box is not checked, no objects will be removed from the segmented masks.

# -----------------------------------------------------------------------------#
# |                       CHANGE THESE PARAMETERS                             |#
# -----------------------------------------------------------------------------#
# If this is False, no fenestrations will be removed
# If this is True, fenestrations will be removed based on the following parameters
remove_fenestrations_based_on_params = False # @param {type:"boolean"}
pixel_size_nm = 9.28 #@param {type:"number"}
min_diameter_nm = 50 #@param {type:"number"}
max_diameter_nm = 350 #@param {type:"number"}
min_roundness = 0.4 # @param {type:"slider", min:0, max:1, step:0.1}
# -----------------------------------------------------------------------------#



#@markdown Roundness is computed as minor axis length/major axis length of a fitted ellipse.
# Segment every regular file in the input folder, in sorted order.
image_names = sorted(
    f for f in os.listdir(input_folder)
    if os.path.isfile(os.path.join(input_folder, f))
)
for image_name in image_names:
    print(image_name)
    image_path = os.path.join(input_folder, image_name)
    out_mask = inference_on_image_with_overlap(model, image_path)
    if remove_fenestrations_based_on_params:
        out_mask = remove_fenestrations(out_mask, min_diameter_nm, max_diameter_nm, min_roundness, pixel_size_nm)
    # The mask keeps the input file's extension: <name>_mask<ext>.
    filename, ext = os.path.splitext(image_name)
    out = os.path.join(output_folder, filename+'_mask'+ext)
    # cv.imwrite reports failure through its return value, so check it
    # instead of always claiming the file was saved.
    if cv.imwrite(out, out_mask):
        print(f'Saving {out}')
    else:
        print(f'Could not save {out}')

# **5. Apply cell masks**
---

In [None]:
# @title  { display-mode: "form" }
#@markdown ##**Insert folder with cell masks:**
#@markdown You can apply cell masks on the segmented masks, if you have them.


# -----------------------------------------------------------------------------#
# |                       CHANGE THESE PARAMETERS                             |#
# -----------------------------------------------------------------------------#

#!!! If running locally on Windows, use / or \\ as folder separator !!! (not \)

# Folder that contains the cell masks (one per segmented mask).
cell_masks = './gdrive/MyDrive/' #@param {type:"string"}
#@markdown If you want to replace the old masks, check this box. If not, write the new output folder into **new_output_folder**.
# If this is True, the masks will be rewritten
# If this is False, write where to save the mask into new_output_folder
rewrite_old_masks = False # @param {type:"boolean"}
new_output_folder = './gdrive/MyDrive/' #@param {type:"string"}
# -----------------------------------------------------------------------------#



# Form fields are free text: drop accidental leading/trailing whitespace.
cell_masks = cell_masks.strip()
new_output_folder = new_output_folder.strip()

# Sorted listings of regular files. The segmented masks (from output_folder)
# and the cell masks are later paired by position in these two lists.
image_names = [f for f in sorted(os.listdir(output_folder)) if os.path.isfile(os.path.join(output_folder, f))]
mask_names = [f for f in sorted(os.listdir(cell_masks)) if os.path.isfile(os.path.join(cell_masks, f))]

def apply_cell_mask(image_path, mask_path):
    """Zero out every pixel of an image that lies outside the cell mask.

    Both files are read as grayscale. Pixels where the cell mask is 0 are set
    to 0 in the image; all other pixels are left unchanged.

    Args:
        image_path: Path of the segmented mask to restrict.
        mask_path: Path of the cell mask.

    Returns:
        The masked image array.
    """
    segmentation = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
    outside_cell = cv.imread(mask_path, cv.IMREAD_GRAYSCALE) == 0
    segmentation[outside_cell] = 0
    return segmentation

# Results go next to the old masks (overwriting them) or into the new folder.
target_folder = output_folder if rewrite_old_masks else new_output_folder
if target_folder:
    os.makedirs(target_folder, exist_ok=True)

# Files are paired by position; zip() silently drops the surplus, so say so.
if len(image_names) != len(mask_names):
    print(f'Warning: {len(image_names)} segmented masks but {len(mask_names)} cell masks; '
          'only the first matching pairs are processed.')

for image_name, mask_name in zip(image_names, mask_names):
    print(f'{image_name} - {mask_name}')
    image_path = os.path.join(output_folder, image_name)
    mask_path = os.path.join(cell_masks, mask_name)
    image_with_cell_mask = apply_cell_mask(image_path, mask_path)
    # Write to a file inside the target folder, keeping the mask's file name.
    # (Passing the folder itself to cv.imwrite cannot produce a per-image file.)
    cv.imwrite(os.path.join(target_folder, image_name), image_with_cell_mask)
