In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

import tensorflow as tf
import tensorflowjs as tfjs

import wandb
from wandb.keras import WandbCallback

import dataset
import schedulers

# Fail fast (with a readable reason) when TensorFlow cannot see any GPU device.
assert tf.config.list_physical_devices('GPU'), "TensorFlow sees no GPU: tf.config.list_physical_devices('GPU') returned an empty list"

In [3]:
# Table sizes: LETTERS_SIZE is the embedding's input vocabulary; the other three
# are the class counts of the N (niqqud), D (dagesh) and S (sin) output heads.
LETTERS_SIZE, NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE = (
    len(table)
    for table in (dataset.letters_table, dataset.niqqud_table, dataset.dagesh_table, dataset.sin_table)
)

def build_model(units=500, maxlen=64):
    """Build the (uncompiled) diacritization network.

    Pipeline: letter embedding -> bidirectional LSTM -> second bidirectional
    LSTM with a residual connection around it (forward/backward outputs summed)
    -> batch norm -> ReLU projection added back onto the embedding -> three
    softmax heads named 'N' (niqqud), 'D' (dagesh) and 'S' (sin).

    Args:
        units: width of the embedding, LSTM and projection layers.
        maxlen: fixed input sequence length.

    Returns:
        keras.Model taking (batch, maxlen) letter ids and producing three
        per-position probability tensors, in the order N, D, S.

    Layers are created in exactly the same order as in the original
    expression-style code (including each Softmax being created before its
    Dense), so auto-generated layer names and weight order stay compatible
    with existing checkpoints.
    """
    letters = keras.Input(shape=(maxlen,), batch_size=None)
    # mask_zero=True: input id 0 is treated as padding and masked downstream.
    embedded = layers.Embedding(LETTERS_SIZE, units, mask_zero=True)(letters)

    def bilstm():
        return layers.Bidirectional(layers.LSTM(units, return_sequences=True), merge_mode='sum')

    first = bilstm()(embedded)
    stacked = layers.add([first, bilstm()(first)])
    normed = layers.BatchNormalization()(stacked)
    hidden = layers.add([embedded, layers.Dense(units, activation='relu')(normed)])

    heads = []
    for head_name, n_classes in (('N', NIQQUD_SIZE), ('D', DAGESH_SIZE), ('S', SIN_SIZE)):
        softmax = layers.Softmax(name=head_name)
        heads.append(softmax(layers.Dense(n_classes)(hidden)))

    return keras.Model(inputs=letters, outputs=heads)


In [4]:

# Masked versions of accuracy and sparse categorical cross-entropy: positions
# whose true label is 0 are excluded from the average (presumably 0 marks
# padding / "nothing to predict" — confirm against dataset.py).

def _masked_mean(values, labels):
    """Average ``values`` over the positions where ``labels`` != 0.

    tf.math.divide_no_nan makes a batch with no labelled position yield 0
    instead of 0/0 = NaN. A NaN loss would propagate NaN gradients into the
    weights on the next optimizer step.
    """
    mask = tf.cast(tf.math.not_equal(labels, 0), dtype=values.dtype)
    return tf.math.divide_no_nan(tf.reduce_sum(values * mask), tf.reduce_sum(mask))


def accuracy(real, pred):
    """Sparse categorical accuracy averaged over positions whose true label is non-zero."""
    return _masked_mean(tf.keras.metrics.sparse_categorical_accuracy(real, pred), real)


def sparse_categorical_crossentropy(y_true, y_pred, sample_weight=None):
    """Sparse categorical cross-entropy averaged over positions whose true label is non-zero.

    ``y_pred`` is expected to hold probabilities (build_model ends each head in
    a Softmax layer), so from_logits keeps its default of False.
    ``sample_weight`` is accepted for signature compatibility but ignored.
    """
    return _masked_mean(tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred), y_true)


In [5]:
# Model width, training batch size, and fixed input sequence length.
UNITS, BATCH_SIZE, MAXLEN = 500, 64, 64

In [99]:
# Training corpora keyed by name. Each value comes from dataset.load_data and is
# unpacked later as (train, validation); validation_rate differs per corpus.
CORPORA = (
    ('mix', ['hebrew_diacritized_private/poetry',
             'hebrew_diacritized_private/rabanit',
             'hebrew_diacritized_private/pre_modern'], 0.1),
    ('modern', ['hebrew_diacritized/modern'], 0.2),
)
data = {}
for corpus_name, corpus_paths, held_out in CORPORA:
    data[corpus_name] = dataset.load_data(corpus_paths, validation_rate=held_out, maxlen=MAXLEN)

In [100]:
%env WANDB_MODE run


model = build_model(units=UNITS, maxlen=MAXLEN)

# learning_rate here is only the optimizer's initial value; presumably the
# CircularLearningRate callback used during training overrides it — verify in schedulers.py.
model.compile(loss=sparse_categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              metrics={'N': accuracy, 'D': accuracy, 'S': accuracy})

# Snapshot of the freshly initialised weights.
model.save_weights('./checkpoints/uninit')

# Logged to wandb as the run config. (The previous literal listed 'batch_size'
# twice; the duplicate key was redundant and has been removed.)
config = {
    'batch_size': BATCH_SIZE,
    'units': UNITS,
    'maxlen': MAXLEN,
    'model': model,
    # Training schedule, consumed in order by the training loop. Each entry is
    # (key into `data`, list of CircularLearningRate argument tuples — one
    # model.fit per tuple, checkpoint / artifact name).
    'order': [
        ('mix', [(30e-4, 80e-4, 1e-4)], 'mix'),
        ('modern', [(50e-4, 50e-4, 1e-5)], 'modern'),
        ('modern', [(50e-4, 50e-4, 1e-5), (50e-4, 50e-4, 1e-5)], 'modern_over'),
    ],
}

run = wandb.init(project="dotter",
                 name='cleaned-rafe-new-aftercleanup',
                 tags=['CLR', 'ordered', 'new-style'],
                 config=config)

def get_xy(d):
    """Turn a dataset split into Keras ``(inputs, targets)``.

    Targets are keyed by the model's output names: 'N' (niqqud), 'D' (dagesh)
    and 'S' (sin). A ``None`` split is passed through unchanged as ``None``.
    """
    if d is not None:
        inputs = d.normalized
        targets = dict(N=d.niqqud, D=d.dagesh, S=d.sin)
        return inputs, targets
    return None

# Staged training driven by config['order']: (dataset key, CircularLearningRate
# argument tuples, checkpoint name). Weights carry over between stages, so each
# stage continues from the previous one.
#
# `history` still holds the History of the last fit (the plotting cell reads it).
# `histories` keeps every fit as (checkpoint name, clr args, History), so earlier
# stages' curves are not lost.
histories = []
with run:
    for kind, clrs, save in config['order']:
        train, validation = data[kind]

        training_data = (x, y) = get_xy(train)
        validation_data = get_xy(validation)  # get_xy passes None through

        # One wandb callback per stage, shared by all of that stage's fits.
        wandb_callback = WandbCallback(log_batch_frequency=10, training_data=training_data, validation_data=validation_data,
                                       log_weights=True)

        for clr in clrs:
            scheduler = schedulers.CircularLearningRate(*clr)
            scheduler.set_dataset(train, BATCH_SIZE)
            callbacks = [wandb_callback, scheduler]
            # No `epochs` argument: Keras' default of a single epoch per fit.
            history = model.fit(x, y, validation_data=validation_data, batch_size=BATCH_SIZE, verbose=1, callbacks=callbacks)
            histories.append((save, clr, history))
        model.save(os.path.join(wandb.run.dir, save + ".h5"))
        model.save_weights(os.path.join('./checkpoints', save))


env: WANDB_MODE=run


Error generating diff: Command '['git', 'diff', '--submodule=diff', 'HEAD']' timed out after 5 seconds




In [None]:
# Export the final 'modern_over' checkpoint. The model is rebuilt first so this
# cell does not depend on whatever `model` the kernel currently holds
# (same pattern as the test-set evaluation cell).
model = build_model(units=UNITS, maxlen=MAXLEN)
model.load_weights('./checkpoints/modern_over')

# Compiled without loss/optimizer/metrics before saving.
model.compile()
model.save('modern.h5')
# TensorFlow.js artifacts are written into the current working directory.
tfjs.converters.save_keras_model(model, '.')

In [None]:
import matplotlib.pyplot as plt

# Learning curves from `history` (the most recent model.fit in the training cell):
# one column per head (N, D, S), one row each for accuracy and loss.
fig, axes = plt.subplots(nrows=2, ncols=3)

for row, metric in enumerate(['accuracy', 'loss']):
    for col, head in enumerate(['N', 'D', 'S']):
        panel = axes[row][col]
        # marker='o': each fit runs Keras' default single epoch, so a curve may be a
        # single point, which a marker-less line plot would not draw at all.
        panel.plot(history.history[f'{head}_{metric}'], marker='o')
        panel.plot(history.history[f'val_{head}_{metric}'], marker='o')
        # The second curve is the validation split passed to model.fit, not the test corpus.
        panel.legend([head + '_Train', head + '_Validation'], loc='center right')
        panel.set_title(f'{head} {metric}')
        panel.set_xlabel('epoch')

fig.tight_layout()

In [101]:
# Evaluate the final checkpoint on the modern test corpus with the same masked
# loss/accuracy used in training. The model is rebuilt from the checkpoint so the
# result does not depend on the training cell's in-memory state.
model = build_model(units=UNITS, maxlen=MAXLEN)
model.load_weights('./checkpoints/modern_over')

# Second positional argument is 0 (presumably the validation rate — confirm in dataset.py).
test, _ = dataset.load_data(['test/modernTestCorpus/'], 0, MAXLEN)
x, y = test.normalized, dict(N=test.niqqud, D=test.dagesh, S=test.sin)

model.compile(loss=sparse_categorical_crossentropy,
              metrics={head: accuracy for head in 'NDS'})

# Output order: [total loss, N/D/S losses, N/D/S accuracies].
model.evaluate(x=x, y=y, batch_size=BATCH_SIZE)



[0.40030497312545776,
 0.2414025068283081,
 0.09960678964853287,
 0.059295691549777985,
 0.9323585629463196,
 0.9679851531982422,
 0.9871422052383423]

In [62]:
model.load_weights('./checkpoints/modern_over')

# Hebrew alphabet including final forms; used to decide which words get scored.
_HEBREW_LETTERS = 'אבגדהוזחטיכלמנסעפצקרשתךםןףץ'


def _rate(flags):
    """Return the fraction of truthy entries in ``flags``, or nan when ``flags`` is empty.

    np.mean([]) also gives nan but emits a RuntimeWarning, which cluttered the
    report for lines without any scored word.
    """
    return float(np.mean(flags)) if len(flags) else float('nan')


def _summary(flags):
    """Format ``flags`` as 'accuracy% (errors out of total)'."""
    return f'{_rate(flags):.2%} ({len(flags) - np.sum(flags)} out of {len(flags)})'


def _is_scored_word(word):
    """Return True when ``word`` contains at least two *distinct* Hebrew letters.

    Each alphabet letter counts at most once, so a word consisting of one
    repeated letter is not scored.
    """
    return sum(letter in word for letter in _HEBREW_LETTERS) > 1


def real_evaluation(data, s=slice(0, None), print_comparison=True):
    """Compare predicted diacritics with the gold ones on ``data`` lines ``s`` and print accuracies.

    Uses the notebook-global ``model``.
    - Letter accuracy: for each head (niqqud, dagesh, sin), only positions whose
      gold label is > 0 are compared.
    - Word accuracy: fully merged (diacritized) words are compared, skipping words
      with fewer than two distinct Hebrew letters.

    Args:
        data: a split from ``dataset.load_data``; reads ``.normalized``, ``.text``,
            ``.niqqud``, ``.dagesh`` and ``.sin``.
        s: slice of lines to evaluate (default: all lines).
        print_comparison: if True, print every mismatching word pair (predicted,
            then gold), and for each line the predicted ('מצוי') and expected
            ('רצוי') text with letter/word scores.
    """
    batch = data.normalized[s]
    prediction = model.predict(batch)
    actual_niqqud = dataset.from_categorical(prediction[0])
    actual_dagesh = dataset.from_categorical(prediction[1])
    actual_sin = dataset.from_categorical(prediction[2])
    expected_niqqud, expected_dagesh, expected_sin = data.niqqud[s], data.dagesh[s], data.sin[s]
    actual = dataset.merge(data.text[s], batch, actual_niqqud, actual_dagesh, actual_sin)
    expected = dataset.merge(data.text[s], batch, expected_niqqud, expected_dagesh, expected_sin)

    # (gold, predicted) label arrays per head, in the order niqqud, dagesh, sin.
    heads = [(expected_niqqud, actual_niqqud), (expected_dagesh, actual_dagesh), (expected_sin, actual_sin)]

    total_letters = []
    total_words = []
    for i, (_, a, e) in enumerate(zip(batch, actual, expected)):
        letters = []
        for gold, predicted in heads:
            labelled = gold[i] > 0
            letters.extend(gold[i][labelled] == predicted[i][labelled])
        total_letters.extend(letters)

        words = []
        for aw, ew in zip(a.split(), e.split()):
            if _is_scored_word(aw):
                words.append(aw == ew)
                if print_comparison and aw != ew:
                    print(aw, ew)
        total_words.extend(words)

        if print_comparison:
            print('מצוי: ', a)
            print('רצוי: ', e)
            print(_summary(letters))
            print(_summary(words))
            print()
    print(f'letters: {_rate(total_letters):.2%}, words: {_rate(total_words):.2%}')

real_evaluation(test, s=slice(0, None), print_comparison=True)

וִיקְטְרִינָה וְיַקְטָרִינָה
חִילְּקוּ, חִילְקוֹ,
לִסְפּוֹרְטָאִית לַסְּפּוֹרְטָאִית
הָאוֹזְבָּקִית הָאוּזְבֵּקִית
מצוי:  וִיקְטְרִינָה חִילְּקוּ, שֶׁהָיְיתָה לִסְפּוֹרְטָאִית הָאוֹזְבָּקִית הָרִאשׁוֹנָה, וְהַיְּחִידָה עַד 
רצוי:  וְיַקְטָרִינָה חִילְקוֹ, שֶׁהָיְיתָה לַסְּפּוֹרְטָאִית הָאוּזְבֵּקִית הָרִאשׁוֹנָה, וְהַיְּחִידָה עַד 
88.66% (11 out of 97)
50.00% (4 out of 8)

לְשַׇׁלְחוֹ לְשׇׁלְחוֹ
לְהָלִיוֹפּוּלִיס לְהֶלְיוֹפּוֹלִיס
שֶׁיחֲקוֹר שֶׁיַּחְקוֹר
מצוי:  לְשַׇׁלְחוֹ לְהָלִיוֹפּוּלִיס בְּמִצְרַיִם כְּדֵי שֶׁיחֲקוֹר שָׁם אֶת תְּחוּמֵי הָאַסְטְרוֹנוֹמְיָה 
רצוי:  לְשׇׁלְחוֹ לְהֶלְיוֹפּוֹלִיס בְּמִצְרַיִם כְּדֵי שֶׁיַּחְקוֹר שָׁם אֶת תְּחוּמֵי הָאַסְטְרוֹנוֹמְיָה 
91.21% (8 out of 91)
66.67% (3 out of 9)

בְּלוֹאַר בְּלוֹאֵר
סַיָּיד סַיְיד
בָּנְיוּ בִּנְיוּ
אַׇשְׁרֵי אׇשְׁרִי
מצוי:  הַגָּדוֹל בְּלוֹאַר אִיסְט סַיָּיד שֶׁל מַנְהֵטְן בָּנְיוּ יוֹרְק. הָרַב אַׇשְׁרֵי נִפְטַר בָּב' 
רצוי:  הַגָּדוֹל בְּלוֹאֵר אִיסְט סַיְיד שֶׁל מַנְהֵטְן בִּנְיוּ יוֹרְק. הָרַב אׇשְׁרִי נִפְטַר בְּב' 
9

In [69]:
import hebrew
import dataset