# Semantic segmentation of aerial images with deep networks

This notebook presents a straightforward PyTorch implementation of a Fully Convolutional Network for semantic segmentation of aerial images. More specifically, we aim to automatically perform scene interpretation of images taken from a plane or a satellite by classifying every pixel into several land cover classes.

As a demonstration, we are going to use the [SegNet architecture](http://mi.eng.cam.ac.uk/projects/segnet/) to segment aerial images over the cities of Vaihingen and Potsdam. The images are from the [ISPRS 2D Semantic Labeling dataset](http://www2.isprs.org/commissions/comm3/wg4/results.html). We will train a network to segment roads, buildings, vegetation and cars.

This work is a PyTorch implementation of the baseline presented in ["Beyond RGB: Very High Resolution Urban Remote Sensing With Multimodal Deep Networks"](https://hal.archives-ouvertes.fr/hal-01636145), *Nicolas Audebert*, *Bertrand Le Saux* and *Sébastien Lefèvre*, ISPRS Journal, 2018.

## Requirements

This notebook requires a few useful libraries, e.g. `torch`, `scikit-image`, `numpy` and `matplotlib`. You can install everything using `pip install -r requirements.txt`.

This is expected to run on a GPU, so you should use `torch` with CUDA/cuDNN. It can probably be made to run on CPU, but be warned that:
  * you have to replace all `.cuda()` calls and `device='cuda'` arguments throughout this notebook,
  * this will be very slow.
  
A "small" GPU should be enough, e.g. this runs fine on a 4.7GB Tesla K20m. It uses quite a lot of RAM as the dataset is stored in-memory (about 5GB for Vaihingen). You can spare some memory by disabling the caching below. 4GB should be more than enough without caching.
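Rather than deleting the CUDA calls by hand, one option is to choose the device once and move modules and tensors with `.to(device)`. A minimal sketch relying only on PyTorch itself; `pick_device` is a hypothetical helper, not something this notebook defines:

```python
import torch

def pick_device(prefer_gpu=True):
    """Return the CUDA device if it is available (and wanted), else the CPU."""
    if prefer_gpu and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

device = pick_device()
# Modules and tensors are moved with .to(device) / device=... instead of .cuda()
layer = torch.nn.Conv2d(3, 6, kernel_size=3, padding=1).to(device)
x = torch.rand(2, 3, 32, 32, device=device)
y = layer(x)
print(device, tuple(y.shape))
```

The same `device` object can then replace every hard-coded `'cuda'` string, so the notebook runs unchanged on either kind of machine.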

In [None]:
from gepcore.utils import convolution
from gepcore.utils import cell_graph
from gepcore.entity import Gene, Chromosome, KExpressionGraph
from gepcore.symbol import PrimitiveSet
#from gepcore.operators import *
from nas_seg.seg_model import get_net, arch_config
from ptflops import get_model_complexity_info

# import fastai deep learning tools
from fastai.vision.all import *


# imports and stuff
import numpy as np
from skimage import io
from glob import glob
from pathlib import Path

from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix
import random
import itertools
# Matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from torch.autograd import Variable

## Parameters

There are several parameters that can be tuned to use this notebook with different datasets. The default parameters are suitable for the ISPRS dataset, but you can change them to work with your data.

### Examples

  * Binary classification: `N_CLASSES = 2`
  * Multi-spectral data (e.g. IRRGB): `IN_CHANNELS = 4`
  * New folder naming convention: `DATA_FOLDER = MAIN_FOLDER + 'sentinel2/sentinel2_img_{}.tif'`
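The folder variables are `str.format` templates whose `{}` placeholder is filled with a tile id. The sketch below uses a hypothetical helper, `ids_from_template` (not used elsewhere in the notebook), to show how tile ids could be recovered from any file names that follow such a template, which is what the id-parsing code in the "Loading the data" section does by hand for each dataset. The paths are made up for illustration:

```python
import re

def ids_from_template(template, paths):
    """Recover the part of each path that fills the '{}' slot of `template`.

    Paths that do not match the template are skipped.
    """
    prefix, suffix = template.split('{}')
    pattern = re.compile(re.escape(prefix) + r'(.+)' + re.escape(suffix) + r'$')
    ids = []
    for p in paths:
        m = pattern.match(p)
        if m:
            ids.append(m.group(1))
    return ids

# Hypothetical Potsdam-style template and file list
template = '/data/Potsdam/5_Labels_for_participants/top_potsdam_{}_label.tif'
paths = [template.format('2_10'), template.format('7_11'), '/data/other.tif']
print(ids_from_template(template, paths))
```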

In [None]:
# Parameters
WINDOW_SIZE = (128, 128) # Patch size
STRIDE = 32 # Stride for testing
IN_CHANNELS = 3 # Number of input channels (e.g. RGB)
FOLDER = "/home/cliff/rs_imagery/ISPRS-DATASETS/" # Replace with your "/path/to/the/ISPRS/dataset/folder/"
BATCH_SIZE = 10 # Number of samples in a mini-batch

LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # Label names
N_CLASSES = len(LABELS) # Number of classes
WEIGHTS = torch.ones(N_CLASSES) # Weights for class balancing
CACHE = True # Store the dataset in-memory

DATASET = 'Vaihingen'

if DATASET == 'Potsdam':
    MAIN_FOLDER = FOLDER + 'Potsdam/'
    # Uncomment the next line for IRRG data
    # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif'
    # For RGB data
    DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif'
    LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif'
    ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif'    
elif DATASET == 'Vaihingen':
    MAIN_FOLDER = FOLDER + 'Vaihingen/'
    DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif'
    LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif'
    ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif'

## Visualizing the dataset

First, let's check that we are able to access the dataset and see what's going on. We use ```scikit-image``` for image manipulation.

As the ISPRS dataset is stored with a ground truth in the RGB format, we need to define the color palette that can map the label id to its RGB color. We define two helper functions to convert from numeric to colors and vice-versa.

In [None]:
# ISPRS color palette
# Let's define the standard ISPRS color palette
palette = {0 : (255, 255, 255), # Impervious surfaces (white)
           1 : (0, 0, 255),     # Buildings (blue)
           2 : (0, 255, 255),   # Low vegetation (cyan)
           3 : (0, 255, 0),     # Trees (green)
           4 : (255, 255, 0),   # Cars (yellow)
           5 : (255, 0, 0),     # Clutter (red)
           6 : (0, 0, 0)}       # Undefined (black)

invert_palette = {v: k for k, v in palette.items()}

def convert_to_color(arr_2d, palette=palette):
    """ Numeric labels to RGB-color encoding """
    arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)

    for c, i in palette.items():
        m = arr_2d == c
        arr_3d[m] = i

    return arr_3d

def convert_from_color(arr_3d, palette=invert_palette):
    """ RGB-color encoding to grayscale labels """
    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)

    for c, i in palette.items():
        m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
        arr_2d[m] = i

    return arr_2d

# # We load one tile from the dataset and we display it
# img = io.imread('/home/cliff/rs_imagery/ISPRS-DATASETS/Vaihingen/top/top_mosaic_09cm_area11.tif')
# fig = plt.figure()
# fig.add_subplot(121)
# plt.imshow(img)

# # We load the ground truth
# gt = io.imread('/home/cliff/rs_imagery/ISPRS-DATASETS/Vaihingen/gts_for_participants/top_mosaic_09cm_area11.tif')
# fig.add_subplot(122)
# plt.imshow(gt)
# plt.show()

# # We also check that we can convert the ground truth into an array format
# array_gt = convert_from_color(gt)
# print("Ground truth in numerical format has shape ({},{}) : \n".format(*array_gt.shape[:2]), array_gt)
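The two helpers loop over the palette entries. As a sketch of an equivalent formulation, the conversion can also be done with a lookup table indexed by label, and the reverse direction by packing each RGB triple into one integer and searching it in the table. `labels_to_rgb` and `rgb_to_labels` are hypothetical names; as with `convert_from_color`, colors missing from the palette end up as label 0:

```python
import numpy as np

palette = {0: (255, 255, 255), 1: (0, 0, 255), 2: (0, 255, 255),
           3: (0, 255, 0), 4: (255, 255, 0), 5: (255, 0, 0), 6: (0, 0, 0)}

# Row i of the lookup table holds the RGB color of label i
lut = np.array([palette[i] for i in range(len(palette))], dtype=np.uint8)

def pack_rgb(rgb):
    """Pack the last axis (R, G, B) into a single int64 key."""
    rgb = rgb.astype(np.int64)
    return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]

def labels_to_rgb(labels):
    """(H, W) integer labels -> (H, W, 3) uint8 colors via fancy indexing."""
    return lut[labels]

def rgb_to_labels(rgb):
    """(H, W, 3) colors -> (H, W) uint8 labels; unknown colors become 0."""
    keys = pack_rgb(rgb)
    lut_keys = pack_rgb(lut)
    order = np.argsort(lut_keys)
    sorted_keys = lut_keys[order]
    # Position of each pixel key among the sorted palette keys
    idx = np.clip(np.searchsorted(sorted_keys, keys), 0, len(sorted_keys) - 1)
    found = sorted_keys[idx] == keys
    return np.where(found, order[idx], 0).astype(np.uint8)

labels = np.array([[0, 1], [4, 6]])
print((rgb_to_labels(labels_to_rgb(labels)) == labels).all())
```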

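We also need a few utility functions: random patch extraction, pixel accuracy, sliding-window iteration over a tile, batching of those windows, and the evaluation metrics (confusion matrix, overall accuracy, per-class F1 score and Cohen's kappa).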

In [None]:
# Utils

def get_random_pos(img, window_shape):
    """ Extract a random 2D patch of shape window_shape from an (H, W) or (H, W, C) image """
    w, h = window_shape
    W, H = img.shape[:2]
    # randint is inclusive, so W - w is the last valid start position
    x1 = random.randint(0, W - w)
    x2 = x1 + w
    y1 = random.randint(0, H - h)
    y2 = y1 + h
    return x1, x2, y1, y2


# def CrossEntropy2d(input, target, weight=None):#, size_average=True):
#     """ 2D version of the cross entropy loss """
#     dim = input.dim()
#     if dim == 2:
#         return F.cross_entropy(input, target, weight)
#     elif dim == 4:
#         output = input.view(input.size(0),input.size(1), -1)
#         output = torch.transpose(output,1,2).contiguous()
#         output = output.view(-1,output.size(2))
#         target = target.view(-1)
#         return F.cross_entropy(output, target, weight)
#     else:
#         raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))

def accuracy(input, target):
    return 100 * float(np.count_nonzero(input == target)) / target.size

def sliding_window(top, step=10, window_size=(20,20)):
    """ Slide a window_shape window across the image with a stride of step """
    for x in range(0, top.shape[0], step):
        if x + window_size[0] > top.shape[0]:
            x = top.shape[0] - window_size[0]
        for y in range(0, top.shape[1], step):
            if y + window_size[1] > top.shape[1]:
                y = top.shape[1] - window_size[1]
            yield x, y, window_size[0], window_size[1]
            
def count_sliding_window(top, step=10, window_size=(20,20)):
    """ Count the number of windows in an image """
    c = 0
    for x in range(0, top.shape[0], step):
        if x + window_size[0] > top.shape[0]:
            x = top.shape[0] - window_size[0]
        for y in range(0, top.shape[1], step):
            if y + window_size[1] > top.shape[1]:
                y = top.shape[1] - window_size[1]
            c += 1
    return c

def grouper(n, iterable):
    """ Browse an iterator by chunk of n elements """
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk

def metrics(predictions, gts, label_values=LABELS):
    """ Print the confusion matrix, overall accuracy, per-class F1 scores and Cohen's kappa """
    # Recent scikit-learn releases require `labels` to be passed by keyword
    cm = confusion_matrix(
            gts,
            predictions,
            labels=list(range(len(label_values))))
    
    print("Confusion matrix :")
    print(cm)
    
    print("---")
    
    # Compute global accuracy (correct pixels are on the diagonal)
    total = np.sum(cm)
    accuracy = 100 * np.trace(cm) / float(total)
    print("{} pixels processed".format(total))
    print("Total accuracy : {}%".format(accuracy))
    
    print("---")
    
    # Compute F1 score: 2 * TP / (2 * TP + FP + FN)
    F1Score = np.zeros(len(label_values))
    for i in range(len(label_values)):
        denom = np.sum(cm[i, :]) + np.sum(cm[:, i])
        # A class absent from both ground truth and predictions keeps an F1 score of 0
        if denom > 0:
            F1Score[i] = 2. * cm[i, i] / denom
    print("F1Score :")
    for l_id, score in enumerate(F1Score):
        print("{}: {}".format(label_values[l_id], score))

    print("---")
        
    # Compute kappa coefficient (agreement corrected for chance)
    pa = np.trace(cm) / float(total)
    pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total*total)
    kappa = (pa - pe) / (1 - pe)
    print("Kappa: " + str(kappa))
    return accuracy

# IoU = TP / (TP + FP + FN)
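The comment above gives the intersection-over-union formula, which `metrics` does not report. A minimal sketch of how per-class IoU and its mean could be derived from the same confusion matrix, assuming rows are ground truth and columns are predictions, as returned by `sklearn.metrics.confusion_matrix`; `iou_from_confusion` is a hypothetical helper:

```python
import numpy as np

def iou_from_confusion(cm):
    """Per-class IoU = TP / (TP + FP + FN) from a square confusion matrix.

    Rows are ground-truth classes, columns are predicted classes.
    Classes absent from both ground truth and predictions get NaN.
    """
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as class i but labelled otherwise
    fn = cm.sum(axis=1) - tp   # labelled class i but predicted otherwise
    denom = tp + fp + fn
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = np.where(denom > 0, tp / denom, np.nan)
    return iou

# Toy 3-class matrix in which the last class never occurs
cm = np.array([[8, 2, 0],
               [1, 7, 0],
               [0, 0, 0]])
iou = iou_from_confusion(cm)
print(iou, np.nanmean(iou))  # mean IoU over the classes that occur
```

Using `np.nanmean` keeps classes that are missing from a test tile from dragging the mean IoU down to 0.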

## Loading the dataset

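We define a PyTorch dataset (```torch.utils.data.Dataset```, here through torchvision's `VisionDataset`). The original version, kept commented out below, loads whole tiles, caches them in memory on the fly and samples random patches from them. The active version instead reads pre-cut patches from disk (`img_path` and `msk_path`) and converts the RGB label masks to class indices.

The dataset also performs random data augmentation (horizontal and vertical flips). The commented-out version scales the data to [0, 1]; the active one applies the `transforms` it is given, which in this notebook are `ToTensor` (scaling to [0, 1]) followed by a per-channel `Normalize`.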
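The key constraint of this augmentation is that the image and its label mask receive exactly the same flips; otherwise pixels and labels drift apart. A minimal NumPy sketch of such a joint augmentation, assuming an (H, W, C) image and an (H, W) mask as in the active dataset class (`random_flip` is a hypothetical helper; the class itself does this with `torchvision.transforms.functional.hflip`/`vflip`):

```python
import numpy as np

def random_flip(image, mask, rng):
    """Apply the same random horizontal/vertical flips to an image and its mask."""
    if rng.random() < 0.5:
        # Horizontal flip: reverse the column (width) axis
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:
        # Vertical flip: reverse the row (height) axis
        image, mask = image[::-1], mask[::-1]
    # Copy into contiguous arrays: torch.from_numpy rejects negative strides
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)

rng = np.random.default_rng(42)
image = np.arange(4 * 5 * 3).reshape(4, 5, 3)
mask = image[..., 0] // 3          # mask is a deterministic function of the image
img_f, mask_f = random_flip(image, mask, rng)
print((img_f[..., 0] // 3 == mask_f).all())  # image and mask stay aligned
```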

In [None]:
# Dataset class

# class ISPRS_dataset(torch.utils.data.Dataset):
#     def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER,
#                             cache=False, augmentation=True):
#         super(ISPRS_dataset, self).__init__()
        
#         self.augmentation = augmentation
#         self.cache = cache
        
#         # List of files
#         self.data_files = [DATA_FOLDER.format(id) for id in ids]
#         self.label_files = [LABEL_FOLDER.format(id) for id in ids]

#         # Sanity check : raise an error if some files do not exist
#         for f in self.data_files + self.label_files:
#             if not os.path.isfile(f):
#                 raise KeyError('{} is not a file !'.format(f))
        
#         # Initialize cache dicts
#         self.data_cache_ = {}
#         self.label_cache_ = {}
            
    
#     def __len__(self):
#         # Default epoch size is 10 000 samples
#         return 8000
    
#     @classmethod
#     def data_augmentation(cls, *arrays, flip=True, mirror=True):
#         will_flip, will_mirror = False, False
#         if flip and random.random() < 0.5:
#             will_flip = True
#         if mirror and random.random() < 0.5:
#             will_mirror = True
        
#         results = []
#         for array in arrays:
#             if will_flip:
#                 if len(array.shape) == 2:
#                     array = array[::-1, :]
#                 else:
#                     array = array[:, ::-1, :]
#             if will_mirror:
#                 if len(array.shape) == 2:
#                     array = array[:, ::-1]
#                 else:
#                     array = array[:, :, ::-1]
#             results.append(np.copy(array))
            
#         return tuple(results)
    
#     def __getitem__(self, i):
#         # Pick a random image
#         random_idx = random.randint(0, len(self.data_files) - 1)
        
#         # If the tile hasn't been loaded yet, put in cache
#         if random_idx in self.data_cache_.keys():
#             data = self.data_cache_[random_idx]
#         else:
#             # Data is normalized in [0, 1]
#             data = 1/255 * np.asarray(io.imread(self.data_files[random_idx]).transpose((2,0,1)), dtype='float32')
#             if self.cache:
#                 self.data_cache_[random_idx] = data
            
#         if random_idx in self.label_cache_.keys():
#             label = self.label_cache_[random_idx]
#         else: 
#             # Labels are converted from RGB to their numeric values
#             label = np.asarray(convert_from_color(io.imread(self.label_files[random_idx])), dtype='int64')
#             if self.cache:
#                 self.label_cache_[random_idx] = label

#         # Get a random patch
#         x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE)
#         data_p = data[:, x1:x2,y1:y2]
#         label_p = label[x1:x2,y1:y2]
        
#         # Data augmentation
#         data_p, label_p = self.data_augmentation(data_p, label_p)

#         # Return the torch.Tensor values
#         return (torch.from_numpy(data_p),
#                 torch.from_numpy(label_p))

In [None]:
# class MyDataset(Dataset):
#     def __init__(self, image_paths, target_paths, train=True):
#         self.image_paths = image_paths
#         self.target_paths = target_paths

    
#     def __getitem__(self, index):
#         image = Image.open(self.image_paths[index])
#         mask = Image.open(self.target_paths[index])
#         x, y = self.transform(image, mask)
#         return x, y

#     def __len__(self):
#         return len(self.image_paths)

In [None]:
# class ISPRSDataset(VisionDataset):
#     def __init__(self, img_dir, msk_dir, transform=None, target_transform=None, transforms=None):
#         super(ISPRSDataset, self).__init__(transform, target_transform, transforms)
#         self.img_dir = img_dir
#         self.msk_dir = msk_dir
#         self.img_files = get_files(self.img_dir)
#         self.get_mask = lambda x: self.msk_dir/f'{x.stem}{x.suffix}'

#     def __len__(self):
#         return len(self.img_files)

#     def __getitem__(self, idx):
#         img_file = self.img_files[idx]
#         mask_file = self.get_mask(img_file)
#         img = io.imread(img_file)
#         mask = np.asarray(io.imread(mask_file))
#         if self.transforms is not None:
#             img, mask = self.transforms(img, mask)
#         return img, mask
    
data_path = Path('/home/cliff/rs_imagery/ISPRS-DATASETS/Vaihingen/vaihingen_256')
img_path = data_path/'images/train'
msk_path = data_path/'masks/train'

mask_labels = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] 
num_classes = len(mask_labels) 
print(img_path, '\n', msk_path)

In [None]:
# Dataset class
from torchvision.datasets.vision import VisionDataset
#from torch.utils.data import Dataset
import torchvision.transforms.functional as tf

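class ISPRS_dataset(VisionDataset):
    def __init__(self, data_dir, label_dir, transform=None, target_transform=None, transforms=None):
        # VisionDataset expects (root, transforms, transform, target_transform). The `transforms`
        # used in this notebook (ToTensor + Normalize) only applies to the image, so it is
        # forwarded as the image `transform`; the mask stays an array of class indices.
        image_transform = transform if transform is not None else transforms
        super(ISPRS_dataset, self).__init__(str(data_dir), transform=image_transform,
                                            target_transform=target_transform)
        self.data_dir = data_dir
        self.label_dir = label_dir
        # fastai's get_files lists all image patches in data_dir
        self.data_files = get_files(self.data_dir)
        # Function mapping an image path to the mask with the same file name in label_dir
        self.label_files = lambda x: self.label_dir/f'{x.stem}{x.suffix}'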
        
#         #List of files
#         self.data_files = [DATA_FOLDER.format(id) for id in ids]
#         self.label_files = [LABEL_FOLDER.format(id) for id in ids]

#         #Sanity check : raise an error if some files do not exist
#         for f in self.data_files + self.label_files:
#             if not os.path.isfile(f):
#                 raise KeyError('{} is not a file !'.format(f))
        
#         #Initialize cache dicts
#         self.data_cache_ = {}
#         self.label_cache_ = {}
#         for i in range(len(self.data_files)):
#             data = np.asarray(io.imread(self.data_files[i]))
#             label = np.asarray(convert_from_color(io.imread(self.label_files[i])))
#             self.data_cache_[i] = data
#             self.label_cache_[i] = label

            
    def flip_img(self, image, mask):
        # to PIL image
        image = tf.to_pil_image(image)
        mask = tf.to_pil_image(mask)
        
        # Random horizontal flipping
        if random.random() > 0.5:
            image = tf.hflip(image)
            mask = tf.hflip(mask)

        # Random vertical flipping
        if random.random() > 0.5:
            image = tf.vflip(image)
            mask = tf.vflip(mask)

        # image to np array
        image = np.array(image)
        mask = np.array(mask)
        return image, mask

    
    def __len__(self):
#         total = 0
#         for i in self.data_cache_:
#             total += count_sliding_window(self.data_cache_[i], step=128, window_size=(256, 256)) 
        return len(self.data_files)
    
             
    def __getitem__(self, idx):
        data_file = self.data_files[idx]
        label_file = self.label_files(data_file)
        data_p = np.asarray(io.imread(data_file))
        label_p = np.asarray(convert_from_color(io.imread(label_file)))

        # Data augmentation
        if self.transforms is not None:
            data_p, label_p = self.flip_img(data_p, label_p)
            data_p, label_p = self.transforms(data_p, label_p)

        return (data_p, label_p)
        
#     for idx in self.data_cache_:
#              #for i, coords in enumerate(grouper(1, sliding_window(img, step=128, window_size=(256,256)))):
#             #for j, coords in enumerate(sliding_window(self.data_cache_[idx], step=128, window_size=(256,256))):
#             coords = sliding_window(self.data_cache_[idx], step=128, window_size=(256,256))
#             print(coords)
#             x, y, w, h = coords
#             data_p = self.data_cache_[j][x:x+w, y:y+h, :]
#             label_p = self.label_cache_[j][x:x+w, y:y+h]               
# #         # Pick a random image
#         random_idx = random.randint(0, len(self.data_files) - 1)
        
#         # If the tile hasn't been loaded yet, put in cache
#         if random_idx in self.data_cache_.keys():
#             data = self.data_cache_[random_idx]
#         else:
#             data = np.asarray(io.imread(self.data_files[random_idx]))
#             if self.cache:
#                 self.data_cache_[random_idx] = data
            
#         if random_idx in self.label_cache_.keys():
#             label = self.label_cache_[random_idx]
#         else: 
#             # Labels are converted from RGB to their numeric values
#             label = np.asarray(convert_from_color(io.imread(self.label_files[random_idx])))
#             if self.cache:
#                 self.label_cache_[random_idx] = label

#         # Get a random patch
#         x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE)
#         data_p = data[x1:x2,y1:y2,:]
#         label_p = label[x1:x2,y1:y2]
        
#         # Data augmentation
#         if self.transforms is not None:
#             data_p, label_p = self.flip_img(data_p, label_p)
#             data_p, label_p = self.transforms(data_p, label_p)
        
#         return (data_p, label_p)

## Network definition

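We can now define the fully convolutional network. Instead of the hand-designed SegNet, the network used here is assembled by `nas_seg.seg_model.get_net` from cell graphs stored as Graphviz `.dot` files in `../graphs/`, which can be generated with the `gepcore` cellular-encoding code in the commented-out cell below. Any other network can be used as a drop-in replacement, provided that its output has dimensions `(N_CLASSES, W, H)`, where `W` and `H` are the sliding window dimensions (i.e. the network must preserve the spatial dimensions).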
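A quick way to check the drop-in requirement is to push a dummy batch through a candidate network and compare shapes. The sketch below uses a deliberately tiny, hypothetical encoder-decoder (`TinyFCN`, not the generated architecture) with one pooling and one upsampling step, so the spatial size is preserved whenever H and W are even; `preserves_spatial_dims` is likewise a hypothetical helper:

```python
import torch
import torch.nn as nn

class TinyFCN(nn.Module):
    """Hypothetical minimal encoder-decoder that preserves spatial dimensions."""
    def __init__(self, in_channels=3, n_classes=6, width=16):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, width, 3, padding=1), nn.BatchNorm2d(width), nn.ReLU(inplace=True),
            nn.MaxPool2d(2))                                                    # H, W -> H/2, W/2
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # back to H, W
            nn.Conv2d(width, width, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(width, n_classes, 1))                                     # one score map per class

    def forward(self, x):
        return self.decoder(self.encoder(x))

def preserves_spatial_dims(net, n_classes, in_channels=3, size=(128, 128)):
    """True if net maps (N, C, H, W) to (N, n_classes, H, W)."""
    net.eval()
    with torch.no_grad():
        out = net(torch.zeros(2, in_channels, *size))
    return tuple(out.shape) == (2, n_classes, *size)

print(preserves_spatial_dims(TinyFCN(n_classes=6), n_classes=6))
```

The same check could be run on the notebook's `net` (on the GPU) before training, e.g. with `size=WINDOW_SIZE`.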

In [None]:
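from pygraphviz import AGraph

# Load the cell graphs of a previously generated architecture (Graphviz .dot files)
graph = [AGraph(g) for g in glob('../graphs/*.dot')]
# Turn them into the computational graphs consumed by the network builder
_, comp_graphs = cell_graph.generate_comp_graph(graph)

# Build the segmentation network (channels=64) with one output map per class
conf = arch_config(comp_graphs=comp_graphs, channels=64, classes=N_CLASSES)
net = get_net(conf)
net.cuda()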

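We can now instantiate the network using the specified parameters. The commented-out cell below shows how an architecture such as the one stored in `../graphs/` can be generated with `gepcore`: a primitive set of cellular-encoding program symbols (`seq`, `cpo`, `cpi`) and convolution cell symbols is defined, a random `Chromosome` is sampled, its graph is saved and drawn, and `get_net` builds the corresponding network. Unless `get_net` initializes the weights itself, PyTorch's default initialization applies, which for convolutions is a uniform variant of the [He policy](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf).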
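If you want the He policy applied explicitly rather than relying on PyTorch's default (`kaiming_uniform_` with `a=sqrt(5)` for convolutions), one option is a small function passed to `nn.Module.apply`. A sketch assuming ReLU activations; `init_he` is a hypothetical helper, and whether it helps the generated architecture is untested here:

```python
import torch.nn as nn

def init_he(module):
    """He (Kaiming) normal init for conv layers with zero bias; BatchNorm reset to identity."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

# In the notebook this would be net.apply(init_he), which visits every submodule
demo = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU())
demo.apply(init_he)
```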

In [None]:
# ## # instantiate the network
# # net = SegNet()
# # define primitive set
# pset = PrimitiveSet('cnn')
# # add cellular encoding program symbols
# # pset.add_program_symbol(cell_graph.end)
# pset.add_program_symbol(cell_graph.seq)
# pset.add_program_symbol(cell_graph.cpo)
# pset.add_program_symbol(cell_graph.cpi)

# # add convolutional operations symbols
# conv_symbol = convolution.get_symbol()
# # pset.add_program_symbol(conv_symbol.conv1x1)
# # pset.add_program_symbol(conv_symbol.conv3x3)
# # pset.add_cell_symbol(conv_symbol.dwconv3x3)
# # pset.add_cell_symbol(conv_symbol.sepconv3x3)
# # pset.add_cell_symbol(conv_symbol.sepconv5x5)
# pset.add_cell_symbol(conv_symbol.dilconv3x3)
# pset.add_cell_symbol(conv_symbol.dilconv5x5)
# pset.add_cell_symbol(conv_symbol.conv3x3)
# # pset.add_cell_symbol(conv_symbol.conv1x1)
# # pset.add_cell_symbol(conv_symbol.dwconv3x3)
# # pset.add_cell_symbol(conv_symbol.maxpool3x3)
# # pset.add_cell_symbol(conv_symbol.avgpool3x3)

# def gene_gen():
#     return Gene(pset, 1)

# ch = Chromosome(gene_gen, 3)
# graph, comp_graphs = cell_graph.generate_comp_graph(ch)

# cell_graph.save_graph(graph, '../graphs/')
# cell_graph.draw_graph(graph, '../graphs/')

# conf = arch_config(comp_graphs=comp_graphs,
#                    channels=64,
#                    classes=N_CLASSES)

# net = get_net(conf)
# net.cuda()

### Loading the data

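We now create a train/test split. If you want to use another dataset, you have to adjust the method to collect all filenames. In our case, we specify a fixed train/validation/test split for the demo. Note that training reads the pre-cut patches from `img_path`/`msk_path`, whereas validation (`valid_ids`) and testing (`test_ids`) run on the full tiles.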
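In the next cell the random split is overwritten by a fixed list, but if you use the random variant on another dataset you will probably want it to be reproducible. A minimal sketch using a dedicated, seeded `random.Random` instance and the notebook's `2 * n // 3 + 1` training share (`split_ids` and the seed value are hypothetical):

```python
import random

def split_ids(all_ids, seed=0):
    """Deterministically split tile ids into train and test lists."""
    rng = random.Random(seed)   # independent of the global random state
    ids = sorted(all_ids)       # sort first so the result doesn't depend on input order
    train = rng.sample(ids, 2 * len(ids) // 3 + 1)
    test = [i for i in ids if i not in train]
    return train, test

demo_train, demo_test = split_ids(['1', '3', '5', '7', '11', '13'], seed=42)
print(demo_train, demo_test)
```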

In [None]:
import torchvision.transforms as transforms

tfms = transforms.Compose([transforms.ToTensor(),
                           transforms.Normalize([0.4776, 0.3226, 0.3189], [0.1816, 0.1224, 0.1185])])
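The `Normalize` constants are per-channel means and standard deviations on the [0, 1] scale produced by `ToTensor`, presumably measured on the training images. A sketch of how such statistics could be computed in a single streaming pass over uint8 (H, W, C) arrays, so the whole dataset never has to sit in memory at once (`channel_stats` is a hypothetical helper; it could be fed e.g. `io.imread(f) for f in get_files(img_path)`):

```python
import numpy as np

def channel_stats(images):
    """Per-channel mean and std over all pixels of an iterable of (H, W, C) uint8 images.

    Values are scaled to [0, 1] first, matching transforms.ToTensor().
    """
    n, s, s2 = 0, None, None
    for img in images:
        x = img.reshape(-1, img.shape[-1]).astype(np.float64) / 255.0
        if s is None:
            s = np.zeros(x.shape[1])
            s2 = np.zeros(x.shape[1])
        n += x.shape[0]
        s += x.sum(axis=0)           # running sum per channel
        s2 += (x ** 2).sum(axis=0)   # running sum of squares per channel
    mean = s / n
    std = np.sqrt(s2 / n - mean ** 2)  # population std over all pixels
    return mean, std
```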

In [None]:
# Load the datasets
import os

if DATASET == 'Potsdam':
    all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
    # e.g. top_potsdam_2_10_label.tif -> '2_10'
    all_ids = ["_".join(os.path.basename(f).split('_')[2:4]) for f in all_files]
elif DATASET == 'Vaihingen':
    all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
    # e.g. top_mosaic_09cm_area11.tif -> '11'
    all_ids = [f.split('area')[-1].split('.')[0] for f in all_files]
# Random tile numbers for train/test split
train_ids = random.sample(all_ids, 2 * len(all_ids) // 3 + 1)
test_ids = list(set(all_ids) - set(train_ids))

# Example of a fixed train/validation/test split on Vaihingen:
train_ids = ['1', '3', '11', '13', '17', '26', '28', '32', '34', '37']
valid_ids = ['21', '15'] 
test_ids = ['5', '7', '23', '30']
print("Tiles for training : ", train_ids)
print("Tiles for testing : ", test_ids)
print("Tiles for validation : ", valid_ids)

train_set = ISPRS_dataset(data_dir=img_path, label_dir=msk_path, transforms=tfms)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
#img, lb = next(iter(train_loader))

In [None]:
#plt.imshow(convert_to_color(lb[0]))

In [None]:
#im = transforms.ToPILImage()(img[0])
#plt.imshow((img[0].numpy()))

### Designing the optimizer

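We use the Adam optimizer (with a weight decay of 1e-2) to optimize the network's weights. The SGD setup of the original SegNet baseline is kept as a comment.

The learning rate follows a one-cycle policy (``torch.optim.lr_scheduler.OneCycleLR``): it rises to `max_lr = 1e-2` and then anneals over the 50 training epochs, which requires one scheduler step per mini-batch. An alternative setup with ``ReduceLROnPlateau`` on the validation accuracy is left commented out in the next cell.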
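To see why `OneCycleLR` must be stepped per batch, the sketch below records its schedule on a dummy parameter, with 100 steps standing in for `epochs * len(train_loader)` and the default `pct_start=0.3`, `div_factor=25`. The variable names are prefixed with `demo_` so they do not clobber the notebook's `optimizer` and `scheduler`:

```python
import torch
import torch.optim as optim

param = torch.nn.Parameter(torch.zeros(1))
demo_opt = optim.Adam([param], weight_decay=1e-2)
total_steps = 100                      # stands in for epochs * len(train_loader)
demo_sched = optim.lr_scheduler.OneCycleLR(demo_opt, max_lr=1e-2, total_steps=total_steps)

lrs = []
for _ in range(total_steps):
    lrs.append(demo_sched.get_last_lr()[0])   # learning rate used for this step
    param.grad = torch.ones_like(param)       # dummy gradient in place of loss.backward()
    demo_opt.step()
    demo_sched.step()                         # one scheduler step per batch

print(lrs[0], max(lrs), lrs[-1])
```

The rate starts at `max_lr / 25`, peaks at `max_lr` after 30% of the steps and ends far below its starting value; stepping only once per epoch would leave it stuck near the warm-up value for the whole run.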

In [None]:
max_lr = 1e-2
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), lr=max_lr, momentum=0.9, weight_decay=0.0005)
optimizer = optim.Adam(net.parameters(), weight_decay=1e-2)
# One-cycle schedule over 50 epochs; it must be stepped after every optimizer step (i.e. per batch)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr, steps_per_epoch=len(train_loader), epochs=50)

In [None]:
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(net.parameters(), lr=1e-2, weight_decay=0.004)
# # We define the scheduler
# #scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [10, 25, 35], gamma=0.1)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

In [None]:
#import torchvision.transforms.functional as tf

from IPython.display import clear_output

In [None]:
def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE):
    # Use the network on the test set
    test_images = (np.asarray(io.imread(DATA_FOLDER.format(id))) for id in test_ids)
    test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(id)), dtype='uint8') for id in test_ids)
    eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(id))) for id in test_ids)
    all_preds = []
    all_gts = []
    
    # Switch the network to inference mode
    net.eval()

    for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False):
        pred = np.zeros(img.shape[:2] + (N_CLASSES,))

        total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size
        for i, coords in enumerate(tqdm(grouper(batch_size, 
                                                sliding_window(img, step=stride, window_size=window_size)), 
                                                total=total, leave=False)):
            # Display in progress results
            if i > 0 and total > 10 and i % int(10 * total / 100) == 0:
                    _pred = np.argmax(pred, axis=-1)
                    fig = plt.figure()
                    fig.add_subplot(1,3,1)
                    plt.imshow(np.asarray(img, dtype='uint8'))
                    fig.add_subplot(1,3,2)
                    plt.imshow(convert_to_color(_pred))
                    fig.add_subplot(1,3,3)
                    plt.imshow(gt)
                    clear_output()
                    plt.show()
                    
            # Build the tensor: normalize each patch with the same statistics as in training
            image_patches = [tf.normalize(tf.to_tensor(img[x:x+w, y:y+h]),
                                          [0.4776, 0.3226, 0.3189], [0.1816, 0.1224, 0.1185])
                             for x,y,w,h in coords]
            image_patches = torch.stack(image_patches).cuda()
            
            # Do the inference without tracking gradients to save GPU memory
            with torch.no_grad():
                outs = net(image_patches)
            outs = outs.cpu().numpy()
            
            # Fill in the results array by summing the class scores of overlapping windows
            for out, (x, y, w, h) in zip(outs, coords):
                out = out.transpose((1,2,0))
                pred[x:x+w, y:y+h] += out
            del(outs)

        pred = np.argmax(pred, axis=-1)

        # Display the result
        clear_output()
        fig = plt.figure()
        fig.add_subplot(1,3,1)
        plt.imshow(np.asarray(img, dtype='uint8'))
        fig.add_subplot(1,3,2)
        plt.imshow(convert_to_color(pred))
        fig.add_subplot(1,3,3)
        plt.imshow(gt)
        plt.show()

        all_preds.append(pred)
        all_gts.append(gt_e)

        clear_output()
        # Compute some metrics
        metrics(pred.ravel(), gt_e.ravel())
        accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]), 
                           np.concatenate([p.ravel() for p in all_gts]).ravel())
    if all:
        return accuracy, all_preds, all_gts
    else:
        return accuracy

In [None]:
def train(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch = 1):
    # Loss of every iteration and its moving average over the last 100 iterations
    losses = []
    mean_losses = []
    #weights = weights.cuda()
    #criterion = nn.NLLLoss2d(weight=weights)
    # Per-batch schedulers (e.g. OneCycleLR) step every iteration;
    # ReduceLROnPlateau steps once per validation with the validation accuracy
    plateau = isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)
    iter_ = 0
    
    for e in range(1, epochs + 1):
        net.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device='cuda', dtype=torch.float32)
            # Masks are already (batch, H, W) class indices; squeezing would break a batch of size 1
            target = target.to(device='cuda', dtype=torch.long)
            optimizer.zero_grad()
            output = net(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if scheduler is not None and not plateau:
                scheduler.step()
            losses.append(loss.item())
            mean_losses.append(np.mean(losses[-100:]))
            
            if iter_ % 100 == 0:
                clear_output()
                pred = np.argmax(output.data.cpu().numpy()[0], axis=0)
                gt = target.data.cpu().numpy()[0]
                print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss.item(), accuracy(pred, gt)))
                plt.plot(mean_losses)
                plt.show()
                fig = plt.figure()
                fig.add_subplot(121)
                plt.imshow(convert_to_color(gt))
                plt.title('Ground truth')
                fig.add_subplot(122)
                plt.title('Prediction')
                plt.imshow(convert_to_color(pred))
                plt.show()
            iter_ += 1
            
            del(data, target, loss)
            
        if e % save_epoch == 0:
            # We validate with the largest possible stride for faster computing
            acc = test(net, valid_ids, all=False)
            if plateau:
                scheduler.step(acc)
            torch.save(net.state_dict(), './segnet256_epoch{}_{}'.format(e, acc))
    torch.save(net.state_dict(), './segnet_final')

### Training the network

Let's train the network for 50 epochs, the same number given to `OneCycleLR`. The `matplotlib` graph is periodically updated with the loss plot and a sample inference. Depending on your GPU, this might take from a few hours (Titan Pascal) to a full day (old K20).
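The loss curve shown during training is a moving average over the most recent iterations, which smooths out batch-to-batch noise. A sketch of the same idea with a bounded `collections.deque` and a window of 100 iterations as in the training loop (`RunningMean` is a hypothetical helper that avoids re-summing the window at every step):

```python
from collections import deque

class RunningMean:
    """Mean of the last `window` values added."""
    def __init__(self, window=100):
        self.values = deque(maxlen=window)   # old values are dropped automatically
        self.total = 0.0

    def add(self, value):
        if len(self.values) == self.values.maxlen:
            self.total -= self.values[0]      # value about to be evicted
        self.values.append(value)
        self.total += value
        return self.total / len(self.values)

rm = RunningMean(window=3)
print([rm.add(v) for v in [1, 2, 3, 4, 5]])
```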

In [None]:
train(net, optimizer, 50, scheduler)

### Testing the network

Now that training has ended, we can load the final weights and test the network using a reasonable stride, e.g. half or a quarter of the window size. By default `test()` uses a stride equal to the window size (no overlap); pass e.g. `stride=STRIDE` to enable overlap. Inference time depends on the chosen stride: a step size of 32 (75% overlap with 128 × 128 windows) will take ~15 minutes, whereas no overlap takes only a minute or two.
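The cost of inference grows with the number of windows, which is roughly inversely proportional to the square of the stride. A sketch that mirrors the counting logic of `count_sliding_window` (one window per stride step along each axis, the last one clamped to the border) for a hypothetical 2000 × 2500 tile and the 128 × 128 window used here; `n_windows` is a hypothetical helper:

```python
import math

def n_windows(shape, stride):
    """Number of windows visited by sliding_window on an image of the given (H, W) shape."""
    rows = math.ceil(shape[0] / stride)   # == len(range(0, H, stride))
    cols = math.ceil(shape[1] / stride)
    return rows * cols

tile, window = (2000, 2500), 128
for stride in (128, 64, 32):
    overlap = 1 - stride / window
    print(stride, f"{overlap:.0%} overlap", n_windows(tile, stride), "windows")
```

Halving the stride roughly quadruples the number of forward passes, which is why the 75% overlap run takes an order of magnitude longer than the non-overlapping one.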

In [None]:
net.load_state_dict(torch.load('./segnet_final'))  # final weights saved at the end of train()

In [None]:
_, all_preds, all_gts = test(net, test_ids, all=True)

In [None]:
Confusion matrix :
[[4493643  157861  114900   50671    6974    4918]
 [ 189511 5306269   30744   10540      69     458]
 [ 102023   89794 2316540  463292     717      27]
 [  23284    9544  121207 3921389     124     101]
 [  41181    3322     455     557  133306    7710]
 [   4067     763     567      15       0     980]]
---
17607523 pixels processed
Total accuracy : 91.8478255004977%
---
F1Score :
roads: 0.9281820438895197
buildings: 0.9556416377851561
low veg.: 0.8337667357831099
trees: 0.9202856146122447
cars: 0.8135334629151015
clutter: 0.09521033712231614
---
Kappa: 0.8900782505413787

### Saving the results

We can visualize and save the resulting tiles for qualitative assessment.

In [None]:
for p, id_ in zip(all_preds, test_ids):
    img = convert_to_color(p)
    plt.imshow(img) and plt.show()
    io.imsave('./inference_tile_{}.png'.format(id_), img)