In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms.functional as F
from torchvision.io import ImageReadMode
from torchvision.io import read_image
from torchvision.utils import draw_segmentation_masks
from torchvision.utils import make_grid
from torchvision.ops import sigmoid_focal_loss
from torchvision.transforms.functional import convert_image_dtype

import pl_bolts
from pl_bolts.models.vision import UNet

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

import torchmetrics as tm

import os

In [None]:
# Seed the global random number generators through Lightning before any data
# is split or any model is created, so repeated runs start from the same state.
# workers=True additionally asks Lightning to seed dataloader worker processes.
pl.seed_everything(42, workers=True)

# 1. Preparation

## 1.1 DataModule

In [None]:
class RoadSatelliteModule(pl.LightningDataModule):
    """Data module for the road-segmentation satellite images.

    Training images are read as RGB from ``train/images/``, their masks as
    single-channel grayscale from ``train/groundtruth/`` and the unlabeled
    test images as RGB from ``test/`` (all relative to the working directory).
    Everything is held in memory as a list of tensors.
    """

    def prepare_data(self):
        """Read every image from disk, binarise the masks and pair them up.

        Sets ``train_images``, ``train_masks``, ``train_zip`` (a list of
        ``(image, mask)`` tuples) and ``test_images``.

        Raises:
            ValueError: if the number of training images and masks differ.

        NOTE(review): Lightning recommends not assigning state in
        ``prepare_data``; it is kept here because the notebook calls this hook
        by hand and reads ``train_zip`` right afterwards.
        """
        self.train_images = self.read_images('train/images/', ImageReadMode.RGB)
        self.train_masks = self.read_images('train/groundtruth/', ImageReadMode.GRAY)

        # zip() would silently drop the surplus items and hide a broken dataset.
        if len(self.train_images) != len(self.train_masks):
            raise ValueError(
                f'Found {len(self.train_images)} training images but '
                f'{len(self.train_masks)} masks; they must match one to one.'
            )

        # Binarise the masks in place: every non-zero pixel becomes class 1.
        for train_mask in self.train_masks:
            train_mask[train_mask > 0] = 1

        # Pairing is positional; read_images() returns files in sorted name
        # order, so this assumes an image and its mask share the same file
        # name in their respective folders -- TODO confirm for this dataset.
        self.train_zip = list(zip(self.train_images, self.train_masks))

        self.test_images = self.read_images('test/', ImageReadMode.RGB)

    def setup(self, stage=None):
        """Create the datasets for the requested stage.

        For ``stage`` ``None`` or ``'fit'`` the paired training data is split
        randomly into 80% training and 20% validation samples. For ``stage``
        ``None`` or ``'test'`` the test images are exposed as ``test_data``.
        """
        if stage in (None, 'fit'):
            train_length = int(len(self.train_zip) * 0.8)
            valid_length = len(self.train_zip) - train_length

            self.train_data, self.valid_data = random_split(self.train_zip, [train_length, valid_length])

        if stage in (None, 'test'):
            self.test_data = self.test_images

    def train_dataloader(self):
        """Return a loader over the training split (batches of 16)."""
        return DataLoader(self.train_data, batch_size=16)

    def val_dataloader(self):
        """Return a loader over the validation split (batches of 16)."""
        return DataLoader(self.valid_data, batch_size=16)

    def test_dataloader(self):
        """Return a loader over the test images (batches of 16, images only)."""
        return DataLoader(self.test_data, batch_size=16)

    def read_images(self, data_dir, read_mode):
        """Read every file in ``data_dir`` as an image tensor.

        Args:
            data_dir: directory whose entries are all read as images.
            read_mode: ``ImageReadMode`` passed through to ``read_image``.

        Returns:
            A list of image tensors ordered by file name.
        """
        # os.listdir() makes no ordering guarantee; sorting keeps the result
        # reproducible and keeps separately listed folders (images vs. masks)
        # aligned by file name.
        return [
            read_image(os.path.join(data_dir, file), read_mode)
            for file in sorted(os.listdir(data_dir))
        ]

In [None]:
# Data module instance shared by the inspection cells and the training system.
road_data = RoadSatelliteModule()

In [None]:
# Run the data hooks by hand so that `train_zip` and the train/validation
# split exist for the cells below.
road_data.prepare_data()
road_data.setup()

## 1.2 Inspect data

In [None]:
def show_image(imgs):
    """Plot one image tensor, or a list of them, side by side in a single row.

    Args:
        imgs: a single image tensor or a list of image tensors; each one is
            converted to a PIL image before being drawn.
    """
    images = imgs if isinstance(imgs, list) else [imgs]

    # squeeze=False keeps `axes` two-dimensional even for a single image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)

    for ax, image in zip(axes[0], images):
        pil_image = F.to_pil_image(image.detach())
        ax.imshow(np.asarray(pil_image))
        # Hide ticks and tick labels; only the picture matters.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Overlay every ground-truth mask on its training image for visual inspection.
seg_images = [
    draw_segmentation_masks(image, mask.bool())
    for image, mask in road_data.train_zip
]

In [None]:
#for seg_image in seg_images:
#    show_image(seg_image)

# 2. System

In [None]:
class SemanticSegmentationSystem(pl.LightningModule):
    """Lightning wrapper that trains a binary segmentation model.

    The wrapped model is expected to return raw logits: training applies
    ``sigmoid_focal_loss`` to them and validation / visualisation apply
    ``torch.sigmoid`` before scoring or plotting.

    Args:
        model: network mapping an image batch to per-pixel logits.
        datamodule: data module that provides the train/val/test loaders.
        lr: learning rate for Adam.
        batch_size: stored on the module as ``self.batch_size``.
            NOTE(review): ``RoadSatelliteModule`` in this notebook builds its
            loaders with a fixed batch size of 16, so changing this attribute
            does not change the batches that are actually produced.
    """

    def __init__(self, model: nn.Module, datamodule: pl.LightningDataModule, lr: float = 1e-4, batch_size: int = 16):
        super().__init__()

        self.model = model
        self.datamodule = datamodule

        self.lr = lr
        self.batch_size = batch_size

    def training_step(self, batch, batch_idx):
        """Compute and log the mean sigmoid focal loss for one batch."""
        X, y = batch

        # Images and masks are loaded as integer tensors; the model and the
        # loss both need floating point input.
        X = X.float()
        y = y.float()

        y_pred = self.model(X)

        loss = sigmoid_focal_loss(y_pred, y, reduction='mean')

        self.log('training_loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the accuracy of the sigmoid predictions."""
        X, y = batch

        X = X.float()
        y = y.int()

        y_pred = self.model(X)
        y_sig = torch.sigmoid(y_pred)

        metric = tm.functional.accuracy(y_sig, y, average='samples')

        # Logged under the name that early stopping and the LR scheduler monitor.
        self.log('validation_metric', metric)

        return metric

    def test_step(self, batch, batch_idx):
        """Return the raw model output (logits) for one test batch.

        The test loader yields image-only batches (a single tensor), while the
        train/val loaders yield ``(image, mask)`` pairs; both are accepted.
        """
        X = batch[0] if isinstance(batch, (list, tuple)) else batch

        # Same cast as in training/validation: images arrive as integer tensors.
        return self.model(X.float())

    def _predict_validation_batch(self):
        """Return the first validation batch's images and sigmoid predictions.

        Predictions are computed without gradient tracking on the module's
        current device and moved back to the CPU for plotting.
        """
        Xs, _ = next(iter(self.val_dataloader()))

        with torch.no_grad():
            y_preds = torch.sigmoid(self.model(Xs.float().to(self.device)))

        return Xs, y_preds.cpu()

    def visualize_results(self):
        """Plot the predicted probability map for each image of one validation batch."""
        _, y_preds = self._predict_validation_batch()

        for y_pred in y_preds:
            show_image(y_pred)

    def visualize_results_overlay(self):
        """Plot each image of one validation batch with its predicted mask overlaid in green."""
        Xs, y_preds = self._predict_validation_batch()

        # Rounding the probabilities thresholds them at 0.5 to get a boolean mask.
        seg_images = [
            draw_segmentation_masks(image, y_pred.round().bool(), colors=['#00ff00'])
            for image, y_pred in zip(Xs, y_preds)
        ]

        for seg_image in seg_images:
            show_image(seg_image)

    def train_dataloader(self):
        """Delegate to the data module's training loader."""
        return self.datamodule.train_dataloader()

    def val_dataloader(self):
        """Delegate to the data module's validation loader."""
        return self.datamodule.val_dataloader()

    def test_dataloader(self):
        """Delegate to the data module's test loader."""
        return self.datamodule.test_dataloader()

    def configure_optimizers(self):
        """Adam plus a plateau scheduler driven by the validation metric."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # The monitored value is an accuracy (higher is better), so the
        # scheduler must react when it stops *increasing*: mode='max'.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', patience=4, verbose=True
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'validation_metric'
        }

# 3. Model

In [None]:
class Conv2d(nn.Module):
    """Single-convolution baseline mapping an image to per-pixel logits.

    Args:
        in_channels: number of input channels (3 for RGB).
        out_channels: number of output channels (1 for a binary mask).
        kernel_size: side length of the square convolution kernel.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1, kernel_size: int = 5):
        super().__init__()

        # No activation on purpose: the training system feeds the output to
        # sigmoid_focal_loss / torch.sigmoid, which expect raw logits. A
        # Softmax over a single output channel would also turn every pixel
        # into the constant 1.0 and give zero gradient.
        self.model = nn.Sequential(
            # padding='same' keeps the spatial size of the input.
            nn.Conv2d(in_channels, out_channels, kernel_size, padding='same'),
        )

    def forward(self, X):
        """Return the logits for the image batch ``X``."""
        return self.model(X)

In [None]:
# Single-convolution baseline model.
# NOTE(review): `model` is reassigned to a UNet in the next cell, so when the
# notebook is run top to bottom this instance is not used by later cells.
model = Conv2d()

In [None]:
# U-Net from pl_bolts; this is the model the cells below use.
# NOTE(review): the positional arguments presumably mean num_classes=1,
# input_channels=3, num_layers=5, features_start=12 -- confirm against the
# installed pl_bolts version.
model = UNet(1, 3, 5, 12)

In [None]:
# Fetch the first training batch as (images, masks).
X, y = next(iter(road_data.train_dataloader()))

In [None]:
# Wrap the model and the data module in the Lightning training system
# (default learning rate and batch size).
system = SemanticSegmentationSystem(model, road_data)

In [None]:
# Predicted probability maps for one validation batch, before any training.
system.visualize_results()

In [None]:
# The same validation batch with the predicted masks overlaid, before any training.
system.visualize_results_overlay()

# 4. Training

In [None]:
# Stop training once the validation accuracy has not improved for 10 checks.
# mode='max' because a higher 'validation_metric' is better.
early_stop_callback = EarlyStopping(
   monitor='validation_metric',
   patience=10,
   verbose=True,  # `verbose` is an on/off flag, so pass a bool rather than 2
   mode='max'
)

In [None]:
# Trainer configuration: GPU training, learning-rate finder and batch-size
# scaling enabled (both are run by `trainer.tune`), stochastic weight
# averaging, deterministic mode and early stopping on the validation metric.
trainer = pl.Trainer(
    #fast_dev_run=True,
    gpus=-1,
    auto_select_gpus=True,
    auto_lr_find=True,
    auto_scale_batch_size='binsearch',
    stochastic_weight_avg=True,
    deterministic=True,
    callbacks=[early_stop_callback]
)

In [None]:
# Run the tuning steps enabled on the Trainer (auto_lr_find and
# auto_scale_batch_size) before the actual fit.
trainer.tune(system)

In [None]:
# Train the system; the loaders come from the system's own dataloader hooks.
trainer.fit(system)

In [None]:
# Overlay the predicted masks on one validation batch again, now after training.
system.visualize_results_overlay()