In [None]:
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import torch
from torch import nn, optim

from model import CNN_Segmentation

from utils import ImageSegmentationTrainingDataset

## Load dataset

In [18]:
# Full dataset
# The only preprocessing step is ToTensor(); it is passed to the dataset as
# its `transform` argument.
transform = transforms.Compose([transforms.ToTensor()])

# Samples are read from the local 'train' directory by the project's dataset class.
dataset = ImageSegmentationTrainingDataset(
    root_dir='train',
    transform=transform,
)

# Shuffled loader over the complete (unsplit) dataset.
# NOTE(review): the cells visible here only consume train_loader / val_loader;
# confirm whether this loader is still needed.
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [19]:
# Data split
# Hold out `val_percent` of the samples for validation; the remainder is
# used for training.
val_percent = 0.2
val_size = int(len(dataset) * val_percent)
train_size = len(dataset) - val_size

# Seed the split so the same samples land in the validation set on every
# Restart & Run All. With an unseeded split, validation losses from different
# runs are measured on different samples and are not comparable.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Training batches are reshuffled every epoch; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

## Functions for training

In [21]:
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Args:
        model: Network to optimise; put into training mode by this call.
        dataloader: Iterable yielding ``(rgb, _, labels)`` triples; the middle
            element is ignored. ``labels`` is reduced with ``argmax(dim=1)``,
            so it is expected to hold per-class scores / one-hot values along
            dim 1 (``(batch, class, H, W)``).
        optimizer: Optimiser that updates ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, targets)``. Assumed to
            return the batch mean (the default reduction of
            ``nn.CrossEntropyLoss``) — confirm if another reduction is used.
        device: Device the input and label tensors are moved to.

    Returns:
        float: Loss averaged over all samples seen. Each batch loss is
        weighted by its batch size so a smaller final batch does not count as
        much as a full one.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    seen_samples = 0
    for rgb, _, labels in dataloader:
        rgb, labels = rgb.to(device), labels.to(device)

        # (batch, class, H, W) -> class-index targets of shape (batch, H, W)
        targets = labels.argmax(dim=1)

        optimizer.zero_grad()
        outputs = model(rgb)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the number of samples it covers.
        batch_size = rgb.size(0)
        running_loss += loss.item() * batch_size
        seen_samples += batch_size

    return running_loss / seen_samples

In [22]:
def evaluate(model, dataloader, criterion, device):
    """Compute the mean loss of ``model`` over ``dataloader`` without training.

    Args:
        model: Network to evaluate; put into eval mode by this call (it is
            not switched back afterwards).
        dataloader: Iterable yielding ``(rgb, _, labels)`` triples; the middle
            element is ignored. ``labels`` is reduced with ``argmax(dim=1)``,
            so it is expected to be one-hot along dim 1.
        criterion: Loss called as ``criterion(outputs, targets)``. Assumed to
            return the batch mean (the default reduction of
            ``nn.CrossEntropyLoss``) — confirm if another reduction is used.
        device: Device the input and label tensors are moved to.

    Returns:
        float: Loss averaged over all samples seen. Each batch loss is
        weighted by its batch size so a smaller final batch does not count as
        much as a full one.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    seen_samples = 0
    with torch.no_grad():
        for rgb, _, labels in dataloader:
            # In this notebook: rgb [16, 3, 544, 1024], labels [16, 9, 544, 1024] (one-hot).
            rgb, labels = rgb.to(device), labels.to(device)

            # Cross-entropy takes class indices as its target, not one-hot maps:
            # input (N, C, d_1, ..., d_k) vs target (N, d_1, ..., d_k),
            # here [16, 9, 544, 1024] vs [16, 544, 1024].
            targets = labels.argmax(dim=1)
            outputs = model(rgb)
            loss = criterion(outputs, targets)

            # Weight the batch-mean loss by the number of samples it covers.
            batch_size = rgb.size(0)
            total_loss += loss.item() * batch_size
            seen_samples += batch_size

    return total_loss / seen_samples

## Training loop

In [None]:
# Hyperparameters
EPOCHS = 2   # number of passes over the training loader
LR = 1e-3    # learning rate handed to Adam

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network and move it to the selected device before the optimiser
# is created from its parameters.
model = CNN_Segmentation()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LR)

In [10]:
# Loop
# Each epoch: one optimisation pass over the training loader, then one
# evaluation pass over the validation loader, then a one-line loss report.
for epoch in range(EPOCHS):
    train_loss = train(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
    )
    val_loss = evaluate(
        model=model,
        dataloader=val_loader,
        criterion=criterion,
        device=device,
    )
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([10, 544, 1024])
outputs torch.Size([10, 9, 544, 1024])
Epoch 1/2 | Train Loss: 2.0543 | Val Loss: 1.9422
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([16, 544, 1024])
outputs torch.Size([16, 9, 544, 1024])
torch.Size([

In [11]:
# Saves
# Write the model's state_dict (not the whole module object) to a file in the
# current working directory.
torch.save(obj=model.state_dict(), f='segmentation_model.pth')