### Install libraries

In [None]:
!pip install -qU wandb
!pip install -q segmentation-models-pytorch 
!pip install -q torchsummary
!pip install -q rasterio
!pip install -q colorama

In [None]:
%load_ext autoreload
%autoreload 2

## IMPORT LIBRARIES

In [None]:
# Misc
import copy
import time
import timm
import joblib
import random
import os, shutil
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from IPython import display as ipd
from collections import defaultdict
from joblib import Parallel, delayed
from matplotlib.patches import Rectangle 

# Sklearn
from sklearn.model_selection import train_test_split 
from sklearn.model_selection import StratifiedKFold, KFold

# Pytorch
import torch 
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.cuda import amp
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader

from torchsummary import summary
import segmentation_models_pytorch as smp 

# Image Processing Libraries
import cv2
import rasterio 
import skimage.io

from PIL import Image

# Albumentations for augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
c_  = Fore.GREEN        # prefix used to colour the "improved" messages printed by run_training
sr_ = Style.RESET_ALL   # suffix that resets the terminal colour afterwards

# NOTE(review): this module-level `device` does not appear to be read anywhere else in the
# visible notebook; the model and the training code use CFG.device instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# Silence every Python warning for the rest of the session.
import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
# Import and Login to Weight&Biases
import os
import wandb

# SECURITY: the API key must never be committed in the notebook. It is read from the
# WANDB_API_KEY environment variable instead; any key that was previously hard-coded
# here should be revoked at https://wandb.ai/authorize.
try:
    wandb.login(key=os.environ['WANDB_API_KEY'])
    anonymous = None
except Exception:
    # Missing variable or rejected key: fall back to an anonymous run (best effort),
    # without swallowing KeyboardInterrupt/SystemExit like a bare `except:` would.
    anonymous = "must"
    print('To use your W&B account,\nset the WANDB_API_KEY environment variable to your W&B access token. \nGet your W&B access token from here: https://wandb.ai/authorize')

In [None]:
import gc
# Run a garbage-collection pass now; `gc` is also used by the training and validation loops below.
gc.collect()

## CONFIGURATIONS

In [None]:
class CFG:
    # Central experiment configuration, read everywhere as class attributes (CFG.<name>).
    # The non-dunder attributes are also sent to W&B as the run config.
    seed          = 101     # handed to set_seed()
    debug         = False # set debug=False for Full Training
    exp_name      = 'Baselinev1'
    comment       = 'UNet-Resnet152-512x512-aug2-split2'  # used as the W&B run group
    model_name    = 'UNet'       # only used to build the W&B run name
    encoder       = 'resnet152'  # encoder_name given to smp.Unet
    train_bs      = 5            # batch size of the training DataLoader
    valid_bs      = train_bs     # batch size of the validation DataLoader
    # NOTE(review): img_size is only used to build the W&B run name; the dataset actually
    # resizes images and masks to `size` x `size` (256), not 512x512.
    img_size      = [512, 512]
    size          = 256          # side length DroneDataset resizes images and masks to
    epochs        = 20
    lr            = 1e-3         # initial learning rate of the Adam optimizer
    scheduler     = 'CosineAnnealingLR'  # selects the branch taken in fetch_scheduler()
    min_lr        = 1e-6         # eta_min / min_lr of the schedulers
    # T_max is expressed in scheduler steps; train_one_epoch steps the scheduler once per batch.
    # The 10000 is presumably the approximate number of training samples - TODO confirm.
    T_max         = int(10000/train_bs*epochs)+50
    T_0           = 25           # T_0 of CosineAnnealingWarmRestarts
    warmup_epochs = 0
    wd            = 1e-4         # weight_decay of the Adam optimizer
    n_fold        = 5
    num_classes   = 2            # number of output classes of smp.Unet
    device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## REPRODUCIBILITY

In [None]:
def set_seed(seed = 42):
    '''Seed every random number generator used by the notebook.

    Seeds NumPy, Python's `random`, and PyTorch (CPU and current CUDA device),
    switches cuDNN to deterministic mode with benchmarking disabled, and stores
    the seed in the PYTHONHASHSEED environment variable.

    Args:
        seed: integer seed applied to all generators (default 42).
    '''
    # Same seed for each generator, applied in a fixed order.
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN needs two extra switches to behave deterministically.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Record the seed as the hash seed as well.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(CFG.seed)

## IMAGE PRE-PROCESSING

#### Load Data

In [None]:
# Image and Mask Path
# Directory prefixes for the training images and their label masks. File names are appended
# by plain string concatenation further down, so both must keep their trailing '/'.
IMAGE_PATH = '../input/seagrass/train_images/training_images/'
MASK_PATH = '../input/seagrass/train_label/training_label/'

In [None]:
n_classes = 2

def create_df():
    '''Build the image/mask lookup table from the files under IMAGE_PATH and MASK_PATH.

    Both directory trees are walked recursively. Every file contributes its name
    without the final extension, and the two sorted lists are paired row by row.

    Returns:
        pd.DataFrame with string columns 'img_id' and 'mask_id' and an integer
        index 0..n-1.

    Raises:
        ValueError: if the number of image files and mask files differ, since the
            rows could then no longer be paired.
    '''
    def _sorted_stems(root):
        # Sorted file names below `root`, each stripped of its last extension only,
        # so an id such as 'tile.01' survives intact (split('.')[0] would cut it to 'tile').
        return sorted(os.path.splitext(filename)[0]
                      for _, _, filenames in os.walk(root)
                      for filename in filenames)

    img_name = _sorted_stems(IMAGE_PATH)
    mask_name = _sorted_stems(MASK_PATH)

    if len(img_name) != len(mask_name):
        raise ValueError(
            f'Found {len(img_name)} image files but {len(mask_name)} mask files; '
            'images and masks cannot be paired.')

    return pd.DataFrame({'img_id': img_name, 'mask_id': mask_name}, index = np.arange(0, len(img_name)))

df = create_df()
print('Total Images: ', len(df))

In [None]:
df # list of drone images

#### Split data into Train, Validation, and Test sets

In [None]:
# Split data
# Hold out 10% of the image ids as the test set, then take 15% of the remaining ids
# for validation. Both calls use random_state=19, so the split is repeatable.
X_trainval, X_test = train_test_split(df['img_id'].values, test_size=0.1, random_state=19)
X_train, X_val = train_test_split(X_trainval, test_size=0.15, random_state=19)

print('Train Size   : ', len(X_train))
print('Val Size     : ', len(X_val))
print('Test Size    : ', len(X_test))

In [None]:
# Load one image/mask pair (row 500 of df) and overlay the mask on the image as a visual check.
# NOTE(review): the row index is hard-coded, so this cell fails when df has 500 rows or fewer.
img = skimage.io.imread(IMAGE_PATH + df['img_id'][500] + '.tif', plugin='tifffile')
mask = skimage.io.imread(MASK_PATH + df['mask_id'][500] + '.tif', plugin='tifffile')
print('Image Size', np.asarray(img).shape)
print('Mask Size', np.asarray(mask).shape)


plt.imshow(img)
# Draw the mask semi-transparently on top of the image.
plt.imshow(mask, alpha=0.30)
plt.title('Picture with Mask Appplied')
plt.show()

In [None]:
class DroneDataset(Dataset):
    '''Dataset of paired images and label masks stored as '<id>.tif' files.

    Each item is read from disk, resized to CFG.size x CFG.size, optionally
    augmented, normalised with the fixed per-channel mean/std below, and returned
    as a pair (image tensor, mask tensor of dtype long).

    Args:
        img_path: directory prefix of the images. It is string-concatenated with
            the id, so it must end with a path separator.
        mask_path: directory prefix of the masks (same convention).
        X: sequence of ids (file names without '.tif'). The same id is used for
            the image and for its mask.
        transform: optional augmentation, called as transform(image=..., mask=...)
            and expected to return a mapping with 'image' and 'mask' entries.
        patch: when True, each item is additionally cut into tiles by `tiles`.
    '''

    def __init__(self, img_path, mask_path, X, transform=None, patch=False):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        # Per-channel statistics passed to T.Normalize for every image.
        self.mean = [0.4855, 0.5464, 0.4754]
        self.std = [0.1446, 0.1593, 0.1308]
        self.transform = transform
        self.patches = patch
        self.count = 0  # not used by this class; kept so existing attribute access keeps working

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        item_id = self.X[idx]

        # NOTE(review): the uint8 cast assumes 8-bit TIFFs - TODO confirm.
        img = skimage.io.imread(self.img_path + item_id + '.tif', plugin='tifffile')
        img = np.array(img, dtype='uint8')
        img = cv2.resize(img, dsize=(CFG.size,CFG.size), interpolation=cv2.INTER_AREA)

        mask = skimage.io.imread(self.mask_path + item_id + '.tif', plugin='tifffile')
        mask = np.array(mask, dtype='uint8')
        # Masks hold class labels, not intensities: nearest-neighbour keeps every pixel a
        # valid label, whereas area interpolation averages neighbouring labels together.
        mask = cv2.resize(mask, dsize=(CFG.size,CFG.size), interpolation=cv2.INTER_NEAREST)

        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']

        # HWC uint8 array -> PIL image -> normalised CHW float tensor.
        img = Image.fromarray(img)
        transform = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        img = transform(img)

        mask = torch.from_numpy(mask).long()

        if self.patches:
            img, mask = self.tiles(img, mask)

        return img, mask

    def tiles(self, img, mask):
        '''Cut an image tensor and its mask into non-overlapping CFG.size x CFG.size tiles.

        Args:
            img: tensor of shape (3, H, W).
            mask: tensor of shape (H, W).

        Returns:
            (img_patches, mask_patches) with shapes (n, 3, CFG.size, CFG.size) and
            (n, CFG.size, CFG.size), tiles ordered row by row.
        '''
        # Unfold height then width: (3, nH, nW, size, size).
        img_patches = img.unfold(1, CFG.size, CFG.size).unfold(2, CFG.size, CFG.size)
        img_patches = img_patches.contiguous().view(3, -1, CFG.size, CFG.size)
        # Put the tile dimension first: (n, 3, size, size).
        img_patches = img_patches.permute(1, 0, 2, 3)

        mask_patches = mask.unfold(0, CFG.size, CFG.size).unfold(1, CFG.size, CFG.size)
        mask_patches = mask_patches.contiguous().view(-1, CFG.size, CFG.size)

        return img_patches, mask_patches

#### Data augmentation

In [None]:
# Data Augmentation
# Training pipeline: random flips, a small shift/scale/rotate, colour jitter, grid distortion,
# brightness/contrast change and Gaussian noise.
t_train = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), 
                     A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
                     A.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25, p=0.75),
                     A.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=1, p=0.5), 
                     A.RandomBrightnessContrast((0,0.5),(0,0.5)),
                     A.GaussNoise()])

# NOTE(review): the validation pipeline also applies random flips and grid distortion, so
# validation metrics are not computed on identical inputs from one epoch to the next - confirm
# that this is intended.
t_val = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), 
                   A.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=1, p=0.5)])

#datasets

train_set = DroneDataset(IMAGE_PATH, MASK_PATH, X_train, t_train, patch=False)
train_loader = DataLoader(train_set, batch_size=CFG.train_bs, shuffle=True, num_workers=0, pin_memory=True)

# NOTE(review): shuffle=True on the validation loader is unusual; validation order does not
# normally need to be randomised - confirm.
val_set = DroneDataset(IMAGE_PATH, MASK_PATH, X_val, t_val, patch=False)
val_loader = DataLoader(val_set, batch_size=CFG.valid_bs, shuffle=True, num_workers=0, pin_memory=True)  

In [None]:
# Pull a single training batch and display the image and mask tensor sizes as a sanity check.
imgs, msks = next(iter(train_loader))
imgs.size(), msks.size()

## MODEL

In [None]:
def build_model():
    '''Create the U-Net described by CFG and move it to CFG.device.

    The encoder named by CFG.encoder is initialised with ImageNet weights and the
    head has CFG.num_classes output channels.

    Returns:
        The smp.Unet model, already on CFG.device. It outputs raw logits.
    '''
    model = smp.Unet(
        encoder_name = CFG.encoder,
        encoder_weights = 'imagenet',
        encoder_depth = 5,
        classes = CFG.num_classes,
        # No final activation: the notebook trains with nn.CrossEntropyLoss, which expects
        # unnormalised logits, and the metrics apply softmax themselves. A sigmoid here
        # squashed every logit into (0, 1) before the loss. Predicted classes (argmax) are
        # unaffected and the layer has no weights, so existing checkpoints still load.
        activation = None,
        decoder_channels = [256, 128, 64, 32, 16]
    )

    model.to(CFG.device)
    return model

def load_model(path):
    '''Rebuild the model and load a saved state dict into it for inference.

    Args:
        path: file written with torch.save(model.state_dict(), ...).

    Returns:
        The model on CFG.device, switched to eval mode.
    '''
    model = build_model()
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    # (or on a different GPU) instead of failing while deserialising CUDA tensors.
    state_dict = torch.load(path, map_location=CFG.device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

#### Loss Function

In [None]:
# Loss Function
# CrossEntropyLoss is the criterion handed to run_training in the training cell.
CrossEntropyLoss = nn.CrossEntropyLoss()
# NOTE(review): BCELoss is not referenced anywhere else in the visible notebook.
BCELoss = nn.BCELoss() 

# Metrics
def pixel_accuracy(output, mask):
    '''Fraction of pixels whose predicted class equals the label.

    Args:
        output: model scores with the class dimension at dim=1.
        mask: integer labels, compared element-wise with the arg-max of `output`.

    Returns:
        Python float in [0, 1].
    '''
    with torch.no_grad():
        probabilities = F.softmax(output, dim=1)
        predicted = torch.argmax(probabilities, dim=1)
        hits = torch.eq(predicted, mask).int()
        n_hits = float(hits.sum())
        n_pixels = float(hits.numel())
    return n_hits / n_pixels

def dice_coef(pred_mask, mask, smooth = 1):
    '''Dice coefficient between the arg-max prediction and the label mask.

    Both tensors are flattened and cast to float32, so the class indices
    themselves are used as values in the overlap and the sums.

    Args:
        pred_mask: model scores with the class dimension at dim=1.
        mask: integer label tensor with as many elements as the arg-max prediction.
        smooth: constant added to numerator and denominator.

    Returns:
        The score as a 0-dimensional NumPy array.
    '''
    with torch.no_grad():
        labels_hat = torch.argmax(F.softmax(pred_mask, dim=1), dim=1)

        flat_pred = labels_hat.contiguous().view(-1).to(torch.float32)
        flat_true = mask.contiguous().view(-1).to(torch.float32)

        overlap = torch.sum(flat_true * flat_pred)
        total = torch.sum(flat_true) + torch.sum(flat_pred) + smooth
        score = (2. * overlap + smooth) / total

    return score.cpu().numpy()

def mIoU(pred_mask, mask, smooth=1e-10, n_classes=23):
    '''Mean intersection-over-union over the classes present in the labels.

    Args:
        pred_mask: model scores with the class dimension at dim=1.
        mask: integer label tensor with as many elements as the arg-max prediction.
        smooth: constant added to intersection and union.
        n_classes: class ids 0..n_classes-1 are evaluated. Ids that never occur
            in `mask` are recorded as NaN and ignored by the mean.

    Returns:
        np.nanmean of the per-class IoU values.
    '''
    with torch.no_grad():
        predicted = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
        target = mask.contiguous().view(-1)

        def _class_iou(class_id):
            # IoU of one class id, or NaN when the labels do not contain it.
            is_pred = predicted == class_id
            is_true = target == class_id
            if is_true.long().sum().item() == 0:
                return np.nan
            inter = torch.logical_and(is_pred, is_true).sum().float().item()
            union = torch.logical_or(is_pred, is_true).sum().float().item()
            return (inter + smooth) / (union + smooth)

        return np.nanmean([_class_iou(class_id) for class_id in range(0, n_classes)])

#### Train One Epoch Function

In [None]:
def train_one_epoch(model, optimizer, scheduler, criterion, dataloader, device, epoch, patch=False):
    '''Run one training pass over `dataloader`.

    Args:
        model: network to train; switched to train mode.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per batch, or None to
            keep the learning rate fixed.
        criterion: loss called as criterion(prediction, masks).
        dataloader: iterable of (images, masks) batches.
        device: device the batches are moved to.
        epoch: current epoch number (not used inside this function).
        patch: not used inside this function.

    Returns:
        (epoch_loss, train_scores): the sample-weighted mean loss, and the
        per-batch mean of [mIoU, pixel accuracy, dice] as a NumPy array.
    '''
    model.train()

    dataset_size = 0
    running_loss = 0.0

    train_scores = []

    pbar = tqdm(enumerate(dataloader), total = len(dataloader), desc = 'Train')
    for steps, (images, masks) in pbar:

        images = images.to(device)
        masks = masks.to(device)

        batch_size = images.size(0)

        y_pred = model(images)
        loss = criterion(y_pred, masks)

        loss.backward()

        # Apply the update, then clear the gradients ready for the next batch.
        optimizer.step()
        optimizer.zero_grad()

        # fetch_scheduler may return None (no schedule); only step a real scheduler.
        if scheduler is not None:
            scheduler.step()

        # Weight the batch loss by its size so a smaller last batch does not skew the mean.
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        train_iou = mIoU(y_pred, masks)
        train_acc = pixel_accuracy(y_pred, masks)
        train_dce = dice_coef(y_pred, masks)
        train_scores.append([train_iou, train_acc, train_dce])

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(train_loss = f'{epoch_loss:0.4f}',
                        lr = f'{current_lr:0.5f}',
                        gpu_mem = f'{mem:0.2f} GB')

    train_scores = np.mean(train_scores, axis = 0)

    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, train_scores

#### Validate One Epoch Function

In [None]:
@torch.no_grad()
def valid_one_epoch(model, dataloader, criterion, device, epoch, optimizer=None):
    '''Run one evaluation pass over `dataloader` without computing gradients.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: iterable of (images, masks) batches.
        criterion: loss called as criterion(prediction, masks).
        device: device the batches are moved to.
        epoch: current epoch number (not used inside this function).
        optimizer: optional optimizer whose current learning rate is shown in the
            progress bar. When omitted, a notebook-level variable named
            `optimizer` is used if one exists, otherwise the learning rate is
            left out of the progress bar.

    Returns:
        (epoch_loss, val_scores): the sample-weighted mean loss, and the
        per-batch mean of [mIoU, pixel accuracy, dice] as a NumPy array.
    '''
    model.eval()

    # The learning rate is display-only. Previously it was read from an undeclared global,
    # which raised NameError whenever no notebook-level `optimizer` existed.
    if optimizer is None:
        optimizer = globals().get('optimizer')
    current_lr = optimizer.param_groups[0]['lr'] if optimizer is not None else None

    dataset_size = 0
    running_loss = 0.0

    val_scores = []

    pbar = tqdm(enumerate(dataloader), total = len(dataloader), desc = 'Validation')
    for steps, (images, masks) in pbar:

        images = images.to(device)
        masks = masks.to(device)

        batch_size = images.size(0)

        y_pred = model(images)
        loss = criterion(y_pred, masks)

        # Weight the batch loss by its size so a smaller last batch does not skew the mean.
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        val_iou = mIoU(y_pred, masks)
        val_acc = pixel_accuracy(y_pred, masks)
        val_dce = dice_coef(y_pred, masks)
        val_scores.append([val_iou, val_acc, val_dce])

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        postfix = {'valid_loss': f'{epoch_loss:0.4f}',
                   'gpu_memory': f'{mem:0.2f} GB'}
        if current_lr is not None:
            postfix['lr'] = f'{current_lr:0.5f}'
        pbar.set_postfix(**postfix)

    val_scores = np.mean(val_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores

#### Run Training Function

In [None]:
def run_training(model, optimizer, scheduler, criterion, device, epochs):
    '''Train for `epochs` epochs and keep the weights with the best validation IoU.

    Every epoch runs train_one_epoch and valid_one_epoch, appends the metrics to
    a history dict, logs them to W&B and writes 'last_epoch-XX.bin'. Whenever the
    validation IoU is at least as good as the best so far, the W&B run summary is
    updated and 'best_epoch-XX.bin' is written and uploaded.

    Relies on the notebook-level `train_loader`, `val_loader` and W&B `run`.

    Args:
        model: network to train.
        optimizer: optimizer used for the updates.
        scheduler: learning-rate scheduler, or None.
        criterion: loss function.
        device: device the batches are moved to.
        epochs: number of epochs to run.

    Returns:
        (model, history): the model with the best weights loaded, and a dict
        mapping metric names to per-epoch lists.
    '''
    # log gradients to wandb
    wandb.watch(model, criterion, log='all', log_freq=100)

    if torch.cuda.is_available():
        print ("cuda: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = -np.inf
    best_dce = -np.inf
    best_acc = -np.inf
    best_epoch = -1
    history = defaultdict(list)

    for epoch in range(1, epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{epochs}', end = '')

        # Honour the `device` argument instead of silently falling back to CFG.device.
        train_loss, train_scores = train_one_epoch(model, optimizer, scheduler, criterion,
                                     dataloader = train_loader,
                                     device = device, epoch = epoch)

        val_loss, val_scores = valid_one_epoch(model, val_loader, criterion,
                                               device = device, epoch = epoch)
        train_iou, train_acc, train_dce = train_scores
        val_iou, val_acc, val_dce = val_scores

        history['Train Loss'].append(train_loss)
        history['Train IoU'].append(train_iou)
        history['Train Pixel Accuracy'].append(train_acc)
        history['Train Dice'].append(train_dce)
        history['Valid Loss'].append(val_loss)
        history['Valid IoU'].append(val_iou)
        history['Valid Pixel Accuracy'].append(val_acc)
        history['Valid Dice'].append(val_dce)

        # Log the metrics
        # The learning rate is read from the optimizer so that logging also works when
        # there is no scheduler, or one without get_last_lr().
        wandb.log({"Train Loss": train_loss,
                   "Train Iou": train_iou,
                   "Train Pixel Accuracy": train_acc,
                   "Train Dice": train_dce,
                   "Valid Loss": val_loss,
                   "Valid IoU": val_iou,
                   "Valid Pixel Accuracy": val_acc,
                   "Valid Dice": val_dce,
                   "LR": optimizer.param_groups[0]['lr']})

        print(f'Train IoU: {train_iou:0.4f} | Train Pixel Accuracy: {train_acc:0.4f} | Train Dice: {train_dce:0.4f}')
        print(f'Valid IoU: {val_iou:0.4f} | Valid Pixel Accuracy: {val_acc:0.4f} | Valid Dice: {val_dce:0.4f}')


        # Deep copy model
        # Model selection is driven by validation IoU only; dice and accuracy are recorded
        # for the same epoch.
        if val_iou >= best_iou:
            print(f"{c_}Valid IoU Improved ({best_iou:0.4f} ---> {val_iou:0.4f})")
            print(f"{c_}Valid Dice Coef Improved ({best_dce:0.4f} ---> {val_dce:0.4f})")
            print(f"{c_}Valid Pixel Accuracy Improved ({best_acc:0.4f} ---> {val_acc:0.4f})")
            best_iou = val_iou
            # best_dce was never updated before, so "Best Dice Coef" was always logged as -inf.
            best_dce = val_dce
            best_acc = val_acc
            best_epoch = epoch
            run.summary["Best IoU"] = best_iou
            run.summary["Best Dice Coef"] = best_dce
            run.summary["Best Accuracy"] = best_acc
            run.summary["Best Epoch"] = best_epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"best_epoch-{best_epoch:02d}.bin"
            torch.save(model.state_dict(), PATH)
            wandb.save(PATH)
            print(f"Model Saved{sr_}")

        # Always keep a checkpoint of the most recent epoch as well.
        PATH = f"last_epoch-{epoch:02d}.bin"
        torch.save(model.state_dict(), PATH)

        print() ; print()

    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print('Best Score: {:.4f}'.format(best_iou))

    # Hand back the model with the best validation IoU, not the last epoch's weights.
    model.load_state_dict(best_model_wts)

    return model, history

#### Scheduler Function

In [None]:
def fetch_scheduler(optimizer):
    '''Create the learning-rate scheduler selected by CFG.scheduler.

    Args:
        optimizer: optimizer whose learning rate the scheduler controls.

    Returns:
        The scheduler instance, or None when CFG.scheduler is None.

    Raises:
        ValueError: if CFG.scheduler names a scheduler that is not supported here.
    '''
    name = CFG.scheduler

    if name is None:
        return None

    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.T_max,
                                              eta_min=CFG.min_lr)
    if name == 'CosineAnnealingWarmRestarts':
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0,
                                                        eta_min=CFG.min_lr)
    if name == 'ReduceLROnPlateau':
        # NOTE(review): this scheduler's step() needs the monitored metric, but
        # train_one_epoch calls scheduler.step() without arguments - confirm before use.
        return lr_scheduler.ReduceLROnPlateau(optimizer,
                                              mode='min',
                                              factor=0.1,
                                              patience=7,
                                              threshold=0.0001,
                                              min_lr=CFG.min_lr,)
    if name == 'ExponentialLR':
        # This branch used to test the misspelt `CFG.scheduer`, which raised AttributeError
        # and made both this branch and the None case unreachable.
        return lr_scheduler.ExponentialLR(optimizer, gamma=0.85)

    # An unknown name previously ended in an unbound `scheduler` variable.
    raise ValueError(f'Unsupported CFG.scheduler: {name!r}')

In [None]:
# Build the model, its Adam optimizer and the learning-rate scheduler.
# NOTE(review): the training cell below rebuilds all three objects, so the instances created
# here are discarded before training starts. valid_one_epoch reads the notebook-level
# `optimizer`, so that name has to exist before training.
model = build_model()
optimizer = optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
scheduler = fetch_scheduler(optimizer)

## TRAINING

In [None]:
# Start the W&B run. The config is every CFG attribute whose name does not contain '__';
# `anonymous` comes from the login cell.
run = wandb.init(project='bluecares-seagrass-dl', 
                config={k:v for k, v in dict(vars(CFG)).items() if '__' not in k},
                anonymous=anonymous,
                name=f"model-{CFG.model_name}|encoder-{CFG.encoder}|dim-{CFG.img_size[0]}x{CFG.img_size[1]}",
                group=CFG.comment,
)

# Fresh model, optimizer and scheduler for this run.
model = build_model()
optimizer = optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
scheduler = fetch_scheduler(optimizer)

#optimizer = torch.optim.AdamW(model.parameters(), lr = CFG.lr, weight_decay=CFG.wd)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, CFG.lr, epochs=CFG.epochs,
#                                            steps_per_epoch=len(train_loader))

# Train with cross-entropy; run_training returns the model with the best validation IoU loaded.
model, history = run_training(model, optimizer, scheduler, CrossEntropyLoss, 
                  device=CFG.device,
                  epochs=CFG.epochs)
# Close the W&B run, then embed its dashboard page in the notebook output.
run.finish()
display(ipd.IFrame(run.url, width=1000, height=720)) 