# Data Preprocessing and Exploration

This notebook demonstrates data preprocessing and visualization for µCT tooth scans.

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

# Make the project's `src/` directory importable before loading local modules
sys.path.append('../src')

from data import CTPreprocessor, VolumeAugmenter

%matplotlib inline

## 1. Load Sample Data

In [None]:
# Read one raw µCT scan from the training images
sample_path = '../data/train/images/sample_001.npy'
volume = np.load(sample_path)

# Report the array geometry and basic intensity statistics of the raw scan
print(
    f"Volume shape: {volume.shape}",
    f"Value range: [{volume.min():.2f}, {volume.max():.2f}]",
    f"Mean: {volume.mean():.2f}, Std: {volume.std():.2f}",
    sep="\n",
)

## 2. Visualize Raw Data

In [None]:
def visualize_volume(volume, title='Volume'):
    """Show the central slice of a 3D volume along each of its three axes.

    Args:
        volume: Array-like with at least three dimensions that supports
            ``.shape`` and tuple indexing.
        title: Prefix used in each panel title (``'<title> - <plane>'``).

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # Panel order follows the array axes: 0 -> axial (z), 1 -> coronal (y),
    # 2 -> sagittal (x)
    plane_names = ('Axial', 'Coronal', 'Sagittal')

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for axis_no, (panel, plane) in enumerate(zip(axes, plane_names)):
        # Take every element on the other axes and the midpoint on this one
        selector = [slice(None)] * 3
        selector[axis_no] = volume.shape[axis_no] // 2

        panel.imshow(volume[tuple(selector)], cmap='gray')
        panel.set_title(f'{title} - {plane}')
        panel.axis('off')

    plt.tight_layout()
    plt.show()

# Central orthogonal slices of the raw scan, before any preprocessing
visualize_volume(volume, title='Original Volume')

## 3. Apply Preprocessing

In [None]:
# Voxel spacing of the raw scan (example value)
original_spacing = (0.15, 0.15, 0.15)

# Configure the preprocessing step
preprocessor = CTPreprocessor(
    target_spacing=(0.1, 0.1, 0.1),
    target_size=(128, 128, 128),
    normalize=True,
    clip_range=(-1000, 3000),
)

# Run the raw volume through the preprocessor
preprocessed = preprocessor.preprocess(volume, original_spacing)

# Same summary as for the raw scan, so the two can be compared side by side
print(
    f"Preprocessed shape: {preprocessed.shape}",
    f"Value range: [{preprocessed.min():.2f}, {preprocessed.max():.2f}]",
    f"Mean: {preprocessed.mean():.2f}, Std: {preprocessed.std():.2f}",
    sep="\n",
)

In [None]:
# Central orthogonal slices of the preprocessed result
visualize_volume(preprocessed, title='Preprocessed Volume')

## 4. Apply Data Augmentation

In [None]:
# Seed NumPy's global RNG so the randomly augmented example looks the same on
# every Restart & Run All.
# NOTE(review): assumes VolumeAugmenter draws from NumPy's global RNG — confirm
# against src/data; if it uses `random` or torch, seed that generator instead.
np.random.seed(42)

# Create augmenter
augmenter = VolumeAugmenter(
    rotation_range=15.0,
    flip_prob=0.5,
    noise_std=0.05,
    brightness_range=0.2
)

# Apply augmentation to the preprocessed volume
augmented = augmenter(preprocessed)

print(f"Augmented shape: {augmented.shape}")
visualize_volume(augmented, 'Augmented Volume')

## 5. Visualize Segmentation Masks

In [None]:
# Load the segmentation mask that corresponds to sample_001
mask_path = '../data/train/masks/sample_001.npy'
mask = np.load(mask_path)

# Label value -> class name / display colour (list index is the label value)
class_names = ['Background', 'Enamel', 'Dentin', 'Pulpa']
class_colors = ['black', 'red', 'green', 'blue']  # Background, Enamel, Dentin, Pulpa

# Build the colormap from class_colors so every label is drawn in the colour
# listed above (the generic 'tab10' map does not use these colours).
mask_cmap = ListedColormap(class_colors)
max_label = len(class_names) - 1

# Show the central mask slice along each array axis
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for axis_no, (ax, plane) in enumerate(zip(axes, ('Axial', 'Coronal', 'Sagittal'))):
    selector = [slice(None)] * 3
    selector[axis_no] = mask.shape[axis_no] // 2
    # 'nearest' keeps the labels categorical: no blended in-between values
    # (and therefore no spurious class colours) along class borders.
    ax.imshow(mask[tuple(selector)], cmap=mask_cmap, vmin=0, vmax=max_label,
              interpolation='nearest')
    ax.set_title(f'Mask - {plane}')
    ax.axis('off')

plt.tight_layout()
plt.show()

# Print class distribution
unique, counts = np.unique(mask, return_counts=True)
print("\nClass distribution:")
for cls, count in zip(unique, counts):
    percentage = 100 * count / mask.size
    # Masks saved with a float dtype cannot index a list directly, and labels
    # outside 0..max_label (e.g. an ignore value or a negative label) must not
    # crash or wrap around to the wrong class name.
    is_known = float(cls).is_integer() and 0 <= cls <= max_label
    name = class_names[int(cls)] if is_known else f"Unknown label {cls}"
    print(f"  {name}: {count} voxels ({percentage:.2f}%)")

## 6. Overlay Visualization

In [None]:
def overlay_mask_on_volume(volume, mask, alpha=0.5):
    """Overlay a segmentation mask on the central orthogonal slices of a volume.

    Args:
        volume: 3D intensity array; drawn in grayscale.
        mask: 3D label array with the same shape as ``volume``; drawn on top
            with the 'tab10' colormap scaled to the label range 0..3.
        alpha: Opacity of the mask layer.

    Raises:
        ValueError: If ``volume`` and ``mask`` have different shapes. The two
            are overlaid index-for-index, so a shape mismatch would otherwise
            give a misaligned picture (or an IndexError) without explanation.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    if tuple(volume.shape) != tuple(mask.shape):
        raise ValueError(
            f"volume shape {tuple(volume.shape)} does not match mask shape "
            f"{tuple(mask.shape)}; resample both to the same grid before overlaying"
        )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for axis_no, (ax, plane) in enumerate(zip(axes, ('Axial', 'Coronal', 'Sagittal'))):
        # One shared index per plane, so image and mask always show the same slice
        selector = [slice(None)] * 3
        selector[axis_no] = volume.shape[axis_no] // 2
        selector = tuple(selector)

        ax.imshow(volume[selector], cmap='gray')
        # 'nearest' keeps labels categorical instead of blending them at borders
        ax.imshow(mask[selector], cmap='tab10', alpha=alpha, vmin=0, vmax=3,
                  interpolation='nearest')
        ax.set_title(f'Overlay - {plane}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Overlay on the raw scan, not on `preprocessed`: the mask is loaded without any
# resampling, while `preprocessed` has been passed through the preprocessor
# (target spacing 0.1, target size 128^3), so the two no longer line up.
# NOTE(review): assumes the stored mask shares the raw volume's voxel grid — confirm.
overlay_mask_on_volume(volume, mask)