# Segmentation


In [None]:
import albumentations
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils import data


from skimage.io import imread
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import numpy as np
from collections import namedtuple
import json
from pathlib import Path

In [None]:
from src.utilities import get_filenames_of_path
from src.datasets import SegmentationDataSet

In [None]:
# ref https://forums.fast.ai/t/feedback-on-using-custom-dice-loss-in-multi-class-semantic-segmentation/70431/4

def dice_loss(output, target, eps=1e-7):
    """Multi-class soft Dice loss computed from raw logits.

    Args:
        output: Logits indexed as (N, C, H, W); softmax is taken over dim 1.
        target: Integer class indices indexed as (N, H, W), values in [0, C).
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor ``1 - mean(dice)``. The mean runs over the classes whose
        Dice score exceeds 0.0001; if no class does, it runs over all classes
        so the result stays finite.
    """
    num_classes = output.shape[1]
    # One-hot encode the target as NCHW and place it on the same device as the
    # logits, so the loss works on CPU as well as on GPU.
    targ_onehot = (
        F.one_hot(target.long(), num_classes)
        .permute(0, 3, 1, 2)
        .float()
        .to(output.device)
    )
    # convert logits to probs
    pred = F.softmax(output, dim=1)
    # sum over batch and spatial dims, leaving one value per class
    inter = (pred * targ_onehot).sum(dim=(0, 2, 3))
    union = (pred + targ_onehot).sum(dim=(0, 2, 3))
    dices = 2. * inter / (union + eps)
    # Average only over classes with a non-negligible score.
    keep = dices > 0.0001
    if not bool(keep.any()):
        # An empty selection would make the mean NaN; use every class instead.
        keep = torch.ones_like(keep)
    dice = dices[keep].mean()
    return 1. - dice
    
class DiceLoss(nn.Module):
    """Module wrapper around ``dice_loss`` plus helpers for handling logits."""

    def __init__(self, reduction='mean'):
        super().__init__()
        # Stored only; nothing in this class reads it.
        self.reduction = reduction

    def forward(self, output, targ):
        """Return ``dice_loss(output, targ)``; output is NCHW, targ is NHW."""
        loss = dice_loss(output, targ)
        return loss

    def activation(self, output):
        """Turn logits into per-class probabilities (softmax over dim 1)."""
        return output.softmax(dim=1)

    def decodes(self, output):
        """Return the index of the highest-scoring class along dim 1."""
        return torch.argmax(output, dim=1)

In [None]:
# input parameters 

# Run settings and hyper-parameters for this experiment.
input_param = dict(
    run_no='run_10083',
    batch_size=4,
    img_w=512,
    lr=0.1,
    epoches=10,
    n_blocks=4,
    start_filters=32,
    train_size=5,
    loss='dice',
    optimizer='sgd',
)

# Expose the same settings as an immutable record with attribute access.
Parameters = namedtuple('Parameters', tuple(input_param))
parameters = Parameters(**input_param)

# training data
Load the images into the data loader in preparation for data processing.

In [None]:
class SegmentationDataSetInheritance(SegmentationDataSet):
    """SegmentationDataSet extended with a helper for viewing untransformed samples."""

    def plot_image(self, idx):
        """Show the input and target at position ``idx`` side by side.

        The pair is loaded through ``read_images`` with ``pre_transform=False``,
        i.e. before any transformation is applied.
        """
        image, mask = self.read_images(
            self.inputs[idx], self.targets[idx], pre_transform=False
        )

        # Imported here so matplotlib is only needed when plotting.
        from matplotlib import pyplot as plt
        _, (left, right) = plt.subplots(1, 2)
        left.imshow(image)
        right.imshow(mask)

## read file names

In [None]:
import pathlib

def get_filenames_of_path(path: pathlib.Path, ext: str = '*', ignore_list: "list | None" = None):
    """Return the sorted files inside ``path`` that match the glob ``ext``.

    Args:
        path: Directory to search.
        ext: Glob pattern handed to ``path.glob`` (default ``'*'``).
        ignore_list: File stems (names without the final suffix) to leave out.
            ``None`` or an empty collection means nothing is ignored.

    Returns:
        List of ``pathlib.Path`` objects for regular files only, sorted.
    """
    # A None default avoids sharing one mutable list between calls.
    ignored = ignore_list if ignore_list is not None else ()
    return sorted(
        file for file in path.glob(ext)
        if file.is_file() and file.stem not in ignored
    )

In [None]:
# Training pairs: keep only the first `train_size` images and masks.
root = Path('./data/')
inputs_train = get_filenames_of_path(
    root.joinpath('imgs/train'))[:parameters.train_size]
targets_train = get_filenames_of_path(
    root.joinpath('masks/train'))[:parameters.train_size]

# Validation pairs: capped with the same `train_size` limit.
inputs_valid = get_filenames_of_path(
    root.joinpath('imgs/validation'))[:parameters.train_size]
targets_valid = get_filenames_of_path(
    root.joinpath('masks/validation'))[:parameters.train_size]

## data transformation

In [None]:
from src.transformations import ComposeDouble, create_dense_target, normalize_01
from src.transformations import FunctionWrapperDouble, Resize, AlbuSeg2d

In [None]:
# data transformation
# about resize https://ai.stackexchange.com/questions/6274/how-can-i-deal-with-images-of-variable-dimensions-when-doing-image-segmentation

# Square working resolution: height equals the configured width.
img_h = img_w = parameters.img_w

# Validation pipeline; it currently lists the same four steps as training.
transforms_valid = ComposeDouble(
    [
        Resize(input_size=(img_w, img_h, 3), target_size=(img_w, img_h)),
        FunctionWrapperDouble(create_dense_target, input=False, target=True),
        # Applied to the input only: move the last axis to the front.
        FunctionWrapperDouble(
            np.moveaxis, input=True, target=False, source=-1, destination=0
        ),
        FunctionWrapperDouble(normalize_01, input=True, target=False),
    ]
)

# Training steps are kept in a plain list before being composed.
transformation = [
    Resize(input_size=(img_w, img_h, 3), target_size=(img_w, img_h)),
    FunctionWrapperDouble(create_dense_target, input=False, target=True),
    FunctionWrapperDouble(
        np.moveaxis, input=True, target=False, source=-1, destination=0
    ),
    FunctionWrapperDouble(normalize_01, input=True, target=False),
]

transforms_train = ComposeDouble(transformation)

In [None]:
# test the classes

batch_size = parameters.batch_size

# Training set: training images paired with training masks.
dataset_train = SegmentationDataSetInheritance(inputs=inputs_train,
                                               targets=targets_train,
                                               transform=transforms_train)

# Validation set: validation images must be paired with the validation masks
# (previously the training masks were passed here by mistake).
dataset_valid = SegmentationDataSetInheritance(inputs=inputs_valid,
                                               targets=targets_valid,
                                               transform=transforms_valid)

In [None]:
# Wrap both datasets in DataLoaders; each uses `batch_size` and shuffles.
dataloader_train = data.DataLoader(
    dataset_train, batch_size=batch_size, shuffle=True
)

dataloader_valid = data.DataLoader(
    dataset_valid, batch_size=batch_size, shuffle=True
)

# training model

In [None]:
from src.unet import UNet

In [None]:
# device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("The divice running is: ", device)

## loss function

The dataset is highly imbalanced, with a low presence of crop and weed pixels. It is therefore necessary to adjust the weight factors to increase the penalty for wrongly segmenting crops or weeds.

reference is made to 

In [None]:
# Available criteria, looked up by the name given in the run settings.
loss_functions = dict(
    dice=DiceLoss(),
    crossentropy=nn.CrossEntropyLoss(),
)

## setup the trainer

In [None]:
from src.trainer import Trainer
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [None]:
from src.predict_script import preprocess, postprocess, predict
from src.utilities import Metrics

# Train the model

In [None]:
from torch import nn

In [None]:
# model
# TensorBoard writer for this run; events go to runs/<run_no>.
writer_str = parameters.run_no
writer = SummaryWriter(f'runs/{writer_str}')

# Everything after the writer is opened runs inside try/finally so the writer
# is closed even if model construction or training raises.
try:
    model = UNet(in_channels=3,       # 3 because we have RGB images
                 out_channels=3,      # 3 classes in segmentation
                 n_blocks=parameters.n_blocks,
                 start_filters=parameters.start_filters,
                 activation='relu',
                 normalization='batch',
                 conv_mode='same').to(device)

    # Candidate optimizers keyed by name; both are built with the configured lr.
    optimizer = {
        'sgd': torch.optim.SGD(model.parameters(), parameters.lr),
        'adam': torch.optim.Adam(model.parameters(), parameters.lr),
    }

    # Pick the optimizer named in the run settings.
    opti = optimizer[parameters.optimizer]
    # NOTE: the scheduler is created but not handed to the trainer below
    # (lr_scheduler=None).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opti)

    trainer = Trainer(model=model,
                      device=device,
                      criterion=loss_functions[parameters.loss],
                      optimizer=opti,
                      training_DataLoader=dataloader_train,
                      validation_DataLoader=dataloader_valid,
                      lr_scheduler=None,  # scheduler,
                      # NOTE: fixed at 100 here; parameters.epoches is not used.
                      epochs=100,
                      epoch=0,
                      notebook=True,
                      writer=writer)

    # train the model
    training_loss, validation_loss, lr_rates = trainer.run_trainer()
finally:
    writer.close()