In [None]:
import datetime
import os
import pickle

import IPython.display as ipd
import numpy as np
import tensorflow as tf

from dataset import create_all_file_list, create_dataset
from lossfunction import ArcLoss, AdaptiveArcLoss, AdaptiveArcLossVer2
from metrics import AverageAngle, EqualErrorRate
from model import EmbedModel, CosineSimilarityModel, MyLRSchedule


In [None]:
# Build the train/test file lists once and cache them as pickles under
# `file_list/`. Each list is (re)built only when its .pkl file is missing, so
# this cell is cheap on later runs and the loading cell below always finds
# its inputs on a fresh checkout. Delete a .pkl file to force a rebuild.
TRAIN_FILE_LIST_PATH = 'file_list/all_file_list.pkl'
TEST_FILE_LIST_PATH = 'file_list/all_file_test_list.pkl'
# Source pattern for the test split, passed straight to create_all_file_list.
TIMIT_TEST_GLOB = '/kaggle/input/darpa-timit-acousticphonetic-continuous-speech/data/TEST/*/*'

os.makedirs('file_list', exist_ok=True)

if not os.path.exists(TRAIN_FILE_LIST_PATH):
    all_file = create_all_file_list()  # uses the function's default source
    with open(TRAIN_FILE_LIST_PATH, 'wb') as f:
        pickle.dump(all_file, f)

if not os.path.exists(TEST_FILE_LIST_PATH):
    all_file_test = create_all_file_list(TIMIT_TEST_GLOB)
    with open(TEST_FILE_LIST_PATH, 'wb') as f:
        pickle.dump(all_file_test, f)

In [None]:
# Load the cached train/test file lists (the pickles written to file_list/).
# NOTE: pickle.load can execute arbitrary code -- only load files this
# notebook produced itself, never pickles from an untrusted source.
def _load_pickled(path):
    """Return the single object stored in the pickle file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


all_file_list = _load_pickled('file_list/all_file_list.pkl')
all_file_test_list = _load_pickled('file_list/all_file_test_list.pkl')

In [None]:
# --- Training configuration ---------------------------------------------------
SEED = 42
EPOCHS = 300
BATCH_SIZE = 32
margin_ratio = 0.3      # forwarded to AdaptiveArcLossVer2(margin_ratio=...)
initial_margin = 0.27   # forwarded to AdaptiveArcLossVer2(initial_margin=...)

# Seed the Python, NumPy and TensorFlow RNGs so weight initialisation (and any
# random ops in the input pipeline) are reproducible between runs.
tf.keras.utils.set_random_seed(SEED)

# --- Model: input -> EmbedModel -> CosineSimilarityModel ----------------------
# Each example is a variable-length 1-D sequence (shape=(None,)).
# NOTE(review): presumably raw audio samples -- confirm against dataset.py.
inputs = tf.keras.Input(shape=(None,))
embed_model = EmbedModel()
classify_model = CosineSimilarityModel()
embedding_vector = embed_model(inputs)
outputs = classify_model(embedding_vector)
combined_model = tf.keras.Model(inputs, outputs)

ds_train = create_dataset(batch_size=BATCH_SIZE)

lr_schedule = MyLRSchedule()
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

adaptive_arcloss = AdaptiveArcLossVer2(margin_ratio=margin_ratio, initial_margin=initial_margin)
angle_metric = AverageAngle()

# TensorBoard run directory encodes epochs and loss hyperparameters, e.g.
#   logs300/adaptivearcloss_mr03_i_027_20240101-1200
# The root is derived from EPOCHS so it cannot drift from the real epoch count,
# and an explicit '_' keeps the margin digits and the timestamp apart.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
log_dir = (
    f"logs{EPOCHS}/adaptivearcloss"
    f"_mr{str(margin_ratio).replace('.', '')}"
    f"_i_{str(initial_margin).replace('.', '')}"
    f"_{run_timestamp}"
)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

combined_model.compile(optimizer=optimizer, loss=adaptive_arcloss, metrics=[angle_metric, 'acc'])
# Keep the History (per-epoch loss/metric values) instead of discarding it.
history = combined_model.fit(ds_train, epochs=EPOCHS, callbacks=[tensorboard_callback])

In [None]:
# Record the loss hyperparameters of this run in the notebook output.
# The bare last expression is rendered via Jupyter's rich display (no print()).
loss_config = adaptive_arcloss.get_config()
loss_config

In [None]:
# Equal error rate of the trained embedding network on the test split
# returned by create_dataset(test=True). Only `embed_model` is needed here;
# the cosine-similarity classifier head is not used for verification.
ds_test = create_dataset(test=True)
eer_metric = EqualErrorRate(ds_test)
# Keep the value so later cells can reuse it, and show it as the cell output.
test_eer = eer_metric.calculate_eer(embed_model)
test_eer

In [None]:
# Persist the two sub-models separately, so the embedding network (the only
# part the EER evaluation uses) can be reloaded on its own. Create `weight/`
# up front so saving does not depend on the directory already existing.
os.makedirs("weight", exist_ok=True)
embed_model.save_weights("weight/embed_model.weights.h5")
classify_model.save_weights("weight/classify_model.weights.h5")