In [1]:
from typing import Tuple, List
import numpy as np

In [2]:
class Random3DCanonicalPatchSamplingStrategy():
    def __init__(self, patch_size=[96, 96, 96], patch_overlap_frac=0.4):
        self.patch_overlap_frac = patch_overlap_frac
        self.patch_size = patch_size

def _get_deterministic_3d_patches(
    self, padded_volume: np.ndarray
) -> Tuple[List[np.ndarray], List[tuple]]:
    """
    Gets deterministic patches from a 3D volume, ensuring all patches have the same size and
    considering the desired overlap between them.

    :param padded_volume: np.ndarray. The 3D volume to extract patches from, already modified 
        by self._pad_volume_to_fit_patches
    :return: List[Tuple[np.ndarray, tuple]]. A list of patches and their corresponding slice 
        tuples.
    """
    patches = []
    slice_tuples = []

    # Calculate step size for each dimension
    step_x = int(self.patch_size[0] * (1 - self.patch_overlap_frac))
    step_y = int(self.patch_size[1] * (1 - self.patch_overlap_frac))
    step_z = int(self.patch_size[2] * (1 - self.patch_overlap_frac))

    # Ensure step size is at least 1
    step_x = max(1, step_x)
    step_y = max(1, step_y)
    step_z = max(1, step_z)

    # Iterate over the padded volume and extract patches
    for x in range(0, padded_volume.shape[0] - self.patch_size[0] + 1, step_x):
        for y in range(0, padded_volume.shape[1] - self.patch_size[1] + 1, step_y):
            for z in range(
                0, padded_volume.shape[2] - self.patch_size[2] + 1, step_z
            ):
                slice_tuple = (
                    slice(x, x + self.patch_size[0]),
                    slice(y, y + self.patch_size[1]),
                    slice(z, z + self.patch_size[2]),
                )
                patch = padded_volume[slice_tuple]
                patches.append(patch)
                slice_tuples.append(slice_tuple)

    return patches, slice_tuples

def _pad_volume_to_fit_patches(self, volume: np.ndarray) -> np.ndarray:
    """
    Pads the given volume to ensure that it can accommodate patches of the specified size,
    considering the desired overlap between patches.

    :param volume: np.ndarray. The original volume to pad.
    :return: np.ndarray. The padded volume.
    """
    # Calculate the step size for each dimension considering the overlap
    step_x = int(self.patch_size[0] * (1 - self.patch_overlap_frac))
    step_y = int(self.patch_size[1] * (1 - self.patch_overlap_frac))
    step_z = int(self.patch_size[2] * (1 - self.patch_overlap_frac))

    # Calculate padding for each dimension once step size is known
    pad_x = (step_x - (volume.shape[0] - self.patch_size[0]) % step_x) % step_x
    pad_y = (step_y - (volume.shape[1] - self.patch_size[1]) % step_y) % step_y
    pad_z = (step_z - (volume.shape[2] - self.patch_size[2]) % step_z) % step_z

    # Calculate half-padding + reminders
    hpad_x = pad_x // 2
    hpad_y = pad_y // 2
    hpad_z = pad_z // 2

    # Compute spatial padding tuples
    pad_spatial = (
        (hpad_x, hpad_x + (pad_x % 2)),
        (hpad_y, hpad_y + (pad_y % 2)),
        (hpad_z, hpad_z + (pad_z % 2))
    )

    # Pad the volume
    if len(volume.shape) == 3:
        padded_volume = np.pad(volume, pad_spatial, mode="constant")
    else:
        padded_volume = np.pad(
            volume, pad_spatial + ((0, 0),), mode="constant"
        )
    return padded_volume


In [3]:
def get_gradient_steps_in_trianing(dummy_im, strategy, n_val_im,
                                    val_batchsize, train_steps_per_epoch, train_batchsize):
    


    padded_im = _pad_volume_to_fit_patches(strategy, dummy_im)
    patches, slice_tuples = _get_deterministic_3d_patches(strategy, padded_im)

    n_patches_pre_im = len(patches)
    print(n_patches_pre_im)
    
    
    patches_val = n_val_im * n_patches_pre_im
    steps_valid = patches_val/val_batchsize
    
    patches_train = train_steps_per_epoch/train_batchsize
    
    print(
        f"TRAIN: {patches_train} patches, {train_steps_per_epoch} gradient steps"
    )
    print(
        f"VALIDATION: {patches_val} patches, {steps_valid} inference steps"

    )


In [4]:
dummy_im = np.zeros(shape = (400, 400, 400))

strategy = Random3DCanonicalPatchSamplingStrategy(
    patch_overlap_frac=0
)
get_gradient_steps_in_trianing(
    dummy_im, 
    strategy,
    n_val_im=15,
    val_batchsize=10, 
    train_steps_per_epoch=5000, 
    train_batchsize=5
    )

125
TRAIN: 1000.0 patches, 5000 gradient steps
VALIDATION: 1875 patches, 187.5 inference steps
