# Training a U-Net model for segmentation

Importing useful packages

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor

import os
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import ipdb

Useful paths

In [3]:
# Directory holding the image tiles (passed as img_dir when the dataset is built).
TILE_PATH = './dataset/tiles/images/'
# Directory holding the ground-truth tiles (passed as gt_dir when the dataset is built).
TILE_GT_PATH = './dataset/tiles/ground_truth/'
# CSV file with labels for the tiles; presumably the class labels for the
# N2 subset. Not referenced in the visible cells -- TODO confirm.
LABELS_FILE = './dataset/tiles/labels.csv'

### Loading the data

First we define a `Dataset` class and a transform that converts each data point (an image and its ground truth) to tensors. We then load the tile dataset and split it randomly as follows:

* 10% - N1 training set (pixel-level labels);
* 70% - N2 training set (class labels);
* 20% - validation set.

In [5]:
class VaihingenDataset(Dataset):
    """Dataset of image tiles paired with same-named ground-truth tiles.

    Every entry of ``img_dir`` is treated as one data point; the ground
    truth for it is read from ``gt_dir`` under the same file name.

    Args:
        img_dir: Directory holding the image tiles.
        gt_dir: Directory holding the ground-truth tiles.
        transform: Optional callable applied to the ``(img, gt)`` tuple
            before it is returned from ``__getitem__``.
    """

    def __init__(self, img_dir, gt_dir, transform=None):

        self.img_dir = img_dir
        self.gt_dir = gt_dir

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.data_points = sorted(os.listdir(self.img_dir))

        self.transform = transform

    def __len__(self):
        """Return the number of entries found in ``img_dir``."""
        return len(self.data_points)

    def __getitem__(self, idx):
        """Load the image / ground-truth pair at position ``idx``.

        Returns:
            A tuple ``(img, gt)`` of float32 arrays, or whatever
            ``transform`` returns for that tuple when a transform is set.
        """
        file_name = self.data_points[idx]

        img = io.imread(os.path.join(self.img_dir, file_name))
        gt = io.imread(os.path.join(self.gt_dir, file_name))

        sample = (img.astype(np.float32), gt.astype(np.float32))

        if self.transform:
            sample = self.transform(sample)

        return sample

In [6]:
class ToTensor(object):
    """Turn an ``(img, gt)`` pair of NumPy arrays into a pair of tensors.

    Both arrays must be 3-D: the last axis is moved to the front
    (axes ``(0, 1, 2)`` become ``(2, 0, 1)``) before conversion.
    """

    def __call__(self, sample):
        # Move the trailing axis first, i.e. (a, b, c) -> (c, a, b).
        # NOTE(review): presumably height/width/channel -> channel/height/width.
        axis_order = (2, 0, 1)

        image, mask = sample[0], sample[1]
        image_t = image.transpose(axis_order)
        mask_t = mask.transpose(axis_order)

        return tuple(map(torch.from_numpy, (image_t, mask_t)))

In [7]:
# Full tile dataset; every sample goes through ToTensor before being returned.
ds = VaihingenDataset(
    img_dir=TILE_PATH,
    gt_dir=TILE_GT_PATH,
    transform=ToTensor(),
)

In [8]:
# Fixed seed so that, for a given dataset ordering, the N1 / N2 / validation
# membership is the same on every Restart & Run All.
SPLIT_SEED = 42

n1_len = int(0.1 * len(ds))
n2_len = int(0.7 * len(ds))
# Validation takes the remainder so the three sizes always sum to len(ds).
n_valid_len = len(ds) - n1_len - n2_len

print('N1 train size: {}/{}.'.format(n1_len, len(ds)))
print('N2 train size: {}/{}.'.format(n2_len, len(ds)))
print('Validation size: {}/{}.'.format(n_valid_len, len(ds)))

n1_train, n2_train, n_valid = random_split(
    ds,
    [n1_len, n2_len, n_valid_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

N1 train size: 1463/14636.
N2 train size: 10245/14636.
Validation size: 2928/14636.


We define a data loader for each subset

In [10]:
# Single batch size shared by all three loaders.
BATCH_SIZE = 128

# Training loaders reshuffle at every epoch so batches differ between epochs;
# the validation loader keeps a fixed order.
n1_dataloader = DataLoader(n1_train, batch_size=BATCH_SIZE, shuffle=True)
n2_dataloader = DataLoader(n2_train, batch_size=BATCH_SIZE, shuffle=True)
n_valid_dataloader = DataLoader(n_valid, batch_size=BATCH_SIZE)