In [1]:
import itertools
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
def articles():
    with open('page_revisions_text', 'rb') as text_file:
        pending_article_data = b''
        while True:
            data = text_file.read(1024 ** 2)
            if len(data) == 0:
                break

            articles = data.split(b'\0')
            articles[0] = pending_article_data + articles[0]
            for index, article in enumerate(articles):
                if index + 1 == len(articles):
                    pending_article_data = article
                else:
                    yield article

        print(pending_article_data)
        if len(pending_article_data) != 0:
            yield pending_article_data

In [3]:
subword_text_encoder = tfds.features.text.SubwordTextEncoder.load_from_file('vocab_4096')

In [4]:
BATCH_SIZE = 256
BATCHED_ITEM_LENGTH = 128
BUFFER_SIZE = 1024
TYPE=np.int16

def articles_generator():
    for index, article in enumerate(itertools.islice(articles(), 0, 512)):
        yield np.array(subword_text_encoder.encode(article + b'\0'), dtype=TYPE)

    # Pad the article count to the batch size
    # We do this to ensure that no data is dropped
    index += 1
    while index % BATCH_SIZE != 0:
        yield np.array([0], dtype=TYPE)
        index += 1

def subbatches():
    dataset = tf.data.Dataset.from_generator(articles_generator, output_types=TYPE)
    dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=([None]), drop_remainder=True)
    dataset = dataset.shuffle(2000)

    for batch in dataset.as_numpy_iterator():
        remaining = batch
        while remaining.shape[1] > 1:
            yield remaining[:, :BATCHED_ITEM_LENGTH + 1]
            remaining = remaining[:, BATCHED_ITEM_LENGTH:]

dataset = tf.data.Dataset.from_generator(subbatches, output_types=TYPE, output_shapes=(BATCH_SIZE, None))
dataset = dataset.map(lambda batch: (batch[:, :-1], batch[:, 1:]))

dataset

<MapDataset shapes: ((256, None), (256, None)), types: (tf.int16, tf.int16)>

In [5]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    return tf.keras.Sequential([
        tf.keras.layers.Masking(mask_value=0, batch_input_shape=[batch_size, None]),
        tf.keras.layers.Embedding(vocab_size, embedding_dim),
        tf.keras.layers.LSTM(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        tf.keras.layers.LSTM(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dense(vocab_size),
    ])

In [6]:
from tensorflow.python.eager import context
from tensorflow.python.keras import backend_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module

epsilon = backend_config.epsilon

def get_graph():
    if context.executing_eagerly():
        global _GRAPH
        if _GRAPH is None:
            _GRAPH = func_graph.FuncGraph('keras_graph')
        return _GRAPH
    else:
        return ops.get_default_graph()

def flatten(x):
    return array_ops.reshape(x, [-1])

def cast(x, dtype):
    return math_ops.cast(x, dtype)
  
def _is_symbolic_tensor(x):
    return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor)

# This is based around the `sparse_categorical_crossentropy` implementation in Keras:
# https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/backend.py#L4507-L4582
def loss(target, output, from_logits=False, axis=-1):
    if not from_logits:
        if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or output.op.type != 'Softmax'):
            epsilon_ = constant_op.constant(epsilon(), dtype=output.dtype.base_dtype)
            output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
            output = math_ops.log(output)
        else:
            # When softmax activation function is used for output operation, we
            # use logits from the softmax function directly to compute loss in order
            # to prevent collapsing zero when training.
            # See b/117284466
            assert len(output.op.inputs) == 1
            output = output.op.inputs[0]
  
    if isinstance(output.shape, (tuple, list)):
        output_rank = len(output.shape)
    else:
        output_rank = output.shape.ndims

    if output_rank is not None:
        axis %= output_rank
        if axis != output_rank - 1:
            permutation = list(itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
            output = array_ops.transpose(output, perm=permutation)
    elif axis != -1:
        raise ValueError(
            'Cannot compute sparse categorical crossentropy with `axis={}` on an '
            'output tensor with unknown rank'.format(axis))
  
    target = cast(target, 'int64')
  
    # Try to adjust the shape so that rank of labels = rank of logits - 1.
    output_shape = array_ops.shape_v2(output)
    target_rank = target.shape.ndims
  
    update_shape = (target_rank is not None and output_rank is not None and target_rank != output_rank - 1)
    if update_shape:
        target = flatten(target)
        output = array_ops.reshape(output, [-1, output_shape[-1]])
  
    if __builtins__.any([_is_symbolic_tensor(v) for v in [target, output]]):
        with get_graph().as_default():
            res = huffman_code_lengths(labels=target, logits=output)
    else:
        res = huffman_code_lengths(labels=target, logits=output)
  
    if update_shape and output_rank >= 3:
        # If our output includes timesteps or spatial dimensions we need to reshape
        return array_ops.reshape(res, output_shape[:-1])
    else:
        return res

def log2(x):
    return tf.math.log(x) / tf.math.log(2.0)
    
def huffman_code_lengths(labels, logits):
    category_count = logits.shape[-1] or 0
    
    return tf.reduce_sum(-log2(tf.one_hot(labels, depth=category_count) * tf.nn.softmax(logits) + 0.0001), axis=-1)

In [7]:
def average_batch_length(true_labels, predictions):
    return tf.shape(true_labels)[1]

model = build_model(vocab_size = subword_text_encoder.vocab_size, embedding_dim=512, rnn_units=1024, batch_size=BATCH_SIZE)
model.compile(optimizer='adam', loss=loss, metrics=[average_batch_length])

In [8]:
checkpoint_dir = './training_checkpoints-1' # Directory where the checkpoints will be saved
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}") # Name of the checkpoint files

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True)

In [9]:
class ModelStateResetter(tf.keras.callbacks.Callback):
    def __init__(self):
        self.last_total_length = 0

    def on_batch_end(self, batch, logs={}):
        average_batch_length = logs.get('average_batch_length', 0)
        total_length = int(round(average_batch_length * (batch + 1)))
        current_batch_length = total_length - self.last_total_length
        self.last_total_length = total_length
        
        if current_batch_length < BATCHED_ITEM_LENGTH:
            self.model.reset_states()
        
model_state_resetter_callback = ModelStateResetter()

In [10]:
total_epochs = 10

for epoch in range(total_epochs):
    print('Epoch %d/%d' % (epoch + 1, total_epochs))
    model.fit(dataset, callbacks=[checkpoint_callback, model_state_resetter_callback])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [11]:
import ctypes

class Huffman:
    huffman = ctypes.CDLL('x64/Release/huffman')
    
    huffman.create_tree.restype = ctypes.c_void_p
    huffman.destroy_tree.restype = None
    huffman.load_weights.restype = None
    huffman.create_code_string.restype = ctypes.c_char_p
    
    def __init__(self, category_count):
        self.tree = ctypes.c_void_p(self.huffman.create_tree(category_count))

    def __del__(self):
        self.huffman.destroy_tree(self.tree)
        
    def load_weights(self, weights):
        self.huffman.load_weights(self.tree, weights.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
    
    def get_code_length(self, category):
        return self.huffman.get_code_length(self.tree, category)

    def get_code_zero_count(self, category):
        return self.huffman.get_code_zero_count(self.tree, category)

In [12]:
def huffman_archive_size(model, text):
    archived_size = 0
    zeros = 0
    input_eval = np.array([[0]], dtype=TYPE)
    huffman_tree = Huffman(subword_text_encoder.vocab_size)

    text_generated = []

    model.reset_states()

    for index, byte in enumerate(text):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0) # remove the batch dimension

        weights = tf.nn.softmax(predictions[0]).numpy()
        huffman_tree.load_weights(weights)
        zeros += huffman_tree.get_code_zero_count(byte.item())
        archived_size += huffman_tree.get_code_length(byte.item())

        input_eval = tf.expand_dims([byte], 0)
  
    return archived_size, zeros

In [13]:
tf.train.latest_checkpoint(checkpoint_dir)
model = build_model(vocab_size = subword_text_encoder.vocab_size, embedding_dim=512, rnn_units=1024, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))

In [14]:
total_raw = 0
total_compressed = 0

for index, article in enumerate(itertools.islice(articles(), 0, 512)):
    raw = (len(article) + 1) * 8
    encoded_article = np.array(subword_text_encoder.encode(article + b'\0'), dtype=TYPE)
    compressed, _ = huffman_archive_size(model, encoded_article)
    total_raw += raw
    total_compressed += compressed
    print('Article %d:\tCompression: %f\tAvg Compression: %f' % (index, compressed/raw, total_compressed/total_raw))

Article 0:	Compression: 0.611111	Avg Compression: 0.611111
Article 1:	Compression: 0.630952	Avg Compression: 0.625000
Article 2:	Compression: 0.520408	Avg Compression: 0.577982
Article 3:	Compression: 0.658333	Avg Compression: 0.595324
Article 4:	Compression: 0.671429	Avg Compression: 0.610632
Article 5:	Compression: 0.455128	Avg Compression: 0.582160
Article 6:	Compression: 0.580207	Avg Compression: 0.580214
Article 7:	Compression: 0.577703	Avg Compression: 0.580212
Article 8:	Compression: 0.596875	Avg Compression: 0.580223
Article 9:	Compression: 0.517857	Avg Compression: 0.580179
Article 10:	Compression: 0.773026	Avg Compression: 0.580303
Article 11:	Compression: 0.505682	Avg Compression: 0.580248
Article 12:	Compression: 0.531250	Avg Compression: 0.580211
Article 13:	Compression: 0.671053	Avg Compression: 0.580269
Article 14:	Compression: 0.563830	Avg Compression: 0.580256
Article 15:	Compression: 0.671875	Avg Compression: 0.580294
Article 16:	Compression: 0.578571	Avg Compression:

Article 139:	Compression: 0.650873	Avg Compression: 0.596527
Article 140:	Compression: 0.594263	Avg Compression: 0.596378
Article 141:	Compression: 0.600547	Avg Compression: 0.596409
Article 142:	Compression: 0.662625	Avg Compression: 0.597531
Article 143:	Compression: 0.583555	Avg Compression: 0.597400
Article 144:	Compression: 0.714744	Avg Compression: 0.597410
Article 145:	Compression: 0.631118	Avg Compression: 0.598612
Article 146:	Compression: 0.584941	Avg Compression: 0.598387
Article 147:	Compression: 0.612198	Avg Compression: 0.598758
Article 148:	Compression: 0.541552	Avg Compression: 0.595430
Article 149:	Compression: 0.542976	Avg Compression: 0.592121
Article 150:	Compression: 0.667857	Avg Compression: 0.592125
Article 151:	Compression: 0.550688	Avg Compression: 0.591638
Article 152:	Compression: 0.575468	Avg Compression: 0.590159
Article 153:	Compression: 0.558955	Avg Compression: 0.589443
Article 154:	Compression: 0.586957	Avg Compression: 0.589443
Article 155:	Compression

Article 274:	Compression: 0.687978	Avg Compression: 0.606035
Article 275:	Compression: 0.606064	Avg Compression: 0.606035
Article 276:	Compression: 0.528584	Avg Compression: 0.605945
Article 277:	Compression: 0.557722	Avg Compression: 0.605666
Article 278:	Compression: 0.697410	Avg Compression: 0.606194
Article 279:	Compression: 0.589542	Avg Compression: 0.606089
Article 280:	Compression: 0.692708	Avg Compression: 0.606091
Article 281:	Compression: 0.607606	Avg Compression: 0.606122
Article 282:	Compression: 0.660010	Avg Compression: 0.606242
Article 283:	Compression: 0.579454	Avg Compression: 0.605993
Article 284:	Compression: 0.686028	Avg Compression: 0.607032
Article 285:	Compression: 0.467391	Avg Compression: 0.607029
Article 286:	Compression: 0.650000	Avg Compression: 0.607030
Article 287:	Compression: 0.853573	Avg Compression: 0.607402
Article 288:	Compression: 0.590222	Avg Compression: 0.607384
Article 289:	Compression: 0.650000	Avg Compression: 0.607385
Article 290:	Compression

Article 409:	Compression: 0.599699	Avg Compression: 0.604293
Article 410:	Compression: 0.570945	Avg Compression: 0.604219
Article 411:	Compression: 0.513889	Avg Compression: 0.604218
Article 412:	Compression: 0.616620	Avg Compression: 0.604276
Article 413:	Compression: 0.545833	Avg Compression: 0.604275
Article 414:	Compression: 0.637831	Avg Compression: 0.604641
Article 415:	Compression: 0.713388	Avg Compression: 0.604668
Article 416:	Compression: 0.592111	Avg Compression: 0.604594
Article 417:	Compression: 0.618173	Avg Compression: 0.604613
Article 418:	Compression: 0.634697	Avg Compression: 0.604695
Article 419:	Compression: 0.664232	Avg Compression: 0.606073
Article 420:	Compression: 0.574258	Avg Compression: 0.605782
Article 421:	Compression: 0.563915	Avg Compression: 0.605633
Article 422:	Compression: 0.538930	Avg Compression: 0.605596
Article 423:	Compression: 0.633987	Avg Compression: 0.605602
Article 424:	Compression: 0.538725	Avg Compression: 0.605535
Article 425:	Compression