In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
### Add the repository root to sys.path (the list of directories Python searches for packages and modules)
import sys, os
from pathlib import Path


root_dir = os.path.abspath("../../../")
# `.name` is the last path component; `.stem` would also drop anything after a dot in the directory name.
print("Root dir: ", Path(root_dir).name)
# Only add the entry once, so re-running this cell in the same kernel does not grow sys.path.
if root_dir not in sys.path:
    sys.path.append(root_dir)

Root dir:  pytorch-stardist


In [3]:
import time
from pathlib import Path

from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


import torch


from stardist_tools import calculate_extents, Rays_GoldenSpiral
from stardist_tools.matching import matching, matching_dataset


from src.training import train

from src.data.utils import load_img, save_img

from utils import seed_all, prepare_conf, plot_img_label

from stardist_tools.csbdeep_utils import normalize
from src.models.config import Config3D
from src.models.stardist3d import StarDist3D

### Data

In [4]:
# Directory holding the test volumes to segment.
image_dir = "datasets/demo/test/images"

# Sorted so the processing order (and `image_paths[0]` below) is reproducible across filesystems.
# The glob is non-recursive, so masks written to the `predictions/` subfolder later are not picked up on re-runs.
image_paths = sorted(Path(image_dir).glob("*.tif"))
len(image_paths)

3

In [5]:
# Load the first image to inspect its shape/layout before running inference.
# Fail with a clear message instead of an IndexError when no images were found.
if not image_paths:
    raise FileNotFoundError(f"No .tif images found in {image_dir!r}")
x = load_img(image_paths[0])
x.shape

(64, 128, 128)

In [6]:
# The notebook treats images as channel-first: 3D (Z, Y, X) volumes get a leading channel axis before
# prediction, and axis -4 is the channel axis for joint normalization. A 4D image is therefore (C, Z, Y, X),
# so the channel count is the FIRST dimension (shape[-1] would be the X size).
n_channel = 1 if x.ndim == 3 else x.shape[0]

# Normalize over the last three (spatial) axes only -> each channel is normalized independently.
# Use (-1, -2, -3, -4) instead to normalize all channels jointly.
axis_norm = (-1, -2, -3)
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or -4 in axis_norm else 'independently'))

### Load trained model

In [7]:
# Inference-only configuration for the "demo" run, loading its "best" checkpoint.
# `use_gpu` is True when CUDA is available and None otherwise (`is_available() or None`).
conf = Config3D(
    name='demo',
    use_gpu=torch.cuda.is_available() or None,
    use_amp=True,
    isTrain=False,
    load_epoch="best",
)

model = StarDist3D(conf)

Load path: checkpoints\demo\best.pth cuda:0
Loading threholds ...
Instanciating network
initialize network with normal
Network [StarDistResnet] was created. Total number of parameters: 1.6 million. To see the architecture, do print(network).
<All keys matched successfully>
Loading model from <checkpoints\demo\best.pth>.



In [8]:
# Quick check of the loaded model: epoch count stored in its options and the thresholds loaded with it.
(model.opt.epoch_count, model.thresholds)

(305, {'prob': 0.726973231031345, 'nms': 0.3})

In [9]:
# Ray configuration stored with the model (a dict with the rays class name and its kwargs, as displayed below).
model.opt.rays_json

{'name': 'Rays_GoldenSpiral',
 'kwargs': {'n': 96, 'anisotropy': (2.0, 1.0, 1.0)}}

### Prediction

In [10]:
# Masks go to a subfolder named after the model and the loaded checkpoint, next to the input images.
dest_dir = Path(image_dir) / f"predictions/model_{model.opt.name}_epoch_{conf.load_epoch}"
dest_dir.mkdir(parents=True, exist_ok=True)

print("Predictions will be saved at: ", dest_dir)

Predictions will be saved at:  datasets\demo\test\images\predictions\model_demo_epoch_best


In [11]:
# Out-of-memory during inference? Set `patch_size` to a size smaller than your image but larger than a
# nucleus, e.g. (32, 64, 64). Inference is then run on patches of that size, and the patches are
# reassembled into the final instance mask; the result should not differ from what you would obtain
# by running inference on the whole image. Leave it as None to skip patching.
patch_size = None

In [12]:
# Lower/upper percentiles for per-image intensity normalization before inference.
NORM_PMIN, NORM_PMAX = 1, 99.8

for image_path in tqdm(image_paths, desc="predicting"):
    image = load_img(image_path)
    # Single-channel (Z, Y, X) volumes get a leading channel axis -> (1, Z, Y, X).
    if image.ndim == 3:
        image = image[np.newaxis]

    image = normalize(image, NORM_PMIN, NORM_PMAX, axis=axis_norm)

    # presumably predict_instance returns (labels, details); only the instance label volume is kept — TODO confirm
    pred_mask = model.predict_instance(image, patch_size=patch_size)[0]
    # Each mask keeps its input file name so it can be paired with the source image.
    mask_path = dest_dir / image_path.name

    save_img(mask_path, pred_mask)

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:07<00:00,  2.46s/it]
