In [None]:
import json
import math
import os

import matplotlib.pyplot as plt
import tensorflow as tf

from model.ModelBuilder import ModelBuilder
from utils_train.customLoss import CenterNetLoss
from utils_train.Datagenerator import Dataset_COCO
from utils_train.customOptimizer import GCSGD

In [None]:
# Enable XLA auto-clustering for eligible TensorFlow ops.
tf.config.optimizer.set_jit("autoclustering")

# NOTE: TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's C++ logger only once, at the
# first C++ log call. TensorFlow has already been imported above, so messages
# emitted during import are NOT suppressed by this line. To silence them too,
# set the variable before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Python-side TensorFlow logging is controlled separately from the C++ logger;
# keep only errors so the training output stays readable.
tf.get_logger().setLevel("ERROR")

In [None]:
class LRFind(tf.keras.callbacks.Callback):
    """Learning-rate range test ("LR finder") callback.

    Starting at ``min_lr``, the optimizer's learning rate is multiplied by a
    constant factor after every training batch. It therefore grows
    geometrically and reaches ``max_lr`` after ``n_rounds`` batches, at which
    point training is stopped. For every batch, the learning rate used and the
    loss read from the batch logs are appended to ``self.lrs`` and
    ``self.losses`` so they can be plotted afterwards.

    The model weights and the optimizer's learning rate are captured when
    training begins and restored when it ends. Optimizer state such as
    momentum slots is NOT captured or restored.

    Args:
        min_lr: Starting learning rate; must be > 0.
        max_lr: Learning rate past which the sweep stops; must be > ``min_lr``.
        n_rounds: Number of batches needed to go from ``min_lr`` to
            ``max_lr``; must be > 0.
        loss_key: Key in the batch ``logs`` dict whose value is recorded as
            the loss. Defaults to ``"TotalL"``.
        stop_on_nonfinite: If True (default), stop training as soon as the
            recorded loss is NaN or infinite, because the sweep has diverged.

    Raises:
        ValueError: If ``min_lr``, ``max_lr`` or ``n_rounds`` are invalid.
        KeyError: (during training) if ``loss_key`` is missing from the logs.
    """

    def __init__(self, min_lr, max_lr, n_rounds, loss_key="TotalL",
                 stop_on_nonfinite=True):
        super().__init__()
        if min_lr <= 0:
            raise ValueError(f"min_lr must be > 0, got {min_lr}")
        if max_lr <= min_lr:
            raise ValueError(f"max_lr ({max_lr}) must be > min_lr ({min_lr})")
        if n_rounds <= 0:
            raise ValueError(f"n_rounds must be > 0, got {n_rounds}")
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.n_rounds = n_rounds
        self.loss_key = loss_key
        self.stop_on_nonfinite = stop_on_nonfinite
        # Per-batch multiplier chosen so that min_lr * step_up**n_rounds == max_lr.
        self.step_up = (max_lr / min_lr) ** (1 / n_rounds)
        self.lrs = []
        self.losses = []
        self.weights = None
        self._initial_lr = None

    def _get_lr(self):
        """Return the optimizer's current learning rate as a Python float."""
        return float(tf.keras.backend.get_value(self.model.optimizer.lr))

    def _set_lr(self, value):
        """Assign ``value`` to the optimizer's learning-rate variable."""
        tf.keras.backend.set_value(self.model.optimizer.lr, value)

    def on_train_begin(self, logs=None):
        """Snapshot weights/learning rate, reset history, start at ``min_lr``."""
        # Reset so that re-running fit() with the same callback does not mix
        # two sweeps into one curve.
        self.lrs = []
        self.losses = []
        self.weights = self.model.get_weights()
        self._initial_lr = self._get_lr()
        self._set_lr(self.min_lr)

    def on_train_batch_end(self, batch, logs=None):
        """Record (lr, loss) for the finished batch, then raise the learning rate."""
        logs = logs or {}
        if self.loss_key not in logs:
            raise KeyError(
                f"LRFind: loss key {self.loss_key!r} not found in batch logs "
                f"(available keys: {sorted(logs)})"
            )
        # NOTE(review): depending on how the model reports this metric, the
        # batch-level value may be a running mean over the epoch rather than
        # the loss of this batch alone -- confirm against the model's train_step.
        lr = self._get_lr()
        loss = float(logs[self.loss_key])
        self.lrs.append(lr)
        self.losses.append(loss)

        if self.stop_on_nonfinite and not math.isfinite(loss):
            # The sweep has diverged; further steps carry no information.
            self.model.stop_training = True
            return

        next_lr = lr * self.step_up
        self._set_lr(next_lr)
        if next_lr > self.max_lr:
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        """Restore the weights and learning rate captured in ``on_train_begin``."""
        if self.weights is not None:
            self.model.set_weights(self.weights)
        if self._initial_lr is not None:
            self._set_lr(self._initial_lr)

In [None]:
modelName = "MobileNetV3_FPN_TTFNet"

# The configuration for the selected model lives in model/0_Config/<modelName>.json.
# It is shared by the dataset, model and loss built in later cells.
config_path = os.path.join("model/0_Config", modelName + ".json")
# JSON text is UTF-8 by specification (RFC 8259); don't depend on the platform's
# locale encoding (e.g. cp1252 on Windows) when decoding it.
with open(config_path, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

In [None]:
# The class count is written into the shared config before anything is built
# from it: this dataset here, and the model and loss in the next cell.
# COCO detection uses 80 object categories.
NUM_CLASSES = 80
training_cfg = config['training_config']
training_cfg['num_classes'] = NUM_CLASSES

train_dataset = Dataset_COCO(config, mode='train')

In [None]:
EPOCHS = 1
lr_finder_steps = 400
# Learning-rate range swept geometrically over `lr_finder_steps` batches.
MIN_LR, MAX_LR = 1e-6, 1e1
lr_find = LRFind(MIN_LR, MAX_LR, lr_finder_steps)

# The mixed-precision policy must be set BEFORE the model is built: Keras layers
# take the global dtype policy at construction time. If the model is built first,
# it silently stays in float32 and the LossScaleOptimizer below only adds overhead.
# (Assumes ModelBuilder creates its Keras layers when called -- TODO confirm.)
tf.keras.mixed_precision.set_global_policy(tf.keras.mixed_precision.Policy('mixed_float16'))

model = ModelBuilder(config=config)
# The initial learning rate is irrelevant here: LRFind overwrites it with its
# minimum learning rate when training begins.
base_optimizer = GCSGD(learning_rate=1e-1, momentum=0.9, nesterov=False)
# Dynamic loss scaling guards float16 gradients against underflow.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base_optimizer)
model.compile(loss=CenterNetLoss(config), optimizer=optimizer, weighted_metrics=[])

In [None]:
# Run the learning-rate range test. LRFind stops training once the learning rate
# exceeds its maximum and restores the initial weights when training ends.
model.fit(
    train_dataset.dataset,
    steps_per_epoch=lr_finder_steps,
    epochs=EPOCHS,
    callbacks=[lr_find]
)

# Loss vs. learning rate on log-log axes. A learning rate is commonly picked
# somewhat below the point where the loss starts to rise sharply.
fig, ax = plt.subplots()
ax.plot(lr_find.lrs, lr_find.losses)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(
    title=f"Learning-rate range test ({modelName})",
    xlabel="learning rate",
    ylabel="training loss",
)
ax.grid(True, which="both", alpha=0.3)
plt.show()