<a href="https://colab.research.google.com/github/federico2879/MLDL2024_semantic_segmentation/blob/master/training/training-bisenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/federico2879/MLDL2024_semantic_segmentation.git

Cloning into 'MLDL2024_semantic_segmentation'...
remote: Enumerating objects: 552, done.[K
remote: Counting objects: 100% (203/203), done.[K
remote: Compressing objects: 100% (108/108), done.[K
remote: Total 552 (delta 117), reused 150 (delta 93), pack-reused 349[K
Receiving objects: 100% (552/552), 272.11 KiB | 10.08 MiB/s, done.
Resolving deltas: 100% (324/324), done.


In [2]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import DataLoader
import numpy as np
from MLDL2024_semantic_segmentation.datasets.cityscapes import CityScapes
from MLDL2024_semantic_segmentation.datasets.gta5 import GTA5
from MLDL2024_semantic_segmentation.models.bisenet.build_bisenet import *
from MLDL2024_semantic_segmentation.train import * 
from MLDL2024_semantic_segmentation.utils import *
from MLDL2024_semantic_segmentation.models.IOU import * 

In [3]:
# Setup device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

# Setup fixed parameters
num_epochs = 50
num_classes = 19
batch_size = 4

# optimization
momentum = 0.9
w_decay = 1e-4
power = 0.9
lr_init = 2.5e-2

In [4]:
# Transformations

transform_image = {
    'train': transforms.Compose([transforms.Resize((1280, 720)),
                    transforms.RandomHorizontalFlip(p=0.15),
                    #transforms.ColorJitter(brightness=0.35,contrast=0.15,saturation=0.35,hue=0.05),
                    transforms.GaussianBlur(kernel_size=3, sigma=(0.2,0.8)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]) ]),
    
    
    'test': transforms.Compose([
        transforms.Resize((1024, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
}
                   
transform_target = {
    'train': transforms.Compose([
        transforms.Resize((1280, 720), interpolation=InterpolationMode.NEAREST),
    ]),
    'test': transforms.Compose([
        transforms.Resize((1024, 512), interpolation=InterpolationMode.NEAREST)
    ])
}




In [5]:
# Create dataloader

#/d/famigliadiena
dataset_train = GTA5('/kaggle/input/mldl-gta5/GTA', 
                     transform = transform_image['train'], 
                     label_transform = transform_target['train'])
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_val = CityScapes('/kaggle/input/cityscapes/Cityscapes/Cityspaces', 
                          split = 'val', transform = transform_image['test'], 
                          label_transform = transform_target['test'])
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

In [6]:
# Inizialization of the model
model = BiSeNet(num_classes=num_classes, context_path="resnet18").to(device)
#model = torch.nn.DataParallel(model, device_ids = [0,1]).to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 162MB/s]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:01<00:00, 159MB/s]  


In [7]:
# Define loss and optimizer
loss_fn =  nn.CrossEntropyLoss(ignore_index=255)
optimizer = torch.optim.SGD(model.parameters(), lr=lr_init, momentum=momentum, 
                            weight_decay=w_decay)

In [9]:
# metrics
meanIOU_tr = np.zeros((num_epochs,1))
IOU_tr = np.zeros((num_epochs, num_classes))
loss_tr = np.zeros((num_epochs,1))

meanIOU_val = np.zeros((num_epochs,1))
IOU_val = np.zeros((num_epochs, num_classes))
loss_val = np.zeros((num_epochs,1))

In [12]:
## Set the random seeds
torch.manual_seed(42)
torch.cuda.manual_seed(42)

In [13]:

for epoch in range(num_epochs):
    poly_lr_scheduler(optimizer, lr_init, epoch, lr_decay_iter=1,
                    max_iter=num_epochs, power=power)
    meanIOU_tr[epoch], IOU_tr[epoch,:], loss_tr[epoch] = train(model, optimizer, dataloader_train, loss_fn, num_classes, 0)
    
    meanIOU_val[epoch], IOU_val[epoch,:], loss_val[epoch] = test(model, dataloader_val, loss_fn, num_classes, 0)
    print(f"epoch: {epoch+1}, Validation IOU: {meanIOU_val[epoch,0]:.2f}")

    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'meanIOU_tr': meanIOU_tr,
        'IOU_tr': IOU_tr,
        'loss_tr': loss_tr,
        'meanIOU_val': meanIOU_val,
        'IOU_val': IOU_val,
        'loss_val': loss_val,
    },"checkpoint.pth.tar")

epoch: 1, Validation IOU: 0.13
epoch: 2, Validation IOU: 0.13
epoch: 3, Validation IOU: 0.12
epoch: 4, Validation IOU: 0.15
epoch: 5, Validation IOU: 0.08
epoch: 6, Validation IOU: 0.14
epoch: 7, Validation IOU: 0.09
epoch: 8, Validation IOU: 0.15
epoch: 9, Validation IOU: 0.15
epoch: 10, Validation IOU: 0.07
epoch: 11, Validation IOU: 0.09
epoch: 12, Validation IOU: 0.15
epoch: 13, Validation IOU: 0.13
epoch: 14, Validation IOU: 0.17
epoch: 15, Validation IOU: 0.16
epoch: 16, Validation IOU: 0.16
epoch: 17, Validation IOU: 0.14
epoch: 18, Validation IOU: 0.16
epoch: 19, Validation IOU: 0.14
epoch: 20, Validation IOU: 0.14
epoch: 21, Validation IOU: 0.12
epoch: 22, Validation IOU: 0.14
epoch: 23, Validation IOU: 0.11
epoch: 24, Validation IOU: 0.15
epoch: 25, Validation IOU: 0.12
epoch: 26, Validation IOU: 0.13
epoch: 27, Validation IOU: 0.14
epoch: 28, Validation IOU: 0.14
epoch: 29, Validation IOU: 0.10
epoch: 30, Validation IOU: 0.12
epoch: 31, Validation IOU: 0.14
epoch: 32, Valida

In [None]:

from MLDL2024_semantic_segmentation.load_checkpoint import *



model, optimizer, start_epoch, meanIOU_tr, IOU_tr, loss_tr, meanIOU_val, IOU_val, loss_val \
= load_checkpoint(model, optimizer,"/kaggle/input/checkpoint/checkpoint/checkpoint.pth.tar")


for epoch in range(start_epoch,num_epochs):
    poly_lr_scheduler(optimizer, lr_init, epoch, lr_decay_iter=1,
                      max_iter=num_epochs, power=power)
    meanIOU_tr[epoch], IOU_tr[epoch,:], loss_tr[epoch] = train(model, optimizer, 
                                                               dataloader_train, loss_fn,
                                                               num_classes, 0)
    meanIOU_val[epoch], IOU_val[epoch,:], loss_val[epoch] = test(model, dataloader_val, 
                                                                 loss_fn, num_classes, 0)
    print(f"epoch: {epoch + 1}, Validation IOU: {meanIOU_val[epoch,0]:.2f}")

    
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'meanIOU_tr': meanIOU_tr,
        'IOU_tr': IOU_tr,
        'loss_tr': loss_tr,
        'meanIOU_val': meanIOU_val,
        'IOU_val': IOU_val,
        'loss_val': loss_val
        },"checkpoint.pth.tar")


In [16]:
# writing csv
import csv

with open('meanIOU_tr.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(meanIOU_tr)

with open('IOU_tr.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(IOU_tr)

with open('loss_tr.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(loss_tr)

with open('meanIOU_val.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(meanIOU_val)

with open('IOU_val.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(IOU_val)

with open('loss_val.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(loss_val)