In [None]:
import os
import sys
import glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import backend as K
from tensorflow.keras import mixed_precision
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from models.hrnet_clf import HRNet_CLF
# from models.hrnet_clf_accumilate import HRNet_CLF
# from models.hrnet_clf_gn_accumilate import HRNet_CLF_GN
from data_loaders import SunLoader

K.clear_session()
physical_devices = tf.config.list_physical_devices("GPU")
# Enable memory growth so TF does not reserve all GPU memory up front.
# Applied to every visible GPU: TF requires the setting to be identical across
# GPUs, and looping (instead of indexing [0]) keeps the notebook runnable on a
# CPU-only machine rather than raising IndexError.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

def enable_amp():
    """Set the global Keras dtype policy to ``"mixed_float16"``.

    The policy only affects layers created after this call, so it has to run
    before the model is constructed.
    """
    amp_policy = "mixed_float16"
    mixed_precision.set_global_policy(amp_policy)
%config InlineBackend.figure_format = 'retina'
plt.style.use('ggplot')
plt.rc('xtick',labelsize=16)
plt.rc('ytick',labelsize=16)
print("Tensorflow version: ", tf.__version__)
print(physical_devices,"\n")
# enable_amp() 

In [None]:
img_height = 224
img_width = 224
# NOTE(review): SUN397 nominally has 397 categories; 362 presumably reflects the
# classes left after SunLoader's filtering — confirm against SunLoader.
n_classes = 362

# Dataset root. Set the SUN397_DIR environment variable to point elsewhere
# instead of editing the notebook; the default is the original location.
DATA_DIR = os.environ.get(
    "SUN397_DIR", "/workspace/PythonProjects/cityscapes_segmentation_tf2/SUN397"
)

pipeline = SunLoader(
    data_dir=DATA_DIR,
    bad_imgs_file='SUN397_bad_images.txt',
    n_classes=n_classes,
    img_height=img_height,
    img_width=img_width,
)

image_list = pipeline.get_image_list()
label_list = pipeline.get_label_list(image_list)
# The split cells slice both lists with the same indices, so they must align.
assert len(image_list) == len(label_list), "image/label lists differ in length"

In [None]:
# Per-step batch size and the tf.data shuffle-buffer size (in elements).
BATCH_SIZE = 64
BUFFER_SIZE = 300
# Gradient-accumulation variant (pairs with the commented-out *_accumilate models):
# ACCUM_STEPS = 4
# ADJ_BATCH_SIZE = BATCH_SIZE * ACCUM_STEPS
# print("Effective batch size: {}".format(ADJ_BATCH_SIZE))

In [None]:
# Positional 70 / 15 split; the test set gets the remainder (about 15%).
DATASET_LENGTH = len(image_list)
TRAIN_LENGTH, VALID_LENGTH = (int(DATASET_LENGTH * frac) for frac in (0.7, 0.15))
TEST_LENGTH = DATASET_LENGTH - TRAIN_LENGTH - VALID_LENGTH

# Whole batches only; a trailing partial batch is not counted.
# NOTE: VALIDATION_STEPS is 0 if VALID_LENGTH < BATCH_SIZE.
STEPS_PER_EPOCH, VALIDATION_STEPS = (n // BATCH_SIZE for n in (TRAIN_LENGTH, VALID_LENGTH))

In [None]:
def display(display_list, title=True):
    """Draw every image of ``display_list`` side by side in a single row.

    Note: this shadows IPython's built-in ``display`` in the notebook namespace.

    Args:
        display_list: sequence of images accepted by
            ``tf.keras.preprocessing.image.array_to_img``.
        title: currently unused; kept so existing calls keep working.
    """
    n_images = len(display_list)
    plt.figure(figsize=(15, 5))  # dpi=200
    for position, img in enumerate(display_list, start=1):
        plt.subplot(1, n_images, position)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(img))
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Contiguous, non-overlapping split of the lists:
#   [0, TRAIN_LENGTH) -> train, [TRAIN_LENGTH, VALID_END) -> valid, rest -> test.
# An explicit end index replaces the former `-TEST_LENGTH` slice, which yields
# an empty validation set whenever TEST_LENGTH == 0.
# NOTE(review): the split is positional, so it is only class-balanced if
# pipeline.get_image_list() returns images in shuffled order — confirm in SunLoader.
VALID_END = TRAIN_LENGTH + VALID_LENGTH
AUTOTUNE = tf.data.experimental.AUTOTUNE

train_ds = tf.data.Dataset.from_tensor_slices((image_list[:TRAIN_LENGTH],
                                               label_list[:TRAIN_LENGTH]))
valid_ds = tf.data.Dataset.from_tensor_slices((image_list[TRAIN_LENGTH:VALID_END],
                                               label_list[TRAIN_LENGTH:VALID_END]))

# Unshuffled decoded views (`train` is also used by the sample-preview cell).
train = train_ds.map(pipeline.load_image_train, num_parallel_calls=AUTOTUNE)
valid = valid_ds.map(pipeline.load_image_test, num_parallel_calls=AUTOTUNE)

# Shuffle the cheap (path, label) pairs *before* decoding so the shuffle buffer
# holds BUFFER_SIZE file names instead of BUFFER_SIZE decoded images.
train_dataset = (
    train_ds.shuffle(BUFFER_SIZE)
    .map(pipeline.load_image_train, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .repeat()
    .prefetch(buffer_size=AUTOTUNE)
)
valid_dataset = valid.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

In [None]:
# Decode the first four training examples and keep the last one for a quick
# shape / visual sanity check of the input pipeline.
for image, label in train.take(4):
    pass
sample_image, sample_label = image, label

print(sample_image.shape, sample_label.shape)
display([sample_image])

In [None]:
def _stage_cfg(num_modules, block, num_blocks):
    """Return one HRNet stage config dict; NUM_BRANCHES is len(num_blocks)."""
    return {
        'NUM_MODULES': num_modules,
        'NUM_BRANCHES': len(num_blocks),
        'BLOCK': block,
        'NUM_BLOCKS': list(num_blocks),
    }


# HRNet classifier with width W=48: a BOTTLENECK first stage followed by
# 2-, 3- and 4-branch stages of BASIC blocks, 4 blocks per branch throughout.
model = HRNet_CLF(
    stage1_cfg=_stage_cfg(1, 'BOTTLENECK', [4]),
    stage2_cfg=_stage_cfg(1, 'BASIC', [4, 4]),
    stage3_cfg=_stage_cfg(4, 'BASIC', [4, 4, 4]),
    stage4_cfg=_stage_cfg(3, 'BASIC', [4, 4, 4, 4]),
    input_height=img_height,
    input_width=img_width,
    n_classes=n_classes,
    W=48,
    # GN_GROUPS=24,            # presumably for the commented-out HRNet_CLF_GN variant
    # ACCUM_STEPS=ACCUM_STEPS  # presumably for the gradient-accumulation variant
)

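Reference `model.summary()` printout. NOTE: this was captured from a different configuration. The branch shapes such as `(1, 128, 256, 192)` and the 20-channel final output do not match the 224×224 input and 362 classes used in this notebook, so treat the shapes and totals as indicative only.
```text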
Model: "HRNet_W48"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  1728      
_________________________________________________________________
batch_normalization (BatchNo multiple                  256       
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  36864     
_________________________________________________________________
batch_normalization_1 (Batch multiple                  256       
_________________________________________________________________
re_lu (ReLU)                 multiple                  0         
_________________________________________________________________
sequential_1 (Sequential)    (1, 128, 256, 192)        168192    
_________________________________________________________________
sequential_2 (Sequential)    (1, 128, 256, 48)         83136     
_________________________________________________________________
sequential_4 (Sequential)    (1, 64, 128, 96)          166272    
_________________________________________________________________
high_resolution_module (High multiple                  880704    
_________________________________________________________________
sequential_11 (Sequential)   (1, 32, 64, 192)          166656    
_________________________________________________________________
high_resolution_module_1 (Hi multiple                  3840576   
_________________________________________________________________
high_resolution_module_2 (Hi multiple                  3840576   
_________________________________________________________________
high_resolution_module_3 (Hi multiple                  3840576   
_________________________________________________________________
high_resolution_module_4 (Hi multiple                  3840576   
_________________________________________________________________
sequential_65 (Sequential)   (1, 16, 32, 384)          665088    
_________________________________________________________________
high_resolution_module_5 (Hi multiple                  15891072  
_________________________________________________________________
high_resolution_module_6 (Hi multiple                  15891072  
_________________________________________________________________
high_resolution_module_7 (Hi multiple                  15891072  
_________________________________________________________________
sequential_144 (Sequential)  (1, 128, 256, 20)         535700    
=================================================================
Total params: 65,740,372
Trainable params: 65,655,732
Non-trainable params: 84,640

```

In [None]:
# Best-checkpoint file for this architecture: "weights/<model.name>.h5".
# Create the directory up front: an HDF5 file cannot be created inside a
# missing directory, so the checkpoint write may otherwise fail on a fresh checkout.
WEIGHTS_DIR = "weights"
os.makedirs(WEIGHTS_DIR, exist_ok=True)
MODEL_PATH = WEIGHTS_DIR + "/" + model.name + ".h5"

In [None]:
# Resume from the best checkpoint when one exists. On a first run (or after
# Restart & Run All on a clean checkout) there is nothing to load, and the
# previous unconditional call raised instead of training from scratch.
# NOTE(review): loading HDF5 weights into a subclassed Keras model requires its
# variables to exist already; assumes HRNet_CLF builds itself in __init__ — confirm.
if os.path.exists(MODEL_PATH):
    model.load_weights(MODEL_PATH)
    print("Loaded weights from", MODEL_PATH)
else:
    print("No checkpoint at", MODEL_PATH, "- training from scratch")

In [None]:
# Print the layer-by-layer parameter table (Keras raises if the model has not
# been built yet, i.e. if its variables do not exist).
model.summary()

In [None]:
# NOTE(review): from_logits=True assumes HRNet_CLF's head returns raw logits
# (no final softmax) — confirm in models/hrnet_clf.py.
optimizer = Adam(learning_rate=3e-3)
loss_fn = CategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

In [None]:
# All callbacks watch validation accuracy (higher is better):
#   - stop after 15 epochs without improvement,
#   - multiply the LR by 0.1 after 5 stagnant epochs (never below 1e-5),
#   - write weights to MODEL_PATH only when val_accuracy improves.
_monitor_kwargs = dict(monitor='val_accuracy', mode='max', verbose=2)
callbacks = [
    EarlyStopping(patience=15, **_monitor_kwargs),
    ReduceLROnPlateau(patience=5, factor=0.1, min_lr=1e-5, **_monitor_kwargs),
    ModelCheckpoint(MODEL_PATH, save_best_only=True, save_weights_only=True,
                    **_monitor_kwargs),
]

In [None]:
# Upper bound on training epochs; the EarlyStopping callback (patience 15 on
# val_accuracy) can end the run sooner.
EPOCHS = 100

End of epoch 1
- loss: 5.5664
- accuracy: 0.0305
- val_loss: 4.5994
- val_accuracy: 0.0761

In [None]:
# train_dataset repeats forever, so the epoch length must be given explicitly;
# validation likewise runs a fixed number of whole batches.
results = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=valid_dataset,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
    verbose=1,
)