### Imports

In [None]:
%env SM_FRAMEWORK=tf.keras

import os
import ssl
from sys import platform

import tensorflow as tf
tf.__version__
import keras
import segmentation_models as sm

from config import *
from util import *

print('Done')

### Set Up Preprocessing, Callbacks, and Datasets

In [None]:
prep_input = sm.get_preprocessing(BACKBONE)  # backbone-specific preprocessing (BACKBONE comes from config)
# NOTE(review): prep_input is not referenced anywhere else in this notebook;
# the datasets use get_preprocess() instead. Confirm it is still needed.

# Callbacks: keep the checkpoint with the lowest val_loss, and lower the
# learning rate when val_loss plateaus. They are built from tf.keras so they
# match the tf.keras optimizer and the SM_FRAMEWORK=tf.keras backend. Mixing
# standalone `keras` callbacks with a tf.keras model can fail on setups where
# the two packages differ.
# NOTE(review): every training run writes to the same best_model_name. Each
# run creates a fresh ModelCheckpoint, so the next run overwrites this run's
# best weights on its first epoch. Confirm this is intended.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        best_model_name,
        monitor='val_loss',
        save_weights_only=True,
        save_best_only=True,
        mode='min',
    ),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss'),
]

In [None]:
# Datasets for training/validation split 1.
# os.path.join gives the same path as plain concatenation when train_dir/val_dir
# end with a separator, and still gives a valid path when they don't. The
# trailing '/' on each sub-directory is kept as in the original.
# NOTE(review): the validation set is also built with aug=get_augmentation().
# If that applies random transforms, val_loss (which drives checkpointing and
# LR reduction) will be noisy. Confirm it is deterministic, or drop it here.
train = Dataset(
    images_dir=os.path.join(train_dir, 'XTrain1/'),
    masks_dir=os.path.join(train_dir, 'YTrain1/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)
val = Dataset(
    images_dir=os.path.join(val_dir, 'XValidate1/'),
    masks_dir=os.path.join(val_dir, 'YValidate1/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)

# Shuffle the training batches only; validation is read in order, one sample per batch.
train_dataloader = Dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = Dataloader(val, batch_size=1, shuffle=False)
print('Done')

### Define Optimizer, Loss Functions, and Metrics

In [None]:
# Optimizer: Adam with the learning rate LR from config.
optim = tf.keras.optimizers.Adam(LR)

# Loss: Dice loss plus binary focal loss (focal weight kept at 1).
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# Metrics: IoU and F-score on predictions binarised at 0.5.
# The IoU threshold used to be 5. Model outputs are presumably probabilities
# in [0, 1], so every prediction was binarised to 0 and IoU was meaningless.
# It now matches the F-score threshold.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

print('Done')

### Define and Compile UNet Model

In [None]:
# Disable HTTPS certificate verification on non-Windows hosts
# (original note: "this method is not recommended but easy").
# SECURITY NOTE: this replaces the ssl module's default HTTPS context factory
# for the whole Python process. Every later HTTPS request that relies on that
# default will skip certificate verification, not just the one this works
# around.
# NOTE(review): presumably this exists so pretrained backbone weights can be
# downloaded on hosts with missing CA certificates. Installing certificates is
# the safer fix. Re-running the cell is harmless (same assignment each time).
if platform != 'win32':
    # noinspection PyProtectedMember,PyUnresolvedReferences
    ssl._create_default_https_context = ssl._create_unverified_context
    print('context set')

print('Done')

In [None]:
# Build a fresh U-Net on the configured backbone for this run.
# Then compile it with the optimizer, combined loss and metrics from the previous cell.
model = sm.Unet(
    backbone_name=BACKBONE,
    classes=n_classes,
    activation=activation,
)
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

print('Done')

### Train Model

In [None]:
# Train run 1 for 50 epochs. Steps per epoch are len(loader), presumably one
# pass over each loader. The History object feeds the comparison plots.
history1 = model.fit(
    train_dataloader,
    epochs=50,
    steps_per_epoch=len(train_dataloader),
    validation_data=val_dataloader,
    validation_steps=len(val_dataloader),
    callbacks=callbacks,
)

In [None]:
prep_input = sm.get_preprocessing(BACKBONE)  # backbone-specific preprocessing (BACKBONE comes from config)
# NOTE(review): prep_input is not referenced anywhere else in this notebook;
# the datasets use get_preprocess() instead. Confirm it is still needed.

# Callbacks for run 2: keep the checkpoint with the lowest val_loss, and lower
# the learning rate when val_loss plateaus. They are built from tf.keras so
# they match the tf.keras optimizer and the SM_FRAMEWORK=tf.keras backend.
# NOTE(review): this fresh ModelCheckpoint writes to the same best_model_name
# as the other runs, so it overwrites the earlier run's best weights on its
# first epoch. Confirm this is intended.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        best_model_name,
        monitor='val_loss',
        save_weights_only=True,
        save_best_only=True,
        mode='min',
    ),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss'),
]

In [None]:
# Datasets for training/validation split 2.
# os.path.join gives the same path as plain concatenation when train_dir/val_dir
# end with a separator, and still gives a valid path when they don't.
# NOTE(review): the validation set is also built with aug=get_augmentation().
# If that applies random transforms, val_loss will be noisy. Confirm.
train = Dataset(
    images_dir=os.path.join(train_dir, 'XTrain2/'),
    masks_dir=os.path.join(train_dir, 'YTrain2/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)
val = Dataset(
    images_dir=os.path.join(val_dir, 'XValidate2/'),
    masks_dir=os.path.join(val_dir, 'YValidate2/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)

# Shuffle the training batches only; validation is read in order, one sample per batch.
train_dataloader = Dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = Dataloader(val, batch_size=1, shuffle=False)
print('Done')

### Define Optimizer, Loss Functions, and Metrics

In [None]:
# Optimizer: Adam with the learning rate LR from config.
optim = tf.keras.optimizers.Adam(LR)

# Loss: Dice loss plus binary focal loss (focal weight kept at 1).
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# Metrics: IoU and F-score on predictions binarised at 0.5.
# The IoU threshold used to be 5. Outputs are presumably in [0, 1], so IoU
# always saw empty predictions. It now matches the F-score threshold.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

print('Done')

### Define and Compile UNet Model

In [None]:
# Disable HTTPS certificate verification on non-Windows hosts
# (original note: "this method is not recommended but easy").
# SECURITY NOTE: this replaces the ssl module's default HTTPS context factory
# for the whole Python process, so later HTTPS requests relying on that
# default skip certificate verification.
# NOTE(review): presumably needed to download pretrained backbone weights on
# hosts with missing CA certificates. Re-running the cell is harmless (same
# assignment each time).
if platform != 'win32':
    # noinspection PyProtectedMember,PyUnresolvedReferences
    ssl._create_default_https_context = ssl._create_unverified_context
    print('context set')

print('Done')

In [None]:
# Build a fresh U-Net for run 2.
# Then compile it with the optimizer, combined loss and metrics from the previous cell.
model = sm.Unet(
    backbone_name=BACKBONE,
    classes=n_classes,
    activation=activation,
)
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

print('Done')

In [None]:
# Train run 2 for 50 epochs. Steps per epoch are len(loader).
# The History object feeds the comparison plots.
history2 = model.fit(
    train_dataloader,
    epochs=50,
    steps_per_epoch=len(train_dataloader),
    validation_data=val_dataloader,
    validation_steps=len(val_dataloader),
    callbacks=callbacks,
)

In [None]:
prep_input = sm.get_preprocessing(BACKBONE)  # backbone-specific preprocessing (BACKBONE comes from config)
# NOTE(review): prep_input is not referenced anywhere else in this notebook;
# the datasets use get_preprocess() instead. Confirm it is still needed.

# Callbacks for run 3: keep the checkpoint with the lowest val_loss, and lower
# the learning rate when val_loss plateaus. They are built from tf.keras so
# they match the tf.keras optimizer and the SM_FRAMEWORK=tf.keras backend.
# NOTE(review): this fresh ModelCheckpoint writes to the same best_model_name
# as the other runs, so it overwrites the earlier run's best weights on its
# first epoch. Confirm this is intended.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        best_model_name,
        monitor='val_loss',
        save_weights_only=True,
        save_best_only=True,
        mode='min',
    ),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss'),
]

In [None]:
# Datasets for training/validation split 3.
# os.path.join gives the same path as plain concatenation when train_dir/val_dir
# end with a separator, and still gives a valid path when they don't.
# NOTE(review): the validation set is also built with aug=get_augmentation().
# If that applies random transforms, val_loss will be noisy. Confirm.
train = Dataset(
    images_dir=os.path.join(train_dir, 'XTrain3/'),
    masks_dir=os.path.join(train_dir, 'YTrain3/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)
val = Dataset(
    images_dir=os.path.join(val_dir, 'XValidate3/'),
    masks_dir=os.path.join(val_dir, 'YValidate3/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)

# Shuffle the training batches only; validation is read in order, one sample per batch.
train_dataloader = Dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = Dataloader(val, batch_size=1, shuffle=False)
print('Done')

### Define Optimizer, Loss Functions, and Metrics

In [None]:
# Optimizer: Adam with the learning rate LR from config.
optim = tf.keras.optimizers.Adam(LR)

# Loss: Dice loss plus binary focal loss (focal weight kept at 1).
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# Metrics: IoU and F-score on predictions binarised at 0.5.
# The IoU threshold used to be 5. Outputs are presumably in [0, 1], so IoU
# always saw empty predictions. It now matches the F-score threshold.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

print('Done')

### Define and Compile UNet Model

In [None]:
# Disable HTTPS certificate verification on non-Windows hosts
# (original note: "this method is not recommended but easy").
# SECURITY NOTE: this replaces the ssl module's default HTTPS context factory
# for the whole Python process, so later HTTPS requests relying on that
# default skip certificate verification.
# NOTE(review): presumably needed to download pretrained backbone weights on
# hosts with missing CA certificates. Re-running the cell is harmless (same
# assignment each time).
if platform != 'win32':
    # noinspection PyProtectedMember,PyUnresolvedReferences
    ssl._create_default_https_context = ssl._create_unverified_context
    print('context set')

print('Done')

In [None]:
# Build a fresh U-Net for run 3.
# Then compile it with the optimizer, combined loss and metrics from the previous cell.
model = sm.Unet(
    backbone_name=BACKBONE,
    classes=n_classes,
    activation=activation,
)
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

print('Done')

In [None]:
# Train run 3 for 50 epochs. Steps per epoch are len(loader).
# The History object feeds the comparison plots.
history3 = model.fit(
    train_dataloader,
    epochs=50,
    steps_per_epoch=len(train_dataloader),
    validation_data=val_dataloader,
    validation_steps=len(val_dataloader),
    callbacks=callbacks,
)

In [None]:
prep_input = sm.get_preprocessing(BACKBONE)  # backbone-specific preprocessing (BACKBONE comes from config)
# NOTE(review): prep_input is not referenced anywhere else in this notebook;
# the datasets use get_preprocess() instead. Confirm it is still needed.

# Callbacks for run 4: keep the checkpoint with the lowest val_loss, and lower
# the learning rate when val_loss plateaus. They are built from tf.keras so
# they match the tf.keras optimizer and the SM_FRAMEWORK=tf.keras backend.
# NOTE(review): this fresh ModelCheckpoint writes to the same best_model_name
# as the other runs, so it overwrites the earlier run's best weights on its
# first epoch. Confirm this is intended.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        best_model_name,
        monitor='val_loss',
        save_weights_only=True,
        save_best_only=True,
        mode='min',
    ),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss'),
]

In [None]:
# Datasets for training/validation split 4.
# os.path.join gives the same path as plain concatenation when train_dir/val_dir
# end with a separator, and still gives a valid path when they don't.
# NOTE(review): the validation set is also built with aug=get_augmentation().
# If that applies random transforms, val_loss will be noisy. Confirm.
train = Dataset(
    images_dir=os.path.join(train_dir, 'XTrain4/'),
    masks_dir=os.path.join(train_dir, 'YTrain4/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)
val = Dataset(
    images_dir=os.path.join(val_dir, 'XValidate4/'),
    masks_dir=os.path.join(val_dir, 'YValidate4/'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)

# Shuffle the training batches only; validation is read in order, one sample per batch.
train_dataloader = Dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = Dataloader(val, batch_size=1, shuffle=False)
print('Done')

### Define Optimizer, Loss Functions, and Metrics

In [None]:
# Optimizer: Adam with the learning rate LR from config.
optim = tf.keras.optimizers.Adam(LR)

# Loss: Dice loss plus binary focal loss (focal weight kept at 1).
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# Metrics: IoU and F-score on predictions binarised at 0.5.
# The IoU threshold used to be 5. Outputs are presumably in [0, 1], so IoU
# always saw empty predictions. It now matches the F-score threshold.
# These metric objects are also passed to print_metrics after evaluation.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

print('Done')

### Define and Compile UNet Model

In [None]:
# Disable HTTPS certificate verification on non-Windows hosts
# (original note: "this method is not recommended but easy").
# SECURITY NOTE: this replaces the ssl module's default HTTPS context factory
# for the whole Python process, so later HTTPS requests relying on that
# default skip certificate verification.
# NOTE(review): presumably needed to download pretrained backbone weights on
# hosts with missing CA certificates. Re-running the cell is harmless (same
# assignment each time).
if platform != 'win32':
    # noinspection PyProtectedMember,PyUnresolvedReferences
    ssl._create_default_https_context = ssl._create_unverified_context
    print('context set')

print('Done')

In [None]:
# Build a fresh U-Net for run 4.
# Then compile it with the optimizer, combined loss and metrics from the previous cell.
model = sm.Unet(
    backbone_name=BACKBONE,
    classes=n_classes,
    activation=activation,
)
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

print('Done')

In [None]:
# Train run 4 for 50 epochs. Steps per epoch are len(loader).
# The History object feeds the comparison plots.
history4 = model.fit(
    train_dataloader,
    epochs=50,
    steps_per_epoch=len(train_dataloader),
    validation_data=val_dataloader,
    validation_steps=len(val_dataloader),
    callbacks=callbacks,
)

### Graph Model F1 Score and Loss

In [None]:
# Plot the F1 curves of the four training runs for comparison.
graph_f1_models(
    history1,
    history2,
    history3,
    history4,
)

In [None]:
# Plot the loss curves of the four training runs for comparison.
graph_loss_models(
    history1,
    history2,
    history3,
    history4,
)

### Model Evaluation

In [None]:
# Select a specific saved checkpoint and load its weights into the current model.
# NOTE(review): this rebinds the global best_model_name that the checkpoint
# callbacks are built from, so re-running a callbacks cell afterwards would
# write here. 'best_model_9.h5' may also not be the file the training runs
# above wrote. Confirm which checkpoint should be evaluated.
best_model_name = os.path.join(model_dir, 'best_model_9.h5')
model.load_weights(filepath=best_model_name)

In [None]:
# remake_test(200)
# NOTE(review): uncomment the line above only to rebuild the fixed test split
# (remake_test is not defined in this notebook; presumably it comes from util).

# Fixed test set. os.path.join gives the same path as plain concatenation when
# test_dir ends with a separator, and still gives a valid path when it doesn't.
# NOTE(review): the test set is also built with aug=get_augmentation(). If that
# applies random transforms, test scores will not be reproducible. Confirm.
test = Dataset(
    images_dir=os.path.join(test_dir, 'XTestFixed'),
    masks_dir=os.path.join(test_dir, 'YTestFixed'),
    classes=CLASSES,
    aug=get_augmentation(),
    prep=get_preprocess(),
)
# Evaluate in order, one sample per batch.
test_dataloader = Dataloader(test, batch_size=1, shuffle=False)

In [None]:
# Evaluate the loaded model on the test loader.
# Hand the results and the compiled metric list to print_metrics for display.
scores = model.evaluate(x=test_dataloader)
print_metrics(scores, metrics)

In [None]:
# Run the model over the test set, post-process the predicted masks, then
# display them and save them to output_dir.
show_results(
    test,
    model,
    len(test),  # sample count: presumably the whole test set
    postprocess,
    out=output_dir,
    save=True,
    show=True,
    region_fill=True,
    largest_only=True,
)