In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
### Adding root directory to sys.path (list of directories python look in for packages and modules)
import sys, os

root_dir = os.path.abspath("../..")
sys.path.append( root_dir )

In [4]:
import time
from pathlib import Path

from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


import torch


from stardist_tools import calculate_extents, Rays_GoldenSpiral#, random_label_cmap, relabel_image_stardist3D
from stardist_tools.matching import matching, matching_dataset


from src.training import train

from src.data.utils import load_tif, save_tif

from utils import seed_all, prepare_conf, plot_img_label

from stardist_tools.csbdeep_utils import normalize
from src.models.config import ConfigBase, Config2D
from src.models.stardist2d import StarDist2D

### Data

In [5]:
image_dir = r"datasets\dsb2018/test/images"

image_paths = list( Path(image_dir).glob("*.tif") )
len(image_paths)

50

In [6]:
x = load_tif(image_paths[0])
x.shape

(256, 256)

In [7]:
n_channel = 1 if x.ndim == 3 else x.shape[-1]

axis_norm = axis_norm = (-1, -2) # # normalize channels independently
#axis_norm = axis_norm = (-1, -2, -3) # # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or -3 in axis_norm else 'independently'))

Normalizing image channels independently.


In [8]:
#list( Path(r"./checkpoints/dsb2018").glob('*') )

### Load trained model

In [9]:
conf = Config2D(
    name                           = 'dsb2018',
    use_gpu                        = True if torch.cuda.is_available() else None,
    use_amp                        = True,
    isTrain                        = False ,
    load_epoch                     = "best",
)

model = StarDist2D(conf)

Load path: checkpoints\dsb2018\best.pth cuda:0
Loading threholds ...
Instanciating network
initialize network with normal
Network [StarDistResnet] was created. Total number of parameters: 0.5 million. To see the architecture, do print(network).
<All keys matched successfully>
Loading model from <checkpoints\dsb2018\best.pth>.



In [10]:
model.opt.epoch_count, model.thresholds

(10, {'prob': 0.48855887404256104, 'nms': 0.3})

### Prediction

In [11]:
dest_dir = Path(image_dir) / f"predictions/model_{model.opt.name}_epoch_{conf.load_epoch}" #f"results/{model.opt.name}_{conf.load_epoch}/"

dest_dir = Path(dest_dir)
os.makedirs(dest_dir, exist_ok=True)

print("Predictions will be saved at: ", dest_dir)

Predictions will be saved at:  datasets\dsb2018\test\images\predictions\model_dsb2018_epoch_best_per_patch


In [12]:
for image_path in tqdm(image_paths):
    image = load_tif(image_path).squeeze()
    if image.ndim==2:
        image = image[np.newaxis]
    
    image = normalize(image, 1,99.8,axis=axis_norm)
    
    pred_mask = model.predict_instance(image)[0]
    mask_path = dest_dir / f"{image_path.name}"
    
    save_tif(mask_path, pred_mask)

100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:07<00:00,  6.87it/s]
