In [None]:
import os
import sys
import glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import backend as K
from tensorflow.keras import mixed_precision
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers.schedules import PolynomialDecay, PiecewiseConstantDecay
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from models.clf.hrnet_clf_accumilate import HRNet_CLF

K.clear_session()
physical_devices = tf.config.experimental.list_physical_devices("GPU")
tf.config.experimental.set_memory_growth(physical_devices[0], True)

def enable_amp():
    mixed_precision.set_global_policy("mixed_float16")
%config InlineBackend.figure_format = 'retina'
plt.style.use('ggplot')
plt.rc('xtick',labelsize=16)
plt.rc('ytick',labelsize=16)
print("Tensorflow version: ", tf.__version__)
print(physical_devices,"\n")
enable_amp() 

In [None]:
class ImageNetLoader():
    """Per-example preprocessing for datapoints with 'image' and 'label' entries.

    Images are resized to (img_height, img_width), scaled to [0, 1] and
    standardized per channel with MEAN / STD; labels are one-hot encoded over
    n_classes. The training path additionally applies random flip, random
    crop and random colour jitter.
    """

    def __init__(self, img_height, img_width, n_classes):
        """Store the target image size, the class count and channel statistics."""
        self.n_classes = n_classes
        self.img_height = img_height
        self.img_width = img_width
        # Per-channel statistics for images scaled to [0, 1] (presumably the
        # usual ImageNet RGB mean/std). Kept as float32 so they match the
        # float32 image tensors they are combined with.
        self.MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    @tf.function
    def random_crop(self, image):
        """Return a random crop whose sides are a random fraction of the input's.

        A single scale is drawn uniformly from eight values in [0.5625, 1.0]
        and applied to both height and width.
        Assumes `image` has shape (height, width, 3).
        """
        scale_choices = [0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1.0]
        scales = tf.constant(scale_choices, dtype=tf.float32)
        # Bound the index by the list length so the two cannot drift apart.
        index = tf.random.uniform(shape=[], minval=0, maxval=len(scale_choices), dtype=tf.int32)
        scale = scales[index]

        shape = tf.cast(tf.shape(image), tf.float32)
        h = tf.cast(shape[0] * scale, tf.int32)
        w = tf.cast(shape[1] * scale, tf.int32)
        return tf.image.random_crop(image, size=[h, w, 3])

    def _standardize(self, image):
        """Subtract MEAN and divide by STD for an image already scaled to [0, 1]."""
        return (image - self.MEAN) / self.STD

    @tf.function
    def normalize(self, image):
        """Scale a float image by 1/255, then standardize it with MEAN / STD."""
        return self._standardize(image / 255.0)

    @tf.function
    def load_image_train(self, datapoint):
        """Build an augmented (image, one-hot label) training pair.

        Steps: random horizontal flip, random crop, resize, scale to [0, 1],
        optional colour jitter, per-channel standardization.
        """
        img = datapoint['image']
        label = tf.one_hot(tf.cast(datapoint['label'], tf.int32), self.n_classes)

        if tf.random.uniform(()) > 0.5:
            img = tf.image.flip_left_right(img)

        img = self.random_crop(img)
        img = tf.image.resize(img, (self.img_height, self.img_width))
        img = tf.cast(img, tf.float32) / 255.0

        # Colour jitter must run on [0, 1] RGB values: the saturation and hue
        # ops go through HSV, which is only well defined in that range.
        # Standardizing first would feed them negative / >1 values.
        if tf.random.uniform(()) > 0.5:
            img = tf.clip_by_value(tf.image.random_brightness(img, 0.05), 0.0, 1.0)
            img = tf.image.random_saturation(img, 0.6, 1.6)
            img = tf.clip_by_value(tf.image.random_contrast(img, 0.7, 1.3), 0.0, 1.0)
            img = tf.image.random_hue(img, 0.05)

        img = self._standardize(img)
        return img, label

    def load_image_test(self, datapoint):
        """Build a deterministic (image, one-hot label) evaluation pair.

        The image is resized straight to the target size (no crop) and then
        normalized.
        """
        img = datapoint['image']
        label = datapoint['label']
        label = tf.one_hot(tf.cast(label, tf.int32), self.n_classes)
        img = tf.image.resize(img, (self.img_height, self.img_width))
        img = self.normalize(tf.cast(img, tf.float32))
        return img, label

In [None]:
# Target input resolution (pixels) and number of output classes.
img_height, img_width = 224, 224
n_classes = 1000

# Preprocessing pipeline configured for that resolution and class count.
pipeline = ImageNetLoader(img_height=img_height, img_width=img_width, n_classes=n_classes)

In [None]:
# Load the 'imagenet2012:5.0.0' dataset from the local TFDS directory, with
# its metadata; input files are read in shuffled order.
dataset, info = tfds.load(
    'imagenet2012:5.0.0',
    data_dir='/workspace/tensorflow_datasets/',
    shuffle_files=True,
    with_info=True,
)

In [None]:
# Number of examples per split, as reported by the dataset metadata.
TRAIN_LENGTH = info.splits['train'].num_examples
VALID_LENGTH = info.splits['validation'].num_examples

# Per-example preprocessing: augmented for training, deterministic for validation.
train = dataset['train'].map(
    pipeline.load_image_train,
    num_parallel_calls=tf.data.AUTOTUNE,
)
valid = dataset['validation'].map(
    pipeline.load_image_test,
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
# Size of the shuffle buffer (in examples).
BUFFER_SIZE = 8192

# Batch size per step and number of steps accumulated together.
BATCH_SIZE = 128
ACCUM_STEPS = 2

# Effective batch size = per-step batch size times accumulation steps.
ADJ_BATCH_SIZE = ACCUM_STEPS * BATCH_SIZE
print("Effective batch size: {}".format(ADJ_BATCH_SIZE))

In [None]:
def display(display_list, title=True):
    """Show every image in `display_list` side by side in a single row.

    Each entry is converted with `array_to_img` before plotting and the axes
    are hidden. `title` is accepted but not used.
    """
    n_images = len(display_list)
    plt.figure(figsize=(15, 5)) # dpi=200
    for position, item in enumerate(display_list, start=1):
        plt.subplot(1, n_images, position)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Walk through four preprocessed training examples; the last one seen is kept
# as the sample to inspect.
for image, label in train.take(4):
    sample_image = image
    sample_label = label

print(sample_image.shape, sample_label.shape)
display([sample_image])

In [None]:
# Training stream: shuffle examples, batch them, repeat indefinitely (the
# epoch length is set by `steps_per_epoch` at fit time) and prefetch so input
# preparation overlaps with training.
train_dataset = (
    train
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Validation stream: a single ordered pass, also prefetched. Uses the same
# `tf.data.AUTOTUNE` alias as the `map` calls earlier in the notebook.
valid_dataset = valid.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Build the HRNet classifier: four stages with 1, 2, 3 and 4 branches, using
# the image size, class count and accumulation steps defined earlier.
model = HRNet_CLF(
    stage1_cfg={
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 1,
        'BLOCK': 'BOTTLENECK',
        'NUM_BLOCKS': [4],
    },
    stage2_cfg={
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': [4, 4],
    },
    stage3_cfg={
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 3,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': [4, 4, 4],
    },
    stage4_cfg={
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 4,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': [4, 4, 4, 4],
    },
    input_height=img_height,
    input_width=img_width,
    n_classes=n_classes,
    W=48,
    # GN_GROUPS=24,
    ACCUM_STEPS=ACCUM_STEPS,
)

```text
Model: "HRNet_W48"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  1728      
_________________________________________________________________
batch_normalization (BatchNo multiple                  256       
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  36864     
_________________________________________________________________
batch_normalization_1 (Batch multiple                  256       
_________________________________________________________________
re_lu (ReLU)                 multiple                  0         
_________________________________________________________________
sequential_1 (Sequential)    (1, 128, 256, 192)        168192    
_________________________________________________________________
sequential_2 (Sequential)    (1, 128, 256, 48)         83136     
_________________________________________________________________
sequential_4 (Sequential)    (1, 64, 128, 96)          166272    
_________________________________________________________________
high_resolution_module (High multiple                  880704    
_________________________________________________________________
sequential_11 (Sequential)   (1, 32, 64, 192)          166656    
_________________________________________________________________
high_resolution_module_1 (Hi multiple                  3840576   
_________________________________________________________________
high_resolution_module_2 (Hi multiple                  3840576   
_________________________________________________________________
high_resolution_module_3 (Hi multiple                  3840576   
_________________________________________________________________
high_resolution_module_4 (Hi multiple                  3840576   
_________________________________________________________________
sequential_65 (Sequential)   (1, 16, 32, 384)          665088    
_________________________________________________________________
high_resolution_module_5 (Hi multiple                  15891072  
_________________________________________________________________
high_resolution_module_6 (Hi multiple                  15891072  
_________________________________________________________________
high_resolution_module_7 (Hi multiple                  15891072  
_________________________________________________________________
sequential_144 (Sequential)  (1, 128, 256, 20)         535700    
=================================================================
Total params: 65,740,372
Trainable params: 65,655,732
Non-trainable params: 84,640

```

In [None]:
# Weights file location: "weights/" directory, the model's name, ".h5" suffix.
MODEL_PATH = "weights/"+model.name+".h5"

In [None]:
# Restore the weights stored at MODEL_PATH into the model.
# NOTE(review): presumably the checkpoint being resumed from — the file must already exist.
model.load_weights(MODEL_PATH)

In [None]:
# Print the model's layer-by-layer summary.
model.summary()

In [None]:
# Total number of epochs in the training schedule.
EPOCHS = 100

# Training batches per epoch (the training stream repeats, so the partial
# final batch is simply carried into the next epoch).
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Ceiling division: the validation stream is a single finite pass, so floor
# division would silently skip the last partial batch and the metrics would
# not cover the whole split.
VALIDATION_STEPS = -(-VALID_LENGTH // BATCH_SIZE)

# Total number of steps over the whole run, divided by the accumulation steps.
DECAY_STEPS = (STEPS_PER_EPOCH * EPOCHS) // ACCUM_STEPS
DECAY_STEPS

In [None]:
# Epoch the run is at; the decay milestones below are shifted by this amount.
CURR_EPOCH = 14

# Epochs remaining until each milestone (30, 60, 90).
E1, E2, E3 = (milestone - CURR_EPOCH for milestone in (30, 60, 90))

# The same milestones expressed in steps, divided by the accumulation steps.
S1, S2, S3 = ((STEPS_PER_EPOCH * remaining) // ACCUM_STEPS for remaining in (E1, E2, E3))

print("--- LR decay --- \nstep {}: {} \nstep {}: {} \nstep {}: {}".format(S1, 1e-2, S2, 1e-3, S3, 1e-4))

In [None]:
# Step-wise schedule: 0.1 until S1, then 0.01 until S2, 0.001 until S3, 0.0001 afterwards.
learning_rate_fn = PiecewiseConstantDecay(
    values=[0.1, 0.01, 0.001, 0.0001],
    boundaries=[S1, S2, S3],
)

In [None]:
# SGD with momentum, driven only by the piecewise-constant schedule.
# The former `decay=0.0001` argument was removed: in Keras optimizers `decay`
# is a time-based learning-rate decay (lr / (1 + decay * step)), not weight
# decay, and it was applied on top of the schedule, so the learning rate kept
# shrinking. Newer Keras optimizers also reject the argument.
# NOTE(review): if L2 weight decay was intended, it has to be added another
# way (e.g. kernel regularizers inside the model).
model.compile(
    optimizer=SGD(learning_rate=learning_rate_fn, momentum=0.9),
    loss=CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

In [None]:
# Save weights only, to MODEL_PATH, whenever validation accuracy reaches a new maximum.
callbacks = [
    ModelCheckpoint(
        MODEL_PATH,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True,
        save_weights_only=True,
        verbose=2,
    ),
]

0.5609

In [None]:
# Resume training. `initial_epoch=CURR_EPOCH` makes Keras continue counting
# from the epoch the loaded weights correspond to, so the run stops at EPOCHS
# in total (matching the LR boundaries, which are shifted by CURR_EPOCH)
# instead of training EPOCHS further epochs with mislabelled epoch numbers.
results = model.fit(
    train_dataset,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    initial_epoch=CURR_EPOCH,
    epochs=EPOCHS,
    validation_data=valid_dataset,
    callbacks=callbacks,
    verbose=1
)