In [1]:
import os
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

import dataset
import schedulers

import tensorflow as tf
assert tf.config.list_physical_devices('GPU')

from tensorflow_addons.layers.crf import CRF
from tensorflow_addons.text.crf import crf_log_likelihood

In [44]:

def unpack_data(data):
    if len(data) == 2:
        return data[0], data[1], None
    elif len(data) == 3:
        return data
    else:
        raise TypeError("Expected data to be a tuple of size 2 or 3.")


accuracy = keras.metrics.SparseCategoricalAccuracy(name='accuracy')

class ModelWithCRFLoss(tf.keras.Model):
    """Wrapper around the base model for custom training logic."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def call(self, inputs):
        return self.base_model(inputs)

    def compute_loss(self, x, y, sample_weights, training=False):
        y_pred = self(x, training=training)
        _, potentials, sequence_length, chain_kernel = y_pred
        
        # potentials: Tensor("model_with_crf_loss_4/model_1/N/add_1:0", shape=(None, 82, 16), dtype=float32)
        # expected: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer.

        # sequence_length: Tensor("model_with_crf_loss_4/model_1/N/Cast_8:0", shape=(None,), dtype=int64)
        # expected: A [batch_size] vector of true sequence lengths.

        # y: {'N': <tf.Tensor 'IteratorGetNext:2' shape=(None, 82) dtype=int32>,
        #     'D': <tf.Tensor 'IteratorGetNext:1' shape=(None, 82) dtype=int32>,
        #     'S': <tf.Tensor 'IteratorGetNext:3' shape=(None, 82) dtype=int32>}
        # expected: A [batch_size, max_seq_len] matrix of tag indices for which we compute the log-likelihood.

        # chain_kernel: <tf.Variable 'chain_kernel:0' shape=(16, 16) dtype=float32>
        # expected: A [num_tags, num_tags] transition matrix, if available.

        crf_loss = -crf_log_likelihood(potentials, y, sequence_length, chain_kernel)[0]

        if sample_weights is not None:
            crf_loss = crf_loss * sample_weights

        return tf.reduce_mean(crf_loss), sum(self.losses)

    def train_step(self, data):
        x, y, sample_weight = unpack_data(data)

        with tf.GradientTape() as tape:
            crf_loss, internal_losses = self.compute_loss(
                x, y, sample_weight, training=True
            )
            total_loss = crf_loss + internal_losses

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {"crf_loss": crf_loss, "internal_losses": internal_losses, **{m.name: m.result() for m in self.metrics}}

    def test_step(self, data):
        x, y, sample_weight = unpack_data(data)
        crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
        return {"crf_loss_val": crf_loss, "internal_losses_val": internal_losses, ** {m.name: m.result() for m in self.metrics}}




In [46]:
BATCH_SIZE = 32

LETTERS_SIZE = len(dataset.letters_table)
NIQQUD_SIZE = len(dataset.niqqud_table)
DAGESH_SIZE = len(dataset.dagesh_table)
SIN_SIZE = len(dataset.sin_table)

def build_model(EMBED_DIM=28, UNITS=128):
    inp = keras.Input(batch_shape=(None, None), batch_size=BATCH_SIZE)
    layer = layers.Embedding(LETTERS_SIZE, UNITS, mask_zero=True)(inp)
    layer = layers.Bidirectional(layers.LSTM(UNITS, return_sequences=True), merge_mode='sum')(layer)
    layer = CRF(NIQQUD_SIZE, name='N')(layer)
    model = ModelWithCRFLoss(keras.Model(inputs=inp, outputs=layer))
    model.build((None, None))
    return model

model = build_model()

model.summary()
model.save_weights('./checkpoints/crf_uninit')

Model: "model_with_crf_loss_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
model_8 (Model)              [(None, None), (None, Non 271152    
Total params: 271,152
Trainable params: 271,152
Non-trainable params: 0
_________________________________________________________________


In [27]:
def fit(train_validation):
    train, valid = train_validation
    model.compile()
    callbacks = []
        
    x  = train.normalized
    vx = valid.normalized
    
    y  = train.niqqud # , 'D': train.dagesh, 'S': train.sin }
    vy = valid.niqqud # , 'D': valid.dagesh, 'S': valid.sin }
    
    return model.fit(x, y, validation_data=(vx, vy), batch_size=BATCH_SIZE, epochs=1, metrics=["accuracy"])


def load_data(source, maxlen=82, validation=0.1):
    filenames = [os.path.join('texts', f) for f in source]
    train, valid = dataset.load_data(filenames, validation, maxlen=maxlen)
    return train, valid

In [8]:
data_ynet = load_data(validation=0.1, source=['modern/ynet'])

In [47]:
 history = fit(data_ynet)
 print(history.history)

{'crf_loss': [106.46131134033203], 'internal_losses': [0], 'val_crf_loss_val': [111.09480285644531], 'val_internal_losses_val': [0]}
