In [1]:
import numpy as np
import xarray as xr
import rioxarray as rxr
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix

In [3]:
# Load the preprocessed dataset: the DOP bands together with the forest_mask layer.
# squeeze() drops length-1 dimensions left over by the reader.
# NOTE(review): absolute local path -- consider moving it to a config cell.
filepath = 'D:/COPY/Projekt_FF/Process/dop_forestmask.nc'
dop = rxr.open_rasterio(filepath).squeeze()
dop

### **U-Net approach**
Define the model, using a pretrained MobileNetV2 as the encoder (down-stack) of the U-Net. The following steps are based on the [TensorFlow tutorial on image segmentation](https://www.tensorflow.org/tutorials/images/segmentation).

In [4]:
# Edge length (pixels) of the square tiles the DOP is cut into by the data
# generator. It is also the spatial input size of the MobileNetV2 encoder
# (input_shape = tilesize x tilesize x 3).
tilesize = 224

In [34]:
# Data generator for the model.
# Each batch is one horizontal strip (row) of tiles cut from the dataset. The
# tiles of that strip are split into training/validation subsets, or returned
# whole for the test set.
class CustomImageDataGenerator(tf.keras.utils.Sequence):
    """Serve (RGB tiles, forest-mask tiles) batches, one tile row per batch.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide the variables ``Band1``, ``Band2``, ``Band3`` (used as
        red, green, blue) and ``forest_mask``, plus the coordinates ``y`` and
        ``x``. Variables are indexed positionally as ``[row, column]``, so
        this assumes 2-D variables with dims ordered (y, x) -- TODO confirm.
    sampletype : {"training", "validation", "test"}
        "training" / "validation" return the 90 % / 10 % parts of a
        deterministic (``random_state=0``) split of each row's tiles. Two
        generators built on the same ``ds`` therefore yield complementary
        subsets. "test" returns every tile of the row.
    tilesize : int
        Edge length in pixels of the square tiles. Pixels beyond the last
        whole tile in y and in x are discarded.

    Raises
    ------
    ValueError
        If ``sampletype`` is not one of the values above.
    """

    _SAMPLETYPES = ("training", "validation", "test")

    def __init__(self, ds, sampletype, tilesize=tilesize):
        # Keras 3's PyDataset (what tf.keras.utils.Sequence refers to in
        # recent TF) warns when subclasses skip this; harmless on older Keras.
        super().__init__()
        if sampletype not in self._SAMPLETYPES:
            raise ValueError(
                f"sampletype must be one of {self._SAMPLETYPES}, got {sampletype!r}"
            )
        self.ds = ds
        self.sampletype = sampletype
        # Stored so __getitem__ slices with the same tile size used for ylen/xlen.
        self.tilesize = tilesize
        self.ylen = self.ds.y.size // tilesize  # whole tile rows (= number of batches)
        self.xlen = self.ds.x.size // tilesize  # whole tiles per row

    def __len__(self):
        """Number of batches, i.e. whole tile rows."""
        return self.ylen

    def __getitem__(self, index):
        """Return ``(rgb_tiles, target_tiles)`` for tile row ``index``.

        ``rgb_tiles`` has shape (n, tilesize, tilesize, 3) and ``target_tiles``
        has shape (n, tilesize, tilesize). n depends on ``sampletype``.
        """
        size = self.tilesize
        rows = slice(index * size, (index + 1) * size)
        # Crop x to a whole number of tiles with an explicit end index. A
        # negative-remainder slice like `:-(width % size)` is empty when the
        # remainder is 0.
        cols = slice(0, self.xlen * size)

        rgb = np.stack(
            [np.asarray(self.ds[band][rows, cols]) for band in ("Band1", "Band2", "Band3")],
            axis=-1,
        )
        forest = np.asarray(self.ds.forest_mask[rows, cols])

        # Cut the strip into xlen square tiles along x.
        rgb_tiles = np.stack(np.split(rgb, self.xlen, axis=1))
        target_tiles = np.stack(np.split(forest, self.xlen, axis=1))

        if self.sampletype == "test":
            return rgb_tiles, target_tiles

        rgb_tr, rgb_val, target_tr, target_val = train_test_split(
            rgb_tiles, target_tiles, shuffle=True, test_size=0.1, random_state=0
        )
        if self.sampletype == "training":
            return rgb_tr, target_tr
        return rgb_val, target_val

In [48]:
# Load MobileNetV2 pretrained on ImageNet (Keras default weights="imagenet"),
# without its classification head (include_top=False), to use as the encoder.
base_model = tf.keras.applications.MobileNetV2(input_shape=[tilesize, tilesize, 3], include_top=False)

# Set `button` to True to print the layer-by-layer architecture overview.
button = False

if button:
    # summary() prints the table itself and returns None. Wrapping it in
    # print() would emit an extra "None".
    base_model.summary()

In [49]:
# Build the encoder (down-stack) of the U-Net from the pretrained MobileNetV2.

# Use the activations of these layers as the encoder's multi-scale outputs.
# In the referenced TensorFlow tutorial these feed the up-stack as skip connections.
# Spatial sizes below are for the tilesize = 224 input (strides 2, 4, 8, 16, 32).
layer_names = [
    'block_1_expand_relu',   # 112x112  (tilesize / 2)
    'block_3_expand_relu',   # 56x56    (tilesize / 4)
    'block_6_expand_relu',   # 28x28    (tilesize / 8)
    'block_13_expand_relu',  # 14x14    (tilesize / 16)
    'block_16_project',      # 7x7      (tilesize / 32)
]
layers = [base_model.get_layer(name).output for name in layer_names]

# Feature-extraction model: same input as MobileNetV2, one output per layer above.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

# Freeze the encoder so its pretrained weights are not updated during training.
down_stack.trainable = False