In [84]:
# pip install gputools
# pip install scikit-tensor-py3

from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'nearest'
import matplotlib.pyplot as plt
import os
from PIL import Image
import tensorflow as tf
from glob import glob
from tqdm import tqdm
from tifffile import imread
import struct
import cv2
import pandas as pd

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.models import Config2D, StarDist2D, StarDistData2D

In [85]:
# Folder passed as `basedir` when new models are created; each model is written under it by name.
pth_models = r'\\10.99.68.178\andreex\data\Stardist\12_12_hyperparameter_models'  # place to save models

# Folders holding the H&E image tiles (*.tif) of each split.
pth_training_HE = r'\\10.99.68.178\andreex\data\Stardist\Train_Val_Splits\Training\HE_tiles'  # change this later
pth_validation_HE = r'\\10.99.68.178\andreex\data\Stardist\Train_Val_Splits\Validation\HE_tiles'
pth_testing_HE = r'\\10.99.68.178\andreex\data\Stardist\Testing\monkey\tiles'

# Folders holding the label masks (*.tif) of each split.
# The loader functions pair tiles and masks by sorted filename order, so both
# folders of a split must contain matching files.
pth_training_masks = r'\\10.99.68.178\andreex\data\Stardist\Train_Val_Splits\Training\masks'
pth_validation_masks = r'\\10.99.68.178\andreex\data\Stardist\Train_Val_Splits\Validation\masks'
pth_testing_masks = r'\\10.99.68.178\andreex\data\Stardist\Testing\monkey\qupath\ground_truth\masks'

Helper functions

In [86]:
def augment_images(HE_tiles, mask_tiles, include_original=False):
  """
  Augments HE tiles and their label masks with 90-degree rotations and a horizontal flip.

  For every input the following views are produced, in this order: the input
  rotated three times by 90 degrees (cumulatively), the horizontally flipped
  input, and the flipped input rotated three times by 90 degrees. That is 7
  views per input. Tiles and masks go through exactly the same sequence, so
  HE_aug[i][k] and mask_aug[i][k] always belong together.

  Args:
    HE_tiles: A sequence of NumPy arrays representing the HE images.
    mask_tiles: A sequence of NumPy arrays representing the mask labels for
      the HE images, in the same order and of the same length as HE_tiles.
    include_original: If True, the unmodified input is added as the first
      view (8 views per input). Defaults to False, which keeps the historical
      7-view output in which the original orientation is NOT included.

  Returns:
    HE_aug: A list with one entry per input tile; each entry is a list of
      PIL images (not NumPy arrays) holding the augmented views of that tile.
    mask_aug: The same structure for the mask labels.

  Raises:
    ValueError: If HE_tiles and mask_tiles differ in length.
  """
  if len(HE_tiles) != len(mask_tiles):
    raise ValueError(
      f"Number of HE tiles ({len(HE_tiles)}) and mask tiles ({len(mask_tiles)}) must match"
    )

  def _views(array):
    # Builds the ordered list of augmented PIL views for a single array.
    # NOTE(review): rotate(90) is called without expand, which presumably keeps
    # the input size, so non-square tiles may be cropped - confirm tiles are square.
    original = Image.fromarray(array)
    views = [original] if include_original else []

    rotated = original
    for _ in range(3):
      rotated = rotated.rotate(90)
      views.append(rotated)

    flipped = original.transpose(Image.FLIP_LEFT_RIGHT)
    views.append(flipped)

    for _ in range(3):
      flipped = flipped.rotate(90)
      views.append(flipped)

    return views

  HE_aug = [_views(tile) for tile in HE_tiles]
  mask_aug = [_views(mask) for mask in mask_tiles]

  return HE_aug, mask_aug

In [87]:
import random
import math

def get_random_indices(images, ratio_blurred):
    """Pick a reproducible random subset of tile indices.

    Args:
        images: Sequence of tiles; only its length is used.
        ratio_blurred: Fraction of the tiles to select. The resulting count is
            floor(ratio_blurred * len(images)), clamped to [0, len(images)].

    Returns:
        random_indices: List of distinct indices, each in range(len(images)).
        num_tiles: len(images).

    The selection is drawn from a private generator seeded with 0, so repeated
    calls with the same arguments return the same indices and the global
    `random` state is left untouched.
    """
    num_tiles = len(images)
    num_blurred = math.floor(ratio_blurred * num_tiles)
    # Clamp so that ratios outside [0, 1] cannot request an impossible sample size.
    num_blurred = min(max(num_blurred, 0), num_tiles)

    # sample() draws without replacement from valid indices only; the previous
    # randint(0, num_tiles) could yield duplicates and the out-of-range value num_tiles.
    rng = random.Random(0)
    random_indices = rng.sample(range(num_tiles), num_blurred)
    return random_indices, num_tiles

In [88]:
def blur_images(images, ratio=0.1, radius=2):
    """Apply a Gaussian blur to a random subset of the given images.

    Args:
        images: Mutable list of images that provide a PIL-style `.filter()`
            method. Selected entries are replaced IN PLACE by blurred copies.
        ratio: Fraction of the images to blur (passed to get_random_indices).
        radius: Radius of the Gaussian blur.

    Returns:
        The same list object that was passed in.
    """
    # Imported here because the notebook's import cell only brings in PIL.Image.
    from PIL import ImageFilter

    random_indices, num_tiles = get_random_indices(images, ratio)
    gaussian = ImageFilter.GaussianBlur(radius=radius)

    # De-duplicate and skip anything outside the list so that every selected
    # tile is blurred exactly once.
    for i in sorted(set(random_indices)):
        if 0 <= i < num_tiles:
            images[i] = images[i].filter(gaussian)

    return images

Functions that load the training, validation, and testing tiles

In [89]:
def get_training_set(pth_training_HE, pth_training_masks, gaussian_ratio=0.1, gaussian_radius=2):
    """Load the training tiles and masks, augment them and scale the tiles.

    Args:
        pth_training_HE: Folder with the H&E tiles (*.tif).
        pth_training_masks: Folder with the label masks (*.tif). Tiles and
            masks are paired by their position in the sorted file listings.
        gaussian_ratio: Fraction of augmented tiles passed to blur_images;
            blurring is skipped when this is not greater than 0.
        gaussian_radius: Blur radius forwarded to blur_images.

    Returns:
        HE_trn_tiles: List of arrays, the augmented tiles divided by 255.
        HE_trn_masks: List of arrays, the augmented masks after fill_label_holes.
    """
    tile_paths = sorted(glob(os.path.join(pth_training_HE, '*.tif')))
    mask_paths = sorted(glob(os.path.join(pth_training_masks, '*.tif')))

    tile_arrays = [imread(path) for path in tile_paths]
    mask_arrays = [imread(path) for path in mask_paths]

    HE_aug, mask_aug = augment_images(tile_arrays, mask_arrays)

    # augment_images returns one list of views per input; flatten both results.
    tile_views = []
    for views in HE_aug:
        tile_views.extend(views)

    mask_views = []
    for views in mask_aug:
        for view in views:
            mask_views.append(np.array(view))

    # Fill holes in the annotations, if there are any.
    HE_trn_masks = list(map(fill_label_holes, mask_views))

    if gaussian_ratio > 0:
        tile_views = blur_images(tile_views, ratio=gaussian_ratio, radius=gaussian_radius)

    # Normalize - presumably 8-bit tiles, so values end up in [0, 1].
    HE_trn_tiles = [np.array(view) / 255 for view in tile_views]

    return HE_trn_tiles, HE_trn_masks

In [90]:
def get_validation_set(pth_validation_HE, pth_validation_masks):
    """Load the validation tiles and masks without augmentation.

    Args:
        pth_validation_HE: Folder with the H&E tiles (*.tif).
        pth_validation_masks: Folder with the label masks (*.tif). Tiles and
            masks are paired by their position in the sorted file listings.

    Returns:
        HE_val_tiles: List of arrays, the tiles divided by 255.
        HE_val_masks: List of arrays, the masks after fill_label_holes.
    """
    tile_paths = sorted(glob(os.path.join(pth_validation_HE, '*.tif')))
    mask_paths = sorted(glob(os.path.join(pth_validation_masks, '*.tif')))

    raw_tiles = [imread(path) for path in tile_paths]
    raw_masks = [imread(path) for path in mask_paths]

    # Normalize - presumably 8-bit tiles, so values end up in [0, 1].
    HE_val_tiles = [np.array(tile) / 255 for tile in raw_tiles]

    # Fill holes in the annotations, if there are any.
    HE_val_masks = [fill_label_holes(np.array(mask)) for mask in raw_masks]

    return HE_val_tiles, HE_val_masks

In [91]:
def get_testing_set(pth_testing_HE, pth_testing_masks):
    """Load the testing tiles and masks without augmentation.

    Args:
        pth_testing_HE: Folder with the H&E tiles (*.tif).
        pth_testing_masks: Folder with the label masks (*.tif). Tiles and
            masks are paired by their position in the sorted file listings.

    Returns:
        HE_testing_tiles: List of arrays, the tiles divided by 255.
        HE_testing_masks: List of arrays, the masks after fill_label_holes.
    """
    tile_paths = sorted(glob(os.path.join(pth_testing_HE, '*.tif')))
    mask_paths = sorted(glob(os.path.join(pth_testing_masks, '*.tif')))

    raw_tiles = [imread(path) for path in tile_paths]
    raw_masks = [imread(path) for path in mask_paths]

    # Normalize - presumably 8-bit tiles, so values end up in [0, 1].
    HE_testing_tiles = [np.array(tile) / 255 for tile in raw_tiles]

    # Fill holes in the annotations, if there are any.
    HE_testing_masks = [fill_label_holes(np.array(mask)) for mask in raw_masks]

    return HE_testing_tiles, HE_testing_masks

In [92]:
# Build an augmented training set (10% of the tiles blurred with radius 2) for the visual checks below.
HE_trn_tiles_2, HE_trn_masks_2 = get_training_set(pth_training_HE, pth_training_masks, gaussian_ratio=0.1, gaussian_radius=2)

In [93]:
# Sanity check: the number of tiles and masks should be identical.
print(len(HE_trn_tiles_2))
print(len(HE_trn_masks_2))

189
189


In [94]:
# Load the validation tiles and masks (no augmentation, no blurring).
HE_val_tiles_2, HE_val_masks_2 = get_validation_set(pth_validation_HE, pth_validation_masks)

In [95]:
# Sanity check: the number of tiles and masks should be identical.
print(len(HE_val_tiles_2))
print(len(HE_val_masks_2))

12
12


In [96]:
# Load the testing tiles and masks (no augmentation, no blurring).
HE_testing_tiles_2, HE_testing_masks_2 = get_testing_set(pth_testing_HE, pth_testing_masks)

KeyboardInterrupt: 

In [None]:
# Sanity check: the number of tiles and masks should be identical.
print(len(HE_testing_tiles_2))
print(len(HE_testing_masks_2))

Show some images and make sure everything is ok

In [None]:
# Plot image and label for some images - sanity check
def show_tile_segmented(tile, segmented, **kwargs):
    """Show an image tile (left) and its segmentation/label image (right).

    Any extra keyword arguments are forwarded unchanged to both imshow calls.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8.0, 4.0))

    # Left panel: the tile; right panel: the segmentation. Axes are hidden on both.
    for panel, picture in zip(axes, (tile, segmented)):
        panel.imshow(picture, **kwargs)
        panel.axis('off')

    plt.tight_layout()
    plt.show()

# Index of the tile shown by the browsing cells that follow.
i = 0

# Random color map for label images (seeded so the colors are reproducible).
# NOTE(review): lbl_cmap is not passed to show_tile_segmented in the cells that follow.
np.random.seed(42)
lbl_cmap = random_label_cmap()

In [None]:
# Step to the next training tile every time this cell is re-run.
# NOTE: the index is advanced before it is used, so after `i = 0` the first tile shown is index 1.
i += 1

print(i)

img, lbl = HE_trn_tiles_2[i], HE_trn_masks_2[i]
show_tile_segmented(img,lbl)

In [None]:
# Reset the browsing index before stepping through the validation tiles.
i = 0

In [None]:
# look at validation tiles
# Shows tile `i` with its mask, then advances `i` so that re-running this cell shows the next tile.
img, lbl = HE_val_tiles_2[i], HE_val_masks_2[i]
show_tile_segmented(img,lbl)
i += 1

Configure Model

In [None]:
# Report whether TensorFlow can see a GPU (informational only).
if tf.config.list_physical_devices('GPU'):
    print("GPU is available")
else:
    print("GPU is not available")

print(tf.__version__)

# Define the config by setting some parameter values
# 32 is a good default choice (see 1_data.ipynb)
n_rays = 32  # Number of radial directions for the star-convex polygon.

# Use OpenCL-based computations for data generator during training (requires 'gputools')
use_gpu = True and gputools_available()

# Predict on subsampled grid for increased efficiency and larger field of view
grid = (2,2)

n_channel = 3  # change if not using RGB images

# NOTE(review): `conf` is only displayed here; the model-loading functions
# further down build their own Config2D instead of using it.
conf = Config2D (
    n_rays       = n_rays,
    grid         = grid,
    use_gpu      = use_gpu,
    n_channel_in = n_channel,
)
# print(conf)
# Display the resulting configuration as a dict.
vars(conf)

Functions to load models

In [None]:
#Start from 40x H&E pre-trained model to the specified directory
import copy
import json

def _clone_pretrained_model(published_model, folder_to_write_new_model_folder, name_for_new_model):
    """Create a new StarDist2D model initialised from an existing model.

    The new model is created under `folder_to_write_new_model_folder` with the
    name `name_for_new_model`, receives a copy of the Keras weights of
    `published_model`, and inherits its probability/NMS thresholds.
    """
    # Read the thresholds by name rather than by position so the result does
    # not depend on field order.
    # NOTE(review): relies on StarDist exposing `thresholds` with `prob` and `nms` fields - confirm.
    source_thresholds = published_model.thresholds
    original_thresholds = {'prob': source_thresholds.prob, 'nms': source_thresholds.nms}

    # Only request OpenCL data generation when gputools is actually installed,
    # matching the `use_gpu` guard in the configuration cell.
    configuration = Config2D(n_channel_in=3, grid=(2,2), use_gpu=gputools_available(), train_patch_size=[256, 256])
    model = StarDist2D(config=configuration, basedir=folder_to_write_new_model_folder, name=name_for_new_model)
    model.keras_model.set_weights(published_model.keras_model.get_weights())
    model.thresholds = original_thresholds
    return model


def load_justin_model(folder_to_write_new_model_folder: str, name_for_new_model: str,
                      published_model_path: str = r'\\10.99.68.178\andreex\students\Donald Monkey fetus\stardist\Model_00') -> StarDist2D:
    """Create a new model that starts from a previously trained in-house model.

    Args:
        folder_to_write_new_model_folder: Base directory for the new model.
        name_for_new_model: Name (sub-folder) of the new model.
        published_model_path: Folder of the source model, read with load_model.
            Defaults to the `Model_00` folder this notebook has always used.

    Returns:
        A StarDist2D model with the source model's weights and thresholds.
    """
    published_model = load_model(published_model_path)
    return _clone_pretrained_model(published_model, folder_to_write_new_model_folder, name_for_new_model)


def load_published_he_model(folder_to_write_new_model_folder: str, name_for_new_model: str) -> StarDist2D:
    """Create a new model that starts from StarDist's pretrained '2D_versatile_he' model.

    Args:
        folder_to_write_new_model_folder: Base directory for the new model.
        name_for_new_model: Name (sub-folder) of the new model.

    Returns:
        A StarDist2D model with the pretrained weights and thresholds.
    """
    published_model = StarDist2D.from_pretrained('2D_versatile_he')
    return _clone_pretrained_model(published_model, folder_to_write_new_model_folder, name_for_new_model)

def load_model(model_path: str) -> StarDist2D:
    """Load a StarDist model from a folder on disk.

    Reads `config.json`, `thresholds.json` and `weights_best.h5` from
    `model_path` and returns a StarDist2D model configured with them.

    Args:
        model_path: Folder that contains the three files listed above.

    Returns:
        The loaded StarDist2D model.

    NOTE(review): the model is instantiated with basedir=model_path and
    name='offshoot_model', which presumably makes StarDist use a sub-folder of
    that name inside `model_path` - confirm this side effect is intended.
    """
    # Load StarDist model weights, configurations, and thresholds.
    # os.path.join keeps the paths valid on every platform, not only Windows.
    with open(os.path.join(model_path, 'config.json'), 'r') as f:
        config = json.load(f)
    with open(os.path.join(model_path, 'thresholds.json'), 'r') as f:
        thresh = json.load(f)
    model = StarDist2D(config=Config2D(**config), basedir=model_path, name='offshoot_model')
    model.thresholds = thresh
    print('Overriding defaults:', model.thresholds, '\n')
    model.load_weights(os.path.join(model_path, 'weights_best.h5'))
    return model

def configure_model_for_training(model: "StarDist2D",
                                 epochs: int = 25, learning_rate: float = 1e-6,
                                 batch_size: int = 4, patch_size: "Sequence[int]" = (256, 256)) -> "StarDist2D":
    """Set the training hyperparameters on a model's config.

    Args:
        model: Model whose `config` attribute is updated in place.
        epochs: Value stored in `config.train_epochs`.
        learning_rate: Value stored in `config.train_learning_rate`.
        batch_size: Value stored in `config.train_batch_size`.
        patch_size: Patch size; stored in `config.train_patch_size` as a new list.

    Returns:
        The same model object, to allow chaining.
    """
    model.config.train_epochs = epochs
    model.config.train_learning_rate = learning_rate
    model.config.train_batch_size = batch_size
    # Store a private copy so that models never share (and cannot mutate) the
    # default or the caller's sequence.
    model.config.train_patch_size = list(patch_size)
    return model

In [None]:
from tensorflow.python.summary.summary_iterator import summary_iterator

def get_loss_data(pth_training_log, pth_out) -> list:
    """Extract the per-epoch training loss from a TensorBoard event file.

    Every summary value tagged 'epoch_loss' is decoded from its raw tensor
    bytes as a single native-format 32-bit float. The values are also written,
    one per line, to `loss.txt` inside `pth_out`.

    Args:
        pth_training_log: Path of the TensorBoard event file
            (e.g. an `events.out.tfevents.*.v2` file in a `logs/train` folder).
        pth_out: Existing folder in which `loss.txt` is written.

    Returns:
        The list of loss values in the order they appear in the event file.
    """
    loss_values = []

    for summary in summary_iterator(pth_training_log):
        for value in summary.summary.value:
            if value.tag == 'epoch_loss':
                loss = struct.unpack('f', value.tensor.tensor_content)[0]
                loss_values.append(loss)

    # os.path.join avoids the invalid "\l" escape sequence of the old f-string
    # and builds a valid path on every platform.
    out_txt_name = os.path.join(pth_out, 'loss.txt')

    with open(out_txt_name, 'w') as f:
        f.write('\n'.join(map(str, loss_values)) + '\n')

    return loss_values

Define hyperparameters, train

In [None]:
import itertools

# Hyperparameter grid: learning rates, number of epochs, reduce-LR patience
# values and fraction of blurred training tiles.
lrs = [5e-4, 1e-3, 3e-3, 5e-3]
epochs_strs = [100, 200, 400]
pts = [10, 20, 30, 40]
gaussian_ratios = [0, 0.1]


# Create a list of all the possible combinations of the hyperparameters
list_trainings = list(itertools.product(lrs, epochs_strs, pts, gaussian_ratios))

# Best (minimum) and final training loss of every run, in grid order.
losses_best = []
losses_last = []

# The validation set does not depend on any hyperparameter, so load it once
# instead of re-reading it from disk for every run.
HE_val_tiles, HE_val_masks = get_validation_set(pth_validation_HE, pth_validation_masks)

for i, (lr, epochs, pt, gaussian_ratio) in enumerate(list_trainings):
    nm='monkey_'
    dt='12_12_2023'
    outnm = nm + dt + '_lr_' + str(lr) + '_epochs_' + str(epochs) + '_pt_' + str(pt) + '_gaus_ratio_' + str(gaussian_ratio)
    print(outnm)
    # 1-based progress counter, so the last run reads (N/N).
    print(f'({i + 1}/{len(list_trainings)})')

    # The training set depends on the blur ratio, so it is rebuilt per run.
    HE_trn_tiles, HE_trn_masks = get_training_set(pth_training_HE, pth_training_masks, gaussian_ratio=gaussian_ratio, gaussian_radius=2)

    model = load_published_he_model(pth_models, outnm)
    model.config.train_learning_rate = lr
    model.config.train_patch_size = (256,256)
    model.config.train_reduce_lr={'factor': 0.5, 'patience': pt, 'min_delta': 0}
    model.train(HE_trn_tiles, HE_trn_masks, validation_data=(HE_val_tiles,HE_val_masks), epochs=epochs, steps_per_epoch=100)
    model.optimize_thresholds(HE_val_tiles,HE_val_masks)

    # The model was created with basedir=pth_models and name=outnm, so its
    # training logs are expected under that folder, not under a separate
    # hard-coded "models" directory.
    pth_log_train = os.path.join(pth_models, outnm, 'logs', 'train')

    event_files = glob(os.path.join(pth_log_train, '*.v2'))
    if not event_files:
        raise FileNotFoundError(f'No TensorBoard event file (*.v2) found in {pth_log_train}')
    # If the folder holds several event files (e.g. after a re-run), use the newest one.
    pth_log = max(event_files, key=os.path.getmtime)
    loss = get_loss_data(str(pth_log), pth_log_train)

    losses_best.append(min(loss))
    losses_last.append(loss[-1])