In [1]:
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.optim import Adam
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from tqdm import tqdm
from torchmetrics.classification import MulticlassJaccardIndex, MulticlassAccuracy
import copy

class SemanticSegmentationDataset(Dataset):
    """Pair RGB images with color-coded label images for semantic segmentation.

    Images and labels are matched by position after sorting each directory's
    file paths, so both directories must hold the same number of files whose
    sorted orders correspond (e.g. identical basenames).

    Label pixels are converted to class indices through ``class_colors``;
    any pixel whose RGB color is not listed there is assigned class 0.

    Args:
        image_dir: Directory containing the input images.
        label_dir: Directory containing the color-coded label images.
        transform: Optional callable applied to the RGB ``uint8`` image array
            (shape ``(H, W, 3)`` as returned by OpenCV after BGR->RGB).

    Raises:
        ValueError: If the two directories contain different numbers of files.
    """

    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_paths = sorted(os.path.join(image_dir, name) for name in os.listdir(image_dir))
        self.label_paths = sorted(os.path.join(label_dir, name) for name in os.listdir(label_dir))
        # Pairs are formed purely by sorted position, so a count mismatch would
        # silently misalign every image with the wrong label.
        if len(self.image_paths) != len(self.label_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {image_dir!r} but "
                f"{len(self.label_paths)} labels in {label_dir!r}; images and "
                f"labels are paired by sorted filename and must match one-to-one."
            )
        # RGB label color -> class index.
        self.class_colors = {
            (2, 0, 0): 0,
            (127, 0, 0): 1,
            (248, 163, 191): 2,
        }

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label_mask)`` for sample ``idx``.

        ``image`` is the transformed image (or the raw RGB array when no
        transform is set); ``label_mask`` is always an int64 ``(H, W)`` tensor
        of class indices, the target format ``nn.CrossEntropyLoss`` expects.
        """
        image = self._read_rgb(self.image_paths[idx])
        label_mask = self._rgb_to_mask(self._read_rgb(self.label_paths[idx]))

        if self.transform is not None:
            image = self.transform(image)
        # Convert unconditionally; previously this only happened when a
        # transform was set, leaving a uint8 numpy array otherwise.
        return image, torch.from_numpy(label_mask).long()

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV and return it as an RGB array."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _rgb_to_mask(self, label):
        """Map an ``(H, W, 3)`` RGB label image to an ``(H, W)`` uint8 class mask."""
        label_mask = np.zeros(label.shape[:2], dtype=np.uint8)
        for rgb, class_index in self.class_colors.items():
            label_mask[np.all(label == rgb, axis=-1)] = class_index
        return label_mask

# ``ToTensor`` accepts the HWC uint8 RGB array produced by the dataset directly
# and converts it to a float CHW tensor scaled to [0, 1]. The previous
# ``ToPILImage`` -> ``ToTensor`` round-trip yielded the same tensor at the cost
# of an extra conversion per sample. Re-insert ``transforms.ToPILImage()`` at the
# front if PIL-only augmentations are added later.
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = SemanticSegmentationDataset(
    image_dir='input',
    label_dir='label',
    transform=train_transform,
)


def train_epoch(model, dataloader, criterion, optimizer, device, num_classes):
    """Run one optimisation pass over ``dataloader`` and report epoch metrics.

    Args:
        model: Network to train; switched to training mode here.
        dataloader: Yields ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(model_output, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device that batches and metric state are moved to.
        num_classes: Number of classes for the torchmetrics metrics.

    Returns:
        ``(epoch_loss, mean_accuracy, mean_iou)``. ``epoch_loss`` is the sum of
        ``loss * batch_size`` divided by the dataset length (a per-sample mean,
        assuming ``criterion`` returns a batch mean). The two metrics are the
        torchmetrics results (default averaging) as NumPy arrays.
    """
    model.train()
    acc_meter = MulticlassAccuracy(num_classes=num_classes).to(device)
    iou_meter = MulticlassJaccardIndex(num_classes=num_classes).to(device)
    loss_sum = 0.0

    progress = tqdm(dataloader, desc='Training', unit='batch')
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        scores = model(batch_images)
        batch_loss = criterion(scores, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        loss_sum += loss_value * batch_images.size(0)

        # Per-pixel predicted class = highest-scoring channel.
        predicted = scores.argmax(dim=1)
        for meter in (acc_meter, iou_meter):
            meter(predicted, batch_labels)

        # Running (accumulated-so-far) metrics shown on the progress bar.
        progress.set_postfix({
            'Batch Loss': f'{loss_value:.4f}',
            'Mean Accuracy': f'{acc_meter.compute():.4f}',
            'Mean IoU': f'{iou_meter.compute():.4f}',
        })

    epoch_loss = loss_sum / len(dataloader.dataset)
    return (
        epoch_loss,
        acc_meter.compute().cpu().numpy(),
        iou_meter.compute().cpu().numpy(),
    )

def evaluate(model, dataloader, criterion, device, num_classes):
    """Score ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Yields ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(model_output, labels)``.
        device: Device that batches and metric state are moved to.
        num_classes: Number of classes for the torchmetrics metrics.

    Returns:
        ``(epoch_loss, mean_accuracy, mean_iou)``, computed the same way as in
        ``train_epoch``: the loss is weighted by batch size and divided by the
        dataset length, and the metrics are torchmetrics results as NumPy arrays.
    """
    model.eval()
    acc_meter = MulticlassAccuracy(num_classes=num_classes).to(device)
    iou_meter = MulticlassJaccardIndex(num_classes=num_classes).to(device)
    loss_sum = 0.0

    progress = tqdm(dataloader, desc='Evaluating', unit='batch')
    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            scores = model(batch_images)
            batch_loss = criterion(scores, batch_labels)
            loss_value = batch_loss.item()
            loss_sum += loss_value * batch_images.size(0)

            # Accumulate metric state from per-pixel argmax predictions.
            predicted = scores.argmax(dim=1)
            for meter in (acc_meter, iou_meter):
                meter(predicted, batch_labels)

            # Show running metrics on the progress bar.
            progress.set_postfix({
                'Batch Loss': f'{loss_value:.4f}',
                'Mean Accuracy': f'{acc_meter.compute():.4f}',
                'Mean IoU': f'{iou_meter.compute():.4f}',
            })

    epoch_loss = loss_sum / len(dataloader.dataset)
    return (
        epoch_loss,
        acc_meter.compute().cpu().numpy(),
        iou_meter.compute().cpu().numpy(),
    )
    
class myModel(nn.Module):
    """Small fully-convolutional segmentation network.

    The network runs three 3x3 conv + ReLU stages, with a 2x2 max-pool after
    each of the first two. The features are then resized back to the input's
    spatial size (bilinear), batch-normalised, and projected to ``n_classes``
    channels by a 1x1 conv.

    Args:
        n_classes: Number of output channels (segmentation classes).

    ``forward`` returns raw, unnormalised logits of shape
    ``(N, n_classes, H, W)``, matching the input's ``H`` and ``W``. Apply
    ``torch.softmax(out, dim=1)`` if class probabilities are needed;
    ``argmax`` over dim 1 is unaffected by the softmax.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.batchnorm = nn.BatchNorm2d(64)
        self.conv_out = nn.Conv2d(64, n_classes, kernel_size=1, padding=0, stride=1)

    def forward(self, x):
        # Record the input resolution so the output can be resized back to it
        # exactly. A fixed 4x upsample loses pixels when H or W is not
        # divisible by 4, and then no longer lines up with the label mask.
        out_size = x.shape[-2:]

        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))

        x = F.interpolate(x, size=out_size, mode="bilinear", align_corners=False)
        x = self.batchnorm(x)

        # Return logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so an extra softmax here would squash the gradients.
        return self.conv_out(x)

# 80/20 train/validation split. A fixed-seed generator makes the split
# reproducible across notebook re-runs, so validation metrics (and therefore
# best-checkpoint selection) are always computed on the same images.
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classes = 3  # must match the number of class indices in the dataset's class_colors
model = myModel(classes)
def count_parameters(model, *, trainable_only=False):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Any ``nn.Module``.
        trainable_only: If True, count only parameters with
            ``requires_grad=True`` (e.g. to exclude frozen layers).
            Defaults to False, which counts every parameter.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
total_params = count_parameters(model)
print(f"Total parameters: {total_params}")

# Move the weights to the target device, then wrap the module in DataParallel.
# From here on ``model`` refers to the wrapper.
model = nn.DataParallel(model.to(device))

# Loss, optimiser and number of epochs.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-3)
num_epochs = 5

# Best-checkpoint bookkeeping, updated by the training loop below.
epoch_saved, best_val_mAcc, best_model_state = 0, 0.0, None

for epoch in range(num_epochs):
    epoch_loss_train, mAcc_train, mIoU_train = train_epoch(
        model, train_dataloader, criterion, optimizer, device, classes
    )
    epoch_loss_val, mAcc_val, mIoU_val = evaluate(
        model, val_dataloader, criterion, device, classes
    )

    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Train Loss: {epoch_loss_train:.4f}, Mean Accuracy: {mAcc_train:.4f}, Mean IoU: {mIoU_train:.4f}")
    print(f"Validation Loss: {epoch_loss_val:.4f}, Mean Accuracy: {mAcc_val:.4f}, Mean IoU: {mIoU_val:.4f}")

    # Snapshot the weights whenever validation mean accuracy ties or beats the
    # best so far (ties favour the later epoch). The negated ``>=`` keeps the
    # exact comparison semantics, including for NaN.
    if not (mAcc_val >= best_val_mAcc):
        continue
    epoch_saved, best_val_mAcc = epoch + 1, mAcc_val
    best_model_state = copy.deepcopy(model.state_dict())
    
print("===================")
print(f"Best Model at epoch : {epoch_saved}")
# best_model_state stays None only if no epoch ran or validation accuracy was
# never >= 0.0 (e.g. NaN every epoch); fail clearly instead of inside
# load_state_dict.
if best_model_state is None:
    raise RuntimeError("No model checkpoint was recorded during training.")
model.load_state_dict(best_model_state)
# Script the underlying module rather than the DataParallel wrapper.
if isinstance(model, torch.nn.DataParallel):
    model = model.module
# Export explicitly in eval mode so BatchNorm uses its running statistics in
# the saved model.
model.eval()
model_save = torch.jit.script(model)
model_save.save("NguyenVanA_21119375.pt")
# Check again: reload the exported file onto the training device and re-run
# validation on it.
model = torch.jit.load("NguyenVanA_21119375.pt", map_location=device)
epoch_loss_val, mAcc_val, mIoU_val = evaluate(model, val_dataloader, criterion, device, classes)
print(f"Validation Loss: {epoch_loss_val:.4f}, Mean Accuracy: {mAcc_val:.4f}, Mean IoU: {mIoU_val:.4f}")

Train size: 4800, Validation size: 1200
Total parameters: 23907


Training:  19%|█▉        | 29/150 [01:59<08:18,  4.12s/batch, Batch Loss=0.7940, Mean Accuracy=0.5855, Mean IoU=0.4394]


KeyboardInterrupt: 