In [None]:
from utils.visualize_predictions import visualize_prediction
from utils.KittiRoadsDataset import KittiRoadsDataset
from torch.utils.data import DataLoader, random_split
from utils.UNetTorch import UNet
import torch.optim as optim
import torch.nn as nn
import torch


# --- Data -------------------------------------------------------------------
# Load the KITTI road training directory and hold out 10% of it for evaluation.
dataset = KittiRoadsDataset(
    root_dir="datasets/kitti_roads/training",
)

train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)

# random_split otherwise draws from the global RNG, giving a different split on
# every Restart & Run All; a seeded generator makes the held-out set stable.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=split_generator,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False
)

# Quick sanity check on one sample: the dataset yields dicts with "image" and
# "road_gt" entries (exact shapes/dtypes come from KittiRoadsDataset — see output).
print(f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")
sample = dataset[0]
print("Image shape:", sample["image"].shape)
print("Mask shape:", sample["road_gt"].shape)
print("Unique mask values:", torch.unique(sample["road_gt"]))

# --- Device -----------------------------------------------------------------
# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU so the notebook
# still runs on machines with neither accelerator (the previous code picked
# "mps" unconditionally, which fails on CPU-only hosts).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# --- Model / optimisation ---------------------------------------------------
# out_channels=3 is used with CrossEntropyLoss, so the mask is presumably a
# 3-class label map — NOTE(review): confirm against "Unique mask values" above.
model = UNet(
    in_channels=3,
    out_channels=3
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4
)
# Halve the LR after 3 epochs without improvement in the monitored (min) metric.
# Note: this scheduler is bound to *this* optimizer and is not stepped in this cell.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3
)

In [None]:
# Train `num_models` independently initialised U-Nets on the same training split
# and save each as unet_model_{i}.pth for the ensemble cell below.
best_loss = float("inf")  # lowest epoch-mean training loss seen across all members
num_epochs = 50
num_models = 3

# CrossEntropyLoss holds no training state, so one instance can serve every member.
criterion = nn.CrossEntropyLoss()

for i in range(num_models):
    # Per-member seed: members still differ from one another, but each one is
    # reproducible across Restart & Run All (init weights and shuffle order).
    torch.manual_seed(i)

    model = UNet(in_channels=3, out_channels=3).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # Each member needs its own scheduler bound to its own optimizer; a scheduler
    # created for an earlier optimizer would never affect this one's LR.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch in train_loader:
            images = batch['image'].to(device)
            masks = batch['road_gt'].to(device).long()  # class-index targets for CrossEntropyLoss

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Guard against an empty loader rather than raising ZeroDivisionError.
        mean_loss = epoch_loss / max(len(train_loader), 1)
        scheduler.step(mean_loss)
        best_loss = min(best_loss, mean_loss)

        print(f"Model {i+1} | Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

    # Save each model separately
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, f"unet_model_{i}.pth")

print(f"✅ Finished training all {num_models} models!")

In [None]:
# Combine the checkpoints written by the training cell into one inference ensemble.
from utils.EnsembleUNetTorch import UNetEnsemble

# Same naming pattern the training cell uses when saving: unet_model_0.pth .. unet_model_2.pth
model_paths = [f"unet_model_{idx}.pth" for idx in range(3)]

ensemble_model = UNetEnsemble(model_paths=model_paths, device=device)
ensemble_model.to(device)

# Switch to inference mode; left as the cell's last expression so Jupyter shows its repr.
ensemble_model.eval()