# Predict images using trained UNET and compute entropy

## 1.0 Import libraries

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from osgeo import gdal
import glob

from scipy.stats import entropy
from scipy.special import softmax

In [2]:
import torch
from torch import nn
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.datasets import RasterDataset
from torchsummary import summary
#import pylab as plt
# torch.cuda.is_available()

## 2.0 Load and prepare datasets

In [3]:
# Create CLASS for images and labels
class BengaluruDatasetImages(RasterDataset):
    """Raster dataset for the input imagery.

    Files under ``root`` are selected with ``filename_glob``. The default is
    every ``*.tif`` file; passing one file name loads just that file. All
    other keyword arguments (e.g. ``transforms``) are forwarded unchanged to
    ``RasterDataset``.
    """

    def __init__(self, root, filename_glob='*.tif', **kwargs):
        # Set the pattern on the instance *before* the parent constructor
        # runs, so it shadows the class-level attribute during initialisation
        # (presumably RasterDataset consults it there to discover files).
        self.filename_glob = filename_glob
        super().__init__(root, **kwargs)
    

class BengaluruDatasetLabels(RasterDataset):
    """Raster dataset for the label rasters: every ``*.tif`` file under the root."""

    filename_glob = "*.tif"
    # Marks these rasters as labels: the label transform in this notebook
    # reads them from the sample's "mask" entry rather than "image".
    is_image = False

In [4]:
# Root folders of the train / validation / test splits. Each split folder
# holds an "Images" and a "Labels" sub-folder.
TRAIN_PATH, VALID_PATH, TEST_PATH = (
    os.path.join("Data", split_name) for split_name in ("Train", "Valid", "Test")
)

In [5]:
# Create CLASS for image and label transformation

class TransformBengaluruImages(nn.Module):
    """Min-max scale a sample's ``"image"`` tensor to the range [0, 1].

    The minimum and maximum are taken over the whole tensor (all bands
    together), so the relative scale between bands is preserved. The
    ``inputs`` dict is updated and returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        image = inputs["image"]
        # Work in floating point so integer rasters also come out as floats
        # (in-place true division on an integer tensor would raise).
        if not torch.is_floating_point(image):
            image = image.float()
        image = image - image.min()
        value_range = image.max()
        # A constant image has zero range: leave it all-zero rather than
        # dividing by zero, which would fill it with NaN.
        if value_range > 0:
            image = image / value_range
        inputs["image"] = image

        return inputs
    
class TransformBengaluruLabels(nn.Module):
    """Reduce the ``"mask"`` tensor to its first channel and cast it to long.

    Handles both a single sample (3-D mask: channel, H, W) and a batch
    (4-D mask: batch, channel, H, W). The channel axis is dropped in both cases.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        mask = inputs["mask"]
        is_batch = mask.ndim == 4
        first_channel = mask[:, 0, :, :] if is_batch else mask[0, :, :]
        inputs["mask"] = first_channel.long()
        return inputs



In [6]:
# Create the train, validation and test datasets. Each split pairs an image
# dataset with a label dataset and combines them with `&` (a torchgeo
# intersection), so each sample carries both the image and its mask.
def _build_split(split_root):
    """Return ``(images, labels, images & labels)`` for one split folder."""
    images = BengaluruDatasetImages(os.path.join(split_root, "Images"),
                                    transforms=TransformBengaluruImages())
    labels = BengaluruDatasetLabels(os.path.join(split_root, "Labels"),
                                    transforms=TransformBengaluruLabels())
    return images, labels, images & labels


tr_im, tr_la, TRAIN_DS = _build_split(TRAIN_PATH)
vl_im, vl_la, VALID_DS = _build_split(VALID_PATH)
ts_im, ts_la, TEST_DS = _build_split(TEST_PATH)

In [7]:
# Create UNET class
class UNET(nn.Module):
    """Four-level U-Net-style encoder/decoder producing per-pixel scores.

    Each encoder stage (``contract_block``) applies two 3x3 conv-BN-ReLU
    layers. It then halves H and W with a stride-2 max-pool.

    Each decoder stage (``expand_block``) applies two conv-BN-ReLU layers. It
    then doubles H and W with a stride-2 transposed convolution. Every decoder
    stage after the first sees its predecessor's output concatenated
    (channel axis) with the encoder output of the same spatial size.

    Input height and width should be multiples of 16 so the skip-connection
    tensors have matching sizes for ``torch.cat``.

    Args:
        in_channels: number of input bands.
        out_channels: number of output channels (used downstream as
            per-class scores; no softmax is applied here).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder: 32 -> 64 -> 128 -> 256 channels, halving H and W per stage.
        self.conv1 = self.contract_block(in_channels, 32, 3, 1)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)
        self.conv4 = self.contract_block(128, 256, 3, 1)

        # Decoder: the "*2" inputs account for the concatenated skip tensors.
        self.upconv4 = self.expand_block(256, 128, 3, 1)
        self.upconv3 = self.expand_block(128*2, 64, 3, 1)
        self.upconv2 = self.expand_block(64*2, 32, 3, 1)
        self.upconv1 = self.expand_block(32*2, out_channels, 3, 1)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W) raw scores.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        still runs registered hooks (forward hooks are used e.g. by
        ``torchsummary.summary``). ``model(x)`` works exactly as before.
        """
        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)

        # upsampling part with skip connections
        upconv4 = self.upconv4(conv4)
        upconv3 = self.upconv3(torch.cat([upconv4, conv3], 1))
        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Two conv-BN-ReLU layers followed by a stride-2 max-pool (halves H, W)."""
        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Two conv-BN-ReLU layers followed by a stride-2 transposed conv (doubles H, W)."""
        expand = nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            # Kernel size is fixed at 3 here, independent of
                            # `kernel_size`; kept as-is to match trained weights.
                            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
                            )
        return expand

In [8]:
import inspect

MODEL_PATH = 'UNET/unet_model_128px_100epoches.pt'

# The checkpoint is a whole pickled UNET module (not a state_dict), so
# unpickling needs the UNET class defined above. Since torch 2.6, torch.load
# defaults to weights_only=True, which refuses arbitrary pickled classes, so
# request full unpickling explicitly wherever the argument exists.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
load_kwargs = {'map_location': torch.device('cpu')}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
unet_model = torch.load(MODEL_PATH, **load_kwargs)
unet_model.eval()

UNET(
  (conv1): Sequential(
    (0): Conv2d(5, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(64, 128, kerne

## 3.0 Define GEOIMAGE Class

In [9]:
class GeoImage():
    """Tile-wise network inference over one GeoTIFF, mosaicked back to full size.

    The raster at ``img_src`` is opened twice:
    - with GDAL, for the pixel grid size, projection and geotransform used
      when writing outputs;
    - as a ``BengaluruDatasetImages`` dataset, so torchgeo can sample
      model-sized patches from it.

    Internally, per-pixel outputs are stored in an (x, y) frame: the first
    spatial axis is the column index from the left, the second is the row
    index from the *bottom*. That is how patch bounding boxes map onto the
    grid. The ``get_*`` accessors rotate them back to normal (row, col)
    image order.

    Args:
        img_src: path of the raster to predict.
        n_classes: number of output channels of the network; sizes the
            per-class score buffer.
    """

    def __init__(self, img_src, n_classes=4):
        self.img_src = img_src
        self.n_classes = n_classes

        self.ds = gdal.Open(self.img_src)
        if self.ds is None:
            # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
            raise OSError(f'GDAL could not open {self.img_src!r}')
        self.band = self.ds.GetRasterBand(1)
        # Band 1 is read only for the raster's (rows, cols) dimensions.
        self.arr = self.band.ReadAsArray()

        # os.path.split understands the OS-native separator (glob returns '\\'
        # paths on Windows). glob.escape keeps file names containing '[', '*'
        # or '?' from being interpreted as patterns.
        folder, filename = os.path.split(img_src)
        print(folder)
        print(filename)
        self.ds2 = BengaluruDatasetImages(folder, filename_glob=glob.escape(filename),
                                          transforms=TransformBengaluruImages())

        self._allocate_outputs()

    def _allocate_outputs(self):
        """Create zeroed prediction / score buffers in the internal (x, y) frame."""
        n_rows, n_cols = self.arr.shape[:2]
        self.predictions = np.zeros((n_cols, n_rows), int)
        self.probabilities = np.zeros((self.n_classes, n_cols, n_rows), float)

    def create_dataloader(self, size=128, batch_size=128, stride=110):
        """Return a DataLoader over a regular grid of (overlapping) patches.

        ``size`` and ``stride`` are passed to ``GridGeoSampler``. With
        stride < size, neighbouring patches overlap, and later patches
        overwrite earlier ones when mosaicked.
        """
        self.sampler = GridGeoSampler(self.ds2, size=size, stride=stride, roi=None)
        dataloader  = DataLoader(self.ds2, batch_size, sampler=self.sampler, collate_fn=stack_samples)
        return dataloader

    def predict(self, net, device='cpu', dataloader=None):
        """Run ``net`` on every patch and mosaic the results into the buffers.

        ``net`` must already live on ``device``; each input batch is moved
        there and the outputs are brought back to the CPU. The raw network
        outputs (no softmax) are stored as scores, and their per-pixel argmax
        is stored as the class prediction.
        """
        if dataloader is None:
            dataloader = self.create_dataloader()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            for sample in dataloader:
                pred_batch = net(sample["image"].to(device)).cpu()

                for idx, pred in enumerate(pred_batch):
                    prob_numpy = pred.numpy()
                    pred_numpy = np.argmax(prob_numpy, 0)
                    bbox = sample['bbox'][idx]
                    self.set_predictions(pred_numpy, bbox)
                    self.set_probabilities(prob_numpy, bbox)

    def get_pixel_coordinates(self, bbox):
        """Map a patch bounding box to index ranges in the internal (x, y) frame.

        x is scaled by the raster's column count and y (measured from the
        bottom edge) by its row count.

        Returns:
            (minx_pi, maxx_pi, miny_pi, maxy_pi) pixel indices.
        """
        org_b = self.ds2.bounds
        n_rows, n_cols = self.arr.shape[:2]
        range_x = org_b.maxx - org_b.minx
        range_y = org_b.maxy - org_b.miny
        minx_pi = int(np.round((bbox.minx - org_b.minx) / range_x * n_cols))
        maxx_pi = int(np.round((bbox.maxx - org_b.minx) / range_x * n_cols))
        miny_pi = int(np.round((bbox.miny - org_b.miny) / range_y * n_rows))
        maxy_pi = int(np.round((bbox.maxy - org_b.miny) / range_y * n_rows))

        return minx_pi, maxx_pi, miny_pi, maxy_pi

    def set_predictions(self, pred, bbox):
        """Write a (rows, cols) class patch into the buffer at ``bbox``."""
        minx_pi, maxx_pi, miny_pi, maxy_pi = self.get_pixel_coordinates(bbox)
        # Clockwise rotation turns (row-from-top, col) into (x, y-from-bottom).
        self.predictions[minx_pi:maxx_pi, miny_pi:maxy_pi] = np.rot90(pred, -1)

    def set_probabilities(self, prob, bbox):
        """Write a (classes, rows, cols) score patch into the buffer at ``bbox``."""
        minx_pi, maxx_pi, miny_pi, maxy_pi = self.get_pixel_coordinates(bbox)
        self.probabilities[:, minx_pi:maxx_pi, miny_pi:maxy_pi] = np.rot90(prob, -1, axes=(1,2))

    def get_predictions(self):
        """Return the class map as a (rows, cols) array in image order."""
        return np.rot90(self.predictions)

    def get_probabilities(self):
        """Return the raw per-class scores as a (classes, rows, cols) array."""
        return np.rot90(self.probabilities, axes=(1,2))

    def write_tif_files(self, name, folder=''):
        """Write ``<name>_pred.tif`` plus one ``<name>_prob_class<c>.tif`` per class.

        The per-class files hold the raw scores multiplied by 1000.
        NOTE(review): those arrays are float64, so ``write_geotiff`` stores
        them as Int32 and the fractional part is lost — presumably the reason
        for the x1000 scaling.
        """
        self.write_geotiff(os.path.join(folder, f'{name}_pred.tif'), self.get_predictions())
        probs = self.get_probabilities()*1000
        for c in range(self.n_classes):
            self.write_geotiff(os.path.join(folder, f'{name}_prob_class{c}.tif'), probs[c])

    def write_geotiff(self, filename, arr):
        """Write a 2-D array as a single-band GeoTIFF georeferenced like the source.

        float32 arrays are written as Float32; every other dtype (including
        float64) is written as Int32.
        """
        if arr.dtype == np.float32:
            arr_type = gdal.GDT_Float32
        else:
            arr_type = gdal.GDT_Int32

        driver = gdal.GetDriverByName("GTiff")
        out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, arr_type)
        if out_ds is None:
            raise OSError(f'GDAL could not create {filename!r}')
        out_ds.SetProjection(self.ds.GetProjection())
        out_ds.SetGeoTransform(self.ds.GetGeoTransform())
        band = out_ds.GetRasterBand(1)
        band.WriteArray(arr)
        band.FlushCache()
        band.ComputeStatistics(False)
        # Dropping the last references closes the dataset and flushes it to disk.
        band = None
        out_ds = None
        

def read_geotiff(filename):
    """Read band 1 of a raster with GDAL.

    Returns:
        (arr, ds): the band-1 pixel array and the open GDAL dataset (returned
        presumably so callers can reuse its projection/geotransform).

    Raises:
        OSError: if GDAL cannot open ``filename``.
    """
    ds = gdal.Open(filename)
    if ds is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
        raise OSError(f'GDAL could not open {filename!r}')
    arr = ds.GetRasterBand(1).ReadAsArray()
    return arr, ds


## 4.0 Apply model to images

### 4.1 List and read WV images

In [10]:
# Every training image matching WV*.tif, in sorted path order.
image_path = r'Data/Train/Images/WV*.tif'
image_list = sorted(glob.glob(image_path))
print(image_list)

['Data/Train/Images/WV_refl_B23567_P01.tif', 'Data/Train/Images/WV_refl_B23567_P04.tif', 'Data/Train/Images/WV_refl_B23567_P07.tif', 'Data/Train/Images/WV_refl_B23567_P09.tif', 'Data/Train/Images/WV_refl_B23567_P14.tif', 'Data/Train/Images/WV_refl_B23567_P16.tif', 'Data/Train/Images/WV_refl_B23567_P17.tif', 'Data/Train/Images/WV_refl_B23567_P18.tif', 'Data/Train/Images/WV_refl_B23567_P19.tif']


In [11]:
OUTPUT_FOLDER = 'PRED/First/'
# GDAL's GTiff driver cannot create files in a missing directory (Create
# returns None), so make sure the output folder exists up front.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

entropy_list = []
for im in image_list:
    img = GeoImage(im)
    img.predict(net=unet_model)

    # get_probabilities() holds the raw network outputs; softmax over the
    # class axis turns them into per-pixel class probabilities. The mean
    # per-pixel Shannon entropy (natural log) summarises prediction uncertainty.
    probs = softmax(img.get_probabilities(), axis=0)
    ent = entropy(probs, axis=0).mean()
    entropy_list.append(ent)

    # Output prefix: '1_' + the 3 characters before '.tif' (e.g. 'P01').
    tif_name = '1_' + im[-7:-4]
    print(tif_name)

    img.write_tif_files(name=tif_name, folder=OUTPUT_FOLDER)

entropy_list

Data/Train/Images
WV_refl_B23567_P01.tif
1_P01
Data/Train/Images
WV_refl_B23567_P04.tif
1_P04
Data/Train/Images
WV_refl_B23567_P07.tif
1_P07
Data/Train/Images
WV_refl_B23567_P09.tif
1_P09
Data/Train/Images
WV_refl_B23567_P14.tif
1_P14
Data/Train/Images
WV_refl_B23567_P16.tif
1_P16
Data/Train/Images
WV_refl_B23567_P17.tif
1_P17
Data/Train/Images
WV_refl_B23567_P18.tif
1_P18
Data/Train/Images
WV_refl_B23567_P19.tif
1_P19


[0.44662489528875216,
 0.3590839139655868,
 0.35541934988930063,
 0.3979224458973133,
 0.4842752838202053,
 0.48635569265786815,
 0.47112137627424877,
 0.33142849087068815,
 0.36622887440403756]