In [None]:
# Automatically reload edited project modules (e.g. model.models,
# tf_dataset.tf_dataset imported below) before each cell executes, so code
# changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
import keras
from model.models import get_model, metric_dist
from tf_dataset.tf_dataset import get_tf_dataset

## 1. Train with $\rho=16$

In [None]:
# Root of the Phase 2 data. Override with the PHASE2_DATA_ROOT environment
# variable so the notebook is not tied to one machine's absolute path; the
# default keeps the original location.
root = os.environ.get("PHASE2_DATA_ROOT",
                      "/home/ji/Dropbox/Robotics/CMSC733/Project1/Phase2/Data")

# Stage 1 datasets: rho=16 is passed through to get_tf_dataset (presumably
# the maximum perturbation magnitude -- confirm against tf_dataset).
# do_resize=False because the *_Resize folders are assumed to already hold
# resized images -- NOTE(review): verify.
train_ds = get_tf_dataset(os.path.join(root, "Train_Resize"),
                          mode="unsupervised",
                          do_resize=False,
                          rho=16)

val_ds = get_tf_dataset(os.path.join(root, "Val_Resize"),
                        mode="unsupervised",
                        do_resize=False,
                        rho=16)

In [None]:
# Create the unsupervised model and resume from the rho=16 checkpoint if one
# exists on disk.
# NOTE(review): batch_size is only used below to compute step counts; it is
# not passed to get_tf_dataset -- confirm it matches the dataset's batch size.
batch_size = 8
monitor_name = "mae_loss"
checkpoint_path = "./chkpt/mdl_unsupervised_rho16"
model = get_model(mode="unsupervised")

try:
    model.load_weights(checkpoint_path)
    print("weight loaded")
except Exception as err:
    # A missing checkpoint is expected on the first run, so training starts
    # from fresh weights. Report the reason instead of silently swallowing it
    # (the old bare `except:` also hid unrelated errors and KeyboardInterrupt).
    print(f"no weights loaded from {checkpoint_path}: {err}")

# Adam with per-value gradient clipping at 0.01.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-4,
                                              clipvalue=0.01),
              run_eagerly=False)

In [None]:
# Training schedule: 5000 samples and 1000 validation samples per epoch,
# expressed in batches.
num_epochs = 50
steps_per_epoch = 5000 // batch_size

# Shrink the learning rate 5x whenever the monitored metric stops improving.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor=monitor_name,
    factor=0.2,
    patience=3,
    min_lr=1e-6,
    verbose=1,
    cooldown=3,
)

# Persist weights only when the monitored metric reaches a new minimum.
# NOTE(review): monitor_name is "mae_loss" (a training metric); if model
# selection should use validation, the Keras name would be "val_mae_loss".
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor=monitor_name,
    mode='min',
    save_best_only=True,
    verbose=True,
)

# Train for up to 10 rounds of `num_epochs` epochs each (every round runs,
# even after a successful one). If a round raises, roll back to the best
# checkpoint on disk and continue with the next round.
# `except Exception` (not a bare `except:`) so that interrupting the kernel
# (KeyboardInterrupt) actually stops training instead of triggering a reset.
num_rounds = 10
for round_idx in range(num_rounds):
    try:
        history = model.fit(train_ds,
                            epochs=num_epochs,
                            steps_per_epoch=steps_per_epoch,
                            validation_data=val_ds,
                            validation_steps=1000 // batch_size,
                            validation_freq=1,
                            verbose=True,
                            callbacks=[reduce_lr, checkpoint_callback])
    except Exception as err:
        print(f"round {round_idx} failed: {type(err).__name__}: {err}")
        model.load_weights(checkpoint_path)
        print("======================== reset ==========================")



## 2. Train with $\rho=32$

In [None]:
# Root of the Phase 2 data. Override with the PHASE2_DATA_ROOT environment
# variable so the notebook is not tied to one machine's absolute path; the
# default keeps the original location.
root = os.environ.get("PHASE2_DATA_ROOT",
                      "/home/ji/Dropbox/Robotics/CMSC733/Project1/Phase2/Data")

# Stage 2 datasets: same folders as stage 1, but rho=32 is passed through to
# get_tf_dataset (presumably a larger perturbation -- confirm).
train_ds2 = get_tf_dataset(os.path.join(root, "Train_Resize"),
                           mode="unsupervised",
                           do_resize=False,
                           rho=32)

val_ds2 = get_tf_dataset(os.path.join(root, "Val_Resize"),
                         mode="unsupervised",
                         do_resize=False,
                         rho=32)

# Create the stage-2 (rho=32) model. Resume from its own checkpoint when one
# exists; otherwise warm-start from the stage-1 (rho=16) checkpoint. The
# original unconditional load crashed on a fresh run where no rho32
# checkpoint had been written yet.
# NOTE(review): batch_size is only used for step counts; confirm it matches
# the batch size produced by get_tf_dataset.
batch_size = 8
monitor_name = "mae_loss"
checkpoint_path = "./chkpt/mdl_unsupervised_rho32"
warm_start_path = "./chkpt/mdl_unsupervised_rho16"
model = get_model(mode="unsupervised")
try:
    model.load_weights(checkpoint_path)
    print(f"weight loaded from {checkpoint_path}")
except Exception as err:
    print(f"no weights at {checkpoint_path} ({err}); "
          f"warm-starting from {warm_start_path}")
    # Let this raise if the stage-1 checkpoint is missing too: fine-tuning
    # at this tiny learning rate from random weights would be meaningless.
    model.load_weights(warm_start_path)

# Compile directly with the fine-tuning learning rate. The previous code
# compiled at 1e-3 and then overwrote it with 8e-6; the effective rate is the
# same, but this way it is stated once.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=8e-6,
                                              clipvalue=0.01),
              run_eagerly=False)
# Training schedule: 5000 samples per epoch, expressed in batches.
num_epochs = 50
steps_per_epoch = 5000 // batch_size

# Shrink the learning rate 5x whenever the monitored metric stops improving.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor=monitor_name,
    factor=0.2,
    patience=3,
    min_lr=1e-6,
    verbose=1,
    cooldown=3,
)

# Persist weights only when the monitored metric reaches a new minimum.
# NOTE(review): "mae_loss" is a training metric; use "val_mae_loss" if
# checkpoint selection should be driven by validation performance.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor=monitor_name,
    mode='min',
    save_best_only=True,
    verbose=True,
)

# Train for up to 10 rounds of `num_epochs` epochs each (every round runs,
# even after a successful one). If a round raises, roll back to the best
# checkpoint on disk and continue with the next round.
# `except Exception` (not a bare `except:`) so that interrupting the kernel
# (KeyboardInterrupt) actually stops training instead of triggering a reset.
num_rounds = 10
for round_idx in range(num_rounds):
    try:
        history = model.fit(train_ds2,
                            epochs=num_epochs,
                            steps_per_epoch=steps_per_epoch,
                            validation_data=val_ds2,
                            validation_steps=1000 // batch_size,
                            validation_freq=1,
                            verbose=True,
                            callbacks=[reduce_lr, checkpoint_callback])
    except Exception as err:
        print(f"round {round_idx} failed: {type(err).__name__}: {err}")
        model.load_weights(checkpoint_path)
        print("======================== reset ==========================")
