In [None]:
! pip install -U segmentation-models

In [None]:
import tensorflow as tf
import keras
import numpy as np
from tensorflow.keras.preprocessing.image import load_img
from keras.utils.np_utils import to_categorical   
import random
import segmentation_models as sm
sm.set_framework('tf.keras')
from segmentation_models import Unet
from segmentation_models import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

In [None]:
# Read the file names listed in mixed.txt (one name per line)
input_mixed = []
input_files = []
with open('/kaggle/input/wildfire-segmentation-dataset/mixed.txt') as mixed_list:
    input_mixed = mixed_list.read().splitlines()
# The working file list starts out as the "mixed" entries
input_files += input_mixed

In [None]:
# Train on full dataset
# (Comment out to train only for files with mixed label in CNN classifier)
# Rebuild the list from input_mixed first so that re-running this cell cannot
# append the same file names a second time (duplicates could otherwise end up
# in both the training and the validation split).
input_files = list(input_mixed)
for label in ('burned', 'other', 'unknown', 'vegetation'):
    with open(f'/kaggle/input/wildfire-segmentation-dataset/{label}.txt') as f:
        input_files.extend(f.read().splitlines())

In [None]:
# Shuffle with a dedicated, fixed-seed RNG: an unseeded shuffle would change the
# input order on every run, which makes the seeded train/validation split
# unrepeatable even though it uses a fixed random_state.
random.Random(101).shuffle(input_files)
len(input_files)

In [None]:
# Hold out 20% of the file names for validation; the fixed random_state seeds the split
from sklearn.model_selection import train_test_split

train_files, valid_files = train_test_split(
    input_files,
    random_state=101,
    test_size=0.2,
)

In [None]:
# Number of segmentation classes: channel depth of the one-hot masks and the
# value passed as `classes` to the U-Net.
NUM_CLASSES = 4

In [None]:
# Keras Sequence that loads the band images into NumPy arrays together with their corresponding masks

class Sequence_generator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of multi-band images and one-hot masks.

    Every sample is identified by a file name without extension. For each band
    in ``bands`` the grayscale image ``<input_dir><band>/<file_name>.jpg`` fills
    one channel of the input array, and the grayscale mask
    ``<mask_dir><file_name>.png`` is one-hot encoded into ``NUM_CLASSES``
    channels.

    Args:
        batch_size: Maximum number of samples per batch.
        img_size: Spatial size tuple, passed as ``target_size`` to ``load_img``
            and used as the two spatial dimensions of the batch arrays.
        input_files: Iterable of sample file names (no directory, no extension).
        bands: Band sub-directory names; one input channel is created per band.
        input_dir: Path prefix of the band folders (concatenated as-is, so it
            must end with a path separator).
        mask_dir: Path prefix of the mask files (concatenated as-is, so it
            must end with a path separator).
    """

    def __init__(self, batch_size, img_size, input_files, bands, input_dir, mask_dir):
        super().__init__()
        self.batch_size = batch_size
        self.img_size = tuple(img_size)
        # Private copy: on_epoch_end shuffles in place and must not reorder the
        # list object owned by the caller.
        self.input_files = list(input_files)
        self.bands = bands
        self.input_dir = input_dir
        self.mask_dir = mask_dir

    def __len__(self):
        """Number of batches per epoch, counting a final partial batch."""
        # Ceiling division so the trailing samples are not silently dropped.
        return (len(self.input_files) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Returns tuple (input, target) corresponding to batch #idx."""
        start = idx * self.batch_size
        batch_files = self.input_files[start:start + self.batch_size]
        # Size the arrays by the real number of files so that the last, shorter
        # batch does not carry all-zero padding samples.
        n_samples = len(batch_files)
        x = np.zeros((n_samples,) + self.img_size +
                     (len(self.bands),), dtype="float32")
        y = np.zeros((n_samples,) + self.img_size +
                     (NUM_CLASSES,), dtype="float32")
        for j, file_name in enumerate(batch_files):
            # One grayscale image per band -> one channel of the input tensor
            for b, band in enumerate(self.bands):
                img = load_img(
                    self.input_dir + f'{band}/{file_name}.jpg',
                    target_size=self.img_size,
                    color_mode='grayscale'
                )
                x[j, :, :, b] = img
            # Mask pixel values are used directly as class indices
            mask = load_img(self.mask_dir + file_name+'.png',
                            target_size=self.img_size, color_mode='grayscale')
            y[j] = to_categorical(mask, num_classes=NUM_CLASSES)
        return x, y

    def on_epoch_end(self):
        """Reshuffle the sample order so batch composition changes between epochs."""
        random.shuffle(self.input_files)
        return super().on_epoch_end()

In [None]:
# Bands to be used for segmentation: each name is a sub-directory of the input
# directory, and each band fills one input channel of the batches.
unet_bands = ['B8A', 'B11', 'B12']

In [None]:
# Training hyper-parameters
epochs = 40
batch_size = 32
image_size = (256, 256)  # passed as target_size when images and masks are loaded

# Dataset layout: band folders live directly under input_dir, masks in a sub-folder
input_dir = "/kaggle/input/wildfire-segmentation-dataset/all_data/"
mask_dir = input_dir + "normalized_mask/"

In [None]:
# Batch generators for the training and validation file lists (same settings otherwise)
train_gen = Sequence_generator(
    batch_size=batch_size,
    img_size=image_size,
    input_files=train_files,
    bands=unet_bands,
    input_dir=input_dir,
    mask_dir=mask_dir,
)
valid_gen = Sequence_generator(
    batch_size=batch_size,
    img_size=image_size,
    input_files=valid_files,
    bands=unet_bands,
    input_dir=input_dir,
    mask_dir=mask_dir,
)

In [None]:
# Sanity check: fetch the first training batch and show the shape of its target array
t1 = train_gen[0]
t1[1].shape

In [None]:
# U-Net with a ResNet-34 encoder initialised from ImageNet weights
model_name = 'resnet34'
model = Unet(
    model_name,
    classes=NUM_CLASSES,
    # One input channel per band, derived from the band list instead of a
    # hard-coded 3 so it cannot drift from what the generators produce.
    # NOTE(review): the ImageNet encoder weights presumably require exactly
    # 3 input channels — confirm before changing the number of bands.
    input_shape=image_size + (len(unet_bands),),
    encoder_weights='imagenet'
)
model.compile('adamax', loss=bce_jaccard_loss, metrics=[iou_score])

# Checkpoint name: <encoder>_<band>_<band>_..., built from however many bands are listed
save_file = '_'.join([model_name, *unet_bands]) + '.h5'

In [None]:
# Display the checkpoint file name
save_file

In [None]:
# Checkpoint only when the validation loss improves. The callback comes from
# tf.keras so it matches the framework the model was built with.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(save_file, monitor='val_loss', save_best_only=True)
]

# fit model
# batch_size is intentionally not passed: train_gen is a Sequence that already
# yields complete batches, and Keras documents that batch_size must not be
# specified for such inputs.
model.fit(
    train_gen,
    epochs=epochs,
    validation_data=valid_gen,
    callbacks=callbacks
)

In [None]:
# Evaluate the trained model on the files from the "mixed" list only
mixed_gen = Sequence_generator(
    batch_size=batch_size,
    img_size=image_size,
    input_files=input_mixed,
    bands=unet_bands,
    input_dir=input_dir,
    mask_dir=mask_dir,
)
model.evaluate(mixed_gen)