In [1]:
# DS LIBS
import numpy as np
import pandas as pd

# VISUALIZATION LIBS
import cv2
from pprint import pprint
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# SYSTEM
import os
import config
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")


# DL LIBS
import torch
import torch.nn as nn
from torch.nn.functional import relu
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

from segmentation_models_pytorch.utils.train import TrainEpoch, ValidEpoch



from sklearn.model_selection import train_test_split

import albumentations as album

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_training_augmentation():
    train_transform = [    
        album.RandomCrop(height=256, width=256, always_apply=True),
        album.OneOf(
            [
                album.HorizontalFlip(p=1),
                album.VerticalFlip(p=1),
                album.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
    ]
    return album.Compose(train_transform)


def get_validation_augmentation():   
    # Add sufficient padding to ensure image is divisible by 32
    test_transform = [
        album.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=0),
    ]
    return album.Compose(test_transform)


def to_tensor(x):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Construct preprocessing transform    
    Args:
        preprocessing_fn (callable): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    """   
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))
        
    return album.Compose(_transform)

In [3]:
class ImageDataset(Dataset):
    
    def __init__(self, image_list:str, mask_list:str, augmentation = None, preprocessing = None):
        self.image_list = image_list
        self.mask_list = mask_list
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __len__(self):
        return len(self.image_list)


    def __getitem__(self, index):
        
        # GET ITEM
        image = cv2.resize(self.image_list[index], (config.IMAGE_SIZE , config.IMAGE_SIZE))
        #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image)

        # GET MASK 
        mask = cv2.resize(self.mask_list[index], (config.IMAGE_SIZE , config.IMAGE_SIZE))
        #mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = self.transform(mask)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            
        return image, mask
        

    def transform(self, image):
        img_mean, img_std = np.mean(image, axis=(0, 1)), np.std(image, axis=(0, 1))

        tensor_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=img_mean,std=img_std)
        ])
        
        return tensor_transforms(image)

In [4]:
# GET DATA
x:list = []
y:list = []

# GET X DATA
for filename in os.listdir(config.X_TRAIN_DATA_TOY):
    full_path = os.path.join(config.X_TRAIN_DATA, filename)
    image = cv2.imread(full_path)
    x.append(image)

# GET Y DATA
for filename in os.listdir(config.Y_TRAIN_DATA_TOY):
    full_path = os.path.join(config.Y_TRAIN_DATA, filename)
    mask = cv2.imread(full_path)
    y.append(mask)


# SPLIT DATA INTO TRAIN AND VAL
x_train, x_val, y_train, y_val = train_test_split(x,y,test_size=config.TEST_SPLIT, random_state=52)

In [5]:
# Get train and val dataset instances
train_dataset = ImageDataset(
    x_train, 
    y_train, 
    #augmentation=get_training_augmentation(),
    #preprocessing=get_preprocessing(preprocessing_fn=None),
)

valid_dataset = ImageDataset(
    x_val,
    y_val, 
    #augmentation=get_validation_augmentation(), 
    #preprocessing=get_preprocessing(preprocessing_fn=None),
)

# Get train and val data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=12)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4)

print(train_loader)

<torch.utils.data.dataloader.DataLoader object at 0x0000020550274210>


In [6]:
train_epoch = TrainEpoch(
    config.MODEL, 
    loss = config.LOSS, 
    metrics = config.METRICS, 
    optimizer = config.OPTIMIZER,
    device = config.DEVICE,
    verbose = True,
)

valid_epoch = ValidEpoch(
    config.MODEL, 
    loss = config.LOSS, 
    metrics = config.METRICS, 
    device = config.DEVICE,
    verbose = True,
)

In [10]:
%%time

if(config.TRAINING == True):

    best_iou = 0.0
    train_logs_list, valid_logs_list = [], []

    for index in range(0, config.EPOCHS):

        # Perform training & validation
        print('\nEpoch: {}'.format(index))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Save model if a better val IoU score is obtained
        if best_iou_score < valid_logs['iou_score']:
            best_iou = valid_logs['iou_score']
            torch.save(model, './best_model.pth at iteration {}'.format(index))
            print('Model saved!')


Epoch: 0
train:   0%|          | 0/1 [00:00<?, ?it/s]