In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = None
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file

from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents
from stardist import fill_label_holes, random_label_cmap
from stardist.matching import matching_dataset

np.random.seed(42)
lbl_cmap = random_label_cmap()

In [None]:
import sys

In [None]:
# Make the local ``stardist_mpcdf`` package importable. The location can be
# overridden with the STARDIST_MPCDF_PATH environment variable. The membership
# check keeps re-runs of this cell from appending duplicate sys.path entries.
import os
STARDIST_MPCDF_PATH = os.environ.get('STARDIST_MPCDF_PATH', '/u/ejelli/jupyter/Notebooks/stardist_mpcdf')
if STARDIST_MPCDF_PATH not in sys.path:
    sys.path.append(STARDIST_MPCDF_PATH)

In [None]:
from stardist_mpcdf.data import readDataset

In [None]:
import yaml

# Project-wide configuration, resolved relative to the notebook's working directory.
CONFIG_PATH = r'../../config.yaml'
with open(CONFIG_PATH) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [None]:
# Name of the patch dataset handed to stardist_mpcdf's readDataset.
DATASET_NAME = 'patches-semimanual-raw-64x128x128'
X, Y = readDataset(DATASET_NAME)

In [None]:
# Every training image needs a matching label volume; fail early if the splits disagree.
assert len(X['train']) == len(Y['train']), 'mismatched number of training images and labels'
# Number of training samples (displayed as the cell output).
len(X['train'])

In [None]:
# Keep only the training split. The guard makes the cell safe to re-run: after the
# first run X/Y are already the selected sequences, and indexing them with 'train'
# again would raise. NOTE(review): assumes readDataset returns mapping-like
# objects (having .keys()) keyed by split name; confirm against stardist_mpcdf.
if hasattr(X, 'keys'):
    X = X['train']
if hasattr(Y, 'keys'):
    Y = Y['train']

In [None]:
#X = config['napari_raw_files']
#Y = config['napari_label_files']

In [None]:
#X, Y = X[:10], Y[:10]

In [None]:
#X = list(map(imread,X))
#Y = list(map(imread,Y))

In [None]:
# Per-axis object extents over all training label volumes. Dividing the largest
# extent by each axis's extent gives the anisotropy factors (longest axis -> 1).
extents = calculate_extents(Y)
largest_extent = np.max(extents)
anisotropy = tuple(largest_extent / extents)
print('empirical anisotropy of labeled objects = %s' % (anisotropy,))

In [None]:
# Pick one training sample for visual inspection; holes in its labels are filled first.
i = 3
img = X[i]
lbl = fill_label_holes(Y[i])
# Expected axes for img and lbl: ZYX, optionally followed by a channel axis C.
assert img.ndim in (3, 4)

In [None]:
# Central XY plane (fixed z) and central XZ plane (fixed y) of the selected volume:
# raw intensities on the left, ground-truth labels on the right.
z = img.shape[0] // 2
y = img.shape[1] // 2
panels = [
    ((img[z], 'Raw image (XY slice)'), (lbl[z], 'GT labels (XY slice)')),
    ((img[:,y], 'Raw image (XZ slice)'), (lbl[:,y], 'GT labels (XZ slice)')),
]
for (raw_plane, raw_title), (lbl_plane, lbl_title) in panels:
    fig, (ax_raw, ax_lbl) = plt.subplots(1, 2, figsize=(16, 10))
    ax_raw.imshow(raw_plane, cmap='gray')
    ax_raw.axis('off')
    ax_raw.set_title(raw_title)
    ax_lbl.imshow(lbl_plane, cmap=lbl_cmap)
    ax_lbl.axis('off')
    ax_lbl.set_title(lbl_title)
None;

In [None]:
def reconstruction_scores(n_rays, anisotropy, labels=None):
    """Score how well star-convex polyhedra with a given ray count reproduce the labels.

    For each ray count, every label volume is re-labelled with
    ``relabel_image_stardist3D`` using rays from ``Rays_GoldenSpiral``. The result
    is compared with the original labels via ``matching_dataset``, and its
    ``mean_true_score`` is recorded (used in this notebook as the mean IoU).

    Parameters
    ----------
    n_rays : iterable of int
        Ray counts to evaluate, in the order the scores are returned.
    anisotropy : tuple of float or None
        Forwarded to ``Rays_GoldenSpiral``; ``None`` gives isotropic rays.
    labels : sequence of label arrays, optional
        Ground-truth label volumes to reconstruct. Defaults to the notebook-level
        ``Y`` so existing calls keep their behaviour.

    Returns
    -------
    list
        One ``mean_true_score`` per entry of ``n_rays``.
    """
    if labels is None:
        # Backward-compatible fallback to the global training labels.
        labels = Y
    scores = []
    for r in tqdm(n_rays):
        rays = Rays_GoldenSpiral(r, anisotropy=anisotropy)
        labels_reconstructed = [relabel_image_stardist3D(lbl, rays) for lbl in labels]
        # thresh=0 imposes no minimum IoU on matched object pairs.
        stats = matching_dataset(labels, labels_reconstructed, thresh=0,
                                 parallel=False, show_progress=False)
        scores.append(stats.mean_true_score)
    return scores

In [None]:
n_rays = [8, 16, 32, 64, 96, 128, 192, 256]
# Evaluate the same ray counts twice: isotropic rays first, then rays adapted to
# the empirical anisotropy of the labelled objects.
scores_iso, scores_aniso = [
    reconstruction_scores(n_rays, anisotropy=ray_anisotropy)
    for ray_anisotropy in (None, anisotropy)
]

In [None]:
# Reconstruction quality vs. ray count for isotropic and anisotropic rays,
# saved both as vector (EPS) and raster (PNG) figures.
fig, ax = plt.subplots(figsize=(8, 5))
for curve, curve_label in ((scores_iso, 'Isotropic'), (scores_aniso, 'Anisotropic')):
    ax.plot(n_rays, curve, 'o-', label=curve_label)
ax.set_xlabel('Number of stardist rays')
ax.set_ylabel('Mean intersection over union')
ax.legend()
ax.grid()
for out_path in ('stardist_ray_number.eps', 'stardist_ray_number.png'):
    fig.savefig(out_path)
None;

In [None]:
# np.save always writes the single-array .npy format and appends '.npy' to any
# other file name ('scores_iso.npz' would end up as 'scores_iso.npz.npy'), so use
# file names carrying the extension the files actually have.
np.save('scores_iso.npy', scores_iso)
np.save('scores_aniso.npy', scores_aniso)