In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
import data

In [None]:
def create_tcm(tokenizer, vocabulary, corpus, skip_grams_window):
    """
    Build a term co-occurrence matrix (TCM) weighted by inverse distance.

    For every token, each neighbour at most ``skip_grams_window`` positions away
    (on either side) adds ``1 / distance`` to ``tcm[token][neighbour]``.
    Without the window limit (and the distance weighting) this would reduce to
    the product of the document-term matrix with its transpose.

    :param tokenizer: forwarded to ``data.corpus_to_token_ids``
    :param vocabulary: its length fixes the matrix size; also forwarded to
        ``data.corpus_to_token_ids``
    :param corpus: documents to scan, forwarded to ``data.corpus_to_token_ids``
    :param skip_grams_window: maximum distance, in tokens, on each side of a word
    :return: float64 array of shape (len(vocabulary), len(vocabulary)) whose
        diagonal is zeroed
    """
    n_terms = len(vocabulary)
    cooc = np.zeros((n_terms, n_terms))

    for doc_ids in data.corpus_to_token_ids(corpus, tokenizer, vocabulary):
        doc_len = len(doc_ids)

        for pos, center in enumerate(doc_ids):
            left_window = doc_ids[max(0, pos - skip_grams_window):pos]
            right_window = doc_ids[pos + 1:min(pos + 1 + skip_grams_window, doc_len)]

            # Walk outwards from the centre word on the left side first,
            # so the nearest neighbour gets distance 1.
            for dist, neighbour in enumerate(left_window[::-1], start=1):
                cooc[center][neighbour] += 1. / dist
            # Then the right side, again nearest first.
            for dist, neighbour in enumerate(right_window, start=1):
                cooc[center][neighbour] += 1. / dist

    # A word co-occurring with itself is not a useful training signal.
    np.fill_diagonal(cooc, 0)

    return cooc


# Implementation of these two functions is borrowed from
# https://github.com/erwtokritos/keras-glove/blob/43ce3a262a517e2c7aed04f1726bc7ea049fd031/app/models.py
# with modifications and debugging

def create_loss(a=0.75, x_max=100):
    """
    Build the weighted least-squares GloVe loss as a Keras-compatible callable.

    :param a: exponent applied to the clipped, normalized co-occurrence weight
    :param x_max: co-occurrence value at (and above) which the weight is 1
    :return: ``custom_loss(y_true, y_pred)`` suitable for ``model.compile(loss=...)``
    """

    @tf.function
    def custom_loss(y_true, y_pred):
        """
        GloVe loss, summed over the last axis:
        ``clip(y_true / x_max, 0, 1) ** a * (y_pred - log(1 + y_true)) ** 2``.

        :param y_true: observed co-occurrence values X_ij
        :param y_pred: model output (predicted log co-occurrence)
        :return: loss for this batch, reduced over the last axis
        """
        weight = K.pow(K.clip(y_true / x_max, 0.0, 1.0), a)
        # log(1 + X) keeps the target finite when X == 0 (such cells also get weight 0).
        residual = y_pred - K.log(1 + y_true)
        return K.sum(weight * K.square(residual), axis=-1)

    return custom_loss


def glove_model(vocab_size, vector_dim):
    """
    Build the GloVe architecture in Keras.

    For a (central word id, context word id) pair the model outputs
    ``w_central . w_context + b_central + b_context`` as a (batch, 1) tensor.
    Central and context words use separate embedding and bias tables.

    :param vocab_size: number of distinct words (rows in each table)
    :param vector_dim: dimensionality of each word vector
    :return: an uncompiled ``tf.keras.models.Model`` with inputs
        ``central_word_id`` and ``context_word_id``
    """
    layers = tf.keras.layers

    central_ids = layers.Input((1,), name='central_word_id')
    context_ids = layers.Input((1,), name='context_word_id')

    # Layer names must stay stable: get_vectors_from_model looks them up by name.
    central_vectors_layer = layers.Embedding(vocab_size, vector_dim, input_length=1, name='central_embeddings')
    central_bias_layer = layers.Embedding(vocab_size, 1, input_length=1, name='central_biases')
    context_vectors_layer = layers.Embedding(vocab_size, vector_dim, input_length=1, name='context_embeddings')
    context_bias_layer = layers.Embedding(vocab_size, 1, input_length=1, name='context_biases')

    # Vectors are looked up before biases (this fixes the order in which weights get built).
    central_vec = central_vectors_layer(central_ids)
    context_vec = context_vectors_layer(context_ids)
    central_b = central_bias_layer(central_ids)
    context_b = context_bias_layer(context_ids)

    raw_dot = layers.Dot(axes=-1)([central_vec, context_vec])
    # Flatten the dot product and both biases to (batch, 1) so they can be summed.
    flat_terms = [layers.Reshape((1,))(term) for term in (raw_dot, central_b, context_b)]

    prediction = layers.Add()(flat_terms)

    return tf.keras.models.Model(inputs=[central_ids, context_ids], outputs=prediction)


def data_set_generator_from_tcm(tcm, skip_zeros=False):
    """
    Return a zero-argument generator function over the off-diagonal cells of ``tcm``.

    Each item is ``({'central_word_id': [i], 'context_word_id': [j]}, tcm[i, j])``
    for ``i != j``, in row-major order. The returned callable may be invoked more
    than once; every call starts a fresh pass and reads ``tcm`` at that time.

    :param tcm: 2-D co-occurrence matrix
    :param skip_zeros: if True, only cells with a non-zero value are yielded.
        With the GloVe weighting used by ``create_loss`` a zero count receives
        weight ``clip(0 / x_max, 0, 1) ** a == 0``, so those cells add nothing to
        the loss sum, while still costing one iteration each out of V**2.
        Defaults to False, which yields every off-diagonal cell.
    :return: generator function suitable for ``tf.data.Dataset.from_generator``
    """
    tcm = np.asarray(tcm)

    def _glove_gen():
        if skip_zeros:
            rows, cols = np.nonzero(tcm)
            # tolist() gives plain Python ints, matching the dense path below.
            pairs = zip(rows.tolist(), cols.tolist())
        else:
            pairs = np.ndindex(tcm.shape)

        for first_id, second_id in pairs:
            # Self co-occurrence is not a training example.
            if first_id != second_id:
                yield {'central_word_id': [first_id], 'context_word_id': [second_id]}, tcm[first_id, second_id]

    return _glove_gen


def get_vectors_from_model(model):
    """
    Combine the central and context embedding tables of a trained GloVe model.

    Follows the recipe ``word_vectors = wv_main + t(wv_context)``: the final word
    vectors are the element-wise sum of the two tables (both are stored here as
    vocab_size x vector_dim, so no transpose is needed).

    :param model: model containing layers named 'central_embeddings' and
        'context_embeddings'
    :return: ``context_weights + central_weights`` (the first weight of each layer)
    :raises AssertionError: if either of the two embedding layers is missing
    """
    wv_context = wv_main = None
    for layer in model.layers:
        if layer.name == 'context_embeddings':
            wv_context = layer.weights[0]
        elif layer.name == 'central_embeddings':
            wv_main = layer.weights[0]

    # Both tables are required; fail with a clear message instead of a TypeError on `None + ...`.
    assert wv_main is not None and wv_context is not None, \
        "model must contain both 'central_embeddings' and 'context_embeddings' layers"

    return wv_context + wv_main


def create_training_dataset(batch_size, tcm):
    """
    Wrap the TCM pair generator in a batched ``tf.data.Dataset``.

    Before batching, each element is
    ``({'central_word_id': int32[1], 'context_word_id': int32[1]}, float64 scalar)``.
    No shuffling is applied; batches follow the generator's order.

    :param batch_size: number of examples per batch
    :param tcm: term co-occurrence matrix passed to ``data_set_generator_from_tcm``
    :return: the batched dataset
    """
    feature_types = {'central_word_id': tf.int32, 'context_word_id': tf.int32}
    feature_shapes = {
        'central_word_id': tf.TensorShape([1, ]),
        'context_word_id': tf.TensorShape([1, ]),
    }

    unbatched = tf.data.Dataset.from_generator(
        data_set_generator_from_tcm(tcm),
        output_types=(feature_types, tf.float64),
        output_shapes=(feature_shapes, tf.TensorShape([])),
    )
    return unbatched.batch(batch_size)

In [None]:
def train_embeddings(*,
                     tcm: np.ndarray,
                     vector_dims: int,
                     batch_size: int = 2048,
                     plot_model: bool = False,
                     model_summary: bool = False,
                     x_max: int = 50,
                     **kwargs):
    """
    Train GloVe embeddings on a term co-occurrence matrix.

    :param tcm: square co-occurrence matrix; ``tcm.shape[0]`` is used as the vocabulary size
    :param vector_dims: embedding dimensionality
    :param batch_size: examples per training batch
    :param plot_model: if True, render the architecture with
        ``tf.keras.utils.plot_model`` and display it inline (when IPython is available)
    :param model_summary: if True, call ``model.summary()``
    :param x_max: saturation point of the GloVe weighting function
    :param kwargs: these keys are consumed here; all remaining keys are passed to
        ``model.fit`` (e.g. ``epochs``):
        - ``optimizer``: Keras optimizer, defaults to a new ``tf.keras.optimizers.Adam()``
        - ``save_model``: if given, path passed to ``model.save`` after training
        - ``save_embedding``: if given, path passed to ``np.save`` for the embeddings
    :return: ``(model, embeddings)``; embeddings is the sum of the central and
        context embedding tables
    """
    # Build the default lazily so no throw-away Adam instance is created when the
    # caller supplies an optimizer.
    optimizer = kwargs.pop('optimizer', None)
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam()
    save_model_name = kwargs.pop('save_model', None)
    save_embedding_name = kwargs.pop('save_embedding', None)

    vocab_size = tcm.shape[0]

    dataset = create_training_dataset(batch_size, tcm)

    model = glove_model(
        vocab_size=vocab_size,
        vector_dim=vector_dims
    )
    model.compile(loss=create_loss(x_max=x_max), optimizer=optimizer)
    if model_summary:
        model.summary()
    if plot_model:
        image = tf.keras.utils.plot_model(model, show_shapes=True)
        # Inside a function the returned image is not auto-displayed by Jupyter,
        # so show it explicitly.
        try:
            from IPython.display import display
        except ImportError:
            pass  # not running under IPython: nothing to display inline
        else:
            display(image)

    model.fit(dataset, **kwargs)
    if save_model_name is not None:
        model.save(save_model_name)

    embeddings = get_vectors_from_model(model)

    if save_embedding_name is not None:
        np.save(save_embedding_name, embeddings)

    return model, embeddings


def load_pretrained_embeddings(**kwargs):
    """
    Load either saved embeddings or a saved Keras model; exactly one source must be given.

    :param kwargs: ``embedding_address`` (path read with ``np.load``) or
        ``model_address`` (path read with ``tf.keras.models.load_model``);
        any other keys are ignored
    :return: the loaded array for ``embedding_address``, or the loaded model for
        ``model_address``
    :raises AssertionError: if neither or both addresses are supplied
    """
    embedding_address = kwargs.pop('embedding_address', None)
    model_address = kwargs.pop('model_address', None)

    have_model = model_address is not None
    have_embedding = embedding_address is not None

    # An "exactly one of" check written as two asserts, so the failing line shows which rule broke.
    assert have_model or have_embedding
    assert not (have_model and have_embedding)

    if have_embedding:
        return np.load(embedding_address)

    if have_model:
        # NOTE(review): the model is compiled with a custom loss in train_embeddings;
        # load_model may need custom_objects or compile=False for some save formats — confirm.
        return tf.keras.models.load_model(model_address)

    return None