In [None]:
%load_ext autoreload

In [None]:
import os
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
assert tf.config.list_physical_devices('GPU'), 'No GPU is visible to TensorFlow'

%autoreload
import dataset

import schedulers

%autoreload
import transformer

In [None]:
# --- Training / model configuration -------------------------------------------
BATCH_SIZE = 32
MAXLEN = 50  # used both for dataset loading and for the position-encoding sizes

# Vocabulary sizes are taken from the lookup tables exposed by `dataset`.
LETTERS_SIZE, NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE = (
    len(table)
    for table in (
        dataset.letters_table,
        dataset.niqqud_table,
        dataset.dagesh_table,
        dataset.sin_table,
    )
)

d_model = 500

# Letters in, niqqud out; dagesh / sin sizes are computed but not used by this model.
model = transformer.Transformer(
    input_vocab_size=LETTERS_SIZE,
    target_vocab_size=NIQQUD_SIZE,
    maximum_position_encoding_input=MAXLEN,
    maximum_position_encoding_target=MAXLEN,
    num_layers=2,
    num_heads=5,
    d_model=d_model,
    dff=128,
    rate=0.0,  # presumably the dropout rate (0.0 = off) -- confirm in transformer.py
)

# Snapshot the initial weights so later cells can restart training from the same point.
# NOTE(review): the model is never called or built before this line; if it has no
# variables yet, this checkpoint may contain no weights, and later
# `load_weights('./checkpoints/uninit')` calls would then NOT reset the model -- confirm.
model.save_weights('./checkpoints/uninit')

In [None]:
def load_data(source, validation=0.1, maxlen=None, data_dir='texts'):
    """Load one or more corpora and return the train / validation split from `dataset`.

    Parameters
    ----------
    source : str or iterable of str
        Corpus name(s), resolved relative to ``data_dir``. A single string is
        treated as one corpus name rather than being iterated character by character.
    validation : float
        Passed through unchanged as the second argument of ``dataset.load_data``.
    maxlen : int, optional
        Forwarded to ``dataset.load_data`` as ``maxlen``. Defaults to the
        notebook-level ``MAXLEN`` when omitted.
    data_dir : str
        Directory that contains the corpora (default ``'texts'``).

    Returns
    -------
    tuple
        ``(train, valid)`` exactly as returned by ``dataset.load_data``.
    """
    if isinstance(source, str):
        # Iterating a bare string would produce one "filename" per character.
        source = [source]
    if maxlen is None:
        maxlen = MAXLEN
    filenames = [os.path.join(data_dir, name) for name in source]
    train, valid = dataset.load_data(filenames, validation, maxlen=maxlen)
    return train, valid


def fit(data, epochs=1):
    """Train the global ``model`` on the training split of ``data`` with manual batching.

    Parameters
    ----------
    data : tuple
        ``(train, valid)`` pair as returned by ``load_data``; only ``data[0]`` is used.
        It must support ``len()`` and expose sliceable ``normalized`` (inputs) and
        ``niqqud`` (targets) attributes.
    epochs : int
        Number of passes over the training split.

    Returns
    -------
    list
        One entry per epoch: the dict returned by ``model.train_step`` for the last
        batch of that epoch, or ``None`` if the split has fewer than ``BATCH_SIZE`` rows.
    """
    # Learning-rate schedule with 1000 warm-up steps, defined in the `transformer` module.
    learning_rate = transformer.CustomSchedule(d_model, warmup_steps=1000)
    # NOTE(review): `train_loss` is reset via `reset_states()` below, so it looks like a
    # metric rather than a loss function. That is harmless only if `model.train_step`
    # computes its own loss and ignores the compiled one -- confirm in transformer.py.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=transformer.train_loss
    )
    train = data[0]
    # Only whole batches are used; trailing rows that do not fill a batch are skipped.
    total = len(train) // BATCH_SIZE
    history = []
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        # Reset the module-level metric objects at the start of every epoch.
        transformer.train_loss.reset_states()
        transformer.train_accuracy.reset_states()
        res = None
        for i in range(total):
            s = slice(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)
            res = model.train_step(train.normalized[s], train.niqqud[s])
            # 1-based counter so the final batch reads total/total; '\r' redraws the same line.
            print(f"{i + 1:4d}/{total:4d} - loss: {res['loss']:.4f} - accuracy: {res['acc']:.4f}", end='       \r')
        print()
        history.append(res)
    return history

In [None]:
# Mixed training corpora: poetry, rabanit and pre-modern texts (10% held out for validation).
data_mix = load_data(source=['poetry', 'rabanit', 'pre_modern'], validation=0.1)

In [None]:
# Restart from the initial-weights snapshot, train on the mixed corpora, then checkpoint.
model.load_weights(filepath='./checkpoints/uninit')
history = fit(data=data_mix, epochs=2)
model.save_weights(filepath='./checkpoints/mix')

In [None]:
# Modern-Hebrew corpus only, with the same 10% validation split as the mixed data.
data_modern = load_data(['modern'], 0.1)

In [None]:
# Restart from the initial-weights snapshot and train on the modern corpus only.
# NOTE(review): this resets the model only if './checkpoints/uninit' actually holds
# weights. If the model had no variables when that checkpoint was written, training
# here continues from the weights learned on the mixed corpora -- confirm.
model.load_weights('./checkpoints/uninit')
history = fit(data_modern, epochs=1)
model.save_weights('./checkpoints/modern')

In [None]:
# Layer / parameter overview. Keras needs a built model for summary(); presumably the
# training cells above are what build it -- confirm if this cell is run on its own.
model.summary()

In [None]:
# Restore the weights trained on the modern corpus before inspecting predictions.
model.load_weights(filepath='./checkpoints/modern')

def print_predictions(data, s=slice(0, None)):
    """Print predicted vs. expected diacritized text for rows ``s`` of a split, with niqqud accuracy.

    Parameters
    ----------
    data :
        A dataset split exposing ``normalized``, ``niqqud``, ``dagesh``, ``sin`` and
        ``text``, each indexable with ``s``.
    s : slice
        Rows of ``data`` to evaluate (all rows by default).

    Returns
    -------
    float
        Fraction of matching niqqud labels over all compared positions,
        or ``nan`` when no position was compared.
    """
    batch = data.normalized[s]
    prediction = model.predict(batch)
    expected_niqqud, expected_dagesh, expected_sin = data.niqqud[s], data.dagesh[s], data.sin[s]
    if isinstance(prediction, (list, tuple)):
        # Multi-output model: one prediction per head, in niqqud / dagesh / sin order.
        actual_niqqud, actual_dagesh, actual_sin = (
            dataset.from_categorical(p) for p in prediction[:3])
    else:
        # A single array means a single-output (niqqud-only) model; indexing it with
        # 0/1/2 would pick samples, not heads. Show the reference dagesh / sin unchanged.
        actual_niqqud = dataset.from_categorical(prediction)
        actual_dagesh, actual_sin = expected_dagesh, expected_sin
    actual = dataset.merge(data.text[s], ts=batch, ns=actual_niqqud, ds=actual_dagesh, ss=actual_sin)
    expected = dataset.merge(data.text[s], ts=batch, ns=expected_niqqud, ds=expected_dagesh, ss=expected_sin)

    all_matches = []
    for i, (a, e) in enumerate(zip(actual, expected)):
        print('מצוי: ', a)  # "found": the model's output
        print('רצוי: ', e)  # "desired": the reference
        row = expected_niqqud[i].tolist()
        # Compare only up to the first 0 label (treated as padding here -- confirm that 0
        # is never a real niqqud label). A row with no 0 is compared in full instead of
        # raising ValueError from list.index.
        last = row.index(0) if 0 in row else len(row)
        matches = np.asarray(expected_niqqud[i][:last] == actual_niqqud[i][:last])
        all_matches.extend(matches.tolist())
        row_accuracy = float(np.mean(matches)) if last else float('nan')
        # The parenthesised count is the number of mismatching positions.
        print(round(row_accuracy, 2), f'({last - int(matches.sum())} out of {last})')
        print()
    overall = float(np.mean(all_matches)) if all_matches else float('nan')
    print(round(overall, 3))
    return overall

# Evaluate on the modern corpus's validation split; ';' keeps the cell from echoing a return value.
print_predictions(data=data_modern[1], s=slice(0, None));

In [None]:
def plot_attention_weights(attention, sentence, result, layer,
                           input_tokenizer=None, target_tokenizer=None):
    """Plot one attention heat-map per head for a single layer.

    Parameters
    ----------
    attention : mapping
        Attention weights indexed by ``layer``. ``attention[layer]`` must have a leading
        axis of size 1 (squeezed away); the next axis enumerates heads.
    sentence : str
        Input sentence, encoded with ``input_tokenizer`` to label the x axis.
    result : sequence of int
        Output token ids. Ids below ``target_tokenizer.vocab_size`` label the y axis.
    layer :
        Key selecting the layer in ``attention``.
    input_tokenizer, target_tokenizer : optional
        Objects providing ``encode`` / ``decode`` (and ``vocab_size`` for the target).
        They default to module globals named ``tokenizer_pt`` / ``tokenizer_en`` when
        those exist.

    Raises
    ------
    ValueError
        If a tokenizer is neither passed in nor available as the fallback global.
    """
    # The original code used the globals `tokenizer_pt` / `tokenizer_en` directly. Keep
    # that as a fallback, but fail with a clear message instead of a NameError.
    if input_tokenizer is None:
        input_tokenizer = globals().get('tokenizer_pt')
    if target_tokenizer is None:
        target_tokenizer = globals().get('tokenizer_en')
    if input_tokenizer is None or target_tokenizer is None:
        raise ValueError('plot_attention_weights needs input_tokenizer and target_tokenizer')

    import math
    import matplotlib.pyplot as plt  # pyplot is not imported by the notebook's import cell

    fig = plt.figure(figsize=(16, 8))
    sentence = input_tokenizer.encode(sentence)
    attention = np.squeeze(np.asarray(attention[layer]), axis=0)

    num_heads = attention.shape[0]
    ncols = 4
    # Keep the original 2x4 grid for up to 8 heads and add rows beyond that instead of failing.
    nrows = max(2, math.ceil(num_heads / ncols))
    fontdict = {'fontsize': 10}

    for head in range(num_heads):
        ax = fig.add_subplot(nrows, ncols, head + 1)
        # The last row of each head's matrix is dropped, as in the original plot.
        ax.matshow(attention[head][:-1, :], cmap='viridis')
        ax.set_xticks(range(len(sentence) + 2))
        ax.set_yticks(range(len(result)))
        ax.set_ylim(len(result) - 1.5, -0.5)
        ax.set_xticklabels(
            ['<start>'] + [input_tokenizer.decode([i]) for i in sentence] + ['<end>'],
            fontdict=fontdict, rotation=90)
        ax.set_yticklabels(
            [target_tokenizer.decode([i]) for i in result if i < target_tokenizer.vocab_size],
            fontdict=fontdict)
        ax.set_xlabel('Head {}'.format(head + 1))

    plt.tight_layout()
    plt.show()