# Setup: project path and imports

In [None]:
import os
import sys

# The kernel's working directory is treated as the project root; the local
# `src/` package is expected to live directly beneath it.
PROJECT_ROOT = os.path.abspath(os.curdir)

# Prepend rather than append so the local `src` package takes precedence
# over anything with the same name elsewhere on the import path.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print("Project root:", PROJECT_ROOT)

In [None]:
import random

import torch

from src.data.lits_dataset import make_loaders
from src.models.enhanced_unet import EnhancedUNet, DiceLoss
from src.training.train_loop import train_model

# Hyperparameters, seeding and device

In [None]:
# --- Data -------------------------------------------------------------------
BASE_PATH = "./LiTS17"  # dataset root, relative to the kernel's working directory

IMG_HEIGHT = 256
IMG_WIDTH = 256
BATCH_SIZE = 16
NUM_WORKERS = 2

# --- Optimization -----------------------------------------------------------
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
SAVE_DIR = "./"
EARLY_STOPPING_PATIENCE = 10  # forwarded to train_model as early_stopping_patience

# --- Intensity window -------------------------------------------------------
# Passed to make_loaders as clip_min / clip_max (presumably Hounsfield units
# for the CT volumes — confirm in src/data/lits_dataset.py).
CLIP_MIN = -200
CLIP_MAX = 250

# --- Reproducibility --------------------------------------------------------
# Without a seed, weight initialization and any torch-driven loader randomness
# differ on every run. torch.manual_seed seeds the CPU and all CUDA devices.
# NOTE(review): if the dataset's augmentations use numpy's RNG, seed it as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Build DataLoaders

In [None]:
# Everything make_loaders needs, gathered in one place. The fixed
# random_state keeps the train/validation split the same across runs.
loader_config = dict(
    base_path=BASE_PATH,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    val_size=0.2,  # presumably the held-out validation fraction — confirm in make_loaders
    random_state=42,
    clip_min=CLIP_MIN,
    clip_max=CLIP_MAX,
)
train_loader, val_loader = make_loaders(**loader_config)

# Quick sanity check

In [None]:
# Pull a single training batch to validate the loader output before training.
sample_images, sample_masks = next(iter(train_loader))

# Cheap invariants: a mismatch here would otherwise surface only as a
# confusing shape error (or a silently wrong loss) inside the training loop.
assert sample_images.shape[0] == sample_masks.shape[0], "image/mask batch sizes differ"
assert sample_images.shape[-2:] == sample_masks.shape[-2:], "image/mask spatial sizes differ"
assert torch.isfinite(sample_images).all(), "non-finite values in image batch"

print("Image batch:", sample_images.shape, sample_images.dtype)
print("Mask batch: ", sample_masks.shape, sample_masks.dtype)

# Value range after preprocessing, and the label values actually present in the masks.
print("Image range:", float(sample_images.min()), "to", float(sample_images.max()))
print("Mask unique:", sample_masks.unique())

# Initialize model

In [None]:
# One input channel, one output channel.
model = EnhancedUNet(in_channels=1, out_channels=1)
model = model.to(device)  # nn.Module.to moves parameters/buffers and returns the module
criterion = DiceLoss()

model_name = type(model).__name__
print(model_name, "initialized.")

# Train

In [None]:
# Gradient-norm ceiling forwarded to train_model as grad_clip_max_norm.
GRAD_CLIP_MAX_NORM = 1.0

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    criterion=criterion,
    learning_rate=LEARNING_RATE,
    num_epochs=NUM_EPOCHS,
    save_dir=SAVE_DIR,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    grad_clip_max_norm=GRAD_CLIP_MAX_NORM,
)
# train_model returns three histories (presumably one entry per epoch):
# training loss, validation loss and validation Dice.
train_losses, val_losses, val_dice_scores = history

In [None]:
def _last_or_none(history):
    """Return the final entry of a training-history sequence, or None if it is empty."""
    return history[-1] if len(history) else None


print("Last Train Loss:", _last_or_none(train_losses))
print("Last Val Loss:  ", _last_or_none(val_losses))
print("Last Val Dice:  ", _last_or_none(val_dice_scores))
# When early stopping triggers, the final epoch is typically not the best one,
# so report the best validation Dice as well.
print("Best Val Dice:  ", max(val_dice_scores) if len(val_dice_scores) else None)

# os.path.join avoids the doubled separator that f"{SAVE_DIR}/models/..." produced
# for SAVE_DIR = "./"; abspath makes the printed location unambiguous.
# NOTE(review): this mirrors where train_model is expected to write the checkpoint;
# confirm against src/training/train_loop.py.
checkpoint_path = os.path.abspath(os.path.join(SAVE_DIR, "models", "best_model.pth"))
print("Checkpoint should be at:", checkpoint_path)
print("Checkpoint exists:", os.path.exists(checkpoint_path))