In [6]:
import load
import matplotlib.pyplot as plt


root = load.get_root()

picks = load.get_picks_dict(root)

vol, coords, scales = load.get_run_volume_picks(root, level=0)

mask = load.get_picks_mask(vol.shape, picks, coords, int(scales[0]))


In [7]:
import numpy as np
import load
import visual
radii = [ 6,
          6,
          9,
          15,
          13,
          14 ]
for key in coords.keys():
    coords[key] = np.array(coords[key], dtype=np.int16)
coord_list = []
for key in coords.keys():
    coord_list.append(coords[key])

In [8]:
import torch
def create_exponential_heatmap_gpu(labels, volume_shape, points, radii, device="cuda"):
    heatmap = torch.zeros((labels, *volume_shape), dtype=torch.float32, device=device)
    zz, yy, xx = torch.meshgrid(
        torch.arange(volume_shape[0], device=device),
        torch.arange(volume_shape[1], device=device),
        torch.arange(volume_shape[2], device=device),
        indexing="ij"
    )

    for label in range(labels):
        radius = radii[label]  # Get the radius for this label
        for point in points[label]:
            distances = torch.sqrt((zz - point[0])**2 + (yy - point[1])**2 + (xx - point[2])**2)
            mask = distances <= radius
            decay = torch.exp(-distances / radius) - torch.exp(torch.tensor(-1.0, device=device))
            decay /= 1 - torch.exp(torch.tensor(-1.0, device=device))  # Normalize
            heatmap[label][mask] = torch.maximum(heatmap[label][mask], decay[mask])

    background_channel = torch.ones(volume_shape, dtype=torch.float32, device=device)
    for i in range(heatmap.shape[0]):
        background_channel -= heatmap[i]
    background_channel[background_channel < 0.0] = 0.0
    
    return torch.cat((background_channel.unsqueeze(0), heatmap), dim=0)

In [9]:
heatmap = create_exponential_heatmap_gpu(6, vol.shape, coord_list, radii)
heatmap = heatmap.to('cpu')

In [10]:
np.set_printoptions(threshold=np.inf)
np.unique(heatmap[0], return_counts=True)
# np.mean(heatmap[0])

(array([0.        , 0.01097548, 0.0143733 , 0.01643369, 0.0178827 ,
        0.02124977, 0.02980119, 0.02986512, 0.03027961, 0.03199065,
        0.04190373, 0.04306632, 0.04872155, 0.04908365, 0.05245662,
        0.0598177 , 0.06081414, 0.06259567, 0.06792495, 0.06836292,
        0.0709002 , 0.0820508 , 0.08679885, 0.09055224, 0.09115565,
        0.09151286, 0.09650379, 0.1014865 , 0.10202652, 0.10905701,
        0.11096799, 0.11165753, 0.11168101, 0.11712778, 0.11778611,
        0.12339619, 0.1356889 , 0.1357429 , 0.1375544 , 0.14047736,
        0.142335  , 0.14645427, 0.14997685, 0.15199739, 0.16056553,
        0.1614795 , 0.16306591, 0.16319817, 0.16636163, 0.17026281,
        0.17251879, 0.17604974, 0.17924875, 0.18246835, 0.18326867,
        0.1840961 , 0.19294119, 0.19351363, 0.19386268, 0.1973362 ,
        0.19747269, 0.19929272, 0.20388362, 0.20451066, 0.20988724,
        0.2105959 , 0.21288198, 0.21608031, 0.21708477, 0.21820366,
        0.21909148, 0.21998465, 0.22165346, 0.22

In [14]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact
import augment
%matplotlib inline

params = augment.aug_params
params['patch_size'] = (250,250,250)
params['final_size'] = (150,150,150)

samples = random_augmentation(vol, heatmap, num_samples=1, aug_params=params)

new_vol = samples[0]['source']
new_mask = samples[0]['target']  # Mask with interest points (non-zero values)

def plot_cross_section(i):
    plot_vol = new_vol[0]
    plot_mask = new_mask
    
    plt.figure(figsize=(15, 5))
    alpha = 0.4
    # Initialize a blank canvas for displaying the heatmaps
    combined_heatmap_x = np.zeros_like(plot_mask[0, i, :, :])  # For x-slice
    combined_heatmap_y = np.zeros_like(plot_mask[0, :, i, :])  # For y-slice
    combined_heatmap_z = np.zeros_like(plot_mask[0, :, :, i])  # For z-slice

    # Slice at x-coordinate
    plt.subplot(131)
    for idx in range(plot_mask.shape[0]):
        # Overlay each channel's heatmap with transparency
        combined_heatmap_x = np.maximum(combined_heatmap_x, plot_mask[0, i, :, :])
    # plt.imshow(plot_vol[i,:,:], cmap="Blues", alpha=alpha)
    plt.imshow(combined_heatmap_x, cmap="Reds", alpha=alpha)
    plt.title(f'Slice at x={i}')

    # Slice at y-coordinate
    plt.subplot(132)
    for idx in range(plot_mask.shape[0]):
        # Overlay each channel's heatmap with transparency
        combined_heatmap_y = np.maximum(combined_heatmap_y, plot_mask[0, :, i, :])
    # plt.imshow(plot_vol[:,i,:], cmap="Blues", alpha=alpha)
    plt.imshow(combined_heatmap_y, cmap="Reds", alpha=alpha)
    plt.title(f'Slice at y={i}')

    # Slice at z-coordinate
    plt.subplot(133)
    for idx in range(plot_mask.shape[0]):
        # Overlay each channel's heatmap with transparency
        combined_heatmap_z = np.maximum(combined_heatmap_z, plot_mask[0, :, :, i])
    # plt.imshow(plot_vol[:,:,i], cmap="Blues", alpha=alpha)
    plt.imshow(combined_heatmap_z, cmap="Reds", alpha=alpha)
    plt.title(f'Slice at z={i}')

    plt.show()

# Interactive Slider for scrolling through slices
interact(plot_cross_section, i=(0, new_vol.shape[1] - 1))

interactive(children=(IntSlider(value=74, description='i', max=149), Output()), _dom_classes=('widget-interact…

<function __main__.plot_cross_section(i)>

In [83]:
def create_exponential_heatmap(labels, volume_shape, points, radii):
    heatmap = np.zeros((labels, *volume_shape), dtype=np.float32)
    zz, yy, xx = np.meshgrid(
        np.arange(volume_shape[0]),
        np.arange(volume_shape[1]),
        np.arange(volume_shape[2]),
        indexing="ij"
    )

    for label in range(labels):
        radius = radii[label]  # Get the radius for this label
        for point in points[label]:
            distances = np.sqrt((zz - point[0])**2 + (yy - point[1])**2 + (xx - point[2])**2)
            mask = distances <= radius
            decay = np.exp(-distances / radius) - np.exp(-1)  # Normalize decay to start near 0
            decay /= 1 - np.exp(-1)  # Scale to make center 1.0
            heatmap[label][mask] = np.maximum(heatmap[label][mask], decay[mask])
            
    background_channel = np.ones(volume_shape, dtype=np.float32)
    for i in range(heatmap.shape[0]):
        background_channel = background_channel - heatmap[i]
    background_channel[background_channel < 0.0] = 0.0
    return np.concatenate((np.expand_dims(background_channel, axis=0), heatmap), axis=0)

In [11]:
import numpy as np
import zarr
import os
import monai
from monai.transforms import (
    Compose,
    NormalizeIntensityd,
    RandFlipd,
    RandAffined,
    RandSpatialCropd,
    SpatialPadd,
    ResizeWithPadOrCropd,
    RandRotated,
    SqueezeDimd
)
# Will probably add scale prob / range
aug_params = {
    "patch_size": (100,100,100),
    "final_size":   (100,100,100),
    "flip_prob":  0.5,
    "rot_prob":   1.0,
    "rot_range":  np.pi / 2
}

def random_augmentation(volume, 
                        mask,
                        points = None,
                        num_samples=10, 
                        aug_params=aug_params,
                        save=False,
                        dest=None,
                        filename=None,
                        mask_type="continuous"):
    """
    Augment 3D volume and mask with cropping, normalization, padding, flipping, and rotation.

    Parameters:
        volume (np.ndarray): source 3D volume (shape: (C) x D x H x W). Channel optional
        mask (np.ndarray): source 3D mask (shape: D x H x W).
        num_samples (int): Number of augmented samples to generate.
        aug_params (dict): parameters for augmentation {patch_size, final_size, flip_prob, rot_prob, rot_range(radians)}

    Returns:
        list: Augmented (volume, mask) pairs.
    """
    if mask_type == "continuous":
        mask_transform = 'bilinear'
    else:
        mask_transform = 'nearest'
        
    if save: print(f"Generating {filename} samples")
    
    if len(volume.shape) == 3:  volume = np.expand_dims(volume, axis=0)
    if len(mask.shape) == 3:  mask = np.expand_dims(mask, axis=0)

    sample_dict = {"source": volume, "target": mask}
    keys = ["source", "target"]
    mode = ['bilinear', mask_transform]
    #add points if necessary
    if points is not None:
        keys.append("points")
        if len(points.shape) == 3:  points = np.expand_dims(points, axis=0)
        sample_dict["points"] = points
        mode.append('nearest')

    augment = Compose([
        RandSpatialCropd(
            keys=keys, 
            roi_size=aug_params["patch_size"], 
            random_center=True, 
            random_size=False
        ),
        NormalizeIntensityd(
            keys="source"
        ),
        RandFlipd(
                keys=keys, 
                spatial_axis=[0, 1, 2], 
                prob=aug_params["flip_prob"]
        ),
        RandRotated(
            keys=keys, 
            range_x=aug_params["rot_range"], 
            range_y=aug_params["rot_range"], 
            range_z=aug_params["rot_range"], 
            prob=aug_params["rot_prob"],  
            keep_size=True,
            padding_mode='zeros',
            mode=mode
        ),
        ResizeWithPadOrCropd(
            keys=keys, 
            spatial_size=aug_params["final_size"], 
            method="symmetric",
            mode="constant"
        )
    ])
    
    augmented_samples = []
    
    
    for n in range(num_samples):
        sample = augment(sample_dict)
        
        # Add to list
        if save:
            np.save(os.path.join(dest, "source/", f"{filename}-{n}.npy"), sample["source"])
            np.save(os.path.join(dest, "target/", f"{filename}-{n}.npy"), sample["target"])
            if points is not None:
                np.save(os.path.join(dest, "points/", f"{filename}-{n}.npy"), sample["points"])
            checkpts = [0.25,0.5,0.75,0.99]
            for i in checkpts:
                if n == int(num_samples * i):
                    print(f"\t{int(i*100)}%")
        else:
            augmented_samples.append(sample)
            
    if save: print(f"{filename} samples saved\n")
    
    return augmented_samples