In [1]:
%load_ext autoreload

In [26]:
import os
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
assert tf.config.list_physical_devices('GPU')

import collections

import dataset

%autoreload
import transformer

In [27]:
BATCH_SIZE = 32

# Maximum sequence length, used for both the input letters and the target labels.
MAXLEN = 50

# Vocabulary sizes: one input vocabulary (letters) and three output heads.
LETTERS_SIZE = len(dataset.letters_table)
NIQQUD_SIZE = len(dataset.niqqud_table)
DAGESH_SIZE = len(dataset.dagesh_table)
SIN_SIZE = len(dataset.sin_table)

num_layers = 1
num_heads = 4
# Round down to a multiple of num_heads so the model dimension splits evenly across heads.
d_model = 220 // num_heads * num_heads
dff = 512

model = transformer.Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=LETTERS_SIZE,
    output_sizes=[NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE],
    maximum_position_encoding_input=MAXLEN,
    maximum_position_encoding_target=MAXLEN,
    rate=0.0
)

# The optimizer uses the constant `lr` by default. Set USE_LR_SCHEDULE = True to train
# with the warmup schedule from transformer.CustomSchedule instead.
USE_LR_SCHEDULE = False
WARMUP_STEPS = 3000
lr = 12e-4
learning_rate = (
    transformer.CustomSchedule(d_model, warmup_steps=WARMUP_STEPS) if USE_LR_SCHEDULE else lr
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate),
    # Alternatives tried: transformer.MaskedCategoricalCrossentropy(),
    # tf.keras.losses.sparse_categorical_crossentropy.
    loss='sparse_categorical_crossentropy',
    # NOTE: unmasked, so the reported accuracy is incorrect (presumably padding positions count).
    metrics=['accuracy']
)

# Pseudo "build" step so model.summary() can report shapes and parameter counts.
# model.run_eagerly = True  # uncomment to debug step functions eagerly
model.pseudo_build(MAXLEN, MAXLEN)
model.summary()
# Snapshot the random initialisation; the training cells reload it to start from scratch.
model.save_weights('./checkpoints/uninit')

Model: "transformer_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_8 (Encoder)          multiple                  431052    
_________________________________________________________________
decoder_8 (Decoder)          multiple                  621352    
_________________________________________________________________
softmax_13 (Softmax)         multiple                  0         
_________________________________________________________________
dense_157 (Dense)            multiple                  3536      
_________________________________________________________________
softmax_14 (Softmax)         multiple                  0         
_________________________________________________________________
dense_158 (Dense)            multiple                  663       
_________________________________________________________________
softmax_15 (Softmax)         multiple                

In [42]:
def load_data(source, validation=0.1):
    """Load the corpora named in `source` and split them into train/validation sets.

    Each name in `source` is resolved to a file under the ``texts`` directory.
    The paths are handed to ``dataset.load_data`` together with `validation`
    (presumably the validation fraction) and ``maxlen=MAXLEN``.

    Returns:
        The ``(train, valid)`` pair produced by ``dataset.load_data``.
    """
    paths = [os.path.join('texts', corpus_name) for corpus_name in source]
    train_split, valid_split = dataset.load_data(paths, validation, maxlen=MAXLEN)
    return train_split, valid_split


def batch_zip(x, *ys, batch_size=None, drop_remainder=True):
    """Yield aligned fixed-size batches of `x` and of every array in `ys`.

    Args:
        x: sliceable sequence (list / numpy array) of inputs.
        *ys: sliceable sequences aligned with `x` (e.g. the label arrays).
        batch_size: rows per batch; defaults to the global ``BATCH_SIZE``.
        drop_remainder: when True (the default, matching the original behaviour),
            a trailing partial batch is skipped; when False it is yielded too.

    Yields:
        ``(batch_index, x_batch, [y_batch for each y in ys])``.

    Raises:
        ValueError: if `batch_size` is not positive (raised on first iteration).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_batches, remainder = divmod(len(x), batch_size)
    if remainder and not drop_remainder:
        n_batches += 1

    for i in range(n_batches):
        s = slice(i * batch_size, (i + 1) * batch_size)
        yield (i, x[s], [y[s] for y in ys])
    

def fit(data, epochs=1):
    """Train the global ``model`` on ``data`` and return per-epoch validation metrics.

    Each epoch runs ``model.train_step`` over every full training batch, printing
    a running progress line, then ``model.test_step`` over every full validation
    batch. Metrics are reset before each phase. The metrics returned by the last
    validation step are printed and appended to the history.

    Args:
        data: ``(train, valid)`` pair as returned by ``load_data``; each split
            exposes ``normalized``, ``niqqud``, ``dagesh`` and ``sin`` arrays.
        epochs: number of passes over the training split.

    Returns:
        ``collections.defaultdict(list)`` mapping metric name -> list of
        per-epoch validation values. An epoch whose validation split has no
        full batch contributes nothing.
    """
    train, valid = data
    total_train = len(train) // BATCH_SIZE
    history = collections.defaultdict(list)
    for epoch in range(epochs):
        model.reset_metrics()
        for batch, x, ys in batch_zip(train.normalized, train.niqqud, train.dagesh, train.sin):
            res = model.train_step(x, *ys)
            # 1-based counter so the final batch reads N/N rather than (N-1)/N.
            print(f"\r{batch + 1:4d}/{total_train:4d} - {' - '.join(f'{k}: {v:.4f}' for k, v in res.items())}", end='')

        model.reset_metrics()

        # Reset so that a validation split smaller than one batch does not silently
        # report (and record) the last *training* metrics as validation metrics.
        res = None
        for _, x, ys in batch_zip(valid.normalized, valid.niqqud, valid.dagesh, valid.sin):
            res = model.test_step(x, *ys)

        if res is None:
            print(" - no full validation batch; validation metrics skipped")
            continue

        print(''.join(f" - {k}: {v:.4f}" for k, v in res.items()))

        for k, v in res.items():
            history[k].append(v.numpy())

    return history

In [36]:
# Smoke test: one train_step on the first modern training example.
# NOTE: needs `data_modern`, which is created by the `load_data(... source=['modern'])`
# cell further down — run that cell first. train_step presumably updates the weights;
# the training cells below reload './checkpoints/uninit' before training anyway.
train_modern, _ = data_modern
first_row = slice(0, 1)
model.train_step(
    train_modern.normalized[first_row],
    train_modern.niqqud[first_row],
    train_modern.dagesh[first_row],
    train_modern.sin[first_row],
)

{'loss': <tf.Tensor: shape=(), dtype=float32, numpy=1.534505>,
 'output_1_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.26333332>,
 'output_2_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.70666665>,
 'output_3_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.64>}

In [None]:
data_other = load_data(source=['biblical', 'garbage'], validation=0.1)  # biblical + garbage corpora

In [None]:
data_mix = load_data(source=['poetry', 'rabanit', 'pre_modern'], validation=0.1)  # poetry, rabbinic and pre-modern corpora

In [28]:
data_modern = load_data(['modern'], 0.1)  # modern corpus, same 0.1 validation split

In [None]:
# Stage 1: start from the saved random initialisation, train on biblical + garbage text.
model.load_weights('./checkpoints/uninit')
history = fit(data=data_other, epochs=1)
model.save_weights('./checkpoints/other')  # consumed by the stage-2 cell

In [None]:
# Stage 2: continue from the stage-1 checkpoint on the poetry / rabbinic / pre-modern mix.
model.load_weights('./checkpoints/other')
history = fit(data=data_mix, epochs=1)
model.save_weights('./checkpoints/mix')  # loaded again by the decoding cell below

In [None]:
# Train on the modern corpus only, starting again from the random initialisation.
model.load_weights('./checkpoints/uninit')
history = fit(data_modern, epochs=1) # lr=0.0012, d_model=220, num_layers=1, num_heads=4, dff=512 - accuracy: 0.8314
print(f'{lr=}, {d_model=}, {num_layers=}, {num_heads=}, {dff=}')

# Persist the result: the print_predictions cell loads './checkpoints/modern', so it
# must be written here for a fresh Restart & Run All to work.
model.save_weights('./checkpoints/modern')

 677/ 678 - loss: 0.3486 - output_1_accuracy: 0.7302 - output_2_accuracy: 0.9445 - output_3_accuracy: 0.9817

In [None]:
model.load_weights('./checkpoints/modern')

def print_predictions(data, s):
    """Print the model's diacritized output next to the ground truth for rows `s` of `data`.

    For each selected row this prints the predicted text, the expected text, and
    the niqqud accuracy over the positions before the row's first 0 label. It
    finally prints the niqqud accuracy pooled over all printed rows.

    Args:
        data: a split as returned by ``load_data`` (exposes ``normalized``,
            ``text``, ``niqqud``, ``dagesh`` and ``sin`` arrays).
        s: slice selecting the rows to show.
    """
    batch = data.normalized[s]
    prediction = model.predict(batch)
    # One label array per output head, in model order: niqqud, dagesh, sin.
    actual_niqqud, actual_dagesh, actual_sin = [dataset.from_categorical(prediction[k]) for k in range(3)]
    expected_niqqud, expected_dagesh, expected_sin = data.niqqud[s], data.dagesh[s], data.sin[s]
    actual = dataset.merge(data.text[s], ts=batch, ns=actual_niqqud, ds=actual_dagesh, ss=actual_sin)
    expected = dataset.merge(data.text[s], ts=batch, ns=expected_niqqud, ds=expected_dagesh, ss=expected_sin)
    total = []
    for i, (a, e) in enumerate(zip(actual, expected)):
        print('מצוי: ', a)  # "actual": model output
        print('רצוי: ', e)  # "expected": ground truth
        row = expected_niqqud[i].tolist()
        # Label 0 marks the end of the sequence. A row with no 0 (e.g. one that fills
        # MAXLEN) would make list.index raise ValueError, so score the whole row instead.
        last = row.index(0) if 0 in row else len(row)
        res = expected_niqqud[i][:last] == actual_niqqud[i][:last]
        total.extend(res)
        if last:  # skip np.mean of an empty row (nan + RuntimeWarning)
            print(f'{np.mean(res):.2f} ({last - sum(res)} wrong out of {last})')
        print()
    if total:
        print(round(np.mean(total), 3))

print_predictions(data_modern[1], slice(0, BATCH_SIZE))

In [None]:
import matplotlib.pyplot as plt

# One panel per metric recorded by `fit` (one value per epoch).
# squeeze=False keeps `axes` two-dimensional even when there is a single metric:
# plt.subplots would otherwise return a bare Axes, which cannot be indexed.
# max(..., 1) avoids ncols=0, which plt.subplots rejects, for an empty history.
fig, axes = plt.subplots(nrows=1, ncols=max(len(history), 1), squeeze=False)

for ax, (metric, values) in zip(axes[0], history.items()):
    # marker='o' so a single-epoch run (one point, no line segment) is still visible.
    ax.plot(values, marker='o')
    ax.set(title=metric, xlabel='epoch')

plt.tight_layout()

In [None]:
model.load_weights('./checkpoints/mix')

def predict(x):
    """Greedily decode niqqud labels for a batch of letter sequences, one position at a time.

    At step ``i`` the decoder input is the argmax of everything decoded so far
    (start token first), padded with zeros to ``timesteps``. The model's
    distribution at position ``i`` is appended to the running history.

    Args:
        x: 2-D array ``(batch, timesteps)`` of normalized letter indices.

    Returns:
        numpy int32 array ``(batch, timesteps)`` with the greedy label per position.
    """
    batch_len, timesteps = x.shape

    # Seed the decoder with a one-hot distribution on class 1, which this code treats
    # as the start token. Shape (batch, 1, NIQQUD_SIZE).
    y_probs = np.zeros((batch_len, 1, NIQQUD_SIZE), dtype=np.float32)
    y_probs[:, 0, 1] = 1
    y_probs = tf.constant(y_probs)

    padding_mask = transformer.create_padding_mask(x)
    # NOTE(review): the target padding mask is also derived from the input letters,
    # presumably because letters and labels share padding positions — confirm.
    dec_target_padding_mask = transformer.create_padding_mask(x)

    # Label-0 filler for positions not decoded yet; presumably hidden from step i by the
    # decoder's look-ahead mask — TODO confirm in transformer.Decoder.
    invisible_future = tf.zeros([batch_len, timesteps], dtype=tf.int32)

    for i in range(timesteps):
        y_pred = tf.cast(tf.argmax(y_probs, axis=-1), tf.int32)            # (batch, i + 1)
        y_pred = tf.concat([y_pred, invisible_future[:, i+1:]], axis=-1)   # (batch, timesteps)

        predictions, _ = model(x, y_pred, False, dec_target_padding_mask, padding_mask)
        # Keep only the distribution for position i and append it to the history.
        predictions = predictions[:, i:i+1, :]
        y_probs = tf.concat([y_probs, predictions], axis=1)

    y_probs = y_probs[:, 1:, :]  # drop the seeded start position
    return tf.cast(tf.argmax(y_probs, axis=-1), tf.int32).numpy()

# Greedy-decode the first validation batch of the modern corpus with the weights
# currently loaded (the cell defining `predict` loads './checkpoints/mix').
valid_modern = data_modern[1]
rows = slice(0, 1*BATCH_SIZE)
output = predict(valid_modern.normalized[rows])
print(list(valid_modern.normalized[rows][0]))
print(list(output[0]))
print(list(valid_modern.niqqud[0]))

# Accuracy over real positions only: an unmasked mean also rewards matching padding.
# Niqqud label 0 is treated as padding, as print_predictions does — confirm in dataset.
expected_niqqud = valid_modern.niqqud[rows]
is_real = expected_niqqud != 0
(expected_niqqud == output)[is_real].mean()

In [None]:

def merge(normalized, prediction):
    """Render one row as text, writing each letter followed by its niqqud mark.

    Letters and marks are looked up through ``dataset.letters_table.indices_char``
    and ``dataset.niqqud_table.indices_char``. Output stops at the first letter
    equal to ``dataset.letters_table.PAD_TOKEN``.
    """
    letters = dataset.letters_table
    pieces = []
    for letter_id, niqqud_id in zip(normalized, prediction):
        if letter_id == letters.PAD_TOKEN:
            break
        pieces.extend((letters.indices_char[letter_id], dataset.niqqud_table.indices_char[niqqud_id]))
    return ''.join(pieces)

# Teacher-forced evaluation on the first validation batch of the modern corpus.
valid_modern = data_modern[1]
text = valid_modern.normalized[0*BATCH_SIZE:1*BATCH_SIZE]
actual = valid_modern.niqqud[0*BATCH_SIZE:1*BATCH_SIZE]
# Decoder input = gold labels shifted right by one, with start label 1 in front (the same
# start class `predict` seeds). Size the column from `actual` so a validation split
# smaller than BATCH_SIZE still stacks, and keep the labels' integer dtype instead of
# silently promoting them to float64.
start_column = np.ones((len(actual), 1), dtype=actual.dtype)
padded_actual = np.hstack([start_column, actual])[:, :-1]
print(padded_actual.shape)
print(text.shape)
dec_target_padding_mask = transformer.create_padding_mask(actual)
padding_mask = transformer.create_padding_mask(text)
prediction = model(text, padded_actual, False, dec_target_padding_mask, padding_mask)[0]
prediction = np.argmax(prediction, axis=-1)
n = 3
print(text[n])
print(prediction[n])
print(actual[n])
print(prediction[n] == actual[n])
# Score real positions only (niqqud label 0 treated as padding, as in print_predictions — confirm).
print(np.mean((prediction == actual)[actual != 0]))
print(merge(text[n], prediction[n]))
print(merge(text[n], actual[n]))