In [1]:
import gc
import mmap
import os
import random
import time

import numpy as np
import torch
from scipy.ndimage import gaussian_filter
import SimpleITK as sitk
from monai.transforms import MapTransform
from monai.transforms import (
    Compose,
    NormalizeIntensityd,
    RandFlipd,
    RandSpatialCropd,
    ResizeWithPadOrCropd,
    RandRotated,
    ToDeviced,
    Zoomd,
    SqueezeDimd
)

# Augmentation configuration consumed by rand_aug.
aug_params = dict(
    patch_size=(100, 100, 100),  # crop window (in array indices) taken from each volume
    final_size=(100, 100, 100),  # size after the final symmetric pad/crop
    flip_prob=0.5,               # probability passed to RandFlipd
    rot_prob=0.5,                # probability passed to RandRotated
    scale_prob=1.0,
    rot_range=np.pi,             # RandRotated range_x (radians)
    scale_range=0.2,             # zoom factors drawn from [1 - 0.2, 1 + 0.2] per axis
)

def tuple_op(operator, tuple1, tuple2):
    """Combine two equal-length sequences element by element.

    Args:
        operator: binary callable applied to each pair ``(tuple1[i], tuple2[i])``.
        tuple1, tuple2: sequences of the same length.

    Returns:
        A tuple of ``operator(a, b)`` results, in order.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(tuple1) == len(tuple2):
        # map() over two iterables pairs them positionally, like zip().
        return tuple(map(operator, tuple1, tuple2))
    raise ValueError("Tuples must have the same length")

class RandCropMMapd(MapTransform):
    """Randomly crop the same ``roi_size`` window out of several dict entries.

    Intended for memory-mapped arrays: only the selected window is read and
    copied into memory, so the returned crops no longer reference the mapping.
    A leading channel axis is added to every crop.

    Args:
        keys: keys of ``data`` to crop; each entry is assumed to share the
            shape of ``data[ref_key]``.
        roi_size: crop size, one entry per axis of ``data[ref_key]``.
        ref_key: key whose shape defines the valid crop range. Defaults to
            ``"src"``.
    """

    def __init__(self, keys, roi_size, ref_key="src"):
        super().__init__(keys)
        self.roi_size = tuple(roi_size)
        self.ref_key = ref_key

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the entries in ``self.keys`` cropped.

        Raises:
            ValueError: if ``roi_size`` has the wrong number of axes or does not
                fit inside ``data[ref_key]``.
        """
        shape = data[self.ref_key].shape
        if len(shape) != len(self.roi_size):
            raise ValueError(
                f"roi_size {self.roi_size} does not match array shape {shape}"
            )
        slack = tuple(size - roi for size, roi in zip(shape, self.roi_size))
        if any(room < 0 for room in slack):
            raise ValueError(f"roi_size {self.roi_size} exceeds array shape {shape}")

        # Inclusive upper bound: the window may sit flush against the far edge.
        start = tuple(random.randint(0, room) for room in slack)
        window = tuple(slice(lo, lo + roi) for lo, roi in zip(start, self.roi_size))

        # Keep entries that are not being cropped instead of dropping them.
        result = dict(data)
        for key in self.keys:
            # .copy() materializes the window (reads it out of the mmap).
            result[key] = np.expand_dims(data[key][window].copy(), axis=0)
        return result

class ToTorchd(MapTransform):
    """Convert the selected dict entries from NumPy arrays to torch tensors.

    The input dict is shallow-copied, so the caller's dict is not mutated.
    ``torch.from_numpy`` shares memory with the source array (no copy).
    Entries that are already tensors are passed through unchanged.
    """

    def __init__(self, keys):
        super().__init__(keys)

    def __call__(self, data):
        out = dict(data)
        for key in self.keys:
            value = out[key]
            if not isinstance(value, torch.Tensor):
                value = torch.from_numpy(value)
            out[key] = value
        return out

def gaussian_filter_sitk(volume, radius):
    """Gaussian-smooth a NumPy array with SimpleITK's recursive Gaussian filter.

    The array is wrapped as a SimpleITK image, ``radius`` is forwarded as the
    filter's sigma, and the result is converted back to a NumPy array.
    NOTE(review): the filter's sigma is in physical units; ``GetImageFromArray``
    leaves the image spacing at its default (assumed 1.0), so sigma should be
    in voxels — confirm if spacing is ever set.
    """
    smoothed = sitk.SmoothingRecursiveGaussian(sitk.GetImageFromArray(volume), radius)
    return sitk.GetArrayFromImage(smoothed)

def get_heatmap(shape, pts, radius):
    """Render points as a Gaussian-smoothed, max-normalized float16 heatmap.

    Args:
        shape: output volume shape.
        pts: iterable of point coordinates, one value per axis of ``shape``.
            Each point is rounded to the nearest voxel; points outside the
            volume are skipped.
        radius: sigma forwarded to ``gaussian_filter_sitk``.

    Returns:
        float16 array of ``shape`` whose maximum is 1.0, or all zeros when no
        point falls inside the volume (avoids a 0/0 -> NaN normalization).

    Raises:
        ValueError: if a point's dimensionality does not match ``shape``.
    """
    volume = np.zeros(shape, dtype=np.float32)
    for pt in pts:
        idx = tuple(np.rint(pt).astype(int))
        if len(idx) != volume.ndim:
            # A short index would silently fill a whole row/plane instead of a voxel.
            raise ValueError(
                f"point {pt!r} has {len(idx)} coordinates, expected {volume.ndim}"
            )
        if all(0 <= i < n for i, n in zip(idx, volume.shape)):
            volume[idx] = 1.0
    volume = gaussian_filter_sitk(volume, radius)
    peak = np.max(volume)
    if peak > 0:
        volume = volume / peak
    return volume.astype(np.float16)
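# --- testing: earlier SciPy-based heatmap (get_heatmap_old), presumably kept for comparison with get_heatmap ---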

def get_heatmap_old(shape, pts, radius):
    """SciPy-based variant of ``get_heatmap``: points -> smoothed, max-normalized float16 volume.

    Args:
        shape: output volume shape.
        pts: iterable of point coordinates, one value per axis of ``shape``.
            Each point is rounded to the nearest voxel; points outside the
            volume are skipped (a negative coordinate would otherwise wrap
            around to the opposite edge via NumPy negative indexing).
        radius: Gaussian sigma in voxels, passed to ``scipy.ndimage.gaussian_filter``.

    Returns:
        float16 array of ``shape`` whose maximum is 1.0, or all zeros when no
        point falls inside the volume (avoids a 0/0 -> NaN normalization).

    Raises:
        ValueError: if a point's dimensionality does not match ``shape``.
    """
    volume = np.zeros(shape, dtype=np.float32)
    for pt in pts:
        idx = tuple(np.rint(pt).astype(int))
        if len(idx) != volume.ndim:
            raise ValueError(
                f"point {pt!r} has {len(idx)} coordinates, expected {volume.ndim}"
            )
        if all(0 <= i < n for i, n in zip(idx, volume.shape)):
            volume[idx] = 1.0
    # SciPy filters every axis by default, which equals axes=(0, 1, 2) for a
    # 3-D volume and also works on SciPy releases without the `axes` argument.
    volume = gaussian_filter(volume, radius)
    peak = np.max(volume)
    if peak > 0:
        volume = volume / peak
    return volume.astype(np.float16)


import time
import random

def rand_aug(sample: dict, aug_params=None, gpu: bool = True):
    """Run the random augmentation pipeline on a sample, timing every stage.

    Stages, in order: RandCropMMapd (random crop of the possibly memory-mapped
    arrays), ToTorchd, ToDeviced, NormalizeIntensityd on ``src``, RandFlipd,
    Zoomd (optional, see ``scale_prob``), RandRotated, ResizeWithPadOrCropd.

    Args:
        sample: dict with ``"src"`` and ``"tgt"`` arrays (assumed to share one
            shape — TODO confirm against the data files).
        aug_params: dict with ``patch_size``, ``final_size``, ``flip_prob``,
            ``rot_prob``, ``rot_range``, ``scale_range`` and optionally
            ``scale_prob`` (probability of applying the zoom, default 1.0).
        gpu: run the tensor stages on ``"cuda"`` if True, else on ``"cpu"``.

    Returns:
        ``(augmented_sample, times)``; ``times`` maps each executed stage name
        (``"Zoomd"`` only when applied) plus ``"Total"`` to elapsed seconds.

    Raises:
        ValueError: if ``aug_params`` is None.
    """
    if aug_params is None:
        raise ValueError("rand_aug requires an aug_params dict")

    device = "cuda" if gpu else "cpu"
    keys = ["src", "tgt"]
    mode = ["bilinear", "bilinear"]

    # Decide on / draw the zoom before the crop runs, so the Python RNG is
    # consumed in the same order as before when scale_prob == 1.0.
    scale_prob = aug_params.get("scale_prob", 1.0)
    apply_zoom = scale_prob >= 1.0 or random.random() < scale_prob
    if apply_zoom:
        zoom = [
            random.uniform(1.0 - aug_params["scale_range"], 1.0 + aug_params["scale_range"])
            for _ in range(3)
        ]

    transforms = [
        ("RandCropMMapd", RandCropMMapd(keys=keys, roi_size=aug_params["patch_size"])),
        ("ToTorchd", ToTorchd(keys=keys)),
        ("ToDeviced", ToDeviced(keys=keys, device=device)),
        ("NormalizeIntensityd", NormalizeIntensityd(keys="src")),
        ("RandFlipd", RandFlipd(keys=keys, spatial_axis=[0, 1, 2], prob=aug_params["flip_prob"])),
    ]
    if apply_zoom:
        transforms.append(
            ("Zoomd", Zoomd(keys=keys, zoom=zoom, mode="trilinear", keep_size=False, padding_mode="zeros"))
        )
    transforms += [
        # NOTE(review): only range_x is randomized, i.e. rotation about a single
        # axis; set range_y / range_z as well if full 3-D rotation is intended.
        # MONAI Rand* transforms draw from their own RandomState, not `random`.
        ("RandRotated", RandRotated(
            keys=keys, range_x=aug_params["rot_range"], prob=aug_params["rot_prob"],
            keep_size=False, padding_mode="zeros", mode=mode
        )),
        ("ResizeWithPadOrCropd", ResizeWithPadOrCropd(
            keys=keys, spatial_size=aug_params["final_size"], method="symmetric", mode="constant"
        )),
    ]

    # CUDA kernels launch asynchronously; without a sync the elapsed time of a
    # stage would be charged to whichever later stage happens to block.
    sync_cuda = device == "cuda" and torch.cuda.is_available()

    times = {}
    total_start = time.perf_counter()
    for name, transform in transforms:
        start = time.perf_counter()
        sample = transform(sample)
        if sync_cuda:
            torch.cuda.synchronize()
        times[name] = time.perf_counter() - start
    times["Total"] = time.perf_counter() - total_start

    return sample, times



In [2]:
def release_mmap_array(mmap_array):
    """Best-effort unmap of the ``mmap.mmap`` that directly backs ``mmap_array``.

    Intended for arrays returned by ``np.load(..., mmap_mode=...)``, whose
    ``.base`` is the mapping object. Any other kind of base (e.g. a plain
    ndarray, for a view) is left untouched rather than having an arbitrary
    ``close()`` method called on it.

    WARNING: the caller's reference to the array is not affected. Reading the
    array after its mapping has been closed touches unmapped memory and may
    crash the interpreter, so drop every reference to it right away.

    Returns:
        True if a mapping was closed, False otherwise.
    """
    base = getattr(mmap_array, "base", None)
    if not isinstance(base, mmap.mmap):
        return False
    try:
        base.close()
    except BufferError:
        # Something still holds an exported buffer of the mapping; leave it to
        # be unmapped when the last reference is garbage-collected.
        return False
    return True

In [8]:
# Time how long explicitly releasing the memory-mapped inputs takes across a
# random subset of training volumes.
src_dir = '../data/hm30rad/src'
tgt_dir = '../data/hm30rad/tgt'
N_SAMPLES = 16
SEED = 0  # seeds Python's `random` (file choice, crop, zoom); MONAI Rand* use their own RNG

random.seed(SEED)

# Sorted so a given index always maps to the same file (os.listdir order is arbitrary).
names = sorted(os.listdir(src_dir))

# Draw valid 0-based indices from the files actually present.
idxs = random.sample(range(len(names)), min(N_SAMPLES, len(names)))

del_time = 0.0
aug_times = []
for idx in idxs:
    name = names[idx]
    sample = {
        'src': np.load(os.path.join(src_dir, name), mmap_mode='r'),
        'tgt': np.load(os.path.join(tgt_dir, name), mmap_mode='r'),
    }
    aug, times = rand_aug(sample=sample, aug_params=aug_params, gpu=False)
    aug_times.append(times)

    # Time only the explicit unmapping of the loaded inputs.
    del_start = time.perf_counter()
    for key in sample:
        release_mmap_array(sample[key])
    del_time += time.perf_counter() - del_start

    # The released arrays must not be read again; drop the references now.
    del sample
    

In [4]:
del_time
# Total seconds spent in release_mmap_array, summed over the sampled volumes
# (about 0.039 s in the recorded run below).
# Working hypothesis: the bottleneck seems to be how mmapped arrays are being released.
# NOTE(review): compare this total against the per-stage `times` returned by
# rand_aug before concluding that releasing the maps is the bottleneck.

0.03900265693664551