In [None]:
# Imports
import pathlib

import albumentations
import numpy as np
import torch
from torch.utils.data import DataLoader

from customdatasets import SegmentationDataSet3
from transformations import (
    ComposeDouble,
    normalize_01,
    FunctionWrapperDouble,
    create_dense_target,
    AlbuSeg3d,
)

# Dataset root: the "Microtubules3D" folder inside the current working directory.
root = pathlib.Path.cwd().joinpath("Microtubules3D")


def get_filenames_of_path(path: pathlib.Path, ext: str = "*") -> list:
    """Return the regular files in ``path`` that match the glob pattern ``ext``.

    The result is sorted. ``Path.glob`` yields entries in arbitrary,
    filesystem-dependent order. The input and target lists built with this
    function are presumably paired by position by the dataset, so an unsorted
    listing could silently pair an input volume with the wrong target.

    Args:
        path: Directory to search. The search is non-recursive unless ``ext``
            contains ``**``.
        ext: Glob pattern relative to ``path``. Despite its name this is a
            full pattern such as ``"*"`` or ``"*.tif"``, not a bare extension.

    Returns:
        Sorted list of ``pathlib.Path`` objects. Subdirectories are excluded.
    """
    return sorted(file for file in path.glob(ext) if file.is_file())


# Collect the input volumes (root/Input) and their targets (root/Target);
# Input is listed first, then Target.
inputs, targets = (
    get_filenames_of_path(root / subfolder) for subfolder in ("Input", "Target")
)

# Training-time transformations and augmentations.
# The input volumes are grayscale; np.expand_dims(axis=0) below adds a channel
# dimension of size 1. Order matters: AlbuSeg3d currently only works on arrays
# WITHOUT a channel dimension, so it has to run before np.expand_dims.
#
# Optional steps, kept for reference (uncomment and insert to use):
#   resize -- not imported in this file, presumably
#             `from skimage.transform import resize`. For the target use
#             order=0 / anti_aliasing=False / preserve_range=True so label
#             values are not interpolated:
#     FunctionWrapperDouble(resize, input=True, target=False, output_shape=(16, 100, 100)),
#     FunctionWrapperDouble(resize, input=False, target=True, output_shape=(16, 100, 100), order=0, anti_aliasing=False, preserve_range=True),
#   further augmentations routed through AlbuSeg3d:
#     AlbuSeg3d(albumentations.HorizontalFlip(p=0.5)),
#     AlbuSeg3d(albumentations.VerticalFlip(p=0.5)),
#     AlbuSeg3d(albumentations.Rotate(p=0.5)),
#   RandomFlip(ndim_spatial=3) -- also not imported here (source module
#     unknown, TODO confirm); it was listed right after np.expand_dims.
transforms_training = ComposeDouble(
    [
        # random 90-degree rotation (presumably applied to input and target alike)
        AlbuSeg3d(albumentations.RandomRotate90(p=0.5)),
        # target only
        FunctionWrapperDouble(create_dense_target, input=False, target=True),
        # add the channel axis; input/target flags left at FunctionWrapperDouble's defaults
        FunctionWrapperDouble(np.expand_dims, axis=0),
        # presumably rescales values to [0, 1]; default input/target flags
        FunctionWrapperDouble(normalize_01),
    ]
)

# Random seed. It is applied to Python's `random`, NumPy and PyTorch so that
# runs are reproducible. This covers the DataLoader shuffle, which draws from
# torch's default generator, and any augmentation that draws from one of these
# RNGs (albumentations versions differ in which RNG they use -- verify).
import random

random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# dataset training
# Applies `transforms_training` to each sample. Caching and pre-transforms
# are disabled. Presumably inputs[i] is paired with targets[i] -- see
# SegmentationDataSet3.
dataset_train = SegmentationDataSet3(
    inputs=inputs,
    targets=targets,
    transform=transforms_training,
    use_cache=False,
    pre_transform=None,
)


def _print_sample_summary(x, y):
    """Print the input's shape and value range, then the target's shape and unique values.

    Used for both a single dataset sample and a DataLoader batch, so the two
    printouts can be compared line by line.
    """
    print(x.shape)
    print(x.min(), x.max())
    print(y.shape)
    print(torch.unique(y))


# Inspect one transformed sample straight from the dataset
# (index 1 requires at least two input files).
x, y = dataset_train[1]
_print_sample_summary(x, y)

# dataloader training
dataloader_training = DataLoader(
    dataset=dataset_train,
    batch_size=1,
    # batch_size > 1 fails with the default collate function: the samples
    # differ in depth (number of slices), so they cannot be stacked into
    # one tensor.
    shuffle=True,
)

# Inspect one batch. The default collation adds a leading batch dimension
# of size 1 compared with the single sample above.
batch = next(iter(dataloader_training))
x, y = batch
_print_sample_summary(x, y)

# Visual inspection of the (transformed) training samples.
from visual import DatasetViewer

dataset_viewer_training = DatasetViewer(
    dataset_train,
)
# Presumably opens a napari window. Keys: 'n' = next sample, 'b' = previous sample.
dataset_viewer_training.napari()
