# Inference Notebook

## Simple Imports

In [None]:
import os
# Restrict CUDA to GPU index 0. This is set before torch is imported further
# down — presumably so it takes effect before CUDA is initialised (confirm).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

## Data Generator

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset pairing each image with a same-named mask file.

    Images are read with OpenCV and converted from BGR to RGB. Masks are read
    as single-channel images and expanded into one binary channel per
    requested class.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks; every mask must carry the
            same file name as its image.
        classes: Class names (case-insensitive) to extract from the mask.
            ``None`` selects every entry of ``CLASSES``.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.

    Raises:
        ValueError: If a requested class name is not listed in ``CLASSES``.
    """

    CLASSES = ['building']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # The signature advertises ``classes=None``; treat it as "all classes"
        # instead of failing while iterating over None.
        if classes is None:
            classes = self.CLASSES

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        # cv2.imread signals failure by returning None rather than raising,
        # so check explicitly to report which file is at fault.
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError(
                "Could not read image: {}".format(self.images_fps[i])
            )
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise FileNotFoundError(
                "Could not read mask: {}".format(self.masks_fps[i])
            )

        # one boolean plane per requested class value, stacked on the last axis
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of files found in ``images_dir``."""
        return len(self.ids)

## Augmentations

In [None]:
import albumentations as albu

In [None]:
def get_training_augmentation():
    """Build the albumentations pipeline used for training samples.

    The pipeline runs, in order: a random shift/scale (rotation limit 0),
    padding and a random crop to 320x320, additive Gaussian noise, a
    perspective warp, an affine warp, and finally three ``OneOf`` groups
    (each applied with p=0.9) that pick a single transform from their list.

    Returns:
        An ``albu.Compose`` wrapping all of the transforms above.
    """
    spatial_stage = [
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        albu.PadIfNeeded(min_height=320, min_width=320, always_apply=True, border_mode=0),
        albu.RandomCrop(height=320, width=320, always_apply=True),
    ]

    distortion_stage = [
        albu.IAAAdditiveGaussianNoise(p=0.2),
        albu.IAAPerspective(p=0.5),
        albu.IAAAffine(p=1),
    ]

    # First group: contrast / brightness / gamma style changes (or an affine warp).
    tone_group = albu.OneOf(
        [
            albu.RandomContrast(p=1),
            albu.CLAHE(p=1),
            albu.RandomBrightness(p=1),
            albu.RandomGamma(p=1),
            albu.IAAAffine(p=1)
        ],
        p=0.9,
    )

    # Second group: sharpening or blurring (or contrast / an affine warp).
    sharpness_group = albu.OneOf(
        [
            albu.RandomContrast(p=1),
            albu.IAASharpen(p=1),
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1),
            albu.IAAAffine(p=1)
        ],
        p=0.9,
    )

    # Third group: contrast or hue/saturation/value shift (or an affine warp).
    colour_group = albu.OneOf(
        [
            albu.RandomContrast(p=1),
            albu.HueSaturationValue(p=1),
            albu.IAAAffine(p=1)
        ],
        p=0.9,
    )

    pipeline = spatial_stage + distortion_stage + [tone_group, sharpness_group, colour_group]
    return albu.Compose(pipeline)


def get_validation_augmentation():
    """Build the augmentation used for validation / inference samples.

    Applies a single ``PadIfNeeded`` with a minimum height of 384 and a
    minimum width of 480 (both multiples of 32).

    NOTE(review): this only enforces a minimum size; inputs already larger
    than 384x480 are presumably left at their own size, so they are not
    guaranteed to be divisible by 32 — confirm against the data in use.
    """
    pad_to_minimum = albu.PadIfNeeded(min_height=384, min_width=480)
    return albu.Compose([pad_to_minimum])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Args:
        x: Array whose axes are reordered from (0, 1, 2) to (2, 0, 1).
        **kwargs: Accepted and ignored.

    Returns:
        The reordered array as ``float32``.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline for image/mask pairs.

    Args:
        preprocessing_fn: Callable applied to the image only.

    Returns:
        An ``albu.Compose`` that first runs ``preprocessing_fn`` on the image
        and then converts both image and mask with ``to_tensor``.
    """
    image_step = albu.Lambda(image=preprocessing_fn)
    layout_step = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([image_step, layout_step])

## Define Model

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

In [None]:
# --- Model configuration -----------------------------------------------------
ENCODER = 'se_resnext101_32x4d'   # encoder backbone name passed to smp
ENCODER_WEIGHTS = 'imagenet'      # pretrained weights requested for the encoder
CLASSES = ['building']            # one output channel per entry
ACTIVATION = 'sigmoid'            # activation applied to the model output
DEVICE = 'cuda'                   # device the input tensors are moved to

# Image preprocessing function matching the encoder / weights chosen above.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# UNet++ with one output channel per class.
model = smp.UnetPlusPlus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    activation=ACTIVATION,
    classes=len(CLASSES),
)

In [None]:
# Load saved checkpoint.
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on
# a different GPU index still loads and ends up on the same device the input
# tensors are moved to.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoint files that come from a trusted source.
best_model = torch.load('./UnetPP_Istanbul.pth', map_location=DEVICE) # Path of the weights file

In [None]:
# Folder containing the images to run prediction on.
# The data generator also expects a mask folder, so y_test_dir is simply set
# to the same folder as x_test_dir.
# The prediction output folder is defined two cells further down.
x_test_dir = y_test_dir = "TEST_IMAGE/"

In [None]:
# Build the test dataset from the folders defined above, using the validation
# augmentation and the encoder-specific preprocessing.
test_dataset = Dataset(
    images_dir=x_test_dir,
    masks_dir=y_test_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# DataLoader with default settings over the test dataset.
test_dataloader = DataLoader(dataset=test_dataset)

## Prediction - GeoTiff

In [None]:
from skimage import io
from osgeo import osr, gdal
import tifffile
# Take the file names from the dataset itself so that index i always refers to
# the same file in both test_dataset and prediction_images.
prediction_images = test_dataset.ids
output_path = '<path to dir>/TEST_PRED/' # Provide full directory path for prediction output folder.

# The GeoTIFF driver does not change between files; look it up once.
driver = gdal.GetDriverByName("GTiff")

for i, image_name in enumerate(prediction_images):
    image, _ = test_dataset[i]
    print(image_name)

    # Add a batch dimension, predict, and binarise by rounding.
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy().round().astype(np.uint8)

    image_dir = os.path.join(x_test_dir, image_name)
    out_path = os.path.join(output_path, image_name)

    # gdal.Open returns None instead of raising when a file cannot be opened.
    ds = gdal.Open(image_dir)
    if ds is None:
        raise RuntimeError("GDAL could not open input raster: {}".format(image_dir))

    # Raster size as GDAL reports it: X is columns (width), Y is rows (height).
    width = ds.RasterXSize
    height = ds.RasterYSize

    # The validation augmentation may have padded the image, making the
    # prediction larger than the source raster. Crop back to the source size.
    # NOTE(review): assumes the padding is centred (PadIfNeeded's default
    # position) — confirm if the validation augmentation is changed.
    top = max((pr_mask.shape[0] - height) // 2, 0)
    left = max((pr_mask.shape[1] - width) // 2, 0)
    pr_mask = pr_mask[top:top + height, left:left + width]

    # Driver.Create takes (xsize, ysize), i.e. (width, height).
    outdata = driver.Create(out_path, width, height, 1, gdal.GDT_Byte)
    if outdata is None:
        raise RuntimeError("GDAL could not create output raster: {}".format(out_path))

    outdata.SetGeoTransform(ds.GetGeoTransform())  # same geotransform as input
    outdata.SetProjection(ds.GetProjection())  # same projection as input
    outdata.GetRasterBand(1).WriteArray(pr_mask)
    outdata.FlushCache()  # saves to disk

    # Dropping the references closes the GDAL datasets.
    outdata = None
    ds = None