**miniRaVAEn**: a RaVAEn implementation with minimal dependencies on large libraries or codebases.

In [None]:
import os
import time
import math
import glob
import fsspec
import rasterio
import numpy as np
import pylab as plt
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from abc import abstractmethod
from typing import List, Any, Dict, Tuple
from argparse import Namespace

import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.functional import cosine_similarity

def cosine_distance_score(mu1: Tensor, mu2: Tensor):
    """Return the cosine distance ``1 - cos(mu1, mu2)``.

    Uses ``torch.nn.functional.cosine_similarity`` with its default ``dim=1``,
    so ``(batch, dim)`` inputs yield a ``(batch,)`` tensor of distances in [0, 2].
    """
    similarity = cosine_similarity(mu1, mu2)
    return 1 - similarity

# PyTorch Lightning provides the LightningDataModule / LightningModule base
# classes used below; print the installed version for reproducibility.
import pytorch_lightning as pl
print(pl.__version__)

In [None]:
# Path to a model checkpoint (.ckpt). Not referenced in the cells visible here;
# presumably used later to load trained weights — TODO confirm.
PATH_model_weights = "content/last.ckpt"

# Load data

#### @title Init data functions

In [None]:
def load_all_tile_indices_from_folder(settings_dataset):
    """
    Collect tile indices for every GeoTIFF referenced by ``settings_dataset``.

    ``settings_dataset["data_base_path"]`` may be either a directory (every
    ``*.tif`` file directly inside it is used, in sorted order) or a path
    containing ``.tif`` (used as a single file).

    Each file is split with ``file_to_tiles_indices`` using the
    ``tile_px_size``, ``tile_overlap_px`` and
    ``include_last_row_colum_extra_tile`` settings.

    Returns:
        list: concatenated ``[filename, x, y, w, h]`` tile indices of all files.

    Raises:
        ValueError: if the path is neither a directory nor a ``.tif`` path.
    """
    path = settings_dataset["data_base_path"]

    if os.path.isdir(path):
        # A directory: load all tifs inside, sorted for a deterministic order.
        allfiles = sorted(glob.glob(os.path.join(path, "*.tif")))
    elif ".tif" in path:
        # A single file: load that one directly.
        allfiles = [path]
    else:
        # Fail loudly instead of continuing with no file list.
        raise ValueError(
            f"data_base_path must be a directory or a .tif file, got: {path!r}")

    tiles = []
    for idx, filename in enumerate(allfiles):
        tiles_from_file = file_to_tiles_indices(
            filename, settings_dataset,
            tile_px_size=settings_dataset["tile_px_size"],
            tile_overlap_px=settings_dataset["tile_overlap_px"],
            include_last_row_colum_extra_tile=settings_dataset["include_last_row_colum_extra_tile"])

        tiles += tiles_from_file
        print(idx, filename, "loaded", len(tiles_from_file), "tiles.")

    print("Loaded:", len(tiles), "total tile indices")
    return tiles


In [None]:
def file_to_tiles_indices(filename, settings, tile_px_size = 128, tile_overlap_px = 4,
                          include_last_row_colum_extra_tile = True):
    """
    Compute the tile windows of one raster file without loading its pixels.

    Tiles are ``tile_px_size`` x ``tile_px_size``; neighbouring tiles overlap by
    ``tile_overlap_px`` pixels, i.e. the stride is
    ``tile_px_size - tile_overlap_px``. Only the raster height/width are read.

    Args:
        filename: path of the raster, opened with ``rasterio``.
        settings: unused here; kept for interface compatibility.
        tile_px_size: tile side length in pixels.
        tile_overlap_px: overlap between neighbouring tiles in pixels.
        include_last_row_colum_extra_tile: if True, also add tiles aligned to the
            bottom edge, the right edge and the bottom-right corner so the image
            border is covered (these can coincide with regular tiles).

    Returns:
        list of ``[filename, x, y, tile_px_size, tile_px_size]`` where ``x`` is the
        column offset and ``y`` the row offset (postpones loading the data);
        empty if the raster is smaller than one tile.

    Raises:
        ValueError: if ``tile_overlap_px >= tile_px_size`` (non-positive stride).
    """
    stride = tile_px_size - tile_overlap_px
    if stride <= 0:
        raise ValueError(
            f"tile_overlap_px ({tile_overlap_px}) must be smaller than "
            f"tile_px_size ({tile_px_size})")

    with rasterio.open(filename) as src:
        data_h, data_w = src.height, src.width

    if data_h < tile_px_size or data_w < tile_px_size:
        # Too small for even a single tile.
        return []

    # Number of regular tiles per axis; both operands are positive here, so
    # floor division equals floor(a / b).
    h_tiles_n = (data_h - tile_overlap_px) // stride
    w_tiles_n = (data_w - tile_overlap_px) // stride

    # Row-major order: rows (y) in the outer loop, columns (x) in the inner one.
    tiles = [[w_idx * stride, h_idx * stride]
             for h_idx in range(h_tiles_n)
             for w_idx in range(w_tiles_n)]
    if include_last_row_colum_extra_tile:
        tiles += [[w_idx * stride, data_h - tile_px_size] for w_idx in range(w_tiles_n)]
        tiles += [[data_w - tile_px_size, h_idx * stride] for h_idx in range(h_tiles_n)]
        tiles.append([data_w - tile_px_size, data_h - tile_px_size])

    # File ID + window offsets + window size for each tile.
    return [[filename] + t + [tile_px_size, tile_px_size] for t in tiles]

In [None]:
ONCE_PRINT = True
def load_tile_idx(tile, settings):
    """
    Read the pixel values of one tile from disk.

    ``tile`` is an index entry ``[filename, x, y, w, h]``: ``x``/``y`` are the
    column/row offsets of the window and ``w``/``h`` its size.
    ``settings['bands']`` lists 0-based band indices (shifted to rasterio's
    1-based band numbers); ``None`` reads every band. When
    ``settings['nan_to_num']`` is truthy, NaNs are replaced via ``np.nan_to_num``.

    Returns:
        (tile_data, dummy): the window as float32 and an all-zero array of the
        same shape used as a placeholder target.
    """
    global ONCE_PRINT
    filename, x, y, w, h = tile
    window = rasterio.windows.Window(row_off=y, col_off=x, width=w, height=h)

    band_numbers = None
    if settings['bands'] is not None:
        band_numbers = [b + 1 for b in settings['bands']]
        # Report the selected bands only on the first call.
        if ONCE_PRINT:
            print("DEBUG - loaded bands", band_numbers)
            ONCE_PRINT = False

    with rasterio.open(filename) as src:
        if band_numbers is None:
            raw = src.read(window=window)
        else:
            raw = src.read(band_numbers, window=window)

    tile_data = np.float32(raw)
    if settings['nan_to_num']:
        tile_data = np.nan_to_num(tile_data)
    return tile_data, np.zeros_like(tile_data)

In [None]:
class DataNormalizerLogManual():
    """
    Log + linear rescaling normalizer with manually chosen per-band bounds.

    For each of the first ``len(BANDS_S2_BRIEF)`` (8) bands of a ``(bands, H, W)``
    array:

        normalize:    y = ((log(x) - x0) / (x1 - x0)) * (y1 - y0) + y0
        denormalize:  x = exp(((y - y0) / (y1 - y0)) * (x1 - x0) + x0)

    with ``y0 = -1`` and ``y1 = 1``, so log-values in ``[x0, x1]`` map to
    ``[-1, 1]``. Further bands pass through unchanged. ``setup`` must be called
    before ``normalize_x`` / ``denormalize_x``.
    """

    def __init__(self, settings):
        self.settings_dataset = settings["dataset"]
        self.normalization_parameters = None  # not used within this class

    def setup(self, data_module, rescale_bounds=None):
        """
        Build the per-band rescale parameters.

        Args:
            data_module: unused; accepted for interface compatibility.
            rescale_bounds: optional mapping with keys ``'b{i}_x0'`` and
                ``'b{i}_x1'`` for i in 0..7 (log-space lower/upper bound of band
                i). When omitted, the global ``band_params`` mapping is used.
                NOTE(review): ``band_params`` is not defined in the visible
                notebook cells; it must exist before ``setup`` runs.
        """
        params = band_params if rescale_bounds is None else rescale_bounds
        self.BANDS_S2_BRIEF = ["B1","B2","B3","B4","B5","B6","B7","B8"]
        # Band names are 1-based ("B1"...) while the parameter keys are 0-based ("b0_...").
        self.RESCALE_PARAMS = {
            name: {"x0": params[f"b{i}_x0"],
                   "x1": params[f"b{i}_x1"],
                   "y0": -1,
                   "y1": 1}
            for i, name in enumerate(self.BANDS_S2_BRIEF)
        }
        print("normalization params are manually found")

    def _band_rescale(self, band_i):
        """Return ``(x0, x1, y0, y1)`` for band index ``band_i``."""
        r = self.RESCALE_PARAMS[self.BANDS_S2_BRIEF[band_i]]
        return r["x0"], r["x1"], r["y0"], r["y1"]

    def normalize_x(self, data):
        """
        Normalize ``data`` (shape ``(bands, H, W)``) in place and return it.

        Zeros give ``-inf`` under the log and are replaced by NaN (as are any
        other infinities); negative values become NaN as well.
        """
        n_scaled = min(data.shape[0], len(self.BANDS_S2_BRIEF))
        for band_i in range(n_scaled):
            x0, x1, y0, y1 = self._band_rescale(band_i)
            # log(0) = -inf is expected and converted to NaN below.
            with np.errstate(divide="ignore"):
                logged = np.log(data[band_i, :, :])
            logged[np.isinf(logged)] = np.nan
            data[band_i, :, :] = ((logged - x0) / (x1 - x0)) * (y1 - y0) + y0
        return data

    def denormalize_x(self, data):
        """Invert ``normalize_x`` in place on the first 8 bands and return ``data``."""
        n_scaled = min(data.shape[0], len(self.BANDS_S2_BRIEF))
        for band_i in range(n_scaled):
            x0, x1, y0, y1 = self._band_rescale(band_i)
            # Undo the linear rescale, then undo the log.
            log_values = (((data[band_i, :, :] - y0) / (y1 - y0)) * (x1 - x0)) + x0
            data[band_i, :, :] = np.exp(log_values)
        return data

In [None]:
# Torch Dataset:

class TileDataset(Dataset):
    """
    Torch ``Dataset`` over tiles cut from larger GeoTIFF files.

    Only the lightweight index entries ``[filename, x, y, w, h]`` are kept in
    memory; pixel data is read lazily in ``__getitem__`` through
    ``load_tile_idx`` and, if a normalizer is given, passed through its
    ``normalize_x``.
    """
    def __init__(self, tiles, settings_dataset, data_normalizer=None):
        self.tiles = tiles
        self.settings_dataset = settings_dataset
        self.data_normalizer = data_normalizer

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, idx):
        """Return ``(x, y)`` tensors for tile ``idx`` (``y`` is the placeholder from ``load_tile_idx``)."""
        x_np, y_np = load_tile_idx(self.tiles[idx], self.settings_dataset)
        normalizer = self.data_normalizer
        if normalizer is not None:
            x_np = normalizer.normalize_x(x_np)
        return torch.from_numpy(x_np), torch.from_numpy(y_np)

In [None]:
# PyTorch Lightning DataModule

class DataModule(pl.LightningDataModule):
    """
    Lightning data module that tiles the raster(s) at
    ``settings["dataset"]["data_base_path"]`` and splits the tile indices into
    train / validation / test ``TileDataset`` objects.

    Reads ``settings["dataloader"]`` keys ``batch_size``, ``num_workers``,
    ``train_ratio``, ``validation_ratio``, ``test_ratio`` and the optional
    ``random_state`` (seed of the random split; default ``None`` = unseeded).
    With ``train_ratio == 1.0`` every tile goes to the training set.
    """

    def __init__(self, settings, data_normalizer):
        super().__init__()
        self.settings = settings
        self.data_normalizer = data_normalizer

        loader_cfg = self.settings["dataloader"]
        self.batch_size = loader_cfg["batch_size"]
        self.num_workers = loader_cfg["num_workers"]

        self.train_ratio = loader_cfg["train_ratio"]
        self.validation_ratio = loader_cfg["validation_ratio"]
        self.test_ratio = loader_cfg["test_ratio"]
        # Optional seed: two modules with the same seed and the same number of
        # tiles get identical positional splits.
        self.random_state = loader_cfg.get("random_state")

        self.setup_finished = False

    def prepare_data(self):
        # Could contain data download and unpacking...
        pass

    def setup(self, stage=None):
        """Load all tile indices and build the three datasets; later calls are no-ops."""
        if self.setup_finished:
            return True # to prevent double setup

        tiles = load_all_tile_indices_from_folder(self.settings["dataset"])
        print("Altogether we have", len(tiles), "tiles.")

        tiles_train, tiles_val, tiles_test = self._split_tiles(tiles)
        print("train, test, val:",len(tiles_train), len(tiles_test), len(tiles_val))

        dataset_settings = self.settings["dataset"]
        self.train_dataset = TileDataset(tiles_train, dataset_settings, self.data_normalizer)
        self.test_dataset = TileDataset(tiles_test, dataset_settings, self.data_normalizer)
        self.val_dataset = TileDataset(tiles_val, dataset_settings, self.data_normalizer)

        self.setup_finished = True

    def _split_tiles(self, tiles):
        """Return ``(train, val, test)`` tile lists according to the configured ratios."""
        if self.train_ratio == 1.0:
            return tiles, [], []

        tiles_train, tiles_rest = train_test_split(
            tiles, test_size=1 - self.train_ratio, random_state=self.random_state)

        # sklearn rejects a float test_size of 0.0 or 1.0 (and 0/0 would raise),
        # so one-sided held-out splits are assigned directly.
        if self.test_ratio == 0:
            return tiles_train, tiles_rest, []
        if self.validation_ratio == 0:
            return tiles_train, [], tiles_rest

        tiles_val, tiles_test = train_test_split(
            tiles_rest,
            test_size=self.test_ratio / (self.test_ratio + self.validation_ratio),
            random_state=self.random_state)
        return tiles_train, tiles_val, tiles_test

    def _resolve_num_workers(self, num_workers):
        """Use the configured worker count unless one is given (an explicit 0 is honoured)."""
        return self.num_workers if num_workers is None else num_workers

    def train_dataloader(self):
        """Initializes and returns the training dataloader (tiles in index order, no shuffling)."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

    def val_dataloader(self, num_workers=None):
        """Initializes and returns the validation dataloader"""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self._resolve_num_workers(num_workers))

    def test_dataloader(self, num_workers=None):
        """Initializes and returns the test dataloader"""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self._resolve_num_workers(num_workers))

In [None]:
# 0-based indices of the raster bands to read (load_tile_idx shifts them to
# rasterio's 1-based band numbers); here all 8 bands.
BANDS = [0,1,2,3,4,5,6,7]

In [None]:
# Experiment configuration: 'dataloader' and 'dataset' are read by DataModule /
# TileDataset / load_tile_idx; 'normalizer' is the normalizer class, instantiated
# with these settings in the next cell.
settings = {'dataloader': {
                'batch_size': 8,
                'num_workers': 4,
                'train_ratio': 1.00,        # 1.0 -> every tile goes to the train split
                'validation_ratio': 0.00,   # only used when train_ratio < 1.0
                'test_ratio': 0.00,         # only used when train_ratio < 1.0
            },
            'dataset': {
                # Post-event imagery: a folder of .tif files or a single .tif
                # (exactly one entry active; alternatives kept for other sites).
                #'data_base_path': 'Gandahar_Mkt_after_4b',
                #'data_base_path': 'El_Fasher_after_4b',
                #'data_base_path': 'Muqrin_after_4b',
                #'data_base_path': 'Jaranga_after_4b',
                #'data_base_path': 'Sarafaya_after_4b',
                #'data_base_path': 'Sarafaya_after_8b',
                #'data_base_path': 'Muqrin_after_8b',
                #'data_base_path': 'Jaranga_after_8b',
                'data_base_path': 'El_Fasher_after_8b',
                #'data_base_path': 'Gandahar_Mkt_after_8b',

                #'data_base_path': 'Babanusa_after_4b',
                #'data_base_path': 'single_scene_after',
                #'data_base_path': 'test_scene_after',
                #'data_base_path': 'test_scene_after_EG_v5',
                #'data_base_path': 'El_Fasher_after_v1',
                #'data_base_path': 'Al_Me_Eliq_after',
                #'data_base_path': 'Al_Ezayba_after_8b',
                #'data_base_path': 'Al_Ezayba_after_4b',
                #'data_base_path': 'Babanusa_after_8b',
                #'data_base_path': 'Muqrin_after_8b',
                #'data_base_path': 'Jaranga_after_8b',
                'bands': BANDS,                              # 0-based band indices to read
                'tile_px_size': 32,                          # square tiles of 32 x 32 px
                'tile_overlap_px': 0,                        # no overlap between neighbouring tiles
                'include_last_row_colum_extra_tile': False,  # no extra edge-aligned tiles
                'nan_to_num': False,                         # keep NaNs in loaded tiles
             },
            'normalizer': DataNormalizerLogManual,
           }

In [None]:
import copy

# Build the normalizer and the data module for the post-event ("after") imagery.
data_normalizer = settings["normalizer"](settings)
print("loaded data_normalizer")

data_module_after = DataModule(settings, data_normalizer)
data_module_after.setup()
data_normalizer.setup(data_module_after)
# Number of training tiles. The val split is empty with train_ratio == 1.0, so
# its dataloader length would be 0 and make the check below meaningless.
len_train_ds_after = len(data_module_after.train_dataset)

# Deep copy: settings.copy() is shallow, so editing the nested "dataset" dict
# would also overwrite settings["dataset"] of the "after" configuration.
settings_before = copy.deepcopy(settings)
# Pre-event ("before") imagery matching the "after" path chosen above.
#settings_before["dataset"]["data_base_path"] = 'Gandahar_Mkt_before_4b'
#settings_before["dataset"]["data_base_path"] = 'El_Fasher_before_4b'
#settings_before["dataset"]["data_base_path"] = 'Muqrin_before_4b'
#settings_before["dataset"]["data_base_path"] = 'Jaranga_before_4b'
#settings_before["dataset"]["data_base_path"] = 'Sarafaya_before_4b'
#settings_before["dataset"]["data_base_path"] = 'Sarafaya_before_8b'
#settings_before["dataset"]["data_base_path"] = 'Muqrin_before_8b'
#settings_before["dataset"]["data_base_path"] = 'Jaranga_before_8b'
settings_before["dataset"]["data_base_path"] = 'El_Fasher_before_8b'
#settings_before["dataset"]["data_base_path"] = 'Gandahar_Mkt_before_8b'

#settings_before["dataset"]["data_base_path"] = 'Babanusa_before_4b'
#settings_before["dataset"]["data_base_path"] = 'El_Fasher_before_v1'
#settings_before["dataset"]["data_base_path"] = 'Al_Ezayba_before_8b'
#settings_before["dataset"]["data_base_path"] = 'Al_Ezayba_before_4b'
#settings_before["dataset"]["data_base_path"] = 'Babanusa_before_8b'
#settings_before["dataset"]["data_base_path"] = 'Muqrin_before_8b'
#settings_before["dataset"]["data_base_path"] = 'Jaranga_before_8b'
data_module_before = DataModule(settings_before, data_normalizer)
data_module_before.setup()
data_normalizer.setup(data_module_before)
len_train_ds_before = len(data_module_before.train_dataset)

# Before/after images must yield the same number of tiles.
assert len_train_ds_after == len_train_ds_before

In [None]:
bands = 8
#bands = 4
# Show the per-band rescale parameters the normalizer applies (at most the
# bands it knows about).
for band_name in data_normalizer.BANDS_S2_BRIEF[:bands]:
    print(band_name, "->", data_normalizer.RESCALE_PARAMS[band_name])

In [None]:
# We can also later just use the loaded tiles:

# Materialize every tile tensor of both training datasets (after first, then before).
after_array = [sample[0] for sample in tqdm(data_module_after.train_dataset)]
before_array = [sample[0] for sample in tqdm(data_module_before.train_dataset)]

In [None]:
# Convert each torch tensor to a numpy array, keeping tile order.
before_array = list(map(torch.Tensor.numpy, before_array))
after_array = list(map(torch.Tensor.numpy, after_array))

In [None]:
# Stack the per-tile arrays into single numpy arrays (tiles along axis 0).
before_array, after_array = np.asarray(before_array), np.asarray(after_array)

In [None]:
# Inspect the stacked array shape (shown as the cell output).
before_array.shape

In [None]:
# Summarize how many tiles each image produced and the per-tile shape.
print("Now we have", len(before_array), "*", before_array[0].shape,
      "as data from the image before the event and ",
      len(after_array), "*", after_array[0].shape,
      "from the image after the event.")

In [None]:
# This data should already be normalised. The per-band bounds (x0, x1) were
# fixed beforehand (presumably from training data), so some values may fall
# outside [-1, 1], but most should lie near that range.
sample_idx = 0
band_inspect_idx = 0
inspected_band = before_array[sample_idx][band_inspect_idx]
np.nanmin(inspected_band), np.nanmax(inspected_band)


# Make a model

#### @title Init model functions

#### RaVAEn models as torch nn modules


In [None]:
class BaseModel(nn.Module):
    """
    Minimal interface shared by the models in this notebook.

    NOTE: ``nn.Module`` does not use ``ABCMeta``, so ``@abstractmethod`` here
    only documents intent; instantiation is not blocked, and both methods raise
    ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        """Run the model on ``inputs``."""
        raise NotImplementedError

    @abstractmethod
    def loss_function(self, batch: Tensor, *inputs: Any, **kwargs) -> Tensor:
        """Compute the training loss for ``batch``."""
        raise NotImplementedError
    
class BaseAE(BaseModel):
    """
    Deterministic autoencoder base: ``forward(x) = decode(encode(nan_to_num(x)))``.

    Subclasses implement ``encode`` and ``decode``.
    """
    def __init__(self, visualisation_channels):
        super().__init__()
        # Channel indices shown by _visualise_step.
        self.visualisation_channels = visualisation_channels

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        cleaned = torch.nan_to_num(input)
        return self.decode(self.encode(cleaned))

    def loss_function(self,
                      input: Tensor,
                      results: Dict,
                      mask_invalid: bool = False,
                      **kwargs) -> Dict:
        """
        MSE reconstruction loss between ``results`` (the reconstruction) and ``input``.

        With ``mask_invalid`` the NaN positions of ``input`` are excluded;
        otherwise NaN/inf in ``input`` are replaced via ``torch.nan_to_num``.
        """
        if mask_invalid:
            valid = ~torch.isnan(input)
            recons_loss = F.mse_loss(results[valid], input[valid])
        else:
            recons_loss = F.mse_loss(results, torch.nan_to_num(input))

        return {'loss': recons_loss, 'Reconstruction_Loss': recons_loss}

    def _visualise_step(self, batch):
        """Return (input, reconstruction) on the visualisation channels and the max-over-channels absolute error."""
        reconstruction = self.forward(batch)
        abs_error = (batch - reconstruction).abs()
        shown = self.visualisation_channels
        return (batch[:, shown], reconstruction[:, shown], abs_error.max(dim=1).values)

    @property
    def _visualisation_labels(self):
        return ["Input", "Reconstruction", "Rec error"]
    
class BaseVAE(BaseAE):
    """
    Base class for variational autoencoders.

    ``forward`` is expected to return a list whose first element is the
    reconstruction (see ``_visualise_step``).
    """
    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        """Draw samples from the prior; subclasses override this."""
        # NOTE(review): raises RuntimeWarning rather than NotImplementedError.
        raise RuntimeWarning()

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Reconstruct ``x``; subclasses override this."""
        raise NotImplementedError

    def _visualise_step(self, batch):
        """Like BaseAE._visualise_step, but takes the reconstruction from element 0 of the VAE output."""
        reconstruction = self.forward(batch)[0]
        abs_error = (batch - reconstruction).abs()
        shown = self.visualisation_channels
        return (batch[:, shown], reconstruction[:, shown], abs_error.max(dim=1).values)


##### Deeper models — with parameters that control the model size

In [None]:
class ConvBlock(nn.Module):
    """
    Stack of ``depth`` 3x3 convolutions (padding 1, so height and width are kept):

        (Conv2d => [BatchNorm2d] => [activation]) * depth

    The first convolution maps ``in_channels`` -> ``out_channels``; later ones
    map ``out_channels`` -> ``out_channels``. The layer order inside
    ``self.conv_block`` determines checkpoint keys and must stay stable.
    """

    def __init__(self, in_channels, out_channels, depth=2, activation=nn.LeakyReLU, batchnorm=True):
        super().__init__()

        def unit(n_in):
            # One conv (+ optional BN, + optional activation) step.
            parts = [nn.Conv2d(n_in, out_channels, kernel_size=3, padding=1)]
            if batchnorm:
                parts.append(nn.BatchNorm2d(out_channels))
            if activation is not None:
                parts.append(activation())
            return parts

        layers = []
        channels_in = in_channels
        for _ in range(depth):
            layers.extend(unit(channels_in))
            channels_in = out_channels

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_block(x)

class ResConvBlock(ConvBlock):
    """ConvBlock with an identity skip connection: returns ``x + conv_block(x)``.

    The output must match ``x``'s shape, so in practice ``in_channels == out_channels``.
    """
    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
    
class DownConv(nn.Module):
    """
    Downscaling block: one 3x3 convolution with stride 2 and padding 1,
    optionally followed by BatchNorm2d and an activation.

    Height and width become ``ceil(size / 2)``.
    """
    def __init__(self, in_channels, out_channels, activation=nn.LeakyReLU, batchnorm=True):
        super().__init__()
        modules = [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)]
        modules += [nn.BatchNorm2d(out_channels)] if batchnorm else []
        modules += [activation()] if activation is not None else []
        self.conv = nn.Sequential(*modules)

    def forward(self, x):
        return self.conv(x)
    
class UpConv(nn.Module):
    """
    Upscaling block that doubles height and width.

    - ``upsample_method`` in {'nearest', 'linear', 'bilinear', 'bicubic'}:
      ``nn.Upsample(scale_factor=2)`` (``align_corners=True`` except for
      'nearest') followed by a depth-1 ``ConvBlock``.
    - ``'transpose'``: a stride-2 ``ConvTranspose2d`` (kernel 3, padding 1,
      output_padding 1), optionally followed by BatchNorm2d and the activation.

    Raises:
        NotImplementedError: for any other ``upsample_method``.
    """

    def __init__(self, in_channels, out_channels, upsample_method='nearest',
                 activation=nn.LeakyReLU, batchnorm=True):
        super().__init__()

        if upsample_method == 'transpose':
            layers = [nn.ConvTranspose2d(in_channels, out_channels,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1)]
            if batchnorm:
                layers.append(nn.BatchNorm2d(out_channels))
            if activation is not None:
                layers.append(activation())
            self.up = nn.Sequential(*layers)
        elif upsample_method in ('nearest', 'linear', 'bilinear', 'bicubic'):
            corners = None if upsample_method == "nearest" else True
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode=upsample_method, align_corners=corners),
                ConvBlock(in_channels, out_channels, depth=1, activation=activation, batchnorm=batchnorm),
            )
        else:
            raise NotImplementedError(
                f"Upsample method has not been implemented: {upsample_method}"
            )

    def forward(self, x):
        return self.up(x)

In [None]:
class DeeperVAE(BaseVAE):
    """
    Convolutional VAE with a configurable number of downscaling stages.

    Encoder: one stride-2 ``DownConv`` per entry of ``hidden_channels`` (each
    optionally followed by a ``ResConvBlock`` of depth ``extra_depth_on_scale``),
    then linear heads ``fc_mu`` / ``fc_var`` of size ``latent_dim``.
    Decoder: ``decoder_input`` maps back to the encoder output size, then the
    mirrored stack of nearest-neighbour ``UpConv`` blocks; the last one has no
    activation and no batch norm.

    NOTE: submodule attribute names (encoder, fc_mu, fc_var, decoder_input,
    decoder) define checkpoint keys and must not be renamed.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 hidden_channels: List[int],
                 latent_dim: int,
                 extra_depth_on_scale: int,
                 visualisation_channels,
                 **kwargs) -> None:
        """
        Args:
            input_shape: (C, H, W) of the input tiles. H is used as the spatial
                size of both axes, so square tiles are assumed.
            hidden_channels: channels after each downscale (reversed on upscale).
            latent_dim: size of the latent vector.
            extra_depth_on_scale: depth of the residual conv block added after
                every down/up-scale (0 disables it).
            visualisation_channels: channel indices used by ``_visualise_step``.
        """
        super().__init__(visualisation_channels)

        assert input_shape[1]>=2**len(hidden_channels), "Cannot have so many downscaling layers"

        self.latent_dim = latent_dim

        # Size of the encoder output. Each DownConv halves the spatial size, so
        # this assumes the tile size is divisible by 2**len(hidden_channels).
        encoder_output_width = int(input_shape[1]/(2**len(hidden_channels)))
        encoder_output_dim = int(encoder_output_width**2 * hidden_channels[-1])
        self.encoder_output_shape = (hidden_channels[-1], encoder_output_width, encoder_output_width)

        if encoder_output_dim<latent_dim:
            raise UserWarning(
                f"Encoder output dim {encoder_output_dim} is smaller than latent dim {latent_dim}."+
                "This means the bottle neck is tighter than intended."
            )

        in_channels = input_shape[0]

        self.encoder = self._build_encoder([in_channels]+hidden_channels, extra_depth_on_scale)

        self.fc_mu = nn.Linear(encoder_output_dim, latent_dim)
        self.fc_var = nn.Linear(encoder_output_dim, latent_dim)

        self.decoder_input = nn.Linear(latent_dim, encoder_output_dim)

        self.decoder = self._build_decoder(hidden_channels[::-1]+[in_channels], extra_depth_on_scale)

    @staticmethod
    def _build_encoder(channels, extra_depth):
        """DownConv (+ optional ResConvBlock) for each consecutive pair in ``channels``."""
        in_channels = channels[0]
        encoder = []
        for out_channels in channels[1:]:
            encoder+=[
                DownConv(
                    in_channels,
                    out_channels,
                    activation=nn.LeakyReLU,
                    batchnorm=True
                )
            ]
            if extra_depth>0:
                encoder+=[
                    ResConvBlock(
                        out_channels,
                        out_channels,
                        activation=nn.LeakyReLU,
                        batchnorm=True,
                        depth=extra_depth
                    )
                ]
            # for next time round loop
            in_channels = out_channels

        return nn.Sequential(*encoder)

    @staticmethod
    def _build_decoder(channels, extra_depth):
        """UpConv (+ optional ResConvBlock) per pair; the last stage is linear and un-normalized."""
        in_channels = channels[0]
        decoder = []
        up_activation = nn.LeakyReLU
        res_activation = nn.LeakyReLU
        for i, out_channels in enumerate(channels[1:]):
            # if last layer use linear activation
            is_last_layer = (i==(len(channels)-2))
            if is_last_layer:
                up_activation = None

            decoder+=[
                UpConv(
                    in_channels,
                    out_channels,
                    upsample_method='nearest',
                    activation=up_activation,
                    batchnorm=not is_last_layer,
                )
            ]
            if extra_depth>0:
                decoder+=[
                    ResConvBlock(
                        out_channels,
                        out_channels,
                        activation=res_activation,
                        batchnorm=not is_last_layer,
                        depth=extra_depth
                    )
                ]
            # for next time round loop
            in_channels = out_channels

        return nn.Sequential(*decoder)

    def encode(self, input: Tensor, verbose=False) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :param verbose: print the output shape after every encoder stage
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        if verbose:
            # Run the encoder stage by stage (equivalent to self.encoder(input))
            # so intermediate shapes can be reported without a second pass.
            print("input", input.shape)
            result = input
            for layer in self.encoder:
                result = layer(result)
                print(layer," => it's output:\n", result.shape)
        else:
            result = self.encoder(input)

        result = torch.flatten(result, start_dim=1)
        if verbose: print("result", result.shape)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        if verbose: print("mu", self.fc_mu, " => it's output:\n", mu.shape)
        log_var = self.fc_var(result)
        if verbose: print("log_var", self.fc_var, " => it's output:\n", log_var.shape)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, *self.encoder_output_shape)
        result = self.decoder(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: Mean of the latent Gaussian [B x D]
        :param logvar: Log-variance of the latent Gaussian [B x D]
        :return: [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Return [reconstruction, mu, log_var]; NaNs in ``input`` are replaced first."""
        mu, log_var = self.encode(torch.nan_to_num(input))
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), mu, log_var]

    def loss_function(self, input: Tensor, results: Any, **kwargs) -> Dict:
        """
        Computes the VAE loss: MSE reconstruction + ``kwargs['M_N']`` * KL(q(z|x) || N(0, I)).

        :param input: original batch (NaNs are replaced via nan_to_num)
        :param results: output of ``forward``: [reconstruction, mu, log_var]
        :param kwargs: must contain ``M_N``, the KL weight
        :return: dict with 'loss', 'Reconstruction_Loss' and 'KLD'
                 ('KLD' is reported with a flipped sign, i.e. -kld_loss)
        """
        # invalid_mask = torch.isnan(input)
        input = torch.nan_to_num(input)

        recons = results[0]
        mu = results[1]
        log_var = results[2]

        # Account for the minibatch samples from the dataset
        kld_weight = kwargs['M_N']

        recons_loss = F.mse_loss(recons, input)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'KLD': -kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [None]:
# RaVAEn module to be used with pytorch lightning

class Module(pl.LightningModule):
    """
    PyTorch Lightning wrapper around a RaVAEn model.

    All keys of ``cfg`` and ``train_cfg`` become instance attributes. Used here:
    ``input_shape`` (passed to ``model_cls``), ``len_train_ds`` / ``len_val_ds``
    (KL weight ``M_N = batch_size / len_*_ds``), ``lr``, ``weight_decay`` and the
    optional ``scheduler_gamma``, ``lr2`` + ``submodel``, ``scheduler_gamma2``.
    """
    def __init__(self, model_cls, cfg: dict, train_cfg: dict, model_cls_args:dict) -> None:
        super().__init__()
        # Bulk-assign config entries as plain attributes.
        self.__dict__.update(cfg)
        self.__dict__.update(train_cfg)

        self.model = model_cls(input_shape=self.input_shape, **model_cls_args)

        if hasattr(self.model, '_visualise_step'):
            # Batches are (x, y) pairs; visualise only x.
            self._visualise_step = \
                lambda batch: self.model._visualise_step(batch[0])
            self._visualisation_labels = self.model._visualisation_labels

    def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.model(batch, **kwargs)

    def log_losses(self, loss, where):
        """Log every entry of the loss dict as ``{where}/{name}``."""
        for k in loss.keys():
            self.log(f'{where}/{k}', loss[k], on_epoch=True, logger=True)

    def _shared_step(self, batch, batch_idx, optimizer_idx, n_samples, where):
        """Forward + loss on ``batch[0]`` with KL weight ``batch_size / n_samples``; logs and returns the loss dict."""
        x = batch[0]
        batch_size = x.shape[0]

        results = self.forward(x)
        loss = self.model.loss_function(x,
                                        results,
                                        M_N=batch_size / n_samples,
                                        optimizer_idx=optimizer_idx,
                                        batch_idx=batch_idx)

        self.log_losses(loss, where)
        return loss

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        return self._shared_step(batch, batch_idx, optimizer_idx, self.len_train_ds, 'train')

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        # NOTE(review): len_val_ds == 0 would raise ZeroDivisionError if a
        # validation batch is ever processed.
        return self._shared_step(batch, batch_idx, optimizer_idx, self.len_val_ds, 'valid')

    def configure_optimizers(self):
        """
        Adam on all model parameters (+ optional ExponentialLR). If ``lr2`` is set,
        a second Adam optimizes ``getattr(model, submodel)`` and may get its own
        ExponentialLR via ``scheduler_gamma2``.

        Raises:
            ValueError: if ``scheduler_gamma2`` is set without ``lr2``.
        """
        optims = [optim.Adam(self.model.parameters(),
                             lr=self.lr,
                             weight_decay=self.weight_decay)]
        scheds = []

        if hasattr(self, 'scheduler_gamma'):
            scheds.append(
                optim.lr_scheduler.ExponentialLR(optims[0], gamma=self.scheduler_gamma))

        if hasattr(self, 'lr2'):
            optims.append(
                optim.Adam(getattr(self.model, self.submodel).parameters(), lr=self.lr2))
            # The second scheduler belongs to the second optimizer.
            if hasattr(self, 'scheduler_gamma2'):
                scheds.append(
                    optim.lr_scheduler.ExponentialLR(optims[1], gamma=self.scheduler_gamma2))
        elif hasattr(self, 'scheduler_gamma2'):
            # Previously this crashed with an IndexError on optims[1].
            raise ValueError("scheduler_gamma2 requires lr2 (a second optimizer)")

        return optims, scheds

# Use the model

In [None]:
def KL_divergence(mu1: Tensor, log_var1: Tensor, mu2: Tensor, log_var2: Tensor, reduce_axes: Tuple[int] = (-1,)):
    """
    KL(D_2 || D_1) between diagonal Gaussians D_i = N(mu_i, exp(log_var_i)),
    taking D_1 as the reference:

        0.5 * sum_k [ var2/var1 + (mu1 - mu2)^2 / var1 - 1 + log var1 - log var2 ]

    Args:
        mu1, log_var1, mu2, log_var2: tensors of matching shape, e.g. (Z,) or (B, Z).
        reduce_axes: axes to sum over (default: the last, latent axis). The
            "-1" term is summed over the same axes, so reducing over several
            axes is also correct.

    Returns:
        Tensor with ``reduce_axes`` removed (e.g. shape (B,) for (B, Z) inputs).
    """
    log_det = log_var1 - log_var2                 # log var1 - log var2
    trace_cov = (-log_det).exp()                  # var2 / var1
    mean_diff = (mu1 - mu2) ** 2 / log_var1.exp()
    return 0.5 * (trace_cov + mean_diff + log_det - 1).sum(reduce_axes)

In [None]:
def twin_vae_change_score(model, x_1, x_2, verbose=False):
    """Score how much two inputs differ by comparing their VAE latent means.

    Both inputs are passed through ``model.encode``, which must return
    ``(mu, log_var)``. The score is ``cosine_distance_score(mu_1, mu_2)``
    (1 - cosine similarity). The log-variances are not used by this score;
    a KL-based alternative is kept commented out below.

    Args:
        model: encoder model; only classes whose name contains "VAE" are supported.
        x_1, x_2: inputs accepted by ``model.encode`` -- presumably batches of
            shape (B, C, H, W); confirm against the caller.
        verbose: print types/shapes of intermediate values.

    Returns:
        numpy array of distances, detached and moved to the CPU.

    Raises:
        NotImplementedError: if ``model`` is not a VAE.
    """
    # An explicit raise (unlike ``assert False``) still fires under ``python -O``.
    if "VAE" not in str(model.__class__):
        raise NotImplementedError("To be implemented!")

    # Pure inference: no autograd graph is needed, the result is detached anyway.
    with torch.no_grad():
        mu_1, log_var_1 = model.encode(x_1) # batch, latent_dim
        mu_2, log_var_2 = model.encode(x_2) # batch, latent_dim

        if verbose:
            print("x_1", type(x_1), len(x_1), x_1[0].shape)
            print("mu_1", type(mu_1), len(mu_1), mu_1[0].shape)
            print("log_var_1", type(log_var_1), len(log_var_1), log_var_1[0].shape)

        # distance = KL_divergence(mu_1, log_var_1, mu_2, log_var_2)
        distance = cosine_distance_score(mu_1, mu_2)
        if verbose:
            print("distance", type(distance), len(distance), distance[0].shape)

    # convert to numpy
    distance = distance.detach().cpu().numpy()
    if verbose:
        print("distance", distance.shape)

    return distance

In [None]:
def which_device(model):
    """Print and return the device on which ``model``'s first parameter lives."""
    first_param = next(iter(model.parameters()))
    location = first_param.device
    print("Model is on:", location)
    return location

#model = module.model
#model.eval()

#device = which_device(model)

# We have: model.forward .encode, .decode

#compare_func = twin_vae_change_score

In [None]:
def test_model():
    """Load the trained VAE checkpoint and score every before/after tile pair.

    Reads notebook globals defined elsewhere: ``Module``, ``DeeperVAE``,
    ``PATH_model_weights``, ``before_array`` and ``after_array``.

    Returns:
        list with one entry per tile pair: the array returned by
        ``twin_vae_change_score`` for that pair, scored as a batch of one.
    """
    # step 1: build the model from the checkpoint.
    # The architecture arguments must match those the checkpoint was trained with.
    cfg_train = {}

    model_cls_args_VAE = {
            # Using Small model:
            "hidden_channels": [16, 32, 64], # number of channels after each downscale. Reversed on upscale
            "latent_dim": 128,                # bottleneck size
            "extra_depth_on_scale": 0,        # after each downscale and upscale, this many convolutions are applied
            "visualisation_channels": [0, 1, 2],
    }

    # load_from_checkpoint is a classmethod that returns a NEW module holding the
    # trained weights, so its result must be kept. A plain Module(...) would
    # presumably only hold freshly initialised weights.
    # NOTE(review): hparams is an empty Namespace; confirm Module needs no values
    # from it (e.g. input_shape (8, 32, 32) vs (4, 32, 32), len_train_ds).
    module = Module.load_from_checkpoint(checkpoint_path=PATH_model_weights, hparams=Namespace(),
                                         model_cls=DeeperVAE, train_cfg=cfg_train,
                                         model_cls_args=model_cls_args_VAE)

    # step 2: switch to evaluation mode (disables training-only layers such as
    # dropout, if the model has any).
    model = module.model
    model.eval()

    # We have: model.forward .encode, .decode
    compare_func = twin_vae_change_score

    # step 3: score each before/after pair.
    # The arrays are already in memory, so iterate them directly and add a batch
    # axis of one. Iterating the data modules' DataLoaders would load lazily instead.
    predicted_distances = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for before, after in zip(before_array, after_array):
            before_data = torch.from_numpy(before).unsqueeze(0)
            after_data = torch.from_numpy(after).unsqueeze(0)
            predicted_distances.append(compare_func(model, before_data, after_data))

    return predicted_distances

In [None]:
# Inspect the stacked "before" tiles; the first axis presumably indexes tiles
# (test_model iterates over it pair by pair) -- confirm.
before_array.shape

In [None]:
# Shape of a single tile, as passed (after unsqueeze(0)) to the model in test_model.
before_array[0].shape

# Re-tiling

In [None]:
def tiles2image(predicted_distances, grid_shape, overlap=0, tile_size = 32, channels = 1):
    """Paint one score per tile into a float32 mosaic of shape (channels, H, W).

    ``grid_shape`` is (tiles per row, tiles per column). Scores are consumed
    from ``predicted_distances`` in row-major order, and each one fills a
    ``tile_size`` x ``tile_size`` square in every channel. Entries beyond the
    grid are ignored. ``overlap`` is accepted but currently not used.
    """
    n_cols, n_rows = grid_shape[0], grid_shape[1]
    mosaic = np.zeros((channels, n_rows * tile_size, n_cols * tile_size), dtype=np.float32)
    for flat_idx in range(n_rows * n_cols):
        row, col = divmod(flat_idx, n_cols)
        top = row * tile_size
        left = col * tile_size
        # Broadcasting fills the whole (channels, tile, tile) window with the score.
        mosaic[:, top:top + tile_size, left:left + tile_size] = predicted_distances[flat_idx]
    return mosaic

In [None]:
# El Fasher 8-band
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (42, 42) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Gandahar market 8-band
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (32, 32) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Jaranga 8-band
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (16, 16) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Muqrin 8-band
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (33, 33) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Sarafaya 8-band
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (15, 15) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Sarafaya 4-band, re-trained, manual normalization params, ORDER_1
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (15, 15) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Jaranga 4-band, re-trained, manual normalization params, ORDER_1
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (16, 16) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Muqrin 4-band, re-trained, manual normalization params, ORDER_1
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (33, 33) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# El Fasher 4-band, re-trained, manual normalization params, ORDER_1
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (42, 42) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# Gandahar market 4-band, re-trained, manual normalization params, ORDER_1
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 5000
grid_shape = (32, 32) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")
print(pd.DataFrame(max_predicted_distances).describe())

In [None]:
# get no. of tiles per row and column of the input image
filename='./El_Fasher_after_4b/El_Fasher_2024_05_11_after_4_b.tif'
with rasterio.open(filename) as src:
    filename_shape = src.height, src.width
    print(filename_shape)

tile_px_size = 32    
data_h, data_w = filename_shape
h_tiles_n = int(np.floor(data_h / tile_px_size))
w_tiles_n = int(np.floor(data_w / tile_px_size)) 
(h_tiles_n,w_tiles_n)

In [None]:
# Call test_model() 1 + n_runs times and keep the call with the highest single-tile distance.
# NOTE(review): if test_model() is deterministic every call returns the same scores, so the
# extra runs only repeat work (set n_runs = 0); if it is not, keeping the maximum
# cherry-picks the most extreme run rather than a representative one.
n_runs = 500
grid_shape = (32, 32) # (tiles per row, tiles per column), the order tiles2image expects

max_predicted_distances = np.asarray(test_model()).flatten()
best_peak = np.max(max_predicted_distances)
max_run_num = 0 # 0 = the initial call; loop calls are numbered 1..n_runs

for run_num in range(1, n_runs + 1):
    predicted_distances = np.asarray(test_model()).flatten()
    peak = np.max(predicted_distances)
    if peak > best_peak:
        max_predicted_distances, best_peak, max_run_num = predicted_distances, peak, run_num

# tiles2image silently drops extra scores, so flag a tile-count mismatch.
n_tiles_expected = grid_shape[0] * grid_shape[1]
if max_predicted_distances.size != n_tiles_expected:
    print(f"Warning: {max_predicted_distances.size} tile scores, but grid_shape {grid_shape} "
          f"has {n_tiles_expected} tiles")
change_map_image = tiles2image(max_predicted_distances, grid_shape=grid_shape, overlap=0, tile_size=32)

plt.imshow(change_map_image[0])
plt.colorbar()
plt.show()

print(f"Run number: {max_run_num}")