In [None]:
import os
import gc
import cv2
import torch
import shutil
import random
import mlflow
import IPython
import numpy as np
from getpass import getpass
from glob import glob
from skimage import color
import albumentations as albu
from google.colab import files
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset as BaseDataset

## U-Net model: EfficientNet-B0 encoder + scSE attention decoder

In [None]:
def seed_torch(seed=42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer applied to Python's ``random``, NumPy, PyTorch (CPU and
            the current CUDA device) and exported as ``PYTHONHASHSEED``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN to use deterministic kernels only.
    torch.backends.cudnn.deterministic = True


seed_torch()

In [None]:
def _sorted_matches(pattern):
    """Return the paths matching ``pattern`` in lexicographic order."""
    return sorted(glob(pattern))


# Images and masks are listed separately and paired purely by their position
# in these sorted lists.
train_images = _sorted_matches('/content/imgs/train/*')
validation_images = _sorted_matches('/content/imgs/validation/*')
test_images = _sorted_matches('/content/imgs/test/*')

train_masks = _sorted_matches('/content/masks/train/*')
validation_masks = _sorted_matches('/content/masks/validation/*')
test_masks = _sorted_matches('/content/masks/test/*')

In [None]:
def visualize(**images):
    """Show the given images side by side in a single row.

    Args:
        **images: One panel per keyword argument. The keyword is used as the
            panel title (underscores become spaces, words are title-cased) and
            the value is handed to ``plt.imshow``.
    """
    plt.figure(figsize=(16, 5))
    panel_count = len(images)
    for position, key in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # Hide the axis ticks; only the picture matters here.
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key])
    plt.show()

In [None]:
# cv2.imread returns channels in BGR order while matplotlib's imshow expects
# RGB, so convert before plotting to avoid swapped red/blue channels.
visualize(
    image=cv2.cvtColor(cv2.imread(train_images[0]), cv2.COLOR_BGR2RGB),
    mask=cv2.cvtColor(cv2.imread(train_masks[0]), cv2.COLOR_BGR2RGB),
)

In [None]:
class Dataset(BaseDataset):
    """Image / class-index-mask pairs for one of the notebook's data splits.

    Each item is ``(image, mask)``. The mask holds one class index per pixel:
    1 where the mask file's grey level is ``PLANT_GRAY``, 2 where it is
    ``WEED_GRAY`` and 0 everywhere else.

    Args:
        typ: Which split to serve: ``'train'``, ``'valid'`` or ``'test'``.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.

    Raises:
        ValueError: If ``typ`` is not one of the three supported splits.
    """

    # Grey levels (after converting the mask file to mode 'L') for each class.
    PLANT_GRAY = 39
    WEED_GRAY = 78

    def __init__(
            self,
            typ='train',
            augmentation=None,
            preprocessing=None,
    ):
        if typ == 'train':
            self.images, self.masks = train_images, train_masks
        elif typ == 'valid':
            self.images, self.masks = validation_images, validation_masks
        elif typ == 'test':
            self.images, self.masks = test_images, test_masks
        else:
            # Fail here instead of with an AttributeError on first access.
            raise ValueError(
                "typ must be 'train', 'valid' or 'test', got {!r}".format(typ)
            )

        self.typ = typ
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Load, label-encode and transform the ``i``-th image/mask pair."""
        # Context managers make sure the file handles are released promptly.
        with Image.open(self.images[i]) as image_file:
            image = np.array(ImageOps.autocontrast(image_file.convert('RGB')))
        with Image.open(self.masks[i]) as mask_file:
            gray = np.array(mask_file.convert('L'))

        # Grey levels -> class indices; anything unrecognised stays 0.
        mask = np.zeros(gray.shape, dtype=int)
        mask[gray == self.PLANT_GRAY] = 1
        mask[gray == self.WEED_GRAY] = 2

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # NOTE(review): this division also runs after `preprocessing`; if that
        # step already normalises the image, the result is scaled down twice.
        # Left unchanged because existing checkpoints were trained with it --
        # confirm before altering.
        return image / 255., mask

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.images)

In [None]:
# Sanity check: build the training split without augmentation or
# preprocessing and report how many samples it holds.
dataset = Dataset(typ='train')
print(len(dataset))

image, mask = dataset[1] # fetch the sample at index 1
print(mask.shape)

# One panel for the image, then one boolean panel per class index (0, 1, 2).
visualize(image=image, bg=mask == 0, pl=mask == 1, wd=mask == 2)

In [None]:
def get_training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    Returns:
        An ``albu.Compose`` that runs, in order: a random horizontal flip, a
        random shift/scale/rotate, a resize so the shorter side is 768, random
        Gaussian noise, a random perspective warp, at most one colour
        adjustment and at most one sharpen/blur operation.
    """
    colour_jitter = albu.OneOf(
        [
            albu.CLAHE(p=0.3),
            albu.RandomBrightnessContrast(p=0.3),
            albu.RandomGamma(p=0.3),
            albu.HueSaturationValue(p=0.3),
        ],
        p=0.3,
    )
    sharpen_or_blur = albu.OneOf(
        [
            albu.Sharpen(p=0.3),
            albu.Blur(blur_limit=3, p=0.3),
            albu.MotionBlur(blur_limit=3, p=0.3),
        ],
        p=0.3,
    )

    pipeline = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(
            shift_limit=0.3,
            scale_limit=0.2,
            rotate_limit=90,
            border_mode=0,  # 0 == cv2.BORDER_CONSTANT
            p=0.3,
        ),
        albu.SmallestMaxSize(max_size=768, always_apply=True),
        albu.GaussNoise(p=0.3),
        albu.Perspective(p=0.3),
        colour_jitter,
        sharpen_or_blur,
    ]
    return albu.Compose(pipeline)


def get_validation_augmentation():
    """Build the deterministic pipeline used for validation samples.

    Returns:
        An ``albu.Compose`` whose only step resizes the input so that its
        shorter side is 768.
    """
    resize = albu.SmallestMaxSize(max_size=768, always_apply=True)
    return albu.Compose([resize])

def to_tensor(x, **kwargs):
    """Move the last of three axes to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    reordered = np.transpose(x, (2, 0, 1))
    return reordered.astype(np.float32)

def to_tensor_mask(x, **kwargs):
    """Cast a mask array to int32, leaving its shape untouched.

    Extra keyword arguments are accepted and ignored.
    """
    as_int32 = x.astype(np.int32)
    return as_int32

def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn: Callable applied to the image only.

    Returns:
        An ``albu.Compose`` that first runs ``preprocessing_fn`` on the image,
        then ``to_tensor`` on the image and ``to_tensor_mask`` on the mask.
    """
    normalise = albu.Lambda(image=preprocessing_fn)
    convert_layout = albu.Lambda(image=to_tensor, mask=to_tensor_mask)
    return albu.Compose([normalise, convert_layout])

In [None]:
# Preview the training augmentations: the split defaults to 'train' and no
# preprocessing is attached.
augmented_dataset = Dataset(
    augmentation=get_training_augmentation() 
)

# Show the first three samples; the augmentations are random, so re-running
# this cell can produce different pictures.
for i in range(3):
    image, mask = augmented_dataset[i]
    visualize(image=image, bg=mask == 0, pl=mask == 1, wd=mask == 2)

In [None]:
# Encoder backbone and the pretrained weights to load for it
# (the trailing comments record alternatives that were tried).
ENCODER = 'timm-efficientnet-b0' # 'resnet34' 
ENCODER_WEIGHTS = 'noisy-student' # 'imagenet'  
# No output activation is requested from the model; a later cell applies
# CrossEntropyLoss to its output.
ACTIVATION = None 
DEVICE = 'cuda'

# U-Net with three output channels (class indices 0, 1, 2 produced by Dataset)
# and 'scse' attention in the decoder.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS, 
    classes=3,
    activation=ACTIVATION,
    decoder_attention_type='scse'
)

# Input preprocessing function matching the chosen encoder and weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Training split: random augmentations followed by preprocessing.
train_dataset = Dataset(
    typ = 'train',
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Validation split: resize only, followed by the same preprocessing.
valid_dataset = Dataset(
    typ = 'valid',
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Training batches are shuffled; validation runs one image at a time in order.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=2)


In [None]:
class IoU(torch.nn.Module):
    """Smoothed intersection-over-union between two binary masks.

    Args:
        weight: Unused; kept so existing constructor calls keep working.
        size_average: Unused; kept so existing constructor calls keep working.
    """

    def __init__(self, weight=None, size_average=True):
        super(IoU, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute ``(intersection + smooth) / (union + smooth)``.

        Args:
            inputs: Predicted mask; cast to integers, so boolean or 0/1 values
                are expected.
            targets: Ground-truth mask with the same number of elements.
            smooth: Added to numerator and denominator so that two empty
                masks give 1 instead of a division by zero.

        Returns:
            A zero-dimensional floating-point tensor.
        """
        # reshape (not view) so non-contiguous tensors are accepted as well.
        flat_inputs = inputs.long().reshape(-1)
        flat_targets = targets.long().reshape(-1)

        intersection = (flat_inputs * flat_targets).sum()
        total = (flat_inputs + flat_targets).sum()
        union = total - intersection

        # Explicit true division: `/` between integer tensors floors or raises
        # on older torch releases.
        return torch.true_divide(intersection + smooth, union + smooth)

# Class-weighted cross-entropy over the three class indices (0, 1, 2). The
# weight tensor is created on DEVICE, the constant defined next to the model,
# instead of a second hard-coded device string.
lloss = torch.nn.CrossEntropyLoss(
    weight=torch.tensor([0.1, 0.2, 0.7], device=DEVICE)
)

# Metrics evaluated per class in the training loop.
metrics = [
    IoU()
]

# Adam with its default hyper-parameters.
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
def _class_planes(mask):
    """Split one sample's mask into (background, plant, weed) 2-D planes."""
    if mask.ndim == 3:
        # Channel-first (C, H, W) layout: one plane per channel.
        return mask[0], mask[1], mask[2]
    # Class-index (H, W) layout: one boolean plane per class index.
    return mask == 0, mask == 1, mask == 2


def show(x, y, pred):
    """Plot the first sample of a batch with its true and predicted masks.

    Args:
        x: Image batch as an array of shape (B, C, H, W).
        y: Ground-truth masks, either class indices of shape (B, H, W) or
            per-class planes of shape (B, 3, H, W).
        pred: Predicted masks in either of the layouts accepted for ``y``.
    """
    image = x[0]
    bg_gt, pl_gt, wd_gt = _class_planes(y[0])
    bg_mask, pl_mask, wd_mask = _class_planes(pred[0])

    visualize(
        # Undo the final /255 applied by Dataset and go channel-last.
        image=np.moveaxis(image * 255, 0, -1),
        bg_gt=bg_gt,
        pl_gt=pl_gt,
        wd_gt=wd_gt,
        bg_mask=bg_mask,
        pl_mask=pl_mask,
        wd_mask=wd_mask,
    )

Train the model and save the checkpoint with the lowest mean validation loss.

In [None]:
# Lowest mean validation loss seen so far (the variable name is historical).
# Starting from infinity guarantees the first epoch is always saved.
max_score = float('inf')
device = DEVICE

model.to(device)


def _predict_labels(logits):
    """Turn (B, C, H, W) model outputs into (B, H, W) predicted class indices."""
    return torch.argmax(torch.nn.Softmax(dim=1)(logits), dim=1)


def _class_ious(pred_labels, true_labels):
    """Return the [background, plant, weed] IoU of a single label map."""
    return [
        metrics[0](pred_labels == cls, true_labels == cls).item()
        for cls in (0, 1, 2)
    ]


def _one_hot_planes(labels):
    """Expand (B, H, W) class indices into a (B, 3, H, W) one-hot NumPy array."""
    planes = torch.nn.functional.one_hot(labels.long(), num_classes=3)
    return planes.permute(0, 3, 1, 2).cpu().numpy()


def _progress(split, losses, bg_ious, plant_ious, weed_ious):
    """Format the running means displayed next to the progress bar."""
    return (split + ' Loss: ' + str(np.mean(np.array(losses)))
            + ' ' + split + ' BG IoU: ' + str(np.mean(np.array(bg_ious)))
            + ' ' + split + ' Plant IoU: ' + str(np.mean(np.array(plant_ious)))
            + ' ' + split + ' Weed IoU: ' + str(np.mean(np.array(weed_ious))))


for i in range(40):

    print('\nEpoch: {}'.format(i))

    # ---- training pass -------------------------------------------------
    running_loss = []
    running_plant_iou = []
    running_weed_iou = []
    running_bg_iou = []

    model.train()

    with tqdm(train_dataloader) as iterator:
        for x, y in iterator:
            optimizer.zero_grad()

            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = lloss(pred, y.long())

            loss.backward()
            optimizer.step()

            # Metrics are for reporting only; keep them out of autograd.
            with torch.no_grad():
                pred = _predict_labels(pred)
                for sample_pred, sample_true in zip(pred, y):
                    bg_iou, plant_iou, weed_iou = _class_ious(sample_pred, sample_true)
                    running_bg_iou.append(bg_iou)
                    running_plant_iou.append(plant_iou)
                    running_weed_iou.append(weed_iou)

            running_loss.append(loss.item())

            # Drop the batch tensors before the next batch is loaded.
            del pred, x, y

            iterator.set_postfix_str(_progress(
                'Train', running_loss,
                running_bg_iou, running_plant_iou, running_weed_iou))

    # ---- validation pass -----------------------------------------------
    model.eval()

    running_val_loss = []
    running_val_plant_iou = []
    running_val_weed_iou = []
    running_val_bg_iou = []

    with tqdm(valid_dataloader) as iterator:
        for batch_idx, (x, y) in enumerate(iterator):
            x, y = x.to(device), y.to(device)

            with torch.no_grad():
                pred = model(x)
                loss = lloss(pred, y.long())

                pred = _predict_labels(pred)
                for sample_pred, sample_true in zip(pred, y):
                    bg_iou, plant_iou, weed_iou = _class_ious(sample_pred, sample_true)
                    running_val_bg_iou.append(bg_iou)
                    running_val_plant_iou.append(plant_iou)
                    running_val_weed_iou.append(weed_iou)

                # Plot the second validation batch once per epoch. show()
                # indexes one plane per class, so the label maps are expanded
                # to one-hot planes first.
                if batch_idx == 1:
                    show(
                        x.detach().cpu().numpy(),
                        _one_hot_planes(y),
                        _one_hot_planes(pred),
                    )

            running_val_loss.append(loss.item())

            del pred, x, y

            iterator.set_postfix_str(_progress(
                'Valid', running_val_loss,
                running_val_bg_iou, running_val_plant_iou, running_val_weed_iou))

    # ---- checkpointing -------------------------------------------------
    mean_val_loss = np.mean(np.array(running_val_loss))
    if max_score > mean_val_loss:
        max_score = mean_val_loss
        # Saves the whole module object, not just a state_dict; anything that
        # loads './best_model.pth' gets a model instance back.
        torch.save(model, './best_model.pth')
        print('Model saved!')
