# Imports

In [None]:
import copick
import os
import shutil
import numpy as np
import zarr
import json
from tqdm import tqdm
import polars as pl
from itkwidgets import view
import cv2
import torch

import cv2
import polars as pl
import numpy as np
from matplotlib import pyplot as plt
import PIL
from IPython.display import Image
from IPython.display import display as idisplay
import random

from monai import transforms as T
from yagm.transforms import monai_custom as CT
from typing import Tuple
from collections import Counter
import heapq
import math
from math import pi


# Shared switches read by the transform cells further down in this notebook.
# LAZY is forwarded as `lazy=` to the MONAI transforms and to the T.Compose
# built inside viz_transform.
LAZY = True
# LOG_STATS is forwarded as `log_stats=` to T.OneOf in the "Used transforms" cell.
LOG_STATS = False
# OVERRIDES is forwarded as `overrides=` to T.OneOf in the "Used transforms" cell.
OVERRIDES = {'image': {'padding_mode': 'zeros'}}
# ALIGN_CORNERS is forwarded as `align_corners=` / `grid_align_corners=` to the
# resampling transforms.
ALIGN_CORNERS = False

# Functions

In [None]:
def describe(df):
    """Show summary statistics of a polars frame in the notebook.

    The table is rendered while a temporary ``pl.Config(tbl_rows=30)`` is
    active. The percentiles come from a module-level ``PERCENTILES`` and the
    rendering from the notebook's ``display``; neither is defined in this
    part of the file.
    """
    with pl.Config(tbl_rows = 30):
        summary = df.describe(percentiles = PERCENTILES)
        display(summary)
        

def longest_resize(img, max_h = None, max_w = None, upscale = False, interpolation = cv2.INTER_LINEAR):
    """Resize ``img`` so that it fits inside ``max_h`` x ``max_w``, keeping aspect ratio.

    Args:
        img: Array whose first two axes are height and width.
        max_h: Maximum output height, or None for no height limit.
        max_w: Maximum output width, or None for no width limit.
        upscale: When False the image is never enlarged (scale capped at 1.0).
        interpolation: OpenCV interpolation flag used for the resize.

    Returns:
        The input itself when no limit is given or no scaling is needed,
        otherwise the resized image.
    """
    if max_h is None and max_w is None:
        return img
    img_h, img_w = img.shape[:2]
    _ratios = []
    if max_h is not None:
        _ratios.append(max_h / img_h)
    if max_w is not None:
        _ratios.append(max_w / img_w)
    r = min(_ratios)
    if not upscale:
        r = min(1.0, r)
    if r == 1.0:
        return img
    # Truncation can yield 0 for very small images; cv2.resize needs >= 1 pixel.
    new_h, new_w = max(1, int(r * img_h)), max(1, int(r * img_w))
    # The third positional parameter of cv2.resize is `dst`, not
    # `interpolation`, so the flag has to be passed by keyword to take effect.
    img = cv2.resize(img, (new_w, new_h), interpolation=interpolation)
    return img


def pad(img, new_h, new_w, pad_mode = 'top_left', pad_value = 'noise'):
    """Place ``img`` on a larger canvas of size ``new_h`` x ``new_w``.

    Args:
        img: Array whose first two axes are height and width; any trailing
            axes (e.g. channels) are preserved.
        new_h: Canvas height, must be >= the image height.
        new_w: Canvas width, must be >= the image width.
        pad_mode: ``'top_left'`` keeps the image in the top-left corner,
            ``'center'`` centers it on the canvas.
        pad_value: ``'noise'`` fills the canvas with uniform random values
            (scaled to 0..255 for uint8 images), or a number used as a
            constant fill value.

    Returns:
        A new array of shape ``(new_h, new_w, ...)`` with the dtype of ``img``.

    Raises:
        ValueError: For an unknown ``pad_mode`` or an unknown string ``pad_value``.
    """
    ori_h, ori_w = img.shape[:2]
    assert new_h >= ori_h and new_w >= ori_w

    if pad_mode == 'center':
        pad_top = (new_h - ori_h) // 2
        pad_left = (new_w - ori_w) // 2
    elif pad_mode == 'top_left':
        pad_top, pad_left = 0, 0
    else:
        raise ValueError(f'unknown pad_mode: {pad_mode!r}')

    # Keep whatever trailing axes the image has, so 2-D input works as well.
    canvas_shape = (new_h, new_w, *img.shape[2:])
    if isinstance(pad_value, str):
        if pad_value != 'noise':
            raise ValueError(f'unknown pad_value: {pad_value!r}')
        padded = np.random.rand(*canvas_shape)
        if img.dtype == np.uint8:
            padded = (padded * 255).astype(np.uint8)
        else:
            padded = padded.astype(img.dtype)
    else:
        # Constant fill; previously this path left `padded` undefined.
        padded = np.full(canvas_shape, pad_value, dtype=img.dtype)

    padded[pad_top:pad_top + ori_h, pad_left:pad_left + ori_w] = img
    return padded


def concat_imgs(imgs, max_h = None, max_w = None, axis = 1, border_width = 5, border_color = (255, 0, 255)):
    """Scale images by a common factor, pad them to one size and join them.

    Every image is resized with the same ratio so that the largest one fits
    into ``max_h`` x ``max_w`` (never upscaled), padded with noise to the
    common size, and concatenated along ``axis`` with a coloured border
    between neighbours.

    Args:
        imgs: Non-empty sequence of HxW or HxWxC arrays.
        max_h: Optional height limit for the largest image.
        max_w: Optional width limit for the largest image.
        axis: Axis along which the images are concatenated.
        border_width: Thickness of the separator, in pixels.
        border_color: Separator colour as 0..255 channel values. When its
            length differs from the image channel count (e.g. grayscale
            images) the mean of its values is used for every channel.

    Returns:
        The concatenated image.

    Raises:
        ValueError: If ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError('imgs must contain at least one image')

    hws = [img.shape[:2] for img in imgs]
    src_max_h = max(hw[0] for hw in hws)
    src_max_w = max(hw[1] for hw in hws)
    max_h = min(max_h, src_max_h) if max_h is not None else src_max_h
    max_w = min(max_w, src_max_w) if max_w is not None else src_max_w
    r = min(max_h / src_max_h, max_w / src_max_w)
    dst_max_h, dst_max_w = int(r * src_max_h), int(r * src_max_w)

    color = np.array(border_color, dtype = np.uint8)

    new_imgs = []
    for i, img in enumerate(imgs):
        h, w = img.shape[:2]
        new_h, new_w = int(h * r), int(w * r)
        new_img = cv2.resize(img, (new_w, new_h))
        if len(new_img.shape) < 3:
            # cv2.resize drops a singleton channel axis; restore it for pad().
            assert len(new_img.shape) == 2
            new_img = new_img[..., None]
        new_img = pad(new_img, dst_max_h, dst_max_w, pad_mode = 'top_left', pad_value = 'noise')
        new_imgs.append(new_img)
        if i != len(imgs) - 1:
            _shape = list(new_img.shape)
            _shape[axis] = border_width
            border = np.zeros(_shape, dtype = new_img.dtype)
            border_value = color
            if border_value.shape[0] != border.shape[-1]:
                # Channel count mismatch: a 3-value colour cannot be broadcast
                # into e.g. a single-channel image, so fall back to its mean.
                border_value = np.full(border.shape[-1], color.mean())
            if border.dtype != np.uint8:
                border_value = border_value / 255
            border[..., :] = border_value[None, None]
            new_imgs.append(border)
    ret = np.concatenate(new_imgs, axis = axis)
    return ret


def float_to_uint8(imgs):
    """Scale image(s) by 255 and cast them to uint8.

    A single ndarray is treated as a one-element batch. The result is always
    a list with one uint8 array per input image.
    """
    batch = [imgs] if isinstance(imgs, np.ndarray) else imgs
    converted = []
    for arr in batch:
        scaled = 255 * arr
        converted.append(scaled.astype(np.uint8))
    return converted


def display_img(img, max_h = None, max_w = None):
    """Render an image array inline in the notebook.

    The image is first passed through ``longest_resize`` with the given
    limits; non-uint8 data is multiplied by 255 and cast to uint8 before it
    is handed to PIL.

    Args:
        img: HxW, HxWx1 or HxWxC array.
        max_h: Optional maximum display height.
        max_w: Optional maximum display width.
    """
    # `import PIL` alone does not guarantee that the `PIL.Image` submodule is
    # loaded, so import it explicitly instead of relying on another package
    # having imported it first.
    from PIL import Image as _PILImage

    img = longest_resize(img, max_h, max_w)
    print('display:', img.shape, img.dtype)
    if img.dtype != np.uint8:
        img = (img * 255).astype(np.uint8)
    if img.ndim == 3 and img.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array.
        img = img[..., 0]
    idisplay(_PILImage.fromarray(img))


def select(tomo, num_slices, axis = 'z'):
    """Pick regularly spaced slices of a 3-D volume along one axis.

    Indices start at 0 and advance by ``(size - 1) // (num_slices - 1)``
    (at least 1), so the number of returned slices can differ from
    ``num_slices`` when the axis length is not a matching multiple or is
    shorter than ``num_slices``. With ``num_slices == 1`` only index 0 is
    returned.

    Args:
        tomo: 3-D array indexed as (z, y, x).
        num_slices: Requested number of slices, >= 1.
        axis: One of ``'z'``, ``'y'``, ``'x'``.

    Returns:
        Tuple ``(slices, indices)`` where ``slices`` has the selected axis
        moved to the front and ``indices`` is the list of picked positions.
    """
    assert len(tomo.shape) == 3 and num_slices >= 1
    axis = {'z': 0, 'y': 1, 'x': 2}[axis]
    size = tomo.shape[axis]
    if num_slices == 1:
        # Avoid the division by zero below; a step of `size` yields just [0].
        step = max(1, size)
    else:
        # A zero step (axis shorter than num_slices) would make range() raise.
        step = max(1, (size - 1) // (num_slices - 1))
    main_slice = list(range(0, size, step))

    slices = [slice(None), slice(None), slice(None)]
    slices[axis] = main_slice
    # tuple(...) is equivalent to star-unpacking inside the subscript but also
    # parses on Python versions older than 3.11.
    tomo = tomo[tuple(slices)]

    permute = [0, 1, 2]
    permute.remove(axis)
    new_permute = [axis, *permute]
    tomo = np.transpose(tomo, new_permute)
    return tomo, main_slice


def get_norm_func(norm_type = 'sample_min_max'):
    """Return a callable that normalizes a single image array.

    The ``log_*`` variants are plain NumPy functions. Every other variant
    wraps a MONAI dictionary transform: the image gets a leading axis, is
    passed under the ``'image'`` key, and the leading axis is removed again
    from the result.

    Raises:
        ValueError: For an unknown ``norm_type`` or an unknown
            ``sample_percentile`` suffix.
    """
    # --- NumPy-only variants -------------------------------------------------
    if norm_type == 'log_percentile':
        def _log_percentile(image):
            image = np.log(image - image.min())
            low = np.percentile(image, 5)
            top = np.percentile(image, 95)
            print(image.min(), low, top, image.max())
            return np.clip(image, a_min=low, a_max=top)
        return _log_percentile

    if norm_type == 'log_sign':
        def _log_sign(image):
            return np.log(np.abs(image) + 1e-8) * np.sign(image)
        return _log_sign

    if norm_type == 'log_abs':
        def _log_abs(image):
            return np.log(np.abs(image) + 1e-8)
        return _log_abs

    # --- MONAI-backed variants -----------------------------------------------
    if norm_type == "sample_min_max":
        transform = T.ScaleIntensityd(
            keys=["image"],
            minv=0.0,
            maxv=1.0,
            factor=None,
            channel_wise=False,
            dtype=None,
        )
    elif norm_type == "sample_mean_std":
        transform = T.NormalizeIntensityd(
            keys=["image"],
            subtrahend=None,
            divisor=None,
            nonzero=False,
            channel_wise=False,
        )
    elif norm_type == "global_mean_std":
        transform = T.NormalizeIntensityd(
            keys=["image"],
            subtrahend=5.257659e-08,
            divisor=7.199923e-06,
            nonzero=False,
            channel_wise=False,
        )
    elif norm_type.startswith("sample_percentile"):
        # Only the bare name and the "_clip" suffix are recognised.
        suffix = norm_type[len('sample_percentile'):]
        clip_by_suffix = {'': False, '_clip': True}
        if suffix not in clip_by_suffix:
            raise ValueError
        transform = T.ScaleIntensityRangePercentilesd(
            keys=["image"],
            lower=5,
            upper=95,
            b_min=0.0,
            b_max=1.0,
            clip=clip_by_suffix[suffix],
            relative=True,
            channel_wise=False,
        )
    elif norm_type == "sample_hist_equalize":
        transform = T.HistogramNormalized(
            keys=["image"], num_bins=256, min=0.0, max=1.0
        )
    else:
        raise ValueError

    def _apply(image):
        # Add a leading axis for the transform, then strip it from the output.
        return transform({'image': image[None]})['image'][0]

    return _apply


def crop_kpts(
    kpts: torch.Tensor, crop_start: Tuple[int, int, int], crop_end: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Keep only the keypoints inside a crop box and express them relative to it.

    Args:
        kpts: (N, D) tensor with D >= 3; the first three columns are the
            Z, Y, X coordinates
        crop_start: Z, Y, X start of the crop (inclusive)
        crop_end: Z, Y, X end of the crop (inclusive)

    Returns:
        A new tensor with at most N rows whose first three columns have
        ``crop_start`` subtracted; the input tensor is not modified.
    """
    lower = torch.tensor(crop_start).unsqueeze(0)
    upper = torch.tensor(crop_end).unsqueeze(0)
    assert lower.shape == upper.shape == (1, 3)
    assert kpts.ndim == 2 and kpts.shape[1] >= 3

    coords = kpts[:, :3]
    above_start = coords >= lower
    below_end = coords <= upper
    inside = torch.logical_and(above_start, below_end).all(dim=1)

    # Boolean-mask indexing copies, so the shift below leaves `kpts` untouched.
    kept = kpts[inside]
    kept[:, :3] = kept[:, :3] - lower
    return kept


def select_top_k_unique_keypoints(keypoints, K):
    """
    Greedily select up to K keypoints whose axis-aligned planes cover many keypoints.

    Each keypoint is scored by how many keypoints share its z, y or x
    coordinate. Keypoints are visited from highest to lowest score (ties keep
    input order) and accepted only if none of their three coordinates has
    been used by an earlier accepted keypoint.

    Args:
        keypoints (iterable of 3-sequences): 3D coordinates (z, y, x).
        K (int): Maximum number of keypoints to select.

    Returns:
        list of tuple: At most K keypoints with pairwise distinct z, y and x
        coordinates; an empty list when K <= 0.
    """
    # Previously K <= 0 never hit the stop condition and returned every
    # unique keypoint instead of none.
    if K <= 0:
        return []

    # Materialize once so that one-shot iterables survive the two passes below.
    points = [(z, y, x) for z, y, x in keypoints]

    # How many keypoints lie on each z / y / x plane
    z_counter = Counter(z for z, _, _ in points)
    y_counter = Counter(y for _, y, _ in points)
    x_counter = Counter(x for _, _, x in points)

    def _coverage(point):
        z, y, x = point
        # The keypoint itself is counted once per axis; subtract the extras.
        return z_counter[z] + y_counter[y] + x_counter[x] - 2

    # sorted() is stable, so equal scores keep their input order
    ranked = sorted(points, key=_coverage, reverse=True)

    selected_keypoints = []
    used_z, used_y, used_x = set(), set(), set()
    for z, y, x in ranked:
        if z in used_z or y in used_y or x in used_x:
            print('USED X/Y/Z')
            continue
        selected_keypoints.append((z, y, x))
        used_z.add(z)
        used_y.add(y)
        used_x.add(x)
        if len(selected_keypoints) == K:
            break

    return selected_keypoints


# Palette used by viz_transform for keypoint ellipses; it is indexed with the
# integer value of a keypoint's last column.
COLORS = [
    (255, 0, 0),      # Red
    (0, 255, 0),      # Green
    (0, 0, 255),      # Blue
    (0, 255, 255),    # Cyan
    (255, 0, 255),    # Magenta
    (255, 255, 0)     # Yellow
]
# Colours of the two projection lines drawn on each slice by viz_transform;
# indexed by axis index (0, 1, 2 correspond to 'z', 'y', 'x' there).
AXIS_COLORS = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255)
]


def draw_ellipse_from_normal_distribution(image, mean, cov, color=(255, 0, 0), thickness=2, expand=1.0):
    """
    Draws, in place, an ellipse representing a 2D normal distribution on an image.

    The ellipse axes follow the eigenvectors of ``cov``; each half-axis has
    length ``sqrt(eigenvalue) * expand``.

    Args:
        image (np.ndarray): The image on which to draw the ellipse (modified in place).
        mean (sequence): Mean vector (x, y) of the 2D normal distribution.
        cov (np.ndarray): 2x2 symmetric covariance matrix.
        color (tuple): Channel values written by OpenCV, in the channel order of ``image``.
        thickness (int): Thickness of the ellipse outline.
        expand (float): Multiplier applied to the standard deviations, i.e. how
            many sigmas the drawn ellipse spans.
    """
    # eigh returns eigenvalues in ascending order with matching eigenvector columns
    eigenvalues, eigenvectors = np.linalg.eigh(cov)

    # Orientation of the first eigenvector (the one paired with eigenvalues[0])
    angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))

    # Round-off can make an eigenvalue slightly negative; sqrt would then give
    # NaN and int(NaN) raises, so clamp at zero first.
    axis_length = np.sqrt(np.clip(eigenvalues, 0.0, None)) * expand

    # Convert mean to integer tuple for OpenCV
    center = tuple(map(int, mean))

    cv2.ellipse(
        image,
        center,
        (int(axis_length[0]), int(axis_length[1])),  # half-axes, smaller eigenvalue first
        angle,
        0, 360,
        color,
        thickness
    )


def compute_spacing_shape(ori_shape, ori_spacing, target_spacing, scale_extent = False):
    """Compute the array shape obtained when resampling to a new voxel spacing.

    With ``scale_extent`` False each axis spans the voxel centres
    ``[0, dim - 1]`` and one is added to the rescaled extent; with
    ``scale_extent`` True each axis spans the voxel edges
    ``[-0.5, dim - 0.5]`` and the rescaled extent is used directly.

    Returns:
        A list of Python ints, one per axis of ``ori_shape``.
    """
    if scale_extent:
        bounds = [(-0.5, dim - 0.5) for dim in ori_shape]
    else:
        bounds = [(0.0, dim - 1.0) for dim in ori_shape]
    scaled_bounds = np.array(bounds) * ori_spacing / target_spacing
    extent = np.ptp(scaled_bounds, axis = 1)
    if not scale_extent:
        extent = extent + 1.0
    return [int(value) for value in np.round(extent)]
    

def viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings, target_voxel_spacing, transforms, seed = None, view_3d = False):
    """Run a transform chain on each tomogram and display annotated slices.

    The given transforms are composed with
    ``CT.ApplyTransformToNormalDistributionsd`` so that the keypoints follow
    the image. For every tomogram, up to five keypoints with distinct
    z/y/x coordinates are chosen (topped up with random positions) and the
    z, y and x slices through them are shown, with projection lines and one
    ellipse per keypoint lying within 2 slices of the shown slice.

    Args:
        all_tomo_ids: Names printed for each run.
        all_tomos: Volumes; each is axis-reversed and given a leading axis
            before being passed to the transforms.
        all_kpts: Per-tomogram arrays of shape (N, 10). Assumed layout (as
            built by the data-loading cell of this notebook): x, y, z,
            cov_xx, cov_yy, cov_zz, cov_xy, cov_xz, cov_yz, class.
        all_voxel_spacings: Original voxel spacing of each tomogram.
        target_voxel_spacing: Spacing used to derive ``spacing_scale``.
        transforms: Sequence of MONAI dictionary transforms to test.
        seed: Optional random seed for the composed transform.
        view_3d: When True, open itkwidgets on the first tomogram and return.
    """
    assert len(all_tomos) == len(all_kpts)
    transform = T.Compose(
        [
            *transforms[:],
            CT.ApplyTransformToNormalDistributionsd(
                keys=["kpts"],
                refer_keys=["image"],
                dtype=torch.float32,
                affine=None,
                invert_affine=True,
            ),
        ],
        map_items=True,
        log_stats=True,
        lazy=LAZY,
        overrides={'image': {'padding_mode': 'zeros',}},
    )
    # A truthiness test would silently ignore seed=0.
    if seed is not None:
        transform.set_random_state(seed=seed)

    for run_name, ori_tomo, ori_kpts, ori_voxel_spacing in zip(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings):
        kpts = ori_kpts
        print('BEFORE:', f'shape={ori_tomo.shape} dtype={ori_tomo.dtype} min={ori_tomo.min()} mean={ori_tomo.mean()} std={ori_tomo.std()} max={ori_tomo.max()}')
        print('VOXEL SPACING:', ori_voxel_spacing, '-->', target_voxel_spacing)
        print('EXPECTED SHAPE AFTER SPACING:', compute_spacing_shape(ori_tomo.shape, ori_voxel_spacing, target_voxel_spacing, scale_extent = False))
        print('NUM KPTS:', len(ori_kpts))

        # Keypoints expressed in the target spacing: positions scale linearly,
        # covariances quadratically.
        spacing_scale = target_voxel_spacing / ori_voxel_spacing
        spaced_kpts = kpts.copy()
        spaced_kpts[:, :3] /= spacing_scale
        spaced_kpts[:, 3:9] /= (spacing_scale**2)
        data = {
            'image': np.transpose(ori_tomo, (2, 1, 0))[None], # ZYX --> 1XYZ
            'kpts': kpts[None],
            'spaced_kpts': spaced_kpts[None],
            "spacing_scale": spacing_scale
        }
        print('BEFORE Transform:', data['image'].shape, data['image'].dtype, data['image'].min(), data['image'].max())
        data = transform(data)
        print('AFTER Transform:', data['image'].shape, data['image'].dtype, data['image'].min(), data['image'].max())
        image = torch.permute(data["image"][0], (2,1,0)) # 1XYZ -> ZYX
        image = torch.clip(image, 0, 255).byte().cpu().numpy()

        print('AFTER:', f'shape={image.shape} dtype={image.dtype} min={image.min()} mean={image.mean()} std={image.std()} max={image.max()}')
        kpts = data["kpts"]
        assert kpts.shape[0] == 1
        # Reorder columns from XYZ to ZYX convention:
        # zy, zx, yx <-> 6,7,8
        # xy, xz, yz <-> 8, 7, 6
        kpts = kpts[0, :, [2,1,0,5,4,3,8,7,6,9]]
        # filter invalid keypoints
        image_shape = image.shape
        kpts = crop_kpts(kpts, [0, 0, 0], image_shape)
        assert kpts.shape[1] == 10
        print('After crop:', kpts.shape)

        # if view_3d, use itkwidgets
        if view_3d:
            view(image = image, point_set = kpts[:, [2, 1, 0]])
            return

        K=5
        kpt_zyxs = np.round((kpts[:, :3] - 0.5).numpy())
        topk_freq = select_top_k_unique_keypoints(kpt_zyxs, K=K)
        if len(topk_freq) < K:
            print(f'Sample top {len(topk_freq)}/{K}')
            for _ in range(K - len(topk_freq)):
                topk_freq.append([random.randrange(0, axis_shape) for axis_shape in image.shape])
        topk_freq.sort(key = lambda x: x[0])
        topk_freq = np.array(topk_freq, dtype = int)
        # crop_kpts keeps keypoints lying exactly on the upper bound, and
        # rounding can then yield an index equal to the axis length; clamp so
        # the slice lookup below stays in range.
        topk_freq = np.clip(topk_freq, 0, np.array(image.shape) - 1)

        print('RUN:', run_name)
        for axis in ['z', 'y', 'x']:
            axis_idx = {'z': 0, 'y': 1, 'x': 2}[axis]

            slice_idxs = topk_freq[:, axis_idx]
            _slices = [slice(None), slice(None), slice(None)]
            _slices[axis_idx] = slice_idxs
            # tuple(...) is equivalent to star-unpacking in the subscript and
            # also parses on Python versions older than 3.11.
            tomo = image[tuple(_slices)]
            permute = [0, 1, 2]
            permute.remove(axis_idx)
            new_permute = [axis_idx, *permute]
            tomo = np.transpose(tomo, new_permute)

            print(f'axis={axis} axis_idx={axis_idx} slices={slice_idxs} tomo_shape={tomo.shape}')

            # The two remaining axes become the rows and columns of each slice.
            _tmp = [0, 1, 2]
            _tmp.remove(axis_idx)
            h_axis, w_axis = _tmp

            viz_tomo_slices = []
            for j, (slice_idx, tomo_slice) in enumerate(zip(slice_idxs, tomo)):
                tomo_slice = np.repeat(tomo_slice[..., None], 3, axis = -1) # HW -> HW3

                # draw 2 projection lines
                proj_row_idx = topk_freq[j, _tmp[0]]
                proj_col_idx = topk_freq[j, _tmp[1]]
                H, W = tomo_slice.shape[:2]
                cv2.line(tomo_slice, (0, proj_row_idx), (W-1, proj_row_idx), AXIS_COLORS[_tmp[0]], thickness=2)
                cv2.line(tomo_slice, (proj_col_idx, 0), (proj_col_idx, H-1), AXIS_COLORS[_tmp[1]], thickness=2)

                # keep keypoints whose centre lies within +-2 slices of this slice
                _dist_thres = 2.0
                keep = torch.abs(kpts[:, axis_idx] - slice_idx - 0.5) <= _dist_thres
                keep_kpts = kpts[keep]

                for kpt in keep_kpts:
                    # rebuild the symmetric 3x3 covariance from its 6 unique entries
                    cov33 = kpt[[3, 6, 7, 6, 4, 8, 7, 8, 5]].reshape(3,3)
                    mean = [kpt[w_axis], kpt[h_axis]]
                    cov = np.array([[cov33[w_axis, w_axis], cov33[w_axis, h_axis]],
                                    [cov33[w_axis, h_axis], cov33[h_axis, h_axis]]
                                   ])
                    # Draw the ellipse on the image
                    draw_ellipse_from_normal_distribution(tomo_slice, mean, cov, color=COLORS[int(kpt[-1])], thickness=2, expand=4.03)

                viz_tomo_slices.append(tomo_slice)

            row = concat_imgs(viz_tomo_slices, max_h = None, max_w = None, axis = 1, border_width = 5)
            display_img(row, max_h = None, max_w = None)

# Marker printed once every definition in this cell has been executed.
print('done')

# Load data

In [None]:
import ast

from byu.data.io import OpencvTomogramLoader

# Alternative selection: 10 random tomograms with exactly one motor
# df = pl.scan_csv('/home/dangnh36/datasets/.comp/byu/processed/gt.csv').filter(pl.col('num_motors') == 1).collect().sample(10)

# Very confusing sample:
# TP-36-tomo_5764d6-500x928x960-spacing13.100000381469727.jpg

SELECT_TOMO_IDS = ['tomo_1cc887', 'tomo_1ab322', 'tomo_ed1c97',
                  'tomo_0a8f05', 'tomo_0363f2', 'tomo_10c564',
                   'tomo_319f79', 'tomo_4555b6', 'tomo_868255',
                   'tomo_b2ebbc', 'tomo_bede89', 'tomo_918e2b',
                   'tomo_9c0253', 'tomo_b54396', 'tomo_dfc627'
                  ]
df = pl.scan_csv('/home/dangnh36/datasets/.comp/byu/processed/gt.csv').filter(pl.col('tomo_id').is_in(SELECT_TOMO_IDS)).collect()

# tmp = df.sort('num_motors', descending=True).group_by('num_motors').head(1)
# display(tmp)
select_tomo_ids = df['tomo_id'].to_list()
print(select_tomo_ids)
display(df)

# Parallel lists, one entry per loaded tomogram
all_tomo_ids = []
all_tomos = []
all_kpts = []
all_voxel_spacings = []

tomo_loader = OpencvTomogramLoader()
for tomo_id in tqdm(select_tomo_ids):
    sub = df.filter(pl.col('tomo_id') == tomo_id)
    voxel_spacing = sub[0, 'voxel_spacing']

    # load keypoints
    kpts = []
    if sub[0, 'num_motors'] == 0:
        assert sub[0, 'motor_zyx'] == '[]'
        kpts = np.zeros((0, 10))
    else:
        # The column holds a Python-literal list of (z, y, x) triples;
        # literal_eval parses it without executing arbitrary code from the CSV.
        motor_zyxs = ast.literal_eval(sub[0, 'motor_zyx'])
        assert len(motor_zyxs) == sub[0, 'num_motors']
        # Isotropic variance chosen so that the 4.03-sigma ellipse drawn by
        # viz_transform has a radius of 1000 / voxel_spacing voxels.
        variance = (1000 / voxel_spacing / 4.03) ** 2
        for (z, y, x) in motor_zyxs:
            kpts.append(
                    [
                        x,
                        y,
                        z,
                        variance,  # cov_xx
                        variance,  # cov_yy
                        variance,  # cov_zz
                        0, # cov_xy
                        0, # cov_xz
                        0, # cov_yz
                        0, # last column, used as the colour index in viz_transform
                    ]
            )
    # load tomo
    tomo = tomo_loader.load(f'/home/dangnh36/datasets/.comp/byu/raw/train/{tomo_id}')
    all_tomo_ids.append(tomo_id)
    all_tomos.append(tomo)
    all_kpts.append(np.array(kpts))
    all_voxel_spacings.append(voxel_spacing)

len(all_kpts), len(all_tomos), [len(e) for e in all_kpts]

In [None]:
# Force a garbage-collection pass after loading the tomograms.
import gc
gc.collect()

# Test Invertd

In [None]:
# Single-sample input for the Invertd round-trip test, built from the third
# loaded tomogram. The image axes are reversed and a leading axis is added,
# the same layout change viz_transform applies to its 'image' entry.
data = {
    'image': torch.from_numpy(all_tomos[2]).permute(2,1,0)[None],
    'kpts': all_kpts[2][None]
}

In [None]:
from math import pi

# Forward pipeline for the inversion test: random crop, an affine whose
# rotate/translate/scale ranges are all zero, a fixed zoom, then percentile
# intensity scaling.
pre_transform = T.Compose([
    T.RandSpatialCropd(keys=["image"], roi_size=[320, 240, 64], random_center=True, random_size=False, lazy=True),
    # NOTE(review): `(0.0)` below is a plain float, not a 1-tuple like its neighbours.
    T.RandAffined(keys = ['image'], prob=1.0,
                  rotate_range=((0,0), (0.0), (0, 0)), shear_range=None, 
                  translate_range=((0,0), (0,0), (0, 0)),
                  scale_range=((-0.0, 0.0),(-0.0, 0.0),(-0.0, 0.0)),
                  spatial_size=None, mode='bilinear', padding_mode='constant', 
                  cache_grid=True, device=None, lazy=True),
    # min_zoom == max_zoom, so the sampled zoom factors are fixed
    T.RandZoomd(keys=['image'], prob=1.0, min_zoom=(2.0, 1), max_zoom=(2.0, 1), mode='bilinear',
             padding_mode='constant', align_corners=False, keep_size=True, lazy=True),
    T.ScaleIntensityRangePercentilesd(
                keys=["image"],
                lower=5,
                upper=95,
                b_min=0.0,
                b_max=1.0,
                clip=True,
                relative=False,
                channel_wise=False,
            )
])

# Inverse of `pre_transform`, applied to the same 'image' entry it was recorded on.
inv_transform = T.Invertd(
    keys=["image"],  # entry to invert
    transform=pre_transform,
    orig_keys=['image'],  # entry whose recorded transform information is used for the inversion
    nearest_interp=False,  # keep the original interpolation mode instead of switching to "nearest"
    to_tensor=True,  # convert to PyTorch Tensor after inverting
)
print(pre_transform, inv_transform)

In [None]:
# Run the forward pipeline, then its inverse, printing the keys and the image
# shape/dtype after each step.
ret = pre_transform(data)
print(ret.keys())
print(ret['image'].shape, ret['image'].dtype)
inv_ret = inv_transform(ret)
print(inv_ret.keys())
print(inv_ret['image'].shape, inv_ret['image'].dtype)

In [None]:
# Drop the leading axis of the inverted image, clip to [0, 1], rescale to
# uint8 and open it in itkwidgets.
image = inv_ret['image'][0].cpu().numpy()
print(image.min(), image.max(), image.shape)
image = (np.clip(image, a_min=0, a_max=1.0) * 255).astype(np.uint8)
view(image = image)

# Used transforms

## Spatial

In [None]:
# Catalogue of candidate spatial transforms. The expressions in this cell are
# only constructed, not assigned or applied.

# Rotate
T.OneOf(
    transforms=[
        T.RandRotated(keys = ['image'], range_x=(-pi/6, pi/6), range_y=(-pi/6, pi/6), range_z=(0, 2*pi), prob=1.0, keep_size=True, mode='bilinear', padding_mode='zeros', align_corners=ALIGN_CORNERS, lazy=LAZY),
    ],
    weights=None,
    log_stats=LOG_STATS,
    lazy=LAZY,
    overrides=OVERRIDES
)

# Flip
# NOTE: the trailing commas turn each of these three expressions into a 1-tuple.
T.RandFlipd(keys = ['image'], prob=0.5, spatial_axis=0, lazy=LAZY),
T.RandFlipd(keys = ['image'], prob=0.5, spatial_axis=1, lazy=LAZY),
T.RandFlipd(keys = ['image'], prob=0.5, spatial_axis=2, lazy=LAZY),
# T.Rotate180()



# Affine + RandZoom
T.RandAffined(keys = ['image'],
                  prob=1.0,
                  rotate_range=((-pi/6, pi/6), (-pi/6, pi/6), (0, 2*pi)),
                  shear_range=((-0.2, 0.2), (-0.2, 0.2), (-0.2, 0.2)),
                  translate_range=((0,0), (0,0), (0, 0)),
                  scale_range=((-0.3, 0.3), (-0.3, 0.3), (-0.3, 0.3)), # max_skew_xy = 1.3 / 0.7 = 1.86
                  spatial_size=None,
                  mode='bilinear', # bilinear, nearest
                  padding_mode='constant',
                  cache_grid=True,
                  device=None,
                  lazy=LAZY)
T.RandZoomd(keys=['image'], prob=1.0, min_zoom=(0.25, 2.0), max_zoom=(0.25, 2.0), mode='bilinear',
             padding_mode='constant', align_corners=ALIGN_CORNERS, keep_size=False, lazy=LAZY)

# The following transforms are constructed without a `lazy` argument (not lazy)
T.RandGridDistortiond(keys=['image'], num_cells=(10,10,2), prob=1.0, distort_limit=(-0.1, 0.1), mode='bilinear', padding_mode='constant', device=None)
T.Rand3DElasticd(keys = ['image'], sigma_range = (11, 11),
                     magnitude_range = (5, 5),
                     prob=1.0,
                     spatial_size=None,
                     mode='bilinear',
                     padding_mode='constant'
)
# must include
T.RandSimulateLowResolutiond(keys=['image'], prob=1.0, downsample_mode='nearest', upsample_mode='trilinear', zoom_range=(0.3, 0.3), align_corners=ALIGN_CORNERS)

# exponent noise
# note: within gamma=(0.5, 2.0) the interval (0.5, 1.0) is half as long as (1.0, 2.0) ==> pixel intensities tend to be darker
T.RandSmoothFieldAdjustContrastd(keys = ['image'],
                                 spatial_size = (320, 320, 64),
                                 rand_size = (80, 80, 16), pad=0, mode='area',
                                 align_corners=None, prob=1.0, gamma=(0.5, 2.0)
                                )
# multiplicative noise
T.RandSmoothFieldAdjustIntensityd(
        keys = ['image'],
        spatial_size = (320, 320, 64),
        rand_size = (80, 80, 16),
        pad=0, mode='area', align_corners=None, prob=1.0,
        gamma=(1.5, 1.51) # 0.5, 1.5
    )
T.RandSmoothDeformd(keys = ['image'], spatial_size = (320, 320, 64),
                       rand_size = (80, 80, 16), pad=0, field_mode='area', align_corners=None,
                       prob=1.0, def_range=(-0.02, 0.02), grid_mode='nearest', 
                       grid_padding_mode='zeros', grid_align_corners=ALIGN_CORNERS
                      )

# Bare reference to the class (not instantiated); as the last expression of
# the cell it is what the notebook echoes.
T.Transposed

## Intensity

In [None]:


# Catalogue of candidate intensity transforms. The expressions in this cell
# are only constructed, not assigned or applied.

# LOCAL NOISE
T.RandGaussianNoised(['image'], prob=1.0, mean=0.0, std=0.2, sample_std=True)

# GLOBAL INTENSITY CHANGE
# mean shift
T.RandShiftIntensityd(['image'], offsets = (-0.35, 0.35), safe=False, prob=1.0, channel_wise=True)
T.RandStdShiftIntensityd(['image'], (-1.35, 1.35), prob=1.0, nonzero=False, channel_wise=True)
# std scale (multiplicative)
T.RandScaleIntensityFixedMeand(['image'], prob=1.0, factors=(-0.3, 0.3), fixed_mean=True, preserve_range=False)


# mean/std scale (multiplicative)
T.RandScaleIntensityd(['image'], factors = (-0.6, 0.4), prob=1.0, channel_wise=True)
# mean/std polynomial (x**gamma)
T.RandAdjustContrastd(['image'], prob=1.0, gamma=(0.5, 1.5), invert_image=False, retain_stats=False)
# histogram modification
T.RandHistogramShiftd(['image'], num_control_points=(6,15), prob=1.0)

# SMOOTHING
T.MedianSmoothd(['image'], radius = 1) # slow on CPU, radius >=2 -> large RAM
T.RandGaussianSmoothd(['image'], sigma_x=(0.5, 1.25), sigma_y=(0.5, 1.25), sigma_z=(0.5, 1.25), prob=1.0, approx='erf')


# DROPOUT (no candidates listed yet)


# TODO: read more about the following transforms before using them
T.RandBiasFieldd(['image'], degree=3, coeff_range=(-0.5, -0.5), prob=1.0)
T.RandGaussianSharpend(['image'], sigma1_x=(0.5, 1.0), sigma1_y=(0.5, 1.0), sigma1_z=(0.5, 1.0),
                      sigma2_x=0.5, sigma2_y=0.5, sigma2_z=0.5,
                      alpha=(10.0, 30.0), approx='erf', prob=1.0)
T.RandGibbsNoised(['image'], prob=1.0, alpha=(0.0, 0.7))


# LAST PREPROCESS STEP
T.HistogramNormalized(['image'], num_bins=256, min=0, max=1.0, mask=None)

# Test each transform

In [None]:
# Experiment: random crop followed by a random rotation.
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.RandRotated(keys = ['image'], range_x=(-pi/6, pi/6), range_y=(-pi/6, pi/6), range_z=(0, 2*pi), prob=1.0, keep_size=True, mode='bilinear', padding_mode='zeros', align_corners=ALIGN_CORNERS, lazy=LAZY),
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = True
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Experiment: random crop followed by a flip along each of the three spatial axes.
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.RandFlipd(keys = ['image'], prob=1.0, spatial_axis=0, lazy=LAZY),
    T.RandFlipd(keys = ['image'], prob=1.0, spatial_axis=1, lazy=LAZY),
    T.RandFlipd(keys = ['image'], prob=1.0, spatial_axis=2, lazy=LAZY),
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = True
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Experiment: random crop followed by a random zoom (keep_size=False).
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.RandZoomd(keys=['image'], prob=1.0, min_zoom=(0.25, 2.0), max_zoom=(0.25, 2.0), mode='bilinear',
             padding_mode='constant', align_corners=ALIGN_CORNERS, keep_size=False, lazy=LAZY),
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = True
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Experiment: random crop followed by an affine with a fixed shear of 0.2 and
# no rotation, translation or scaling.
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 200, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.RandAffined(keys = ['image'],
                  prob=1.0,
                  rotate_range=None,
                  shear_range=((0.2, 0.2), (0.2, 0.2), (0.2, 0.2)), # HWD -> parallel to W, D, H
                  translate_range=None,
                  scale_range=None,
                  spatial_size=None,
                  mode='bilinear', # bilinear, nearest
                  padding_mode='constant',
                  cache_grid=True,
                  device=None,
                  lazy=LAZY)
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = False
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Experiment: random crop followed by a random grid distortion.
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.RandGridDistortiond(keys=['image'], num_cells=(10,10,2), prob=1.0, distort_limit=(-0.15, 0.15), mode='bilinear', padding_mode='constant', device=None)
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = False
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Experiment: random crop followed by a 3D elastic deformation with fixed
# sigma (11) and magnitude (25).
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.Rand3DElasticd(keys = ['image'], sigma_range = (11, 11),
                     magnitude_range = (25, 25),
                     prob=1.0,
                     spatial_size=None,
                     mode='bilinear',
                     padding_mode='constant'
    ),
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = False
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Experiment: random crop followed by a simulated low resolution (fixed zoom 0.3).
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.RandSimulateLowResolutiond(keys=['image'], prob=1.0, downsample_mode='nearest', upsample_mode='trilinear', zoom_range=(0.3, 0.3), align_corners=ALIGN_CORNERS)
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = True
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Experiment: random crop followed by a smooth-field contrast adjustment.
transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    T.RandSmoothFieldAdjustContrastd(keys = ['image'],
                                     spatial_size = (320, 320, 64),
                                     rand_size = (80, 80, 16), pad=0, mode='area',
                                     align_corners=None, prob=1.0, gamma=(0.5, 2.0)),
]

# NOTE: the transforms above were already built with the previous value of
# LAZY; this assignment only affects the `lazy=LAZY` read inside viz_transform.
LAZY = True
# NOTE(review): viz_transform is defined in this notebook as
# viz_transform(all_tomo_ids, all_tomos, all_kpts, all_voxel_spacings,
# target_voxel_spacing, transforms, seed, view_3d); this call passes only
# three positional arguments and does not match that signature.
viz_transform(all_tomos, all_kpts, transforms, seed = None, view_3d = False)

In [None]:
# Preview: random 320x320x64 crop followed by a smooth-field intensity adjustment.
#
# LAZY is assigned *before* the list is built so that `lazy=LAZY` below captures
# this cell's setting instead of the value left behind by a previous cell.
LAZY = True

transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 320, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    # spatial_size matches the crop size above; rand_size is a quarter of it
    # along every axis.
    T.RandSmoothFieldAdjustIntensityd(
        keys=['image'],
        spatial_size=(320, 320, 64),
        rand_size=(80, 80, 16),
        pad=0,
        mode='area',
        align_corners=None,
        prob=1.0,
        # Near-degenerate range pins gamma at ~1.5 for the preview;
        # alternative range noted by the author: 0.5, 1.5
        gamma=(1.5, 1.51),
    ),
]

viz_transform(all_tomos, all_kpts, transforms, seed=None, view_3d=False)

In [None]:
# Preview: random crop followed by a smooth random deformation field.
#
# The patch size is defined once and reused for the crop, the deformation grid
# and the deformation range, so the three cannot drift apart.
H, W, D = (320, 320, 64)

# LAZY is assigned *before* the list is built so that `lazy=LAZY` below captures
# this cell's setting instead of the value left behind by a previous cell.
LAZY = True

transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[H, W, D],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    # Earlier variant kept for reference:
    # T.RandSmoothDeformd(keys = ['image'], spatial_size = (320, 320, 64),
    #                    rand_size = (80, 80, 16), pad=0, field_mode='area', align_corners=None,
    #                    prob=1.0, def_range=(-0.02, 0.02), grid_mode='nearest',
    #                    grid_padding_mode='zeros', grid_align_corners=ALIGN_CORNERS
    #                   ),

    T.RandSmoothDeformd(
        keys=['image'],
        spatial_size=(H, W, D),
        rand_size=(H // 4, W // 4, D // 4),
        pad=0,
        field_mode="area",
        align_corners=None,
        prob=1.0,
        # 6 is min of particle radius (apo-ferritin); it is divided by the
        # longest patch side to express the range as a fraction of the patch.
        def_range=(-6.0 / max(H, W, D), 6.0 / max(H, W, D)),
        grid_mode="nearest",
        grid_padding_mode="zeros",
        grid_align_corners=ALIGN_CORNERS,
    ),
]

viz_transform(all_tomos, all_kpts, transforms, seed=None, view_3d=False)

In [None]:
# Preview: random 320x240x64 crop followed by the project's RandRotate180d on
# spatial axes (0, 1). The crop is deliberately non-square in-plane
# (320 vs 240) in this cell.
#
# LAZY is assigned *before* the list is built so that `lazy=LAZY` below captures
# this cell's setting instead of the value left behind by a previous cell.
LAZY = True

transforms = [
    T.RandSpatialCropd(
        keys=["image"],
        roi_size=[320, 240, 64],
        # max_roi_size=(192, 192, 192),
        random_center=True,
        random_size=False,
        lazy=LAZY,
    ),
    CT.RandRotate180d(keys=['image'], prob=1.0, spatial_axes=(0, 1), lazy=LAZY),
]

# Fixed seed so the preview is reproducible.
viz_transform(all_tomos, all_kpts, transforms, seed=42, view_3d=False)

# Compare

In [None]:
# Preview of resampling alone: CustomSpacingd reads its scale from the
# 'spacing_scale' entry of each sample.
#
# LAZY is assigned *before* the list is built so that `lazy=LAZY` below captures
# this cell's setting instead of the value left behind by a previous cell.
LAZY = True

transforms = [
    CT.CustomSpacingd(
        keys=['image'],
        spacing_key='spacing_scale',
        mode='bilinear',
        padding_mode='zeros',
        align_corners=ALIGN_CORNERS,
        lazy=LAZY,
    ),
]

viz_transform(
    all_tomo_ids[:],
    all_tomos[:],
    all_kpts[:],
    all_voxel_spacings[:],
    target_voxel_spacing=16.0,
    transforms=transforms,
    seed=611,
    view_3d=False,
)

In [None]:
# Preview: resample with CustomSpacingd, then crop a 224x224x112 patch chosen
# relative to the keypoints stored under 'spaced_kpts'.
#
# LAZY is assigned *before* the list is built so that `lazy=LAZY` below captures
# this cell's setting instead of the value left behind by a previous cell.
LAZY = True

transforms = [
    CT.CustomSpacingd(
        keys=['image'],
        spacing_key='spacing_scale',
        mode='bilinear',
        padding_mode='zeros',
        align_corners=ALIGN_CORNERS,
        lazy=LAZY,
    ),

    # Plain random crop kept for comparison:
    # T.RandSpatialCropd(keys=["image"], roi_size=[224, 224, 112], random_center=True, random_size=False, lazy=LAZY),

    # NOTE(review): this name is used unqualified here, while later cells use
    # CT.RandSpatialCropByKeypointsd. Presumably a notebook-local definition
    # exists in an earlier cell -- confirm, otherwise this raises NameError.
    RandSpatialCropByKeypointsd(
        keys=['image'],
        keypoints_key='spaced_kpts',
        roi_size=(224, 224, 112),
        max_roi_size=None,
        random_center=True,
        random_size=False,
        margin=0.4,
        auto_correct_center=False,
        allow_missing_keys=False,
        lazy=LAZY,
    ),
]

viz_transform(
    all_tomo_ids[:],
    all_tomos[:],
    all_kpts[:],
    all_voxel_spacings[:],
    target_voxel_spacing=16.0,
    transforms=transforms,
    seed=611,
    view_3d=False,
)

In [None]:
# transforms = [
#     T.RandSpatialCropd(keys=["image"], roi_size=[320, 320, 64], random_center=True, random_size=False, lazy=LAZY),
#     T.RandRotate90d(keys=["image"], prob=1.0, max_k=1, spatial_axes=(0, 1), lazy=LAZY),
#     # T.RandZoomd(keys=['image'], prob=1.0, min_zoom=(0.8, 2.0, 1.0), max_zoom=(0.8, 2.0, 1.0), mode='bilinear',
#     #          padding_mode='constant', align_corners=ALIGN_CORNERS, keep_size=True, lazy=LAZY),
#     # # T.RandAffined(keys = ['image'], prob=1.0, rotate_range=((0,0), (0,0), (pi/4, pi/3)), shear_range=None, translate_range=((0,0), (0,0), (0, 0)), 
#     # #               scale_range=((0, 0),(0,0),(0,0)), spatial_size=None, mode='bilinear', padding_mode='constant',
#     # #               cache_grid=True, device=None, lazy=LAZY)
#     # T.RandAffined(keys = ['image'], prob=1.0,
#     #               rotate_range=((0,0), (0,0), (0,2*pi)),
#     #               shear_range=((0.2, 0.2), (0.2, 0.2), (0.2, 0.2)),
#     #               translate_range=((0,0), (0,0), (0, 0)), 
#     #               scale_range=((-0.2, 0.2),(-0.2,0.2),(-0.2,0.2)), spatial_size=None, mode='bilinear', padding_mode='constant',
#     #               cache_grid=True, device=None, lazy=LAZY)
# ]
# LAZY = True
# viz_transform(all_tomos, all_kpts, transforms, seed = 611, view_3d = False)

In [None]:
# spacing_transform = CT.CustomSpacingd(keys=['image'], spacing_key='spacing_scale', mode='bilinear', padding_mode='zeros', align_corners=ALIGN_CORNERS, lazy=LAZY)
# target_voxel_spacing = 32.0

# for row in tqdm(df.iter_rows(named = True), total = len(df)):
#     tomo_id = row['tomo_id']
#     ori_voxel_spacing = row['voxel_spacing']
#     tomo = np.load(f'/home/dangnh36/datasets/byu/processed/npy/{tomo_id}.npy')
#     data = {
#             'image': np.transpose(tomo, (2, 1, 0))[None], # ZYX --> 1XYZ
#             "spacing_scale": target_voxel_spacing / ori_voxel_spacing
#     }
#     data = spacing_transform(data)
#     real_shape = list(data['image'].peek_pending_shape()[::-1])
#     expected_shape = compute_spacing_shape(tomo.shape, ori_voxel_spacing, target_voxel_spacing, scale_extent = False)
#     print('ori:', tomo.shape)
#     print('expect:', expected_shape)
#     print('real:', real_shape)
#     assert expected_shape == real_shape, f'{expected_shape} != {real_shape}'

In [None]:
# Catalogue of candidate intensity augmentations, grouped by effect. Every
# candidate is commented out; the only live statement is the final
# HistogramNormalized constructor, whose result is not assigned to anything.

# LOCAL NOISE
# T.RandGaussianNoised(['image'], prob=1.0, mean=0.0, std=0.2, sample_std=True)

# GLOBAL INTENSITY CHANGE
# mean shift
# T.RandShiftIntensityd(['image'], offsets = (-0.35, 0.35), safe=False, prob=1.0, channel_wise=True)
# T.RandStdShiftIntensityd(['image'], (-1.35, 1.35), prob=1.0, nonzero=False, channel_wise=True)

# std scale (multiplicative)
# T.RandScaleIntensityFixedMeand(['image'], prob=1.0, factors=(-0.3, 0.3), fixed_mean=True, preserve_range=False)


# mean/std scale (multiplicative)
# T.RandScaleIntensityd(['image'], factors = (-0.6, 0.4), prob=1.0, channel_wise=True)
# mean/std polynomial (x**gamma)
# T.RandAdjustContrastd(['image'], prob=1.0, gamma=(0.5, 1.5), invert_image=False, retain_stats=False)
# histogram modification
# T.RandHistogramShiftd(['image'], num_control_points=(6,15), prob=1.0)

# SMOOTHING
# T.MedianSmoothd(['image'], radius = 1) # slow on CPU, radius >=2 -> large RAM
# T.RandGaussianSmoothd(['image'], sigma_x=(0.5, 1.25), sigma_y=(0.5, 1.25), sigma_z=(0.5, 1.25), prob=1.0, approx='erf')


# DROPOUT


# TODO: WE NEED TO READ MORE ABOUT THESE BEFORE USING THEM
# T.RandBiasFieldd(['image'], degree=3, coeff_range=(-0.5, -0.5), prob=1.0)
# T.RandGaussianSharpend(['image'], sigma1_x=(0.5, 1.0), sigma1_y=(0.5, 1.0), sigma1_z=(0.5, 1.0),
#                       sigma2_x=0.5, sigma2_y=0.5, sigma2_z=0.5,
#                       alpha=(10.0, 30.0), approx='erf', prob=1.0)
# T.RandGibbsNoised(['image'], prob=1.0, alpha=(0.0, 0.7))


# LAST PREPROCESS STEP
# As the cell's last expression, the constructed transform is what the
# notebook echoes as output.
T.HistogramNormalized(['image'], num_bins=256, min=0, max=1.0, mask=None)

In [None]:
# BYU
# Candidate augmentation settings for the BYU data, with value ranges on a
# 0-255-like intensity scale (e.g. std=30, offsets=(-40, 80)) rather than the
# 0-1 scale of the catalogue above -- presumably matching the raw tomogram
# range; confirm against the loader. Only the Gaussian-noise constructor is
# live, and its result is not assigned to anything.
T.RandGaussianNoised(['image'], prob=1.0, mean=0, std=30, sample_std=True) # gauss noise

# T.RandShiftIntensityd(['image'], offsets = (-40, 80), safe=False, prob=1.0, channel_wise=True) # mean shift
# T.RandStdShiftIntensityd(['image'], (-1.0, 1.5), prob=1.0, nonzero=False, channel_wise=True) # mean shift

# OneOf
# T.RandScaleIntensityFixedMeand(['image'], prob=1.0, factors=(-0.7, 0.0), fixed_mean=True, preserve_range=False) # decrease std or contrast, higher prob -> harder
# T.RandScaleIntensityFixedMeand(['image'], prob=1.0, factors=(0.0, 1.0), fixed_mean=True, preserve_range=False) # increase std or contrast, lower prob -> easier

# T.RandScaleIntensityd(['image'], factors = (-0.7, 0.7), prob=1.0, channel_wise=True) # mean/std multiplicative x*(1+factor)
# T.RandAdjustContrastd(['image'], prob=1.0, gamma=(0.25, 1.75), invert_image=False, retain_stats=False) # mean/std polynomial x**gamma

# T.RandHistogramShiftd(['image'], num_control_points=(10,20), prob=1.0)

# FLAG
# T.MedianSmoothd(['image'], radius = 2) # slow on CPU, radius >=2 -> large RAM
# T.RandGaussianSmoothd(['image'], sigma_x=(0.5, 4.5), sigma_y=(0.5, 4.5), sigma_z=(0.5, 4.5), prob=1.0, approx='erf')

# make it easier with (3.0, 1.0) -> really sharpen, not blurry like (10.0, 1.0)
# T.RandGaussianSharpend(['image'], sigma1_x=(1.0, 5.0), sigma1_y=(1.0, 5.0), sigma1_z=(1.0, 5.0),
#                       sigma2_x=(1.0, 1.0), sigma2_y=(1.0, 1.0), sigma2_z=(1.0, 1.0),
#                       alpha=(10.0, 30.0), approx='erf', prob=1.0)

# FLAG
# T.HistogramNormalized(['image'], num_bins=256, min=0, max=255.0, mask=None)


### DROPOUT
# CT.RandCoarseDropoutWithKeypointsd(keys=['image'], keypoints_key = 'kpts', holes=3, spatial_size=(500 // S, 500//S, 500 // S), dropout_holes=True,
#                                        fill_value=(0, 255), max_holes=10, max_spatial_size=(2000 // S, 2000 // S, 2000 // S), remove = "patch", keypoint_margins='auto', max_retries = 10, prob=1.0)



In [None]:
# BYU
# Same candidate list as the previous cell, with every constructor live. Each
# transform is built and immediately discarded (nothing is assigned), so the
# cell only checks that the argument sets are accepted; the notebook echoes
# the last expression.
# NOTE(review): `S` is not assigned in this cell; the only assignment visible
# in this part of the notebook is in a later cell (S = 16.0). Confirm it is
# defined before this cell runs, otherwise the dropout line raises NameError.
T.RandGaussianNoised(['image'], prob=1.0, mean=0, std=30, sample_std=True) # gauss noise

T.RandShiftIntensityd(['image'], offsets = (-40, 80), safe=False, prob=1.0, channel_wise=True) # mean shift
T.RandStdShiftIntensityd(['image'], (-1.0, 1.5), prob=1.0, nonzero=False, channel_wise=True) # mean shift

# OneOf
T.RandScaleIntensityFixedMeand(['image'], prob=1.0, factors=(-0.7, 0.0), fixed_mean=True, preserve_range=False) # decrease std or contrast, higher prob -> harder
T.RandScaleIntensityFixedMeand(['image'], prob=1.0, factors=(0.0, 1.0), fixed_mean=True, preserve_range=False) # increase std or contrast, lower prob -> easier

T.RandScaleIntensityd(['image'], factors = (-0.7, 0.7), prob=1.0, channel_wise=True) # mean/std multiplicative x*(1+factor)
T.RandAdjustContrastd(['image'], prob=1.0, gamma=(0.25, 1.75), invert_image=False, retain_stats=False) # mean/std polynomial x**gamma

T.RandHistogramShiftd(['image'], num_control_points=(10,20), prob=1.0)

# FLAG
T.MedianSmoothd(['image'], radius = 2) # slow on CPU, radius >=2 -> large RAM
T.RandGaussianSmoothd(['image'], sigma_x=(0.5, 4.5), sigma_y=(0.5, 4.5), sigma_z=(0.5, 4.5), prob=1.0, approx='erf')

# make it easier with (3.0, 1.0) -> really sharpen, not blurry like (10.0, 1.0)
T.RandGaussianSharpend(['image'], sigma1_x=(1.0, 5.0), sigma1_y=(1.0, 5.0), sigma1_z=(1.0, 5.0),
                      sigma2_x=(1.0, 1.0), sigma2_y=(1.0, 1.0), sigma2_z=(1.0, 1.0),
                      alpha=(10.0, 30.0), approx='erf', prob=1.0)

# FLAG
T.HistogramNormalized(['image'], num_bins=256, min=0, max=255.0, mask=None)


### DROPOUT
# Hole sizes are written as 500 and 2000 divided by S (floor division).
CT.RandCoarseDropoutWithKeypointsd(keys=['image'], keypoints_key = 'kpts', holes=3, spatial_size=(500 // S, 500//S, 500 // S), dropout_holes=True,
                                       fill_value=(0, 255), max_holes=10, max_spatial_size=(2000 // S, 2000 // S, 2000 // S), remove = "patch", keypoint_margins='auto', max_retries = 10, prob=1.0)



In [None]:
### BACKUP
# Saved copy of the spatial pipeline: resample -> keypoint-aware crop ->
# random zoom -> random affine. The list is only assigned here; nothing is
# visualised. H, W, D, LAZY and ALIGN_CORNERS come from whichever cells ran
# before this one.
transforms = [
    CT.CustomSpacingd(
        keys=['image'],
        spacing_key='spacing_scale',
        mode='bilinear',
        padding_mode='zeros',
        align_corners=ALIGN_CORNERS,
        lazy=LAZY,
    ),
    # T.RandSpatialCropd(keys=["image"], roi_size=[224, 224, 112], random_center=True, random_size=False, lazy=LAZY),
    CT.RandSpatialCropByKeypointsd(
        keys=['image'],
        keypoints_key='spaced_kpts',
        roi_size=(H, W, D),
        max_roi_size=None,
        random_center=True,
        random_size=False,
        margin=0.25,
        auto_correct_center=True,
        allow_missing_keys=False,
        lazy=LAZY,
    ),
    T.RandZoomd(
        keys=['image'],
        prob=1.0,
        min_zoom=(0.6, 0.6, 0.6),
        max_zoom=(1.2, 1.2, 1.2),
        mode="bilinear",
        padding_mode="constant",
        align_corners=False,
        dtype=None,
        keep_size=True,
        lazy=LAZY,
    ),
    T.RandAffined(
        keys=['image'],
        prob=1.0,
        rotate_range=(
            # Small tilt (+/- 15 degrees) for the first two rotation parameters.
            (-math.radians(15), math.radians(15)),
            (-math.radians(15), math.radians(15)),
            # Full turn for the third one; math.pi instead of the 3.14
            # approximation so the range really spans 360 degrees.
            (0, 2 * math.pi),
        ),
        shear_range=(
            (-0.2, 0.2),
            (-0.2, 0.2),  # along Z
            (-0.2, 0.2),
        ),
        translate_range=None,
        scale_range=(
            (-0.3, 0.3),
            (-0.3, 0.3),
            (-0.3, 0.3),
        ),
        spatial_size=None,
        mode="bilinear",  # bilinear, nearest
        padding_mode="constant",
        cache_grid=True,
        lazy=LAZY,
    ),
    # Translation-only affine kept for reference:
    # T.RandAffined(
    #     keys=['image'],
    #     prob=1.0,
    #     rotate_range=None,
    #     shear_range=None,
    #     translate_range=(
    #                 (-max(0, H/2 - expect_avg_radius), max(0, H/2 - expect_avg_radius)),
    #                 (-max(0, W/2 - expect_avg_radius), max(0, W/2 - expect_avg_radius)),
    #                 (-max(0, D/2 - expect_avg_radius/4), max(0, D/2 - expect_avg_radius/4)),
    #             ),
    #     scale_range=None,
    #     spatial_size=None,
    #     mode="bilinear",  # bilinear, nearest
    #     padding_mode="constant",
    #     cache_grid=True,
    #     lazy=LAZY,
    # ),
]

In [None]:
# Preview of resampling alone: CustomSpacingd reads its scale from the
# 'spacing_scale' entry of each sample.
#
# LAZY is assigned *before* the list is built so that `lazy=LAZY` below captures
# this cell's setting instead of the value left behind by a previous cell.
LAZY = True

transforms = [
    CT.CustomSpacingd(
        keys=['image'],
        spacing_key='spacing_scale',
        mode='bilinear',
        padding_mode='zeros',
        align_corners=ALIGN_CORNERS,
        lazy=LAZY,
    ),
]

viz_transform(
    all_tomo_ids[:],
    all_tomos[:],
    all_kpts[:],
    all_voxel_spacings[:],
    target_voxel_spacing=16.0,
    transforms=transforms,
    seed=611,
    view_3d=False,
)

In [None]:
# Scratch arithmetic: evaluates to 125.0, which the notebook echoes as the cell output.
2000 / 16

In [None]:
# Full pipeline preview: resample -> keypoint-aware crop -> random zoom ->
# random affine -> contrast adjustment, visualised on the first K tomograms.
import importlib

# Re-import the project transforms so edits made on disk are picked up
# without restarting the kernel.
importlib.reload(CT)

S = 16.0  # passed to viz_transform as target_voxel_spacing
H, W, D = 768, 768, 128
# H, W, D = 448, 448, 224
# H, W, D = 224, 224, 128

# Only referenced by the commented-out translation block below.
expect_avg_radius = 1000. / S

# LAZY is assigned *before* the list is built so that every `lazy=LAZY` below
# captures this cell's setting instead of the value left by a previous cell.
LAZY = True
K = 15  # number of tomograms to visualise

transforms = [
    CT.CustomSpacingd(
        keys=['image'],
        spacing_key='spacing_scale',
        mode='bilinear',
        padding_mode='zeros',
        align_corners=ALIGN_CORNERS,
        lazy=LAZY,
    ),
    # T.RandSpatialCropd(keys=["image"], roi_size=[224, 224, 112], random_center=True, random_size=False, lazy=LAZY),
    CT.RandSpatialCropByKeypointsd(
        keys=['image'],
        keypoints_key='spaced_kpts',
        roi_size=(H, W, D),
        max_roi_size=None,
        random_center=True,
        random_size=False,
        margin=0.25,
        auto_correct_center=True,
        allow_missing_keys=False,
        lazy=LAZY,
    ),
    T.RandZoomd(
        keys=['image'],
        prob=1.0,
        min_zoom=(0.6, 0.6, 0.6),
        max_zoom=(1.2, 1.2, 1.2),
        mode="bilinear",
        padding_mode="constant",
        align_corners=False,
        dtype=None,
        keep_size=True,
        lazy=LAZY,
    ),
    T.RandAffined(
        keys=['image'],
        prob=1.0,
        rotate_range=(
            # Small tilt (+/- 15 degrees) for the first two rotation parameters.
            (-math.radians(15), math.radians(15)),
            (-math.radians(15), math.radians(15)),
            # Full turn for the third one; math.pi instead of the 3.14
            # approximation so the range really spans 360 degrees.
            (0, 2 * math.pi),
        ),
        shear_range=(
            (-0.2, 0.2),
            (-0.2, 0.2),  # along Z
            (-0.2, 0.2),
        ),
        translate_range=None,
        scale_range=(
            (-0.3, 0.3),
            (-0.3, 0.3),
            (-0.3, 0.3),
        ),
        spatial_size=None,
        mode="bilinear",  # bilinear, nearest
        padding_mode="constant",
        cache_grid=True,
        lazy=LAZY,
    ),
    # Translation-only affine kept for reference:
    # T.RandAffined(
    #     keys=['image'],
    #     prob=1.0,
    #     rotate_range=None,
    #     shear_range=None,
    #     translate_range=(
    #                 (-max(0, H/2 - expect_avg_radius), max(0, H/2 - expect_avg_radius)),
    #                 (-max(0, W/2 - expect_avg_radius), max(0, W/2 - expect_avg_radius)),
    #                 (-max(0, D/2 - expect_avg_radius/4), max(0, D/2 - expect_avg_radius/4)),
    #             ),
    #     scale_range=None,
    #     spatial_size=None,
    #     mode="bilinear",  # bilinear, nearest
    #     padding_mode="constant",
    #     cache_grid=True,
    #     lazy=LAZY,
    # ),

    # Intensity candidates tried earlier, kept for reference:
    # T.RandStdShiftIntensityd(
    #     ['image'],
    #     (1.2, 1.2),
    #     prob=1.0,
    #     nonzero=False,
    #     channel_wise=True,
    # ),
    # T.RandScaleIntensityFixedMeand(
    #     ['image'],
    #     prob=1.0,
    #     factors=(-0.6, -0.6),
    #     fixed_mean=True,
    #     preserve_range=False,
    # ),

    # Degenerate gamma range pins the exponent at 0.25 for this preview.
    T.RandAdjustContrastd(
        ['image'],
        prob=1.0,
        gamma=(0.25, 0.25),
        invert_image=False,
        retain_stats=False,
    ),
]

viz_transform(
    all_tomo_ids[:K],
    all_tomos[:K],
    all_kpts[:K],
    all_voxel_spacings[:K],
    target_voxel_spacing=S,
    transforms=transforms,
    seed=6111998,
    view_3d=False,
)