In [1]:
import torch
import numpy as np
from preparation import *
from data_utils import *
from torch.utils.data import DataLoader
import importlib

# paths
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
# Dummy Data
data = np.random.rand(5, 250, 250) # assume each value in the voxel is a fluorescence intensity
data_list = [data] * 50

dataset = CubeDataset(data_list)

sample = dataset[0]
print(sample.shape)

(5, 250, 250)


In [3]:
def custom_collate_fn(batch, mask_percentage=0.6, kernel=12):
    cubes = torch.stack([torch.tensor(cube, dtype=torch.float32) for cube in batch])  # (B, Z, Y, X)
    B, Z, Y, X = cubes.shape

    masks = torch.ones_like(cubes, dtype=torch.float32)
    num_blocks = int(mask_percentage * (Z * Y * X) / (kernel ** 3))

    visible_cubes = []

    for b in range(B):
        for _ in range(num_blocks):
            zi = np.random.randint(0, max(Z - kernel, 1))
            yi = np.random.randint(0, max(Y - kernel, 1))
            xi = np.random.randint(0, max(X - kernel, 1))
            masks[b, zi:zi + kernel, yi:yi + kernel, xi:xi + kernel] = 0

    masked_cubes = cubes * masks

    return cubes, masked_cubes, masks  # all are (B, Z, Y, X)

sparse_train_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=0,
    pin_memory=False, # ok because we load data on gpu already ?
    drop_last=True
)

cubes, masked_cubes, masks = next(iter(sparse_train_loader))
cubes = cubes.unsqueeze(dim=1) # shape (B, C=1, Z, Y, X)
cubes = cubes.to(device)
print(len(cubes))
print(f"Cubes shape w/ feature dimension: {cubes.shape}")
# print(masked_cubes.shape)
# print(masks.shape)

2
Cubes shape w/ feature dimension: torch.Size([2, 1, 5, 250, 250])


We have a dataset of our original samples. They are (5 x 250 x 250) voxel cubes, or (35 x 106.25 x 106.5 um cubes)

Now, we want to embed each of these patches. The PatchEmbed3D layer takes a 3D volumetric image and splits it into non-overlapping 3D patches, then projects each patch into a fixed-length embedding vector. Analogous to how a ViT turns a 2D image into patches. We choose the patch dimension of (5 x 25 x 25) arbitrarily for now (the sample dimension we use is evenly divisilbe by our patch dimensions - important to maintain shape match between patch embeddings and pos embeddings). The patches will be our tokens for our transformer encoder. 

In [4]:
patch_size = (5, 25, 25)
patch_embedding_model = PatchEmbed3D(patch_size=patch_size).to(device)
patch_embeddings = patch_embedding_model(cubes) # reshape cubes from (B, Z, X, Y) --> (B, N_patches, embed_dim)
print(patch_embeddings.shape)

torch.Size([2, 100, 128])


Now, we want to calculate a positional encoding for each patch location. We normalize the voxel coordinates within the range [-1, 1] from the centers to encourage stable training.

In [5]:
def get_patch_centers(volume: torch.Tensor, patch_size: tuple) -> torch.Tensor:
    """
    Extract normalized patch center positions from a 3D volume.

    Args:
        volume (Tensor): Input tensor of shape (B, 1, Z, Y, X)
        patch_size (tuple): Tuple of (pz, py, px) indicating patch size

    Returns:
        Tensor: (B, N, 3) tensor of normalized patch centers in [-1, 1]^3
    """
    B, C, Z, Y, X = volume.shape
    pz, py, px = patch_size
    assert Z % pz == 0 and Y % py == 0 and X % px == 0

    # Number of patches in each dim
    nz = Z // pz
    ny = Y // py
    nx = X // px
    N = nz * ny * nx  # total number of patches per sample

    # Compute the *absolute* (unnormalized) voxel center positions
    z_centers = torch.arange(pz//2, Z, step=pz, dtype=torch.float32)
    y_centers = torch.arange(py//2, Y, step=py, dtype=torch.float32)
    x_centers = torch.arange(px//2, X, step=px, dtype=torch.float32)

    zz, yy, xx = torch.meshgrid(z_centers, y_centers, x_centers, indexing="ij")
    coords = torch.stack([zz, yy, xx], dim=-1)  # (nz, ny, nx, 3)
    coords = coords.view(-1, 3)  # (N, 3)

    # Normalize to [-1, 1] using shape
    norm_coords = coords.clone()
    norm_coords[:, 0] = 2 * (coords[:, 0] / (Z - 1)) - 1
    norm_coords[:, 1] = 2 * (coords[:, 1] / (Y - 1)) - 1
    norm_coords[:, 2] = 2 * (coords[:, 2] / (X - 1)) - 1

    # Repeat for all batch elements
    norm_coords = norm_coords.unsqueeze(0).repeat(B, 1, 1)  # (B, N_patches, 3)
    return norm_coords

In [6]:
patch_centers = get_patch_centers(cubes, patch_size)  # (B, N_patches, 3)
print(f"normalized patch shape: {patch_centers.shape}")

normalized patch shape: torch.Size([2, 100, 3])


In [7]:
from pos_embed_model import *

pos_encoder = LearnedPositionalEncoder(in_dim=3, embed_dim=128) # using same embed_dim as patch embeddings, currently NOT using intensity values in pos_encoder
pos_embeddings = pos_encoder(patch_centers).to(device)
print(f"pos embedding shape: {pos_embeddings.shape}") # (B, N_patches, embed_dim)

pos embedding shape: torch.Size([2, 100, 128])


Here is where we mask our embeddings! We define a method to randomly select 60% of our tokens to mask. The masking percentage may be an interesting hyperparameter to tune.

Now, we do element-wise add of the patch embeddings and positional embeddings to get our input to the transformer encoder. ViTs and PoLArMAE use elementwise add rather than concatenation, so we follow their convention.

In [8]:
def random_masking(x, mask_ratio=0.6):
    """
    x: (B, N, D) patch embeddings
    Returns:
        visible_x: (B, N_vis, D)
        mask_indices: (B, N_masked) indices of masked patches
        unmask_indices: (B, N_vis) indices of visible patches
    """
    B, N, _ = x.shape
    N_vis = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=x.device)  # (B, N)
    ids_sorted = torch.argsort(noise, dim=1)   # ascending order
    ids_keep = ids_sorted[:, :N_vis]
    ids_mask = ids_sorted[:, N_vis:]

    # Gather visible tokens
    batch_idx = torch.arange(B).unsqueeze(-1).to(x.device)  # (B, 1)
    x_visible = x[batch_idx, ids_keep]

    return x_visible, ids_keep, ids_mask

B, N, embed_dim = patch_embeddings.shape
patch_embed_vis, ids_keep, ids_mask = random_masking(patch_embeddings, mask_ratio=0.6)
patch_embed_vis, ids_keep, ids_mask = patch_embed_vis.to(device), ids_keep.to(device), ids_mask.to(device)
pos_embed_vis = pos_embeddings[torch.arange(B).unsqueeze(1), ids_keep].to(device)  # (B, N_visible_patches, D)

x = patch_embed_vis + pos_embed_vis
print(f"transformer input shape: {x.shape}")

transformer input shape: torch.Size([2, 40, 128])


Now we instantiate the transformer. It uses 6 self-attention layers and 4 heads. The transformer does not change the dimension of the input data, it simply applies attention.

In [9]:
from transformer_things import *
transformer_encoder = VisionTransformer3D().to(device)

latents = transformer_encoder(x).to(device)
print(f"per-patch encoder features shape: {latents.shape}") # (B, N_visible_patches, embed_dim)

per-patch encoder features shape: torch.Size([2, 40, 128])


Cool cool now we have our encoded latent vectors for the visible tokens. Our next step is to prepare a full token sequence for the transformer decoder. This involves creating a learnable token embedding which is passed in place of all masked tokens, re-adding positional embeddings to all tokens, and placing these in a sequence with the enocded visible tokens.

Re-adding positional encodings to ALL tokens is a deviation from what I expected. Explain why this is good later?

In [10]:
mask_token = nn.Parameter(torch.zeros(1, 1, 128)).to(device)

# Prepare full token sequence for decoder (visible + masked)
B, N, D = patch_embeddings.shape
x_full = torch.zeros(B, N, D, device=device).to(device)

# Fill in visible tokens at their original indices
x_full.scatter_(1, ids_keep.unsqueeze(-1).expand(-1, -1, D), patch_embed_vis) # tensor.scatter_ used to place elems from src into tgt at given indices - here x_vis indices

# Fill in mask tokens at masked indices
x_full.scatter_(1, ids_mask.unsqueeze(-1).expand(-1, -1, D), mask_token.expand(B, ids_mask.size(1), -1))

print(f"full decoder input sequence shape: {x_full.shape}")

# adding positional encodings to all
decoder_input = x_full + pos_embeddings

full decoder input sequence shape: torch.Size([2, 100, 128])


The decoder is trying to predict the raw voxel intensities inside each patch. The decoder is much more lightweight than our encoder - it only has 2 attention layers, whereas the encoder has 6.

In [16]:
px, py, pz = patch_size
output_dim = px * py * pz

decoder = Decoder(embed_dim=128, hidden_dim=256, num_layers=2, output_dim=output_dim, num_heads=4).to(device)
recon = decoder(decoder_input)  # (B, N, output_dim) - flattened array of intensity values inside each patch
print(f"Decoder output shape: {recon.shape}")

# ground truth
all_patches = extract_voxel_patches(cubes, patch_size)  # (B, N, P)
print(f"ground truth shape: {all_patches.shape}")

# selecting only masked patch ground truths
target_masked = all_patches[torch.arange(B).unsqueeze(1), ids_mask]  # (B, N_mask, P)
print(f"masked token array shape: {target_masked.shape}")

# # decoder outputs
recon_masked = recon[torch.arange(B).unsqueeze(1), ids_mask]  # (B, N_mask, P)
# # Reconstruction loss
loss = F.mse_loss(recon_masked, target_masked)
print(loss)

Decoder output shape: torch.Size([2, 100, 3125])
ground truth shape: torch.Size([2, 100, 3125])
masked token array shape: torch.Size([2, 60, 3125])
tensor(0.6758, device='cuda:0', grad_fn=<MseLossBackward0>)


Future work: Rather than just trying to reconstruct intensities per voxel in the decoder, we can try to reconstruct x, y, z, 1 --> spatial coordinates of voxel centers and intensities. This may be useful for future downstream tasks like instance detection, and track endpoint localization.