# 3D Tooth Segmentation - Training Example

This notebook demonstrates how to train a 3D U-Net model for tooth segmentation from µCT (micro-computed tomography) scans.

In [None]:
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

# Make the project sources in ../src importable before the local imports below.
sys.path.append('../src')

from models import UNet3D
from data import CTPreprocessor, VolumeAugmenter, create_data_loaders
from utils import get_loss_function, SegmentationMetrics
from training.train import Trainer

## 1. Setup and Configuration

In [None]:
# Configuration: every setting a reader might want to tune, in one place.
config = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_classes': 4,
    'batch_size': 2,
    'num_epochs': 100,
    'learning_rate': 1e-4,
    'checkpoint_dir': '../checkpoints',
    'log_dir': '../logs',
    # Seed for all random number generators (applied right below).
    'seed': 42,
}

# Seed Python, NumPy and PyTorch RNGs so weight initialisation, and any
# shuffling/augmentation that draws from these generators, repeat on re-run.
# torch.manual_seed seeds the generator on all devices, including CUDA.
# Bit-exact GPU determinism would additionally require
# torch.use_deterministic_algorithms(True) and deterministic cuDNN settings,
# which are intentionally not enabled here.
random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

print(f"Using device: {config['device']}")

## 2. Create Data Loaders

In [None]:
# Preprocessing applied to every volume: intensities are clipped to
# (-1000, 3000) and normalized, then volumes are resampled to 0.1 spacing on a
# fixed 128^3 grid. NOTE(review): the clip range looks like Hounsfield units
# and the spacing like millimetres — confirm against CTPreprocessor.
preprocessor = CTPreprocessor(
    target_spacing=(0.1,) * 3,
    target_size=(128,) * 3,
    clip_range=(-1000, 3000),
    normalize=True,
)

# Random geometric and intensity augmentation.
# NOTE(review): confirm that create_data_loaders applies this to the training
# split only, never to validation.
augmenter = VolumeAugmenter(
    flip_prob=0.5,
    rotation_range=15.0,
    brightness_range=0.2,
    noise_std=0.05,
)

# Training and validation loaders built from the per-split data directories.
train_loader, val_loader = create_data_loaders(
    train_root='../data/train',
    val_root='../data/val',
    batch_size=config['batch_size'],
    num_workers=4,
    preprocessor=preprocessor,
    augmenter=augmenter,
)

print(
    f"Training samples: {len(train_loader.dataset)}",
    f"Validation samples: {len(val_loader.dataset)}",
    sep="\n",
)

## 3. Create Model

In [None]:
# 3D U-Net with a single input channel and config['num_classes'] output classes.
# NOTE(review): trilinear=False presumably selects learned (transposed-conv)
# upsampling rather than trilinear interpolation — confirm in models.UNet3D.
model = UNet3D(
    n_channels=1,
    n_classes=config['num_classes'],
    base_channels=32,
    trilinear=False,
)

# Total parameter count, trainable and frozen alike.
num_params = sum(map(torch.numel, model.parameters()))
print(f"Model parameters: {num_params:,}")

## 4. Define Loss, Optimizer, and Scheduler

In [None]:
# Composite segmentation loss built from Dice, cross-entropy and focal terms
# with weights 0.5 / 0.3 / 0.2 (they sum to 1).
# NOTE(review): confirm get_loss_function combines the terms as a weighted sum.
criterion = get_loss_function(
    loss_type='combined',
    num_classes=config['num_classes'],
    dice_weight=0.5,
    ce_weight=0.3,
    focal_weight=0.2,
)

# AdamW (decoupled weight decay) over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=1e-4,
    lr=config['learning_rate'],
)

# Cosine decay of the learning rate from config['learning_rate'] down to 1e-6
# over T_max scheduler steps. NOTE(review): T_max equals the epoch count, so
# this assumes Trainer calls scheduler.step() once per epoch — confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    eta_min=1e-6,
    T_max=config['num_epochs'],
)

## 5. Train Model

In [None]:
# Human-readable name for each class index, passed to the Trainer (presumably
# to label per-class metrics). NOTE(review): assumes index i matches label
# value i in the ground-truth masks — confirm against the dataset.
class_names = ['Background', 'Enamel', 'Dentin', 'Pulpa']

# class_names is hard-coded separately from config['num_classes']; fail fast
# before the long training run instead of silently mislabelling metrics.
assert len(class_names) == config['num_classes'], (
    f"class_names has {len(class_names)} entries but "
    f"num_classes is {config['num_classes']}"
)

# Wire model, data, objective and optimisation into the training loop.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=config['device'],
    num_classes=config['num_classes'],
    class_names=class_names,
    checkpoint_dir=config['checkpoint_dir'],
    log_dir=config['log_dir'],
)

# Long-running step: trains for config['num_epochs'] epochs.
trainer.train(num_epochs=config['num_epochs'])

## 6. Visualize Training Results

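You can monitor training progress with TensorBoard. The trainer is given `log_dir='../logs'`, which is relative to this notebook's directory. Run the following command from the repository root, or use `tensorboard --logdir ../logs` from the notebook directory:

```bash
tensorboard --logdir logs
```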

In [None]:
# Report the best validation Dice score recorded by the trainer, to 4 decimals.
print("Best validation Dice score: " + format(trainer.best_val_dice, ".4f"))