In [None]:
# Reload edited modules (e.g. the custom_classes_defs package) before each cell runs,
# so changes to helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Choose which top-level statements of a cell echo their value.
# Options include: "all" | "last" | "last_expr" | "none".
# "last" runs the final top-level statement of each cell interactively.
InteractiveShell.ast_node_interactivity = "last"

In [None]:
# Unpack the image and trimap archives. Each command is skipped when its target folder
# already exists, so Restart & Run All does not re-extract the dataset every time.
# NOTE(review): the loader below reads ./data/images/ and ./data/annotations/trimaps/.
# This assumes the archives store their paths under data/ (check with `tar -tzf`).
# If they hold top-level images/ and annotations/ folders instead, add `-C ./data`
# to each command, otherwise the extracted files land in the current directory.
!test -d ./data/images || tar -xf ./data/images.tar.gz
!test -d ./data/annotations || tar -xf ./data/annotations.tar.gz

### Data preparation and model configuration

In [None]:
# ============= Import required packages ==============
# Custom project modules. Names used later but not defined in this notebook
# (Oxford_Pets, model_config, FNET2D, show_convergence, ...) presumably come from these
# star imports -- TODO confirm which module exports which name and import them explicitly.
# They are imported first, as in the original, in case they configure the backend
# before keras is loaded.
from custom_classes_defs.preprocessing import *
# Alternative architectures: swap the fnet0 import below (and the constructor in the
# "Build model" cell) to train a different network.
# from custom_classes_defs.unet0 import *
# from custom_classes_defs.Unet_like import *
# from custom_classes_defs.unet import *
from custom_classes_defs.fnet0 import *
# from custom_classes_defs.fnet_like import *
# from custom_classes_defs.fnet import *

import keras
import numpy as np
# `tf` is used throughout (tf.config, tf.distribute, tf.keras.optimizers) but was only
# reachable through the star imports above; import it explicitly so the dependency
# is visible and does not break if those modules stop re-exporting it.
import tensorflow as tf

from keras.utils import plot_model

RND_STATE = 247   # global Keras seed and the seed of the train/val/test split
BATCH_SIZE = 64
EPOCHS = 200
keras.utils.set_random_seed(RND_STATE)

# Gates interactive-only output (training-history plots, sample prediction images).
INTERACTIVE_SESSION = True

# import keras_tuner as kt
# -------------------------------------------------------

In [None]:
# --- Library versions: tie results to a specific software stack ---
print("tensorflow version: {}".format(tf.__version__))
print("keras version: {}".format(keras.__version__))

# --- Hardware: list every visible device, then count the GPUs ---
# NUM_GPU > 1 later switches model building to tf.distribute.MirroredStrategy
# (through conf.multiple_gpu_device).
print(tf.config.list_physical_devices())
NUM_GPU = len(tf.config.list_physical_devices(device_type="GPU"))
print("Number of GPUs assigned for computation: {}".format(NUM_GPU))

# Driver / memory details only make sense when a GPU is present.
if NUM_GPU > 0:
    !nvidia-smi

In [None]:
# Oxford-IIIT Pet loader: images plus trimap annotations, at img_size=(160, 160),
# served in batches of BATCH_SIZE.
pets = Oxford_Pets(
    input_dir="./data/images/",
    target_dir="./data/annotations/trimaps/",
    img_size=(160, 160),
    batch_size=BATCH_SIZE,
)

# Seeded split: 50% train, 20% validation, presumably the remaining 30% for test
# (TODO confirm in Oxford_Pets.split_data).
train_dataset, valid_dataset, test_dataset = pets.split_data(
    train_ratio=0.5, val_ratio=0.2, seed=RND_STATE
)

# Split sizes and the element specs the model will be fed.
print("training data (size = {})".format(pets.train_size))
print("validation data (size = {})".format(pets.validation_size))
print("test data (size = {})".format(pets.test_size))
print("Data images tensor:", train_dataset.element_spec[0])
print("Data labels tensor:", train_dataset.element_spec[1])


In [None]:
# Shape of the full training label tensor: (number of training samples, *per-sample
# label shape). shape[1:] drops the leading batch dimension of the batched dataset.
train_shape = (pets.train_size,) + tuple(train_dataset.element_spec[-1].shape[1:])

# NOTE(review): an all-ones `cl_weights = np.ones(train_shape)` array was allocated here,
# but nothing in this notebook uses it. As float64 it costs prod(train_shape) * 8 bytes
# (hundreds of MB for this dataset), so it is no longer created. Re-create it (ideally
# as float32) only once per-pixel class weights are actually passed to training.


In [None]:

# Model / training configuration consumed by the build and training cells below.
conf = model_config(
    epochs=EPOCHS,
    loss="sparse_categorical_crossentropy",
    batch_size=BATCH_SIZE,
    save_path="./oxford-pets/fnet00",
    img_shape=pets.img_size,
    target_size=pets.img_size,
    train_size=pets.train_size,
    test_size=pets.test_size,
    validation_size=pets.validation_size,
    channels_dim=(3, 3),
    pos_label=pets.pos_label,
    new_training_session=True,
    multiple_gpu_device=NUM_GPU > 1,   # selects MirroredStrategy in the build cell
    verbose=1,
)

# Training-time arguments: validation data and the callbacks that conf provides
# (presumably forwarded to model.fit by conf.execute_training).
callbacks = conf.callbacks()
conf.set(validation_data=valid_dataset, callbacks=callbacks)

# Compile-time arguments, presumably what the build cell unpacks via conf.compile_args.
# NOTE(review): `tf.keras.optimizers.legacy` targets Keras 2; verify it still exists
# before upgrading TensorFlow/Keras.
# NOTE(review): the optimizer is created here, outside any MirroredStrategy scope;
# confirm this is acceptable for multi-GPU runs.
conf.set(
    "compile",
    optimizer=tf.keras.optimizers.legacy.Adam(1e-4),
    metrics=["accuracy"],
)

# conf.double_check(INTERACTIVE_SESSION)
conf.info()


### Build model

In [None]:
### SINGLE-HOST, MULTI-DEVICE SYNCHRONOUS TRAINING
## François Chollet. Deep Learning with Python, Second Edition (Kindle Location 12675). Manning Publications Co..
import contextlib

print("\n\n{}\n\t{}\n{}".format('='*55, 'Build model', '-'*55))

# `panel_sizes` passed to the model constructor (shared by the alternative architectures).
PANEL_SIZES = [32, 64, 128, 256]

# With several GPUs the model must be built and compiled inside a MirroredStrategy
# scope; otherwise a no-op context is used. This gives a single build path instead
# of two copies of the same code.
if conf.multiple_gpu_device:
    strategy = tf.distribute.MirroredStrategy()
    print(f"Number of devices: {strategy.num_replicas_in_sync}")
    build_scope = strategy.scope()
else:
    build_scope = contextlib.nullcontext()

with build_scope:
    # m_obj = UNET2D(panel_sizes=PANEL_SIZES, model_arch=conf.model_arch)
    m_obj = FNET2D(panel_sizes=PANEL_SIZES, model_arch=conf.model_arch, add_residual=False)
    model = m_obj.build_model()
    model.compile(**conf.compile_args)

# model.summary()
plot_model(model, 'm_obj.png', show_shapes=True)

# Trainable count = total number of elements across the trainable weight tensors.
# int() keeps the count integral: np.prod(()) of a scalar weight returns a float.
num_trainable_weights = int(sum(np.prod(w.shape) for w in model.trainable_weights))
total_params = model.count_params()
print(f"Total number of parameters: {total_params:,}")
print(f"Total trainable weights: {num_trainable_weights:,}")
print(f"Total non-trainable weights: {total_params - num_trainable_weights:,}")



### Train the model

In [None]:
print("\n\n{}\n\t{}\n{}".format("=" * 55, "Train {} model".format(m_obj.Name), "-" * 55))

# Train with the arguments stored on `conf`; this yields the trained model and a
# history object (its `.history` dict feeds the convergence plots below). History
# is plotted only in interactive sessions.
model, train_history = conf.execute_training(
    model,
    data=train_dataset,
    plot_history=INTERACTIVE_SESSION,
)

In [None]:
# Training vs. validation accuracy curves from the training history.
show_convergence(train_history.history, ['accuracy','val_accuracy'])

In [None]:
# Learning-rate trajectory. Requires a callback that records 'lr' in the history
# (presumably one of conf.callbacks() -- TODO confirm).
show_convergence(train_history.history, 'lr')

### Predict on the test set, visualize, and evaluate

In [None]:
# Generate predictions for every image in the TEST split (not the validation set).
# The same predictions feed the sample display and the scikit-learn metrics below.
y_preds = model.predict(test_dataset, verbose=2)


In [None]:
# Show sample test images alongside the model's predictions (interactive sessions only).
if INTERACTIVE_SESSION:
    pets.display_sample_image(y_preds, 'test')


In [None]:
# Optional Keras-side evaluation (loss + compiled metrics) on the test split.
# Disabled; the scikit-learn evaluation below reports test metrics instead.
# Uncomment both lines to enable it.
# print("\n\n{}\n\t{}\n{}".format('='*55,f'Evaluate {m_obj.Name} model', '-'*55))

# model.evaluate(x=test_dataset)


In [None]:
# Test-set metrics computed with scikit-learn from the predictions above.
# report=True presumably also prints a detailed report -- TODO confirm in evaluate_sklearn.
scores = m_obj.evaluate_sklearn(test_dataset, y_preds,report=True)
print(scores)
