# Stardist perform inference

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import imageio
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = None
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# !pip install csbdeep
# !pip install stardist
# !pip install gputools
# !pip install scikit-tensor-py3

import napari

from glob import glob
from tqdm import tqdm
from tifffile import imread
import csbdeep
from csbdeep.utils import Path, download_and_extract_zip_file

from stardist import fill_label_holes, relabel_image_stardist, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D

import cv2
import os

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import Model, layers, models

from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.image import Iterator, ImageDataGenerator

from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import Iterator, ImageDataGenerator
import tensorflow.keras.backend as K

print(tf.__version__)
print(tf.test.is_built_with_cuda()) 
print(tf.config.list_physical_devices('GPU'))

import skimage.transform

from PIL import Image

# tf.config.gpu.set_per_process_memory_fraction(0.80)
# tf.config.gpu.set_per_process_memory_growth(True)

np.random.seed(42)
lbl_cmap = random_label_cmap()

The version installed is 5.9.7. Please report any issues with this specific QT version at https://github.com/Napari/napari/issues.
  warn(message=warn_message)


2.2.0
True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [13]:
def preprocess(imgs, target_shape=(512, 512), normalize=True, axis_norm=(0, 1)):
    
#     assert len(imgs.shape) == 4  # batch preprocessing only
    
    # perf reshape
    if target_shape != (imgs[0].shape[0], imgs[0].shape[1]):
        bat = []
        for im in imgs:
#             resized = skimage.transform.resize(img, target_shape)
            resized = cv2.resize(im, target_shape, 0, 0, interpolation=cv2.INTER_NEAREST)
            bat.append(resized)
        imgs = np.array(bat)
        
    if normalize:
        bat = []
        for im in imgs:
            bat.append( (im - im.min(axis=(0, 1))) / (im.max(axis=(0, 1)) - im.min(axis=(0, 1))) )
        imgs = np.array(bat)
        
    return imgs

def random_fliprot(img, mask): 
    assert img.ndim >= mask.ndim
    axes = tuple(range(mask.ndim))
    perm = tuple(np.random.permutation(axes))
    img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) 
    mask = mask.transpose(perm) 
    for ax in axes: 
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask 

def random_intensity_change(img):
    img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)
    return img

def augment(x, y):
    """Augmentation of a single input/label image pair.
    x is an input image
    y is the corresponding ground-truth label image
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    # add some gaussian noise
    sig = 0.02*np.random.uniform(0,1)
    x = x + sig*np.random.normal(0,1,x.shape)
    return x, y

def visualize_data(bf, masks, nc_ims=1, nc_masks=1):
    viewer = napari.Viewer()
    if nc_ims == 1:
        viewer.add_image(bf[:, :, :])
    else:
        for k in range(nc_ims):
            viewer.add_image(bf[:, :, :, k], blending="additive")
    
    viewer.add_image(masks[:, :, :], blending="additive")

In [14]:
import re 

def alphanumeric_sort( l ): 
    """ Sort the given iterable in the way that humans expect.""" 
    convert = lambda text: int(text) if text.isdigit() else text 
    alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 
    return sorted(l, key = alphanum_key)

In [20]:
data_path, rfp_path = f"D:\Hugo\Data\H449.1/f0_BF", f"D:\Hugo\Data\H449.1/f0_RFP"

target_dim, normalize_ims = (512, 512), True

print(f"Target shape: {target_dim}")
print(f"Normalize images: {normalize_ims}.")

# Load from dir
X_bf = np.expand_dims(preprocess(np.array([imageio.imread(f"{data_path}/{im}", format="tif") for im in alphanumeric_sort(os.listdir(data_path))]), target_shape=target_dim, normalize=True), axis=-1)
X_rfp = np.expand_dims(preprocess(np.array([imageio.imread(f"{data_path}/{im}", format="tif") for im in alphanumeric_sort(os.listdir(data_path))]), target_shape=target_dim, normalize=True), axis=-1)
print(X_bf.shape, X_rfp.shape)

X = np.concatenate([X_rfp, X_bf], axis=-1)

# Load from stack
# X = preprocess(data_path, target_shape=target_dim, normalize=True)

print("Loaded data.")
print(X.shape, X.min(), X.max())

            
plot = True
if plot:    
    viewer = napari.view_image(X, nc_ims=2)

Target shape: (512, 512)
Normalize images: True.
(689, 512, 512, 1) (689, 512, 512, 1)
Loaded data.
(689, 512, 512, 2) 0.0 1.0


# Load model

In [7]:
from stardist.models import StarDist2D

model_path = "D:/Hugo/BiSeg/Models/BSd125"
model = StarDist2D(None, name=model_path)

Loading network weights from 'weights_best.h5'.
Couldn't load thresholds from 'thresholds.json', using default values. (Call 'optimize_thresholds' to change that.)
Using default values: prob_thresh=0.5, nms_thresh=0.4.


# Make predictions

In [8]:
predictions = []
for im in X:
    label, _ = model.predict_instances(im)
    predictions.append(label)

predictions = np.array(predictions)

ValueError: axes (YXC) must be of length 2.

In [7]:
viewer = napari.Viewer()
viewer.add_image(X, colormap="gray")
viewer.add_image(predictions, blending="additive", colormap="twilight_shifted")

<Image layer 'predictions' at 0x298ea106e20>

In [8]:
save_path = "D:\Hugo\Anaphase\Inter_Div_Correlation\H449.1/Sd32_H449.1_f0.tif"
imageio.volwrite(save_path, predictions)