## 3D Prediction using StarDist

Repo: https://github.com/stardist/stardist?tab=readme-ov-file

## Installation
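
StarDist also needs TensorFlow, which `pip install stardist` does not install; set it up first (see the repository README). `nibabel` is needed below to read the NIfTI volumes.

In [1]:
%pip install stardist nibabel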

Collecting stardist
  Downloading stardist-0.9.1-cp310-cp310-win_amd64.whl.metadata (21 kB)
Collecting csbdeep>=0.8.0 (from stardist)
  Downloading csbdeep-0.8.0-py2.py3-none-any.whl.metadata (2.4 kB)
Collecting numba (from stardist)
  Downloading numba-0.60.0-cp310-cp310-win_amd64.whl.metadata (2.8 kB)
Collecting llvmlite<0.44,>=0.43.0dev0 (from numba->stardist)
  Downloading llvmlite-0.43.0-cp310-cp310-win_amd64.whl.metadata (4.9 kB)
Downloading stardist-0.9.1-cp310-cp310-win_amd64.whl (786 kB)
Downloading csbdeep-0.8.0-py2.py3-none-any.whl (71 kB)
Downloading numba-0.60.0-cp310-cp310-win_amd64.whl (2.7 MB)

## Libraries

In [1]:
import re
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['image.interpolation'] = 'none'
%matplotlib inline

from glob import glob
import nibabel as nib
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

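from stardist import random_label_cmap
from stardist.plot import render_label
from stardist.models import StarDist3D

lbl_cmap = random_label_cmap()  # random colormap for label images (used by the XY-slice plotting cell)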

### Data

In [2]:
X_names = sorted(glob('../data/Gr4/RawImages/Nuclei/*.nii.gz'))
X = [nib.load(f).get_fdata() for f in X_names]  # voxel array (float64) of each NIfTI volume
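
nibabel's `get_fdata()` returns voxels in the file's own (i, j, k) order, which for most NIfTI files corresponds roughly to (x, y, z), whereas StarDist3D expects single-channel volumes in ZYX order. The sketch below assumes that convention holds for these files (inspect the header orientation if unsure) and shows how to reorder the axes with NumPy before prediction and undo it afterwards. `to_zyx` and `from_zyx` are hypothetical helper names, not part of StarDist or nibabel.

```python
import numpy as np

def to_zyx(vol_xyz: np.ndarray) -> np.ndarray:
    """Reorder an (x, y, z) volume to (z, y, x), assuming nibabel's usual axis order."""
    return np.transpose(vol_xyz, (2, 1, 0))

def from_zyx(vol_zyx: np.ndarray) -> np.ndarray:
    """Undo to_zyx so a label volume matches the original NIfTI voxel grid."""
    # Reversing all three axes is its own inverse.
    return np.transpose(vol_zyx, (2, 1, 0))

# Toy volume with a distinct size per axis: 4 (x) by 5 (y) by 3 (z)
vol = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)
zyx = to_zyx(vol)
print(zyx.shape)                            # (3, 5, 4)
print(np.array_equal(from_zyx(zyx), vol))   # True
```

If the assumption holds, one would predict on `to_zyx(...)` of the normalized volume and pass the resulting labels through `from_zyx` before wrapping them in `nib.Nifti1Image`, so they stay aligned with the original affine.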

Load a previously trained model (demo model for 3D nuclei segmentation) and apply it to the data.

In [3]:
model = StarDist3D.from_pretrained('3D_demo')

Found model '3D_demo' for 'StarDist3D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.707933, nms_thresh=0.3.


### Prediction 

Make sure to normalize the input image beforehand or supply a ``normalizer`` to the prediction function.
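
The StarDist example notebooks normalize each image with `csbdeep.utils.normalize(x, 1, 99.8, axis=...)`, i.e. percentile-based min/max scaling: the low percentile maps to 0 and the high percentile to 1. The NumPy sketch below reproduces that idea to make the arithmetic explicit; `percentile_normalize` is a hypothetical helper for illustration, not csbdeep's implementation, and 1/99.8 are simply the percentiles used in the StarDist examples.

```python
import numpy as np

def percentile_normalize(x, pmin=1.0, pmax=99.8, eps=1e-20, clip=False):
    """Scale x so its pmin-th percentile maps to 0 and its pmax-th percentile to 1."""
    x = np.asarray(x, dtype=np.float32)
    lo, hi = np.percentile(x, (pmin, pmax))
    out = (x - lo) / (hi - lo + eps)   # eps guards against a constant image
    if clip:
        out = np.clip(out, 0.0, 1.0)
    return out.astype(np.float32)

rng = np.random.default_rng(0)
vol = rng.normal(loc=500.0, scale=50.0, size=(8, 32, 32))  # toy raw intensities
norm = percentile_normalize(vol)
print(np.percentile(norm, 1), np.percentile(norm, 99.8))   # close to 0 and 1
```

With csbdeep itself, the equivalent call on a whole volume would be `normalize(X[idx], 1, 99.8, axis=(0, 1, 2))`.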

Calling ``model.predict_instances`` will

- predict object probabilities and star-convex polyhedra distances (see ``model.predict`` if you want those)
- perform non-maximum suppression (with overlap threshold ``nms_thresh``) for polyhedra above the object probability threshold ``prob_thresh``
- render all remaining polyhedra instances in a label image
- return the label image and also the details (coordinates, etc.) of all remaining polyhedra
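
The `prob_thresh` / `nms_thresh` step in the list above follows a greedy, highest-probability-first scheme. StarDist measures overlap between star-convex polyhedra; to keep this sketch short and dependency-free it substitutes axis-aligned 3D bounding boxes and IoU, so it illustrates the control flow only, not StarDist's geometry. The default thresholds are the values reported when loading `3D_demo`; the boxes and probabilities are made up.

```python
import numpy as np

def box_iou_3d(a, b):
    """IoU of two axis-aligned boxes given as (z0, y0, x0, z1, y1, x1)."""
    lo = np.maximum(a[:3], b[:3])
    hi = np.minimum(a[3:], b[3:])
    inter = np.prod(np.clip(hi - lo, 0, None))
    vol_a = np.prod(a[3:] - a[:3])
    vol_b = np.prod(b[3:] - b[:3])
    return inter / (vol_a + vol_b - inter)

def greedy_nms(boxes, probs, prob_thresh=0.707933, nms_thresh=0.3):
    """Return indices of kept candidates, highest probability first.

    1. discard candidates with prob <= prob_thresh
    2. repeatedly keep the most probable remaining candidate and drop
       every other candidate whose overlap with it exceeds nms_thresh
    """
    boxes = np.asarray(boxes, dtype=float)
    probs = np.asarray(probs, dtype=float)
    order = [i for i in np.argsort(-probs) if probs[i] > prob_thresh]
    keep = []
    while order:
        best = order.pop(0)
        keep.append(int(best))
        order = [i for i in order if box_iou_3d(boxes[best], boxes[i]) <= nms_thresh]
    return keep

boxes = [
    (0, 0, 0, 10, 10, 10),     # A
    (1, 1, 1, 11, 11, 11),     # B, overlaps A with IoU ~0.57
    (20, 20, 20, 30, 30, 30),  # C, disjoint from A and B
    (0, 0, 0, 5, 5, 5),        # D, below prob_thresh
]
probs = [0.95, 0.90, 0.80, 0.50]
print(greedy_nms(boxes, probs))  # [0, 2]
```

Raising `nms_thresh` tolerates more overlap (here `nms_thresh=0.6` keeps A, B and C), while raising `prob_thresh` discards more low-confidence candidates before suppression starts.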

In [4]:
idx = 2

In [None]:
img = normalize(X[idx], 1, 99.8, axis=(0, 1, 2))  # percentile-normalize the whole volume
labels, details = model.predict_instances(img)

In [None]:
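# X[idx] is a plain NumPy array, so reload the NIfTI header to reuse its affine
segmented_image = nib.Nifti1Image(labels.astype(np.int32), nib.load(X_names[idx]).affine)
file_name = re.search(r'([^/\\]+)(?=\.\w+\.\w+$)', X_names[idx]).group()  # basename without .nii.gz
out_dir = Path('../data/Gr4/Predictions/Stardist')
out_dir.mkdir(parents=True, exist_ok=True)
nib.save(segmented_image, str(out_dir / f'{file_name}.nii.gz'))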
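
The regular expression used for `file_name` keeps the last path component up to its final two extensions, so `.nii.gz` is stripped and either separator (`/` or `\`) is handled. A small stand-alone check with made-up paths shows what it extracts; `stem_via_regex` and `stem_via_pathlib` are hypothetical helpers, the second being a `pathlib`-based alternative.

```python
import re
from pathlib import PureWindowsPath

PATTERN = r'([^/\\]+)(?=\.\w+\.\w+$)'

def stem_via_regex(path: str) -> str:
    """Basename without its two trailing extensions, e.g. '.nii.gz'."""
    match = re.search(PATTERN, path)
    if match is None:
        raise ValueError(f'no double extension in {path!r}')
    return match.group()

def stem_via_pathlib(path: str) -> str:
    """Alternative: strip a literal '.nii.gz' suffix from the file name."""
    # PureWindowsPath accepts both '/' and '\\' as separators on any OS.
    name = PureWindowsPath(path).name
    return name[:-len('.nii.gz')] if name.endswith('.nii.gz') else name

for p in ['../data/Gr4/RawImages/Nuclei/sample_01.nii.gz',
          r'C:\data\Gr4\sample_02.nii.gz']:
    print(stem_via_regex(p), stem_via_pathlib(p))
# sample_01 sample_01
# sample_02 sample_02
```

The regex version works for any double extension, whereas the `pathlib` version is explicit about expecting `.nii.gz`.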

In [None]:
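# imshow and render_label expect 2D input, so show the middle slice along the first axis
z = X[idx].shape[0] // 2
plt.figure(figsize=(16,10))
plt.subplot(121); plt.imshow(X[idx][z], cmap='gray'); plt.axis('off'); plt.title(f'Raw image (slice {z})')
plt.subplot(122); plt.imshow(render_label(labels[z], img=X[idx][z])); plt.axis('off'); plt.title(f'Predicted labels (slice {z})')
plt.tight_layout()
plt.show()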
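
Beyond visual inspection, a quick quantitative check of `labels` can be useful. The sketch below (not part of the original analysis) counts instances and their voxel volumes with `np.unique`, assuming, as in StarDist's label output, that 0 is background and each nucleus has its own positive integer ID. `instance_volumes` is a hypothetical helper; multiplying by the voxel size from the NIfTI header (`nib.load(...).header.get_zooms()`) would convert counts to physical volumes.

```python
import numpy as np

def instance_volumes(labels: np.ndarray) -> dict[int, int]:
    """Map each non-zero label ID to its voxel count (0 is treated as background)."""
    ids, counts = np.unique(labels, return_counts=True)
    return {int(i): int(c) for i, c in zip(ids, counts) if i != 0}

# Toy 3D label image with two "nuclei"
toy = np.zeros((4, 6, 6), dtype=np.int32)
toy[1:3, 1:3, 1:3] = 1      # 2 x 2 x 2 = 8 voxels
toy[0:4, 4:6, 4:6] = 2      # 4 x 2 x 2 = 16 voxels
vols = instance_volumes(toy)
print(len(vols), vols)       # 2 {1: 8, 2: 16}
```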

In [None]:
# plt.figure(figsize=(13,10))
# z = max(0, img.shape[0] // 2 - 5)
# plt.subplot(121)
# plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')
# plt.title('Raw image (XY slice)')
# plt.axis('off')
# plt.subplot(122)
# plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')
# plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)
# plt.title('Image and predicted labels (XY slice)')
# plt.axis('off');