In [1]:
import os
import shutil
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import config
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_addons.layers.crf import CRF
from tensorflow_addons.text.crf import crf_log_likelihood

import dataset
import schedulers

assert config.list_physical_devices('GPU')

In [None]:

def unpack_data(data):
    """Normalise a Keras step payload to an ``(x, y, sample_weight)`` triple.

    Args:
        data: a sequence of length 2 (``x, y``) or 3 (``x, y, sample_weight``).

    Returns:
        ``data`` itself when it already has three items, otherwise a new
        3-tuple whose last element is ``None``.

    Raises:
        TypeError: if ``data`` has any other length.
    """
    size = len(data)
    if size == 3:
        # Already complete: hand the caller's object back unchanged.
        return data
    if size == 2:
        return (data[0], data[1], None)
    raise TypeError("Expected data to be a tuple of size 2 or 3.")


class ModelWithCRFLoss(keras.Model):
    """Keras model whose train/test steps optimise a CRF negative log-likelihood.

    The wrapped network must return the triple
    ``(potentials, sequence_length, chain_kernel)`` when called, because
    ``compute_loss`` unpacks exactly three values from the forward pass.
    """

    def compute_loss(self, x, y, sample_weights, training=False):
        """Run a forward pass and return ``(mean CRF loss, summed layer losses)``.

        Args:
            x: model inputs.
            y: target tag sequences passed to ``crf_log_likelihood``.
            sample_weights: optional weights multiplied into the per-example
                loss before averaging; ``None`` disables weighting.
            training: forwarded to the model call.
        """
        # NOTE(review): a commented-out variant in the original unpacked a
        # leading fourth value here; the active code expects three outputs.
        potentials, sequence_length, chain_kernel = self(x, training=training)

        # crf_log_likelihood returns a tuple; element 0 is the log-likelihood.
        log_likelihood = crf_log_likelihood(
            potentials, y, sequence_length, chain_kernel
        )[0]
        negative_ll = -log_likelihood

        if sample_weights is not None:
            negative_ll = negative_ll * sample_weights

        # self.losses holds losses registered by layers (e.g. regularisers).
        return tf.reduce_mean(negative_ll), sum(self.losses)

    def train_step(self, data):
        """Apply one optimiser update and report both loss components."""
        x, y, sample_weight = unpack_data(data)

        with tf.GradientTape() as tape:
            crf_loss, internal_losses = self.compute_loss(
                x, y, sample_weight, training=True
            )
            total_loss = crf_loss + internal_losses

        variables = self.trainable_variables
        grads = tape.gradient(total_loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))

        return {"crf_loss": crf_loss, "internal_losses": internal_losses}

    def test_step(self, data):
        """Evaluate both loss components without updating weights."""
        x, y, sample_weight = unpack_data(data)
        crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
        return {"crf_loss_val": crf_loss, "internal_losses_val": internal_losses}



In [83]:
# Mini-batch size handed to model.fit and to the scheduler's set_dataset.
BATCH_SIZE = 32

# Vocabulary / class counts, taken from the lookup tables of the project's
# `dataset` module.
LETTERS_SIZE = len(dataset.letters_table)  # input vocabulary (embedding rows)
NIQQUD_SIZE = len(dataset.niqqud_table)    # classes of the 'N' output head
DAGESH_SIZE = len(dataset.dagesh_table)    # classes of the 'D' output head
SIN_SIZE = len(dataset.sin_table)          # classes of the 'S' output head
# NOTE(review): KINDS_SIZE is not referenced in any cell visible here.
KINDS_SIZE = len(dataset.KINDS)

def build_model(EMBED_DIM=10, UNITS=190):
    """Build the character-level tagging network with three softmax heads.

    Architecture: Embedding (index 0 masked) -> Dense -> a shared
    bidirectional LSTM applied twice -> concat(sum, difference) of the two
    passes -> three Dense+Softmax heads named 'N', 'D' and 'S'.

    Args:
        EMBED_DIM: width of the character embedding.
        UNITS: width of the Dense projection and of each LSTM direction.

    Returns:
        ``(model, jsmodel)``. Both names refer to the *same* ``keras.Model``
        object, so loading weights into one also changes the other.
    """
    # Variable-length integer sequences. shape=(None,) yields the same
    # (None, None) batch shape as before; the old call also passed
    # batch_size alongside batch_shape, which had no effect on that shape.
    input_text = keras.Input(shape=(None,))
    layer = layers.Embedding(LETTERS_SIZE, EMBED_DIM, mask_zero=True)(input_text)

    layer = layers.Dense(UNITS, activation=None)(layer)

    # One Bidirectional layer instance is reused, so both passes share weights.
    bidi = layers.Bidirectional(
        layers.LSTM(UNITS, return_sequences=True, dropout=0.0),
        merge_mode='sum',
    )
    first_pass = bidi(layer)
    # Computed once and fed to both branches. The original built this
    # identical sub-graph twice (once for add, once for subtract).
    second_pass = bidi(first_pass)
    layer = layers.concatenate([
        layers.add([first_pass, second_pass]),
        layers.subtract([first_pass, second_pass]),
    ])

    # NOTE: a CRF-based variant (CRF output layers wrapped in
    # ModelWithCRFLoss) was previously sketched here as commented-out code;
    # the active model uses independent softmax heads.
    outputs = [
        layers.Softmax(name='N')(layers.Dense(NIQQUD_SIZE)(layer)),
        layers.Softmax(name='D')(layers.Dense(DAGESH_SIZE)(layer)),
        layers.Softmax(name='S')(layers.Dense(SIN_SIZE)(layer)),
    ]
    model = keras.Model(inputs=input_text, outputs=outputs)

    # The export model is currently just an alias of the training model.
    jsmodel = model
    return model, jsmodel

# Build the network once; `model` and `jsmodel` become notebook-level globals
# that the training, export and evaluation cells below rely on.
model, jsmodel = build_model()

model.summary()
# Snapshot the freshly initialised weights so that later experiments can
# restart from the same starting point via load_weights('./checkpoints/uninit').
model.save_weights('./checkpoints/uninit')

Model: "model_15"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_20 (InputLayer)           [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding_19 (Embedding)        (None, None, 10)     440         input_20[0][0]                   
__________________________________________________________________________________________________
dense_64 (Dense)                (None, None, 190)    2090        embedding_19[0][0]               
__________________________________________________________________________________________________
bidirectional_18 (Bidirectional (None, None, 190)    579120      dense_64[0][0]                   
                                                                 bidirectional_18[0][0]    

In [3]:

def accuracy(y_true, y_pred):
    """Fraction of non-padding positions whose arg-max prediction is correct.

    Label ``0`` is treated as padding: such positions are excluded from both
    the numerator and the denominator. (Previously padded positions that
    happened to be predicted as 0 were counted as hits while the denominator
    only counted real labels, so the value was inflated and could exceed 1.)

    Args:
        y_true: integer class labels, same shape as ``y_pred`` without its
            last axis.
        y_pred: per-class scores; the last axis indexes classes.

    Returns:
        A scalar tensor. It is NaN when the batch holds no non-zero label,
        because the denominator is then zero.
    """
    K = keras.backend
    f = K.floatx()
    y_true = K.cast(y_true, f)
    # Convert dense predictions to label indices, in the same dtype as y_true.
    y_pred_labels = K.cast(K.argmax(y_pred, axis=-1), f)

    valid = K.cast(K.not_equal(y_true, 0), f)
    hits = K.cast(K.equal(y_true, y_pred_labels), f) * valid
    return K.sum(hits) / K.sum(valid)


In [4]:
def fit(train_validation, scheduler=None, verbose=1, lr=1e-4, epochs=1):
    """(Re)compile the global ``model`` and train it on one train/validation pair.

    Args:
        train_validation: ``(train, valid)`` pair; each element must expose
            ``normalized``, ``niqqud``, ``dagesh`` and ``sin``.
        scheduler: optional Keras callback. A
            ``schedulers.CircularLearningRate`` is first bound to the training
            set via ``set_dataset(train, BATCH_SIZE)``.
        verbose: forwarded to ``model.fit``.
        lr: learning rate given to the Adam optimiser.
        epochs: number of passes over the training data. Defaults to 1, the
            value that used to be hard-coded.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    train, valid = train_validation
    # Compiling on every call creates a fresh optimiser for each run.
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        metrics=[accuracy],
    )

    callbacks = []
    if isinstance(scheduler, schedulers.CircularLearningRate):
        scheduler.set_dataset(train, BATCH_SIZE)
    if scheduler:
        callbacks.append(scheduler)

    x = train.normalized
    vx = valid.normalized

    # Keys match the names of the model's three output layers.
    y = {'N': train.niqqud, 'D': train.dagesh, 'S': train.sin}
    vy = {'N': valid.niqqud, 'D': valid.dagesh, 'S': valid.sin}

    return model.fit(
        x, y,
        validation_data=(vx, vy),
        batch_size=BATCH_SIZE,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
    )


In [17]:
# Default value for the `maxlen` argument forwarded to dataset.load_data.
MAXLEN = 82


def load_data(source, maxlen=MAXLEN, validation=0.1):
    """Load the named corpora from ``texts/`` and return ``(train, valid)``.

    Args:
        source: iterable of entry names located under the ``texts`` directory.
        maxlen: forwarded to ``dataset.load_data`` as ``maxlen``.
        validation: forwarded positionally to ``dataset.load_data``.
    """
    paths = [os.path.join('texts', name) for name in source]
    train_set, valid_set = dataset.load_data(paths, validation, maxlen=maxlen)
    return train_set, valid_set

In [18]:
# (train, valid) pair built from the 'biblical', 'garbage' and 'poetry'
# entries under texts/, using load_data's default validation=0.1.
data_other = load_data(['biblical', 'garbage', 'poetry'])

In [19]:
# (train, valid) pair for the 'rabanit' entry under texts/.
data_rabanit = load_data(['rabanit'])

In [20]:
# (train, valid) pair for the 'pre_modern' entry under texts/.
data_pre_modern = load_data(['pre_modern'])

In [71]:
# (train, valid) pair for the 'modern' entry; validation=0.2 is passed here
# instead of load_data's default of 0.1.
data_modern = load_data(validation=0.2, source=['modern'])

In [40]:
# Stage 1: restart from the untrained snapshot, run one fit() call on
# data_other and checkpoint the result as 'other'.
# The meaning of the three CircularLearningRate arguments is defined in the
# project's `schedulers` module (not visible here).
model.load_weights('./checkpoints/uninit')
history = fit(data_other, scheduler=schedulers.CircularLearningRate(30e-4, 150e-4, 5e-4))
model.save_weights('./checkpoints/other')



In [41]:
# Stage 2: continue from the 'other' checkpoint on data_rabanit and save the
# result as 'rabanit'.
model.load_weights('./checkpoints/other')
history = fit(data_rabanit, scheduler=schedulers.CircularLearningRate(30e-4, 50e-4, 5e-4))
model.save_weights('./checkpoints/rabanit')



In [42]:
# Stage 3: continue from the 'rabanit' checkpoint on data_pre_modern and save
# the result as 'pre_modern' (the starting point of the random search below).
model.load_weights('./checkpoints/rabanit')
history = fit(data_pre_modern, scheduler=schedulers.CircularLearningRate(30e-4, 80e-4, 1e-4))
model.save_weights('./checkpoints/pre_modern')



In [46]:
# (train, valid) pair combining every non-'modern' entry used above, loaded
# in a single dataset.load_data call.
data_mix = load_data(['biblical', 'garbage', 'poetry', 'rabanit', 'pre_modern'])

In [76]:
# Alternative to the staged runs: restart from the untrained snapshot and
# train on the combined data_mix in one fit() call; saved as 'mix'.
model.load_weights('./checkpoints/uninit')
history = fit(data_mix, scheduler=schedulers.CircularLearningRate(30e-4, 80e-4, 1e-4))
model.save_weights('./checkpoints/mix')



In [84]:
# Train on data_modern starting from the untrained snapshot (not from any of
# the earlier stages) and save the result as 'modern', the checkpoint that
# the export and evaluation cells load.
model.load_weights('./checkpoints/uninit')
history = fit(data_modern, scheduler=schedulers.CircularLearningRate(5e-3, 6e-3, 6e-5))
model.save_weights('./checkpoints/modern')



In [None]:
# Random search over the three CircularLearningRate arguments. Each value is
# drawn log-uniformly from its (low, high) range. No random seed is set in
# the visible cells, so every run samples different values.
SEARCH_RANGES = ((1e-5, 1e-2), (1e-4, 1e-1), (1e-5, 1e-2))

for _ in range(50):
    # Every trial restarts from the same 'pre_modern' checkpoint.
    model.load_weights('./checkpoints/pre_modern')
    p1, p2, p3 = (
        np.exp(np.random.uniform(low=np.log(lo), high=np.log(hi)))
        for lo, hi in SEARCH_RANGES
    )
    # Print the sampled values and the resulting score on a single line.
    print(p1, p2, p3, end=', ', sep=', ')
    history = fit(data_modern, scheduler=schedulers.CircularLearningRate(p1, p2, p3), verbose=0)
    print(history.history['val_N_accuracy'][0])

In [None]:
import matplotlib.pyplot as plt

# 2x2 grid: rows are the metric (accuracy, loss), columns are the output head
# ('D', 'N'). Each panel overlays the training and validation series of the
# most recent `history`.
fig, ax = plt.subplots(nrows=2, ncols=2)

for row, metric in enumerate(['accuracy', 'loss']):
    for col, head in enumerate(['D', 'N']):
        panel = ax[row][col]
        panel.plot(history.history[f'{head}_{metric}'])
        panel.plot(history.history[f'val_{head}_{metric}'])
        panel.legend([f'{head}_Train', f'{head}_Test'], loc='center right')

plt.tight_layout()

In [None]:
# Export for TensorFlow.js: restore the 'modern' checkpoint and write the
# converted model into the current working directory.
# `jsmodel` is the same object as `model` (see build_model), so this
# load_weights call also replaces the weights of `model`.
import tensorflowjs as tfjs
jsmodel.load_weights('./checkpoints/modern')
tfjs.converters.save_keras_model(jsmodel, '.')

In [73]:
# Restore the weights saved after training on the 'modern' corpus.
# NOTE(review): this cell's execution count (73) is lower than that of the
# cell that writes this checkpoint (84), so the captured output below may
# come from an older checkpoint. Re-run top to bottom to confirm.
model.load_weights('./checkpoints/modern')

def print_predictions(data, s):
    """Print predicted vs. expected text for rows ``s`` of ``data`` with niqqud accuracy.

    For every selected row the predicted line and the reference line are
    printed, followed by the fraction of matching niqqud labels and the
    number of mismatches. A final line gives the accuracy over all rows.
    Only the niqqud labels are scored; dagesh and sin are merged into the
    printed text but not counted.

    Args:
        data: object exposing ``normalized``, ``niqqud``, ``dagesh`` and
            ``sin`` (e.g. one element of the pair returned by ``load_data``).
        s: slice selecting the rows to show.

    Uses the notebook-level ``model`` for prediction.
    """
    batch = data.normalized[s]
    prediction = model.predict(batch)
    # The model has three output heads, in the order niqqud, dagesh, sin.
    actual_niqqud = dataset.from_categorical(prediction[0])
    actual_dagesh = dataset.from_categorical(prediction[1])
    actual_sin = dataset.from_categorical(prediction[2])
    expected_niqqud, expected_dagesh, expected_sin = data.niqqud[s], data.dagesh[s], data.sin[s]

    actual = dataset.merge(batch, ns=actual_niqqud, ds=actual_dagesh, ss=actual_sin)
    expected = dataset.merge(batch, ns=expected_niqqud, ds=expected_dagesh, ss=expected_sin)

    total = []
    for i, (a, e) in enumerate(zip(actual, expected)):
        print('מצוי: ', a)
        print('רצוי: ', e)
        # The first 0 label marks the start of padding. A row that fills the
        # whole length contains no 0; score it in full instead of letting
        # list.index raise ValueError.
        labels = expected_niqqud[i].tolist()
        last = labels.index(0) if 0 in labels else len(labels)
        res = expected_niqqud[i][:last] == actual_niqqud[i][:last]
        total.extend(res)
        print(round(np.mean(res), 2), f'({last - sum(res)} out of {last})')
        print()
    print(round(np.mean(total), 3))

# Evaluate on the whole 'modern' validation split: element 1 of the pair
# returned by load_data; slice(0, None) selects every row.
print_predictions(data_modern[1], slice(0, None))

מצוי:  יֵשׁ מַנְגְּנוֹנִים בֵּינְלְאוּמִּיִּים שֶׁמַּסְדִּירִים תְּבִיעוֹת בֵּין מְדִינוֹת, וְאִם יִשְׂרָאֵל כַּמְדִינָה תִּרְצֶה לִתְבּוֹעַ 
רצוי:  יֵשׁ מַנְגְּנוֹנִים בֵּינְלְאוּמִּיִּים שֶׁמַּסְדִּירִים תְּבִיעוֹת בֵּין מְדִינוֹת, וְאִם יִשְׂרָאֵל כִּמְדִינָה תִּרְצֶה לִתְבּוֹעַ 
0.99 (1 out of 79)

מצוי:  בְּסֵפֶר הַמְּקוֹרִי, אֲשֶׁר יָצָא בְּ-5555, לֹא הִתְאִים לְטוֹן שֶׁל הַסְּפָרִים הַמְּאוּחָרִים יוֹתֵר. קִינְג הִרְגִּישׁ 
רצוי:  בַּסֵּפֶר הַמְּקוֹרִי, אֲשֶׁר יָצָא בְּ-5555, לֹא הִתְאִים לַטּוֹן שֶׁל הַסְּפָרִים הַמְּאוּחָרִים יוֹתֵר. קִינְג הִרְגִּישׁ 
0.97 (2 out of 79)

מצוי:  מִשְׁכָּרָם מֵאֲשֶׁר לְפַטֵּר אֶת שְׁכְנֵיהֶם, וְעוֹבְדִים שֶׁמַּעֲדִיפִים לְקַצֵץ בְּשָׁעוֹת שֶׁלָּהֶם מֵאֲשֶׁר לִרְאוֹת חָבֵר מְאַבֵד 
רצוי:  מִשְּׂכָרָם מֵאֲשֶׁר לְפַטֵּר אֶת שְׁכֵנֵיהֶם, וְעוֹבְדִים שֶׁמַּעֲדִיפִים לְקַצֵּץ בַּשָּׁעוֹת שֶׁלָּהֶם מֵאֲשֶׁר לִרְאוֹת חָבֵר מְאַבֵּד 
0.98 (2 out of 80)

מצוי:  כְּבָר נִרְאִים חוֹלִים, שֶׁהֵם צְרִיכִים פֵּרוֹת וִירָקוֹת טְרִיִים, תַּבְשִׁילִי שֶׁעוֹעִית 

In [None]:
# Reset TensorBoard state: drop the '.tensorboard-info' directory kept in the
# system temp dir and start with an empty local 'logs' directory.
# Requires `shutil` and `tempfile`, which must be imported in the notebook's
# import cell.
shutil.rmtree(os.path.join(tempfile.gettempdir(), '.tensorboard-info'), ignore_errors=True)
shutil.rmtree('logs', ignore_errors=True)
# exist_ok: rmtree errors are ignored above, so 'logs' may still exist here.
os.makedirs('logs', exist_ok=True)
# %tensorboard --logdir logs

In [None]:
# Show the first two `text` entries of the 'modern' validation split
# (element 1 of the pair returned by load_data).
print(data_modern[1].text[0])
print(data_modern[1].text[1])

In [None]:
# List the hexadecimal Unicode code point of every character in the literal,
# which shows the order in which the combining marks are stored.
[hex(ord(x)) for x in 'כָּ']