In [1]:
from utils.visualize_predictions import visualize_prediction
from utils.KittiRoadsDataset import KittiRoadsDataset
from torch.utils.data import DataLoader, random_split
from utils.UNetTorch import UNet
import torch.optim as optim
import torch.nn as nn
import torch


dataset = KittiRoadsDataset(
    root_dir="datasets/kitti_roads/training",
)

train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size]
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False
)

print(f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")
sample = dataset[0]
print("Image shape:", sample["image"].shape)
print("Mask shape:", sample["road_gt"].shape)
print("Unique mask values:", torch.unique(sample["road_gt"]))

Train size: 230, Test size: 26
Image shape: torch.Size([3, 256, 256])
Mask shape: torch.Size([256, 256])
Unique mask values: tensor([0, 1, 2])


In [2]:
device = torch.device(
    "cuda" if torch.cuda.is_available() else "mps"
)

model = UNet(
    in_channels=3,
    out_channels=3
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3
)

In [None]:
best_loss = float("inf")
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for batch in train_loader:
        images = batch['image'].to(device)
        masks = batch['road_gt'].to(device).long()

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, "best_unet_model.pth")
        print("Saved new best model ✅")

    scheduler.step(epoch_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}")

visualize_prediction(
    model,
    dataset,
    idx=5,
    device=device
)
print("Training complete!")

✅ Model improved! Saved new best model.
Epoch [1/50], Loss: 0.6792
✅ Model improved! Saved new best model.
Epoch [2/50], Loss: 0.4682
✅ Model improved! Saved new best model.
Epoch [3/50], Loss: 0.4452
Epoch [4/50], Loss: 0.4457
✅ Model improved! Saved new best model.
Epoch [5/50], Loss: 0.3997
✅ Model improved! Saved new best model.
Epoch [6/50], Loss: 0.3612
✅ Model improved! Saved new best model.
Epoch [7/50], Loss: 0.3346
✅ Model improved! Saved new best model.
Epoch [8/50], Loss: 0.3316
✅ Model improved! Saved new best model.
Epoch [9/50], Loss: 0.3253
✅ Model improved! Saved new best model.
Epoch [10/50], Loss: 0.3095
Epoch [11/50], Loss: 0.3144
Epoch [12/50], Loss: 0.3345
✅ Model improved! Saved new best model.
Epoch [13/50], Loss: 0.2987
Epoch [14/50], Loss: 0.3021
Epoch [15/50], Loss: 0.3113
✅ Model improved! Saved new best model.
Epoch [16/50], Loss: 0.2882
✅ Model improved! Saved new best model.
Epoch [17/50], Loss: 0.2862
✅ Model improved! Saved new best model.
Epoch [18/50]