In [6]:
import numpy as np
import torch
import random
import zarr
import tifffile

In [25]:
def uint16_normals_to_float_normals(normals_uint16):
    # normals_uint16: shape (Z, Y, X, 3) or (3, Z, Y, X)
    # Convert to float in [-1, +1]
    normals_float = (normals_uint16.astype(np.float32) / 32767.5) - 1.0
    return normals_float

def float_normals_to_uint16_normals(normals_float):
    # Flip or rotate has been applied in float
    # Now convert back to [0..65535]
    normals_uint16 = ((normals_float + 1.0) * 32767.5).astype(np.uint16)
    return normals_uint16

class RandomFlipWithNormals:
    """
    Flip 3D volumes (and their normal vectors) along each axis with probability p.
    Expects data in (C, Z, Y, X) format.
    """

    def __init__(self, p=0.5, normal_keys=("normals",)):
        """
        Args:
            p (float): Probability of flipping along each axis independently.
            normal_keys (tuple): which dictionary keys hold normal-vector data.
                                 e.g. ("normals", "surface_normals")
        """
        self.p = p
        self.normal_keys = set(normal_keys)

    def __call__(self, data_dict):
        """
        data_dict: dict of {key: np.ndarray}, each shaped (C, Z, Y, X) or (1, Z, Y, X), etc.

        Returns:
            A modified data_dict after possibly flipping each volume along Z, Y, X.
        """
        # Flip along Z => axis=1 (since axis=0 is channel, axis=1 is Z)
        if random.random() < self.p:
            for k, arr in data_dict.items():
                # flip array
                arr = np.flip(arr, axis=1).copy()
                # if it's normal data, optionally flip the normal's Z component
                if k in self.normal_keys:
                    # Nx = arr[0], Ny = arr[1], Nz = arr[2]
                    # physically reflecting across Z => Nz *= -1
                    arr[2] *= -1
                data_dict[k] = arr

        # Flip along Y => axis=2
        if random.random() < self.p:
            for k, arr in data_dict.items():
                arr = np.flip(arr, axis=2).copy()
                if k in self.normal_keys:
                    # multiply the Y-component (arr[1]) by -1
                    arr[1] *= -1
                data_dict[k] = arr

        # Flip along X => axis=3
        if random.random() < self.p:
            for k, arr in data_dict.items():
                arr = np.flip(arr, axis=3).copy()
                if k in self.normal_keys:
                    # multiply the X-component (arr[0]) by -1
                    arr[0] *= -1
                data_dict[k] = arr

        return data_dict

class RandomRotate90WithNormals:
    """
    Randomly rotate all 3D volumes (and normal vectors) by 90-degree increments
    around one of the axes {X, Y, Z}.

    Expects arrays shaped (C, Z, Y, X).
    If a key is in `normal_keys`, we apply the normal-vector permutations.
    """

    def __init__(self, axes=('x', 'y', 'z'), p=0.5, normal_keys=("normals",)):
        self.axes = axes
        self.p = p
        self.normal_keys = set(normal_keys)

    def __call__(self, data_dict):
        if random.random() >= self.p:
            # No rotation
            return data_dict

        axis = random.choice(self.axes)  # 'x', 'y', or 'z'
        k = random.choice([1, 2, 3])     # 90, 180, 270

        # For each array in the dictionary, rotate the spatial dims
        #  then fix up the normal channels if needed.
        for key, arr in data_dict.items():
            # 1) rotate the array in (Z, Y, X) => axes=(1,2,3)
            if axis == 'z':
                # rotating the Y,X plane => axes=(2,3)
                arr = np.rot90(arr, k=k, axes=(2, 3))
                if key in self.normal_keys:
                    # Nx=arr[0], Ny=arr[1], Nz=arr[2]
                    for _ in range(k):
                        nx = arr[0].copy()
                        ny = arr[1].copy()
                        nz = arr[2].copy()
                        # +90° around Z => (nx, ny, nz) -> (ny, -nx, nz)
                        arr[0] = ny
                        arr[1] = -nx
                        arr[2] = nz

            elif axis == 'y':
                # rotate the Z,X plane => axes=(1,3)
                arr = np.rot90(arr, k=k, axes=(1, 3))
                if key in self.normal_keys:
                    for _ in range(k):
                        nx = arr[0].copy()
                        ny = arr[1].copy()
                        nz = arr[2].copy()
                        # +90° around Y => (nx, ny, nz) -> (nz, ny, -nx)
                        arr[0] = nz
                        arr[2] = -nx

            else:  # axis == 'x'
                # rotate the Z,Y plane => axes=(1,2)
                arr = np.rot90(arr, k=k, axes=(1, 2))
                if key in self.normal_keys:
                    for _ in range(k):
                        nx = arr[0].copy()
                        ny = arr[1].copy()
                        nz = arr[2].copy()
                        # +90° around X => (nx, ny, nz) -> (nx, nz, -ny)
                        arr[1] = nz
                        arr[2] = -ny

            data_dict[key] = arr

        return data_dict

In [8]:
v_raw = zarr.open('/mnt/raid_nvme/s1.zarr', mode='r')
v_sheet = zarr.open('/mnt/raid_nvme/datasets/1-voxel-sheet_slices-closed.zarr/0.zarr', mode='r')
v_normals = zarr.open('/home/sean/Documents/GitHub/VC-Surface-Models/models/normals.zarr', mode='r')

print(f'v_raw shape: {v_raw.shape}, v_normals shape: {v_normals.shape}, v_sheet shape: {v_sheet.shape}')

v_raw shape: (14376, 7888, 8096), v_normals shape: (13700, 7888, 8096, 3), v_sheet shape: (14376, 7888, 8096)


In [28]:
zmin, zmax = 10000, 10200
ymin, ymax = 3000, 3200
xmin, xmax = 3000, 3200

raw_crop = v_raw[zmin:zmax, ymin:ymax, xmin:xmax]
norm_crop = v_normals[zmin:zmax, ymin:ymax, xmin:xmax]
sheet_crop = v_sheet[zmin:zmax, ymin:ymax, xmin:xmax]

norm_crop = uint16_normals_to_float_normals(norm_crop)

In [None]:
tifffile.imwrite('raw_crop.tif', raw_crop)
tifffile.imwrite('norm_crop.tif', norm_crop)
tifffile.imwrite('sheet_crop.tif', sheet_crop)

In [29]:
rot = RandomRotate90WithNormals(axes=('x', 'y', 'z'), p=1)
flip = RandomFlip3DWithNormals(p=1)

for _ in range(3):
    rotated_crop, rotated_norms = rot(raw_crop, norm_crop)
    flipped_crop, flipped_norms = flip(rotated_crop, rotated_norms)
    rotated_norms, flipped_norms = float_normals_to_uint16_normals(rotated_norms), float_normals_to_uint16_normals(flipped_norms)

    tifffile.imwrite(f'rotated_crop_{_}.tif', flipped_crop)
    tifffile.imwrite(f'rotated_norms_{_}.tif', flipped_norms)

In [30]:
import numpy as np
import random

def main():
    # Re-import or define your classes and functions if needed:
    # from your_module import (
    #    uint16_normals_to_float_normals, float_normals_to_uint16_normals,
    #    RandomFlip3DWithNormals, RandomRotate90WithNormals
    # )

    # 1. Create small 3D volume: shape (Z, Y, X) = (2, 2, 2)
    Z, Y, X = 2, 2, 2
    image = np.arange(Z*Y*X).reshape(Z, Y, X).astype(np.float32)

    # 2. Create corresponding normals in float:
    #    We'll do Nx=1 everywhere (the simplest case).
    #    shape => (2,2,2,3)
    normals_float = np.zeros((Z, Y, X, 3), dtype=np.float32)
    normals_float[..., 0] = 1.0  # Nx=1

    # 3. Convert float->uint16, then back->float, just to simulate your workflow
    normals_uint16 = float_normals_to_uint16_normals(normals_float)
    # Now we have [0..65535] data
    normals_float_reloaded = uint16_normals_to_float_normals(normals_uint16)

    # 4. Instantiate your transforms with p=1.0 => they always apply
    #    We'll fix random.seed for reproducibility
    random.seed(1234)
    flip_transform = RandomFlip3DWithNormals(p=1.0)
    rotate_transform = RandomRotate90WithNormals(axes=('x','y','z'), p=1.0)

    # 5. Apply them
    # step A: flip
    flipped_image, flipped_normals = flip_transform(image, normals_float_reloaded)
    # step B: rotate
    rotated_image, rotated_normals = rotate_transform(flipped_image, flipped_normals)

    # 6. Convert back to uint16
    rotated_normals_uint16 = float_normals_to_uint16_normals(rotated_normals)

    # 7. Print or check results
    print("Original image:\n", image)
    print("Flipped image:\n", flipped_image)
    print("Rotated image:\n", rotated_image)
    print("Original normals_float:\n", normals_float)
    print("After flip->rotate (float):\n", rotated_normals)
    print("After flip->rotate, converted back to uint16:\n", rotated_normals_uint16)

    # Because the transforms are random, you'd re-run or forcibly set which flips/rotations happen.
    # Or you can interpret the final result to see if Nx is indeed negative or re-oriented
    # as expected for the randomly chosen axis.

if __name__ == "__main__":
    main()


Original image:
 [[[0. 1.]
  [2. 3.]]

 [[4. 5.]
  [6. 7.]]]
Flipped image:
 [[[7. 6.]
  [5. 4.]]

 [[3. 2.]
  [1. 0.]]]
Rotated image:
 [[[6. 4.]
  [7. 5.]]

 [[2. 0.]
  [3. 1.]]]
Original normals_float:
 [[[[1. 0. 0.]
   [1. 0. 0.]]

  [[1. 0. 0.]
   [1. 0. 0.]]]


 [[[1. 0. 0.]
   [1. 0. 0.]]

  [[1. 0. 0.]
   [1. 0. 0.]]]]
After flip->rotate (float):
 [[[[ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]
   [ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]]

  [[ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]
   [ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]]]


 [[[ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]
   [ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]]

  [[ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]
   [ 1.5258789e-05  1.0000000e+00 -1.5258789e-05]]]]
After flip->rotate, converted back to uint16:
 [[[[32768 65535 32767]
   [32768 65535 32767]]

  [[32768 65535 32767]
   [32768 65535 32767]]]


 [[[32768 65535 32767]
   [32768 65535 32767]]

  [[32768 65535 32767]
   [32768 65535