In [None]:
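# Install required libs (uncomment on first run). `%pip` installs into the kernel's own
# environment; consider pinning a version for reproducibility.
# %pip install -U segmentation-models-pytorch --user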

## Loading data

In [None]:
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import matplotlib.pyplot as plt
from datasets.utils import visualize
from datasets.rellis_3d import Rellis3D as Dataset
from torch.utils.data import DataLoader
import torch
import numpy as np
import segmentation_models_pytorch as smp

### Dataloader

The `Rellis3D` class imported from `datasets.rellis_3d` is the helper for data extraction, transformation and preprocessing.  
https://pytorch.org/docs/stable/data

In [None]:
# Let's look at the data we have: draw one random sample and show it with its masks.
PREVIEW_CLASSES = ['grass', 'tree', 'sky']
PREVIEW_SPLIT = 'test'  # any of 'train', 'val', 'test'
ds = Dataset(classes=PREVIEW_CLASSES, split=PREVIEW_SPLIT)

ind = int(np.random.choice(range(len(ds))))  # uniformly random sample index
image, mask = ds[ind]
# Invert the (x - mean) / std normalization presumably applied by the dataset, for display.
# NOTE(review): the augmented preview below transposes CHW -> HWC before this step, but no
# transpose is applied here, so this assumes the 'test' split yields HWC images — confirm.
image_vis = image * ds.std + ds.mean

# One mask channel per requested class, in the order the classes were given to the dataset
# (grass -> channel 0, tree -> 1, sky -> 2).
visualize(
    image=image_vis,
    **{'{}_mask'.format(name): mask[k, ...] for k, name in enumerate(PREVIEW_CLASSES)},
)

In [None]:
# First two dimensions of the last sample's shape: (H, W) if `image` is HWC.
# TODO(review): confirm the layout — the augmented preview below treats images as CHW.
src_size = np.asarray(image.shape[:2])
src_size

In [None]:
# Visualize the resulting augmented images and masks:
# the same training sample is drawn several times, each with different random transforms.
ds_aug = Dataset(split='train')

AUG_SAMPLE_IDX = 1
N_AUG_VIEWS = 3
for view in range(N_AUG_VIEWS):
    image, mask = ds_aug[AUG_SAMPLE_IDX]
    # Move axis 0 to the end (presumably CHW -> HWC) for display, then undo mean/std scaling.
    image_vis = image.transpose([1, 2, 0]) * ds_aug.std + ds_aug.mean

    # Channel 2 is shown as grass. This matches 'grass' at index 2 of the full class list,
    # assuming Dataset's default classes are that list — TODO confirm.
    visualize(image=image_vis, grass_mask=mask[2, ...])

## Create model and train

In [None]:
# --- Model / training configuration ---
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'  # pretrained weights loaded into the encoder
CLASSES = [
    'void', 'dirt', 'grass', 'tree', 'pole',
    'water', 'sky', 'vehicle', 'object', 'asphalt',
    'building', 'log', 'person', 'fence', 'bush',
    'concrete', 'barrier', 'puddle', 'mud', 'rubble',
]
# Single class -> 'sigmoid'; several classes -> 'softmax2d' (multiclass segmentation).
# None would make the model return raw logits instead.
if len(CLASSES) == 1:
    ACTIVATION = 'sigmoid'
else:
    ACTIVATION = 'softmax2d'
DEVICE = 'cuda'
IMG_SIZE = (352, 640)  # crop size handed to the datasets — presumably (H, W); confirm with Rellis3D
LR = 0.0001

# FPN segmentation model on a pretrained encoder, one output channel per class.
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

In [None]:
# Both splits use the full CLASSES list and the same crop size.
train_dataset = Dataset(classes=CLASSES, crop_size=IMG_SIZE, split='train')
valid_dataset = Dataset(classes=CLASSES, crop_size=IMG_SIZE, split='val')

TRAIN_BATCH_SIZE, TRAIN_WORKERS = 4, 12
VALID_BATCH_SIZE, VALID_WORKERS = 1, 4

# Shuffled mini-batches for training; one sample at a time, in fixed order, for validation.
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=TRAIN_WORKERS,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=VALID_WORKERS,
)

In [None]:
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

# Optimize Dice loss; report IoU (computed with threshold=0.5 on the predictions).
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# One parameter group covering every model parameter (encoder and decoder) at LR.
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': LR}])

In [None]:
# Epoch runners: each `.run(loader)` is a simple loop over the loader's samples.
# Both share loss, metrics, device and verbose progress output; only training gets the optimizer.
_runner_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_runner_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **_runner_kwargs)

In [None]:
# Train the model, keeping the checkpoint with the best validation IoU.
n_epochs = 10
# Epoch after which the learning rate is dropped. This was hard-coded to 25, which can never
# be reached with n_epochs = 10, so the decay silently never happened. It is now tied to the
# run length.
lr_decay_epoch = n_epochs // 2
decayed_lr = 1e-5

max_score = 0
for i in range(n_epochs):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Save whenever validation IoU improves. torch.save(model, ...) pickles the whole module,
    # so loading it back needs the same class definitions to be importable.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')

    if i == lr_decay_epoch:
        # The optimizer was built with a single group holding every model parameter (not only
        # the decoder), so update every group rather than just index 0.
        for group in optimizer.param_groups:
            group['lr'] = decayed_lr
        print('Decrease learning rate to {}!'.format(decayed_lr))