In [1]:
# Automatically reload imported modules before running each cell (mode 2 = all modules),
# so edits to local modules such as `utils` and `config` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import numpy as np
import torch
from torch import nn
import imageio
import napari


from utils.in_out_tools import write_videos_on_disk
from utils.training_inference_tools import get_preds_from_path
from utils.training_script_utils import init_model
from utils.visualization_tools import (
    get_annotations_contour,
    get_discrete_cmap,
    get_labels_cmap,
)

from config import TrainingConfig



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



Parameters needed to configure the dataset and the UNet model (these could eventually be hard-coded in the function).

In [3]:
### Set training-specific parameters ###

# Initialize training-specific parameters from the final model's config file.
config_path = os.path.join("config_files", "config_final_model.ini")
params = TrainingConfig(training_config_file=config_path)

# Run name used to locate the checkpoint directory (runs/<training_name>/ in the next cell).
params.training_name = "final_model"

# Checkpoint file to load; the name suggests weights saved at iteration 100000 (TODO confirm).
model_name = "network_100000.pth"

# Architectures this notebook knows how to load; fail early with the actual value if not one of them.
SUPPORTED_ARCHITECTURES = ("pablos_unet", "github_unet", "openai_unet")
assert params.nn_architecture in SUPPORTED_ARCHITECTURES, (
    "nn_architecture must be one of "
    + ", ".join(repr(arch) for arch in SUPPORTED_ARCHITECTURES)
    + f", got {params.nn_architecture!r}"
)

Load UNet model

In [4]:
### Configure UNet ###
params.set_device(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

network = init_model(params=params)
# Wrap in DataParallel *before* loading the weights: the checkpoint keys presumably
# carry the "module." prefix of a DataParallel-saved model (TODO confirm).
network = nn.DataParallel(network).to(params.device)

### Load UNet model ###

# Path to the saved model checkpoint: runs/<training_name>/<model_name>
models_relative_path = "runs/"
model_path = os.path.join(models_relative_path, params.training_name, model_name)

# Load the model state dictionary onto the selected device.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
state_dict = torch.load(model_path, map_location=params.device)
network.load_state_dict(state_dict)

# Switch to evaluation mode for inference. The trailing `;` keeps the (very long)
# module repr returned by eval() out of the cell output.
network.eval();

Define movie path

In [5]:
# Path of the movie to segment. To use a different file without editing the notebook,
# set the SPARKS_MOVIE_PATH environment variable; otherwise the local default below is used.
# Alternative input (presumably the uncropped movie from the dataset folder):
#     os.path.join(r"C:\Users\dotti\sparks_project\data\sparks_dataset", "34_video.tif")
DEFAULT_MOVIE_PATH = r"C:\Users\dotti\Desktop\cropped 34_video.tif"
movie_path = os.environ.get("SPARKS_MOVIE_PATH", DEFAULT_MOVIE_PATH)

Run inference on the movie

In [7]:
# Run the network on the movie. With return_dict=False the predictions come back
# as a pair, unpacked into the segmentation and the instance labels.
inference_kwargs = dict(
    model=network,
    params=params,
    movie_path=movie_path,
    return_dict=False,
)
segmentation, instances = get_preds_from_path(**inference_kwargs)

### Visualize predictions with napari

In [9]:
# Read the original input movie from disk and expose it as a plain NumPy array
# (displayed below as the base image layer).
raw_movie = imageio.volread(movie_path)
sample = np.asarray(raw_movie)

In [10]:
# Napari display settings.
# 16-level discrete gray colormap; not passed to any layer below, kept for optional use on the image.
cmap = get_discrete_cmap(lut=16, name="gray")
# Colours used for the label layers.
labels_cmap = get_labels_cmap()

In [11]:
# Reduce each class region of the segmentation to its border, so the movie
# underneath stays visible when the labels are overlaid.
segmentation_border = get_annotations_contour(
    segmentation
)

In [12]:
# Interactive viewer: the original movie with the predictions overlaid as label layers.
viewer = napari.Viewer()

viewer.add_image(
    sample,
    name="input movie",
    # colormap=('colors',cmap)  # optional: discrete gray colormap from the settings cell
)

# Class borders only, drawn over the movie (visible by default).
viewer.add_labels(
    segmentation_border,
    name="segmentation",
    opacity=0.9,
    color=labels_cmap,
)

# Whole class regions, hidden by default (toggle in the layer list to see the full ROIs).
# Distinct name: napari would otherwise silently rename a second "segmentation" layer
# to "segmentation [1]".
viewer.add_labels(
    segmentation,
    name="segmentation (filled)",
    opacity=0.5,
    color=labels_cmap,
    visible=False,
)

# Instance labels; trailing `;` keeps the returned layer repr out of the cell output.
viewer.add_labels(
    instances,
    name="instances",
    opacity=0.5,
);

<Labels layer 'instances' at 0x21662e24400>