In [None]:
import Unet
import GetData

import torch
import torch.nn as nn

import time
from random import randint
import os
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Only ask for a GPU name when CUDA is actually in use: on a CPU-only machine
# torch.cuda.get_device_name raises instead of reporting the CPU fallback.
torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu"

In [None]:
height = 224
width = 224
BATCH_SIZE = 16
NUM_WORKERS = 2

# Nearest-neighbour resizing never blends neighbouring values, which matters if
# the dataset applies this transform to the label maps as well (presumably
# GetData does — TODO confirm). The InterpolationMode enum is the documented
# torchvision form of the PIL constant.
tfms = transforms.Compose([
    transforms.Resize((height, width),
                      interpolation=transforms.InterpolationMode.NEAREST)])

trainset = GetData.SegmentationDataset(image_dir="images/leftImg8bit/train",
                                      label_dir="labels_class/train",
                                      transform=tfms)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
                         num_workers=NUM_WORKERS)

testset = GetData.SegmentationDataset(image_dir="images/leftImg8bit/val",
                                     label_dir="labels_class/val",
                                     transform=tfms)
# Evaluation does not need shuffling. A fixed order makes the evaluation
# deterministic, so the sample plotted and scored after each epoch is the same
# image and successive results are comparable.
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS)

In [None]:
def evaluate_model(model, testloader, criterion):
    """Evaluate ``model`` on ``testloader`` and visualise one prediction.

    Runs the model in eval mode without gradients over every batch, then prints
    and returns the mean of the per-batch losses. For the first sample of the
    final batch it plots the input next to [ground truth | prediction] and
    prints per-class IoU via ``iou``.

    Args:
        model: network whose output holds class scores along dim 1.
        testloader: DataLoader yielding ``(image, inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: if ``testloader`` has no batches.
    """
    num_batches = len(testloader)
    if num_batches == 0:
        # Otherwise this fails later with an unbound variable (no batch to
        # plot) or a division by zero.
        raise ValueError("evaluate_model: testloader has no batches")

    model.eval()
    running_loss = 0.
    with torch.no_grad():
        for image, inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
        # Only the last batch is visualised, so the argmax is taken once here.
        predicted = outputs.argmax(dim=1)

    fig, ax = plt.subplots(1, 2, figsize=(20, 6))
    # NOTE(review): assumes ``image`` holds H x W x 3 pixel values in 0-255;
    # a CHW tensor would need a permute, not a reshape — confirm with GetData.
    img = image[0].cpu().numpy().reshape(height, width, 3)
    label = labels[0].cpu().numpy().reshape(height, width)
    preds = predicted[0].cpu().numpy().reshape(height, width)
    ax[0].imshow(img.astype(np.uint8))
    # Ground truth (left half) and prediction (right half) side by side.
    ax[1].imshow(np.concatenate([label, preds], axis=1))
    plt.show()
    plt.close(fig)

    test_loss = running_loss / num_batches
    print(f"Test loss: {test_loss:.4f}")
    iou(preds, label)
    return test_loss


In [None]:
# Display names for the 8 class ids used by the label maps.
classnames = {
    0:"void",
    1:"flat",
    2:"construction",
    3:"object",
    4:"nature",
    5:"sky",
    6:"human",
    7:"vehicle"
}
def iou(prediction, target, num_classes=8):
    """Print per-class IoU and the mean IoU (mIoU) of two label maps.

    A class that appears in neither map has an empty union, so its IoU is
    undefined. Such a class is reported as ``n/a`` and left out of the mean.
    Computing it as 0/0 would yield ``nan`` and make the mIoU ``nan`` whenever
    any class is missing, which is typical for a single image.

    Args:
        prediction: array-like of predicted class ids.
        target: array-like of ground-truth class ids, same shape as
            ``prediction``.
        num_classes: class ids ``0 .. num_classes-1`` are scored.

    Returns:
        The mIoU over classes present in either map, as a float. Returns
        ``nan`` if none of the scored classes appear.
    """
    prediction = np.asarray(prediction)
    target = np.asarray(target)
    scores = []
    for class_id in range(num_classes):
        pred_mask = prediction == class_id
        true_mask = target == class_id
        union = np.logical_or(pred_mask, true_mask).sum()
        name = classnames.get(class_id, str(class_id))
        if union == 0:
            print(f"{name} iou:n/a")
            continue
        intersection = np.logical_and(pred_mask, true_mask).sum()
        score = float(intersection / union)
        scores.append(score)
        print(f"{name} iou:{score}")
    mean = float(np.mean(scores)) if scores else float("nan")
    print(f"mIOU: {mean}")
    return mean

In [None]:
def save_model_state(model, optimizer, loss, test_loss, epoch,
                     model_dir="saved_models", tag="110+"):
    """Write a training checkpoint and return its path.

    The file is ``<model_dir>/unet_epoch_<tag><epoch>.pt``. It holds a dict with
    the keys ``epoch``, ``model_state_dict``, ``opt_state_dict``,
    ``training_loss`` and ``test_loss``.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        loss: training loss, stored under ``training_loss``.
        test_loss: evaluation loss, stored under ``test_loss``.
        epoch: stored under ``epoch`` and used in the file name.
        model_dir: checkpoint directory; created if it does not exist.
        tag: text placed before ``epoch`` in the file name. The default
            ``"110+"`` keeps the historical naming; it looks like the number of
            epochs trained before resuming — TODO confirm.

    Returns:
        The path the checkpoint was written to.
    """
    # torch.save does not create missing parent directories.
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, f"unet_epoch_{tag}{epoch}.pt")
    state_dict = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'opt_state_dict': optimizer.state_dict(),
        'training_loss': loss,
        'test_loss': test_loss,
    }
    torch.save(state_dict, model_path)
    return model_path

In [None]:
def train_net(model, trainloader, testloader, optimizer, criterion, epochs=50):
    """Train ``model`` for ``epochs`` passes, evaluating and checkpointing as it goes.

    Each epoch does the following:
    - runs one optimisation pass over ``trainloader``;
    - about three times per pass, prints the running loss and pixel accuracy,
      plots a sample prediction and prints per-class IoU via ``iou``;
    - after the pass, calls ``evaluate_model`` on ``testloader``.
    Every 10th epoch ``save_model_state`` writes a checkpoint.

    Args:
        model: network whose output holds class scores along dim 1.
        trainloader: DataLoader yielding ``(image, inputs, labels)`` batches.
        testloader: DataLoader handed to ``evaluate_model`` after each epoch.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        epochs: number of passes over ``trainloader``.

    Returns:
        The same ``model`` object, trained in place.
    """
    num_batches = len(trainloader)
    # Integer reporting interval. With a float (len/3), ``i % interval == 0``
    # only fires when the batch count happens to be divisible by 3.
    log_every = max(1, num_batches // 3)
    for epoch in range(epochs):
        model.train()
        running_loss = 0.
        running_correct = 0
        running_pixels = 0
        epoch_loss = 0.

        start = time.time()
        print(f"-------------- Epoch: {epoch+1} Train --------------")
        for i, batch in enumerate(trainloader, 1):
            image, inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            predicted = outputs.detach().argmax(dim=1)
            loss = criterion(outputs, labels)

            epoch_loss += loss.item()
            running_loss += loss.item()
            running_correct += (labels == predicted).sum().item()
            # Count the pixels actually seen. The last batch can be smaller than
            # the configured batch size, so a fixed 16*H*W denominator is wrong.
            running_pixels += labels.numel()

            loss.backward()
            optimizer.step()

            if i % log_every == 0:
                print(f"Batch: {i}/{num_batches}, loss: {(running_loss / log_every):.5f}, "
                      f"acc: {(100 * running_correct / running_pixels):.5f}")
                running_loss = 0.
                running_correct = 0
                running_pixels = 0

                fig, ax = plt.subplots(1, 2, figsize=(20, 6))
                # NOTE(review): assumes ``image`` holds H x W x 3 pixel values in
                # 0-255; a CHW tensor would need a permute, not a reshape.
                img = image[0].cpu().numpy().reshape(height, width, 3)
                label = labels[0].cpu().numpy().reshape(height, width)
                preds = predicted[0].cpu().numpy().reshape(height, width)
                ax[0].imshow(img.astype(np.uint8))
                # Ground truth (left half) and prediction (right half).
                ax[1].imshow(np.concatenate([label, preds], axis=1))
                plt.show()
                plt.close(fig)
                iou(preds, label)

        test_loss = evaluate_model(model, testloader, criterion)
        epoch_loss /= max(1, num_batches)
        if epoch % 10 == 9:  # checkpoint every 10 epochs
            save_model_state(model, optimizer, epoch_loss, test_loss, epoch+1)
            print("Saved model")
        print(f"Epoch: {epoch+1} complete, time: {int(time.time()-start)}s, loss: {epoch_loss:.5f}")

    return model

In [None]:
# Build a fresh U-Net (3 input channels, 8 output classes) and train it from
# random initialisation. NOTE: this runs 500 epochs, and the next section
# reloads a saved checkpoint, so skip this cell when only resuming.
loss_fn = nn.CrossEntropyLoss()
Unet_model = Unet.Unet(num_classes=8, input_channels=3).to(device)
optimizer = torch.optim.Adam(params=Unet_model.parameters())
train_net(Unet_model, trainloader, testloader, optimizer, loss_fn, epochs=500)

# Load a trained model checkpoint and set up to resume training

In [None]:
# Checkpoint to resume from (a dict written in save_model_state's format).
CHECKPOINT_PATH = "saved_models/unet_epoch_140+70+200+110.pt"

Unet_model = Unet.Unet(input_channels=3, num_classes=8).to(device)
loss_fn = nn.CrossEntropyLoss()
# map_location lets a checkpoint written on a GPU be restored onto whatever
# `device` is available now, e.g. a CPU-only machine.
model_info = torch.load(CHECKPOINT_PATH, map_location=device)
optimizer = torch.optim.Adam(Unet_model.parameters())
optimizer.load_state_dict(model_info["opt_state_dict"])
Unet_model.load_state_dict(model_info["model_state_dict"])

In [None]:
# Resume training the checkpoint-restored model for another 400 epochs.
train_net(model=Unet_model, trainloader=trainloader, testloader=testloader,
          optimizer=optimizer, criterion=loss_fn, epochs=400)

In [None]:
# Evaluate on the validation split: prints the test loss and per-class IoU
# and plots one sample prediction.
evaluate_model(model=Unet_model, testloader=testloader, criterion=loss_fn)