In [None]:
from src.data_utils.dataset import build_datset_train_val
from src.Transformers.transformers_prototype import build_transformer_model_v2
from src.callbacks import get_predefine_callbacks
import optuna
import tensorflow as tf
from src.constants import TARGET_MAX_LENGHT, MAX_LENGHT_SOURCE

In [None]:
# Build the train/validation datasets once; every later cell reuses them.
train_dataset, val_dataset = build_datset_train_val(
    split=0.8,
    batch_size=512,
)

In [None]:
def objective(trial):
    """Optuna objective: train one candidate model and score it on validation data.

    Args:
        trial: Optuna trial forwarded to ``build_transformer_model_v2``
            (presumably used there to sample hyperparameters — confirm in
            that builder).

    Returns:
        The last entry of ``model.evaluate(val_dataset)``, which this notebook
        treats as the Levenshtein metric (assumes it is the last compiled
        metric — TODO confirm).
    """
    # Drop Keras global state left behind by any previously built model.
    tf.keras.backend.clear_session()

    candidate = build_transformer_model_v2(trial=trial)
    training_callbacks = get_predefine_callbacks(model_name="v2")
    candidate.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=25,
        callbacks=training_callbacks,
    )

    # Index -1 is the value optimised here; index 0 (the validation loss) is the
    # alternative objective that was tried previously.
    validation_scores = candidate.evaluate(val_dataset)
    return validation_scores[-1]

In [None]:
# The objective returns a distance-style score, so the study minimises it.
study = optuna.create_study(direction="minimize")
study.optimize(
    objective,
    n_trials=10,
    gc_after_trial=True,
    show_progress_bar=True,
)

In [None]:
# Retrain the best configuration(s) found by the study, then report validation
# metrics using the weights stored under ../best_model/prototype/.
#
# NOTE(review): every trial shares the same `model_name`, so if the study
# returns more than one best trial they all read/write the same weights path
# (presumably written by the predefined callbacks — confirm).
model_name = "v3"
trials = study.best_trials

for index, trial in enumerate(trials):
    # Reset Keras global state before every build (as `objective` does) so
    # state from the previous model does not accumulate across iterations.
    tf.keras.backend.clear_session()

    print(f"Best model: {index+1}")

    model = build_transformer_model_v2(trial=trial)
    # NOTE(review): epochs=5000 only makes sense if one of the callbacks stops
    # training early — confirm in get_predefine_callbacks.
    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=5000,
        callbacks=get_predefine_callbacks(model_name=model_name),
    )
    # summary() prints by itself and returns None; wrapping it in print()
    # emitted a stray "None" line after the table.
    model.summary()

    print('validation levenshtein distance: {}'.format(trial.value))
    print("Best hyperparameters: {}".format(trial.params))

    # Evaluate the stored weights rather than the last-epoch weights.
    model.load_weights(f"../best_model/prototype/{model_name}")

    print(f"Metrics in Validation: {model.evaluate(val_dataset)}")

In [None]:
# Persist the retrained model to disk using the TensorFlow ("tf") save format.
model.save(
    "../models/v3/",
    save_format="tf",
)

In [None]:
from src.data_utils.dataset import char_to_num, num_to_char

In [None]:
# Greedy, token-by-token decoding of the first validation batch, printing every
# intermediate step for manual inspection.
first_batch = next(iter(val_dataset))

# Iterating directly over a single dataset element walks its components (dict
# keys, or the (inputs, labels) tuple members), not batches, so indexing with
# "source" fails on the non-dict components.
# NOTE(review): assumes the element is either the inputs dict itself or a tuple
# whose first member is the inputs dict — confirm against build_datset_train_val.
batch = first_batch[0] if isinstance(first_batch, (tuple, list)) else first_batch

sources = batch["source"]
targets = batch["target"]

print(sources.shape)
print(targets.shape)

# Every decoded sequence starts from the "<" token.
start_token = char_to_num["<"]

for source, target in zip(sources, targets):
    target_sequence = [start_token]
    source_sequence = tf.expand_dims(source, axis=0)

    y_true = "".join([num_to_char[w] for w in target.numpy()])

    for _ in range(TARGET_MAX_LENGHT):
        # Right-pad the tokens decoded so far with 0 up to the fixed decoder
        # length and add a batch dimension of 1.
        decoder_input = tf.expand_dims(
            tf.pad(
                tf.constant(target_sequence),
                [[0, TARGET_MAX_LENGHT - len(target_sequence)]],
                mode="CONSTANT",
                constant_values=0,
            ),
            axis=0,
        )

        print("next target sequence: ", decoder_input)

        y_pred = model({"source": source_sequence, "target": decoder_input})

        y_pred = tf.cast(tf.argmax(y_pred, axis=2), dtype=tf.int32)

        print("argmax:", y_pred)

        # The next token is taken as the last non-zero predicted id.
        # NOTE(review): this relies on predictions at padded positions being 0;
        # indexing position len(target_sequence) - 1 may be more robust — confirm.
        mask = tf.not_equal(y_pred, 0)
        non_pad_predictions = y_pred[mask]
        if int(tf.size(non_pad_predictions)) == 0:
            # Nothing but id 0 was predicted; [-1] on an empty tensor would raise.
            print("no non-padding token predicted; stopping early")
            break
        next_token = non_pad_predictions[-1].numpy()

        print("next token: ", num_to_char[next_token], next_token)

        target_sequence.append(next_token)

        print("sequence so far: ", "".join([num_to_char[w] for w in target_sequence]))
        print("Label: ", y_true)

        # ">" marks the end of the sequence.
        if num_to_char[next_token] == ">":
            break

    print("==========================================================================")

In [None]:
from src.custom.metrics import SparseLevenshteinV2

# Reload the saved model, supplying the custom metric class explicitly.
# Only `SparseLevenshteinV2` is imported, so the bare name `SparseLevenshtein`
# is undefined and must not be used as the value. The class is registered under
# both the legacy key and its own class name because the name recorded inside
# the saved model is not visible from this notebook.
# NOTE(review): confirm which key the saved model needs and drop the other.
model_loaded = tf.keras.models.load_model(
    "../models/v3/",
    custom_objects={
        "SparseLevenshtein": SparseLevenshteinV2,
        SparseLevenshteinV2.__name__: SparseLevenshteinV2,
    },
)

In [None]:
# Reload without custom_objects. Intended to be run after the custom metric
# class has been registered as serializable.
model_loaded = tf.keras.models.load_model(
    "../models/v3/",
)