In [1]:
import itertools
import math
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

physical_devices = tf.config.list_physical_devices('GPU')
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(physical_device, enable=True)

In [2]:
TYPE=np.int16

subword_text_encoder = tfds.features.text.SubwordTextEncoder.load_from_file('vocab_4096')

class Articles:
    EMPTY_ARTICLE = np.array([], dtype=TYPE) # used for padding
    
    def __init__(self, path):
        with open(path, 'rb') as text_file:
            data = text_file.read()

        self.articles = sorted(set(data.split(b'\0')[:2000]), key=len)
        self._encoded_articles = None

    @property
    def encoded_articles(self):
        if self._encoded_articles == None:
            self._encoded_articles = [np.array(subword_text_encoder.encode(article), dtype=TYPE) for article in self.articles]
        
        return self._encoded_articles

    def articles_generator(self, batch_size = 1, start = 0, end = None):
        end = end or len(self.articles)

        for _ in range(batch_size - ((end - start - 1) % batch_size + 1)):
            yield self.EMPTY_ARTICLE

        for article in itertools.islice(self.encoded_articles, start, end):
            yield article

    def subbatch_generator(self, batch_size, batch_length, start = 0, end = None):
        end = end or len(self.articles)

        dataset = tf.data.Dataset.from_generator(self.articles_generator, args=(batch_size, start, end), output_types=TYPE)
        dataset = dataset.padded_batch(batch_size, padded_shapes=([None]), drop_remainder=True)
        dataset = dataset.shuffle(100)

        for batch in dataset.as_numpy_iterator():
            remaining = batch
            while remaining.shape[1] > batch_length + 1:
                yield remaining[:, :batch_length + 1]
                remaining = remaining[:, batch_length:]

            if remaining.shape[1] == batch_length + 1:
                yield remaining
                yield np.zeros((batch_size, batch_length + 1), dtype=TYPE)
            else:
                yield np.hstack([remaining, np.zeros([batch_size, batch_length - remaining.shape[1] + 1])])

    def steps(self, batch_size, batch_length):
        articles = self.articles_generator(batch_size, batch_length)
        return sum(math.ceil(len(article) / batch_length) for i, article in enumerate(articles) if (i + 1) % batch_size == 0)

    def dataset(self, batch_size, batch_length, start = 0, end = None):
        end = end or len(self.articles)

        dataset = tf.data.Dataset.from_generator(self.subbatch_generator, args=(batch_size, batch_length, start, end), output_types=TYPE, output_shapes=(batch_size, batch_length + 1))
        return dataset.map(lambda batch: (batch[:, :-1], batch[:, 1:]))

In [3]:
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

def average_final_batch_ratio(true_labels, predictions):
    return 0 ** tf.math.abs(true_labels[-1, -1])

class ModelStateResetter(tf.keras.callbacks.Callback):
    def __init__(self):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        self.last_final_batch_count = 0
        
    def on_batch_end(self, batch, logs={}):
        average_final_batch_ratio = logs.get('average_final_batch_ratio', 0)
        final_batch_count = int(round(average_final_batch_ratio * (batch + 1)))
        is_final = final_batch_count - self.last_final_batch_count
        self.last_final_batch_count = final_batch_count
        
        if is_final:
            self.model.reset_states()

class Model:
    def __init__(self, articles, checkpoint_dir, vocab_size, embedding_dim, rnn_units):
        self._articles = articles
        self._batch_size = None
        self._batched_item_length = None
        self._training_model = None
        self._predicting_model = None
        self._vocab_size = vocab_size
        self._embedding_dim = embedding_dim
        self._rnn_units = rnn_units

        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}") # Name of the checkpoint files

    def training_model(self, batch_size, batched_item_length):
        if self._training_model == None or batch_size != self._batch_size or batched_item_length != self._batched_item_length:
            self._batch_size = batch_size
            self._batched_item_length = batched_item_length
            self._training_model = tf.keras.Sequential([
                tf.keras.layers.Masking(mask_value=0, batch_input_shape=[batch_size, batched_item_length]),
                tf.keras.layers.Embedding(self._vocab_size, self._embedding_dim),
                tf.keras.layers.GRU(self._rnn_units // 4 * 4, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
                tf.keras.layers.GRU(self._rnn_units // 4 * 2, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
                tf.keras.layers.GRU(self._rnn_units // 4 * 1, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
                tf.keras.layers.Dense(self._vocab_size),
            ])

            if os.path.isdir(self._checkpoint_dir):
                self._training_model.load_weights(tf.train.latest_checkpoint(self._checkpoint_dir))

            self._training_model.compile(optimizer='adam', loss=loss, metrics=[average_final_batch_ratio])
            self._predicting_model = None
        
        return self._training_model

    @property
    def callbacks(self):
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=self._checkpoint_prefix, save_weights_only=True)
        model_state_resetter_callback = ModelStateResetter()
        
        return [checkpoint_callback, model_state_resetter_callback]
    
    def train(self, batch_size, batched_item_length, epochs=1):
        dataset = self._articles.dataset(batch_size, batched_item_length)

        model = self.training_model(batch_size, batched_item_length)

        model.fit(dataset, epochs=epochs, callbacks=self.callbacks)
    
    @property
    def predicting_model(self):
        if self._predicting_model == None:
            self._predicting_model = tf.keras.Sequential([
                tf.keras.layers.Masking(mask_value=0, batch_input_shape=[1, 1]),
                tf.keras.layers.Embedding(self._vocab_size, self._embedding_dim),
                tf.keras.layers.GRU(self._rnn_units // 4 * 4, stateful=True, return_sequences=True),
                tf.keras.layers.GRU(self._rnn_units // 4 * 2, stateful=True, return_sequences=True),
                tf.keras.layers.GRU(self._rnn_units // 4 * 1, stateful=True, return_sequences=True),
                tf.keras.layers.Dense(self._vocab_size),
            ])
            
            self._predicting_model.load_weights(tf.train.latest_checkpoint(self._checkpoint_dir))
            self._training_model = None
        
        return self._predicting_model
    
    def predict(self, input_eval):
        return self.predicting_model(input_eval)

In [4]:
articles = Articles('page_revisions_text')

In [5]:
steps = 13

for i in range(steps):
    batch_size = 3 * 2**i
    batch_item_length = 3 * 2**(steps - i - 1)
    count = articles.steps(batch_size, batch_item_length)
    print("batch size: %6d\t batch item length: %4d\tsteps per epoch: %6d" % (batch_size, batch_item_length, count))

batch size:      3	 batch item length: 12288	steps per epoch:      0
batch size:      6	 batch item length: 6144	steps per epoch:      0
batch size:     12	 batch item length: 3072	steps per epoch:      0
batch size:     24	 batch item length: 1536	steps per epoch:    108
batch size:     48	 batch item length:  768	steps per epoch:    166
batch size:     96	 batch item length:  384	steps per epoch:    193
batch size:    192	 batch item length:  192	steps per epoch:    249
batch size:    384	 batch item length:   96	steps per epoch:    374
batch size:    768	 batch item length:   48	steps per epoch:    645
batch size:   1536	 batch item length:   24	steps per epoch:   1215
batch size:   3072	 batch item length:   12	steps per epoch:   2428
batch size:   6144	 batch item length:    6	steps per epoch:   4855
batch size:  12288	 batch item length:    3	steps per epoch:   9709


In [6]:
model = Model(articles, './training_checkpoints-9',
              vocab_size = subword_text_encoder.vocab_size,
              embedding_dim=32,
              rnn_units=1536)

In [14]:
model.train(192, 192, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [15]:
model.train(192, 192, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [8]:
import ctypes

class Huffman:
    huffman = ctypes.CDLL('x64/Release/huffman')
    
    huffman.create_tree.restype = ctypes.c_void_p
    huffman.destroy_tree.restype = None
    huffman.load_weights.restype = None
    huffman.create_code_string.restype = ctypes.c_char_p
    
    def __init__(self, category_count):
        self.tree = ctypes.c_void_p(self.huffman.create_tree(category_count))

    def __del__(self):
        self.huffman.destroy_tree(self.tree)
        
    def load_weights(self, weights):
        self.huffman.load_weights(self.tree, weights.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
    
    def get_code_length(self, category):
        return self.huffman.get_code_length(self.tree, category)

    def get_code_zero_count(self, category):
        return self.huffman.get_code_zero_count(self.tree, category)

In [9]:
def huffman_archive_size(model, text):
    archived_size = 0
    zeros = 0
    input_eval = np.array([[0]], dtype=TYPE)
    huffman_tree = Huffman(subword_text_encoder.vocab_size)

    text_generated = []

    model.predicting_model.reset_states()

    for index, byte in enumerate(text):
        predictions = model.predict(input_eval)
        predictions = tf.squeeze(predictions, 0) # remove the batch dimension

        weights = tf.nn.softmax(predictions[0]).numpy()
        huffman_tree.load_weights(weights)
        zeros += huffman_tree.get_code_zero_count(byte.item())
        archived_size += huffman_tree.get_code_length(byte.item())

        input_eval = tf.expand_dims([byte], 0)
  
    return archived_size, zeros

In [18]:
total_raw = 0
total_compressed = 0

for index, encoded_article in enumerate(articles.articles_generator(1)):
    if index % 10 == 0:
        article = subword_text_encoder.decode(encoded_article)
        raw = len(article) * 8
        if raw == 0:
            continue
        compressed, _ = huffman_archive_size(model, encoded_article)
        total_raw += raw
        total_compressed += compressed

        print('Article %d:\tLength: %d\tCompression: %f\tAvg Compression: %f' % (index, raw, compressed/raw, total_compressed/total_raw))

Article 0:	Length: 128	Compression: 0.507812	Avg Compression: 0.507812
Article 10:	Length: 144	Compression: 0.604167	Avg Compression: 0.558824
Article 20:	Length: 152	Compression: 0.368421	Avg Compression: 0.490566
Article 30:	Length: 160	Compression: 0.593750	Avg Compression: 0.518836
Article 40:	Length: 160	Compression: 0.462500	Avg Compression: 0.506720
Article 50:	Length: 168	Compression: 0.494048	Avg Compression: 0.504386
Article 60:	Length: 168	Compression: 0.511905	Avg Compression: 0.505556
Article 70:	Length: 176	Compression: 0.471591	Avg Compression: 0.500796
Article 80:	Length: 176	Compression: 0.494318	Avg Compression: 0.500000
Article 90:	Length: 184	Compression: 0.505435	Avg Compression: 0.500619
Article 100:	Length: 184	Compression: 0.456522	Avg Compression: 0.496111
Article 110:	Length: 192	Compression: 0.401042	Avg Compression: 0.486948
Article 120:	Length: 192	Compression: 0.458333	Avg Compression: 0.484432
Article 130:	Length: 200	Compression: 0.695000	Avg Compression

Article 1110:	Length: 35264	Compression: 0.524473	Avg Compression: 0.524366
Article 1120:	Length: 35736	Compression: 0.527871	Avg Compression: 0.524486
Article 1130:	Length: 36448	Compression: 0.573200	Avg Compression: 0.526141
Article 1140:	Length: 37448	Compression: 0.503685	Avg Compression: 0.525384
Article 1150:	Length: 38424	Compression: 0.536748	Avg Compression: 0.525764
Article 1160:	Length: 38680	Compression: 0.600155	Avg Compression: 0.528187
Article 1170:	Length: 40216	Compression: 0.505868	Avg Compression: 0.527456
Article 1180:	Length: 41480	Compression: 0.606075	Avg Compression: 0.530025
Article 1190:	Length: 42616	Compression: 0.502534	Avg Compression: 0.529132
Article 1200:	Length: 43488	Compression: 0.568824	Avg Compression: 0.530405
Article 1210:	Length: 45080	Compression: 0.486979	Avg Compression: 0.529007
Article 1220:	Length: 46864	Compression: 0.505868	Avg Compression: 0.528258
Article 1230:	Length: 48608	Compression: 0.493334	Avg Compression: 0.527123
Article 1240

In [None]:
model.train(192, 192, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20


In [7]:
model.train(192, 192, epochs=17)

Epoch 1/17
Epoch 2/17
Epoch 3/17
Epoch 4/17
Epoch 5/17
Epoch 6/17
Epoch 7/17
Epoch 8/17
  1/271 [..............................] - ETA: 50s

InternalError: Failed copying input tensor from /job:localhost/replica:0/task:0/device:GPU:0 to /job:localhost/replica:0/task:0/device:CPU:0 in order to run Identity: GPU sync failed [Op:Identity]

In [7]:
model.train(192, 192, epochs=9)

Epoch 1/9
Epoch 2/9
Epoch 3/9
Epoch 4/9
Epoch 5/9
Epoch 6/9
Epoch 7/9
Epoch 8/9
Epoch 9/9


In [11]:
total_raw = 0
total_compressed = 0

for index, encoded_article in enumerate(articles.articles_generator(1)):
    if index % 10 == 0:
        article = subword_text_encoder.decode(encoded_article)
        raw = len(article) * 8
        if raw == 0:
            continue
        compressed, _ = huffman_archive_size(model, encoded_article)
        total_raw += raw
        total_compressed += compressed

        print('Article %d:\tLength: %d\tCompression: %f\tAvg Compression: %f' % (index, raw, compressed/raw, total_compressed/total_raw))

Article 0:	Length: 128	Compression: 0.406250	Avg Compression: 0.406250
Article 10:	Length: 144	Compression: 0.368056	Avg Compression: 0.386029
Article 20:	Length: 152	Compression: 0.394737	Avg Compression: 0.389151
Article 30:	Length: 160	Compression: 0.331250	Avg Compression: 0.373288
Article 40:	Length: 152	Compression: 0.605263	Avg Compression: 0.421196
Article 50:	Length: 168	Compression: 0.386905	Avg Compression: 0.414823
Article 60:	Length: 168	Compression: 0.386905	Avg Compression: 0.410448
Article 70:	Length: 176	Compression: 0.363636	Avg Compression: 0.403846
Article 80:	Length: 176	Compression: 0.380682	Avg Compression: 0.400983
Article 90:	Length: 184	Compression: 0.364130	Avg Compression: 0.396766
Article 100:	Length: 184	Compression: 0.494565	Avg Compression: 0.406808
Article 110:	Length: 192	Compression: 0.343750	Avg Compression: 0.400706
Article 120:	Length: 192	Compression: 0.411458	Avg Compression: 0.401654
Article 130:	Length: 200	Compression: 0.410000	Avg Compression

Article 1110:	Length: 35264	Compression: 0.427093	Avg Compression: 0.427696
Article 1120:	Length: 35736	Compression: 0.422347	Avg Compression: 0.427512
Article 1130:	Length: 36448	Compression: 0.473359	Avg Compression: 0.429069
Article 1140:	Length: 37696	Compression: 0.384683	Avg Compression: 0.427563
Article 1150:	Length: 37904	Compression: 0.403994	Avg Compression: 0.426785
Article 1160:	Length: 38680	Compression: 0.478206	Avg Compression: 0.428460
Article 1170:	Length: 40216	Compression: 0.410235	Avg Compression: 0.427863
Article 1180:	Length: 41480	Compression: 0.504171	Avg Compression: 0.430357
Article 1190:	Length: 42616	Compression: 0.404261	Avg Compression: 0.429509
Article 1200:	Length: 43488	Compression: 0.476959	Avg Compression: 0.431031
Article 1210:	Length: 45080	Compression: 0.389663	Avg Compression: 0.429700
Article 1220:	Length: 46864	Compression: 0.401844	Avg Compression: 0.428798
Article 1230:	Length: 48608	Compression: 0.386397	Avg Compression: 0.427420
Article 1240

In [7]:
model.train(192, 192, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [10]:
total_raw = 0
total_compressed = 0

for index, encoded_article in enumerate(articles.articles_generator(1)):
    if index % 10 == 0:
        article = subword_text_encoder.decode(encoded_article)
        raw = len(article) * 8
        if raw == 0:
            continue
        compressed, _ = huffman_archive_size(model, encoded_article)
        total_raw += raw
        total_compressed += compressed

        print('Article %d:\tLength: %d\tCompression: %f\tAvg Compression: %f' % (index, raw, compressed/raw, total_compressed/total_raw))

Article 0:	Length: 128	Compression: 0.445312	Avg Compression: 0.445312
Article 10:	Length: 144	Compression: 0.409722	Avg Compression: 0.426471
Article 20:	Length: 152	Compression: 0.473684	Avg Compression: 0.443396
Article 30:	Length: 160	Compression: 0.468750	Avg Compression: 0.450342
Article 40:	Length: 160	Compression: 0.581250	Avg Compression: 0.478495
Article 50:	Length: 168	Compression: 0.458333	Avg Compression: 0.474781
Article 60:	Length: 168	Compression: 0.398810	Avg Compression: 0.462963
Article 70:	Length: 176	Compression: 0.397727	Avg Compression: 0.453822
Article 80:	Length: 176	Compression: 0.500000	Avg Compression: 0.459497
Article 90:	Length: 184	Compression: 0.554348	Avg Compression: 0.470297
Article 100:	Length: 184	Compression: 0.489130	Avg Compression: 0.472222
Article 110:	Length: 192	Compression: 0.411458	Avg Compression: 0.466365
Article 120:	Length: 192	Compression: 0.354167	Avg Compression: 0.456502
Article 130:	Length: 200	Compression: 0.525000	Avg Compression

Article 1110:	Length: 35264	Compression: 0.410163	Avg Compression: 0.411659
Article 1120:	Length: 35736	Compression: 0.392070	Avg Compression: 0.410984
Article 1130:	Length: 36448	Compression: 0.446828	Avg Compression: 0.412201
Article 1140:	Length: 37448	Compression: 0.383332	Avg Compression: 0.411228
Article 1150:	Length: 38424	Compression: 0.410603	Avg Compression: 0.411207
Article 1160:	Length: 38680	Compression: 0.533195	Avg Compression: 0.415180
Article 1170:	Length: 40216	Compression: 0.398051	Avg Compression: 0.414619
Article 1180:	Length: 41480	Compression: 0.487560	Avg Compression: 0.417002
Article 1190:	Length: 42616	Compression: 0.392341	Avg Compression: 0.416201
Article 1200:	Length: 43488	Compression: 0.469279	Avg Compression: 0.417904
Article 1210:	Length: 45080	Compression: 0.381655	Avg Compression: 0.416737
Article 1220:	Length: 46864	Compression: 0.393436	Avg Compression: 0.415983
Article 1230:	Length: 48608	Compression: 0.372799	Avg Compression: 0.414580
Article 1240