In [1]:
import itertools
import math
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

physical_devices = tf.config.list_physical_devices('GPU')
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(physical_device, enable=True)

In [2]:
TYPE=np.int16

subword_text_encoder = tfds.features.text.SubwordTextEncoder.load_from_file('vocab_4096')

class Articles:
    EMPTY_ARTICLE = np.array([], dtype=TYPE) # used for padding
    
    def __init__(self, path):
        with open(path, 'rb') as text_file:
            data = text_file.read()

        self.articles = sorted(set(data.split(b'\0')), key=len)
        self._encoded_articles = None

    @property
    def encoded_articles(self):
        if self._encoded_articles == None:
            self._encoded_articles = [np.array(subword_text_encoder.encode(article), dtype=TYPE) for article in self.articles]
        
        return self._encoded_articles

    def articles_generator(self, batch_size = 1, start = 0, end = None):
        end = end or len(self.articles)

        for _ in range(batch_size - ((end - start - 1) % batch_size + 1)):
            yield self.EMPTY_ARTICLE

        for article in itertools.islice(self.encoded_articles, start, end):
            yield article

    def subbatch_generator(self, batch_size, batch_length, start = 0, end = None):
        end = end or len(self.articles)

        dataset = tf.data.Dataset.from_generator(self.articles_generator, args=(batch_size, start, end), output_types=TYPE)
        dataset = dataset.padded_batch(batch_size, padded_shapes=([None]), drop_remainder=True)
        dataset = dataset.shuffle(100)

        for batch in dataset.as_numpy_iterator():
            remaining = batch
            while remaining.shape[1] > batch_length + 1:
                yield remaining[:, :batch_length + 1]
                remaining = remaining[:, batch_length:]

            if remaining.shape[1] == batch_length + 1:
                yield remaining
                yield np.zeros((batch_size, batch_length + 1), dtype=TYPE)
            else:
                yield np.hstack([remaining, np.zeros([batch_size, batch_length - remaining.shape[1] + 1])])

    def steps(self, batch_size, batch_length):
        articles = self.articles_generator(batch_size, batch_length)
        return sum(math.ceil(len(article) / batch_length) for i, article in enumerate(articles) if (i + 1) % batch_size == 0)

    def dataset(self, batch_size, batch_length, start = 0, end = None):
        end = end or len(self.articles)

        dataset = tf.data.Dataset.from_generator(self.subbatch_generator, args=(batch_size, batch_length, start, end), output_types=TYPE, output_shapes=(batch_size, batch_length + 1))
        return dataset.map(lambda batch: (batch[:, :-1], batch[:, 1:]))

In [3]:
articles = Articles('page_revisions_text')

In [9]:
steps = 13

for i in range(steps):
    batch_size = 3 * 2**i
    batch_item_length = 3 * 2**(steps - i - 1)
    count = articles.steps(batch_size, batch_item_length)
    print("batch size: %6d\t batch item length: %4d\tsteps per epoch: %6d" % (batch_size, batch_item_length, count))

batch size:      3	 batch item length: 12288	steps per epoch:  70297
batch size:      6	 batch item length: 6144	steps per epoch:  37955
batch size:     12	 batch item length: 3072	steps per epoch:  21974
batch size:     24	 batch item length: 1536	steps per epoch:  14546
batch size:     48	 batch item length:  768	steps per epoch:  11792
batch size:     96	 batch item length:  384	steps per epoch:  10574
batch size:    192	 batch item length:  192	steps per epoch:  10573
batch size:    384	 batch item length:   96	steps per epoch:  11676
batch size:    768	 batch item length:   48	steps per epoch:  14588
batch size:   1536	 batch item length:   24	steps per epoch:  20735
batch size:   3072	 batch item length:   12	steps per epoch:  33286
batch size:   6144	 batch item length:    6	steps per epoch:  58607
batch size:  12288	 batch item length:    3	steps per epoch: 109444


In [10]:
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

def average_final_batch_ratio(true_labels, predictions):
    return 0 ** tf.math.abs(true_labels[-1, -1])

class ModelStateResetter(tf.keras.callbacks.Callback):
    def __init__(self):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        self.last_final_batch_count = 0
        
    def on_batch_end(self, batch, logs={}):
        average_final_batch_ratio = logs.get('average_final_batch_ratio', 0)
        final_batch_count = int(round(average_final_batch_ratio * (batch + 1)))
        is_final = final_batch_count - self.last_final_batch_count
        self.last_final_batch_count = final_batch_count
        
        if is_final:
            self.model.reset_states()

class Model:
    def __init__(self, articles, checkpoint_dir, vocab_size, embedding_dim, rnn_units):
        self._articles = articles
        self._batch_size = None
        self._batched_item_length = None
        self._training_model = None
        self._predicting_model = None
        self._vocab_size = vocab_size
        self._embedding_dim = embedding_dim
        self._rnn_units = rnn_units

        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}") # Name of the checkpoint files

    def training_model(self, batch_size, batched_item_length):
        if self._training_model == None or batch_size != self._batch_size or batched_item_length != self._batched_item_length:
            self._batch_size = batch_size
            self._batched_item_length = batched_item_length
            self._training_model = tf.keras.Sequential([
                tf.keras.layers.Masking(mask_value=0, batch_input_shape=[batch_size, batched_item_length]),
                tf.keras.layers.Embedding(self._vocab_size, self._embedding_dim),
                tf.keras.layers.GRU(self._rnn_units // 4 * 4, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
                tf.keras.layers.GRU(self._rnn_units // 4 * 2, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
                tf.keras.layers.GRU(self._rnn_units // 4 * 1, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
                tf.keras.layers.Dense(self._vocab_size),
            ])

            if os.path.isdir(self._checkpoint_dir):
                self._training_model.load_weights(tf.train.latest_checkpoint(self._checkpoint_dir))

            self._training_model.compile(optimizer='adam', loss=loss, metrics=[average_final_batch_ratio])
            self._predicting_model = None
        
        return self._training_model

    @property
    def callbacks(self):
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=self._checkpoint_prefix, save_weights_only=True)
        model_state_resetter_callback = ModelStateResetter()
        
        return [checkpoint_callback, model_state_resetter_callback]
    
    def train(self, batch_size, batched_item_length, epochs=1):
        dataset = self._articles.dataset(batch_size, batched_item_length)

        model = self.training_model(batch_size, batched_item_length)

        model.fit(dataset, epochs=epochs, callbacks=self.callbacks)
    
    @property
    def predicting_model(self):
        if self._predicting_model == None:
            self._predicting_model = tf.keras.Sequential([
                tf.keras.layers.Masking(mask_value=0, batch_input_shape=[1, 1]),
                tf.keras.layers.Embedding(self._vocab_size, self._embedding_dim),
                tf.keras.layers.GRU(self._rnn_units // 4 * 4, stateful=True, return_sequences=True),
                tf.keras.layers.GRU(self._rnn_units // 4 * 2, stateful=True, return_sequences=True),
                tf.keras.layers.GRU(self._rnn_units // 4 * 1, stateful=True, return_sequences=True),
                tf.keras.layers.Dense(self._vocab_size),
            ])
            
            self._predicting_model.load_weights(tf.train.latest_checkpoint(self._checkpoint_dir))
            self._training_model = None
        
        return self._predicting_model
    
    def predict(self, input_eval):
        return self.predicting_model(input_eval)

In [12]:
model = Model(articles, './training_checkpoints-8',
              vocab_size = subword_text_encoder.vocab_size,
              embedding_dim=32,
              rnn_units=1536)

In [14]:
model.training_model(192, 192).summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
masking (Masking)            (192, 192)                0         
_________________________________________________________________
embedding (Embedding)        (192, 192, 32)            129536    
_________________________________________________________________
gru (GRU)                    (192, 192, 1536)          7234560   
_________________________________________________________________
gru_1 (GRU)                  (192, 192, 768)           5313024   
_________________________________________________________________
gru_2 (GRU)                  (192, 192, 384)           1329408   
_________________________________________________________________
dense (Dense)                (192, 192, 4048)          1558480   
Total params: 15,565,008
Trainable params: 15,565,008
Non-trainable params: 0
____________________________________________

In [15]:
model.train(192, 192)



In [16]:
model.train(192, 192, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [17]:
model.train(192, 192, epochs=2)

Epoch 1/2
Epoch 2/2


In [18]:
import ctypes

class Huffman:
    huffman = ctypes.CDLL('x64/Release/huffman')
    
    huffman.create_tree.restype = ctypes.c_void_p
    huffman.destroy_tree.restype = None
    huffman.load_weights.restype = None
    huffman.create_code_string.restype = ctypes.c_char_p
    
    def __init__(self, category_count):
        self.tree = ctypes.c_void_p(self.huffman.create_tree(category_count))

    def __del__(self):
        self.huffman.destroy_tree(self.tree)
        
    def load_weights(self, weights):
        self.huffman.load_weights(self.tree, weights.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
    
    def get_code_length(self, category):
        return self.huffman.get_code_length(self.tree, category)

    def get_code_zero_count(self, category):
        return self.huffman.get_code_zero_count(self.tree, category)

In [19]:
def huffman_archive_size(model, text):
    archived_size = 0
    zeros = 0
    input_eval = np.array([[0]], dtype=TYPE)
    huffman_tree = Huffman(subword_text_encoder.vocab_size)

    text_generated = []

    model.predicting_model.reset_states()

    for index, byte in enumerate(text):
        predictions = model.predict(input_eval)
        predictions = tf.squeeze(predictions, 0) # remove the batch dimension

        weights = tf.nn.softmax(predictions[0]).numpy()
        huffman_tree.load_weights(weights)
        zeros += huffman_tree.get_code_zero_count(byte.item())
        archived_size += huffman_tree.get_code_length(byte.item())

        input_eval = tf.expand_dims([byte], 0)
  
    return archived_size, zeros

In [20]:
total_raw = 0
total_compressed = 0

for index, encoded_article in enumerate(articles.articles_generator(1)):
    if index % 1000 == 0:
        article = subword_text_encoder.decode(encoded_article)
        raw = len(article) * 8
        if raw == 0:
            continue
        compressed, _ = huffman_archive_size(model, encoded_article)
        total_raw += raw
        total_compressed += compressed

        print('Article %d:\tLength: %d\tCompression: %f\tAvg Compression: %f' % (index, raw, compressed/raw, total_compressed/total_raw))

Article 1000:	Length: 144	Compression: 0.562500	Avg Compression: 0.562500
Article 2000:	Length: 152	Compression: 0.467105	Avg Compression: 0.513514
Article 3000:	Length: 160	Compression: 0.506250	Avg Compression: 0.510965
Article 4000:	Length: 168	Compression: 0.404762	Avg Compression: 0.482372
Article 5000:	Length: 168	Compression: 0.357143	Avg Compression: 0.455808
Article 6000:	Length: 168	Compression: 0.505952	Avg Compression: 0.464583
Article 7000:	Length: 176	Compression: 0.357955	Avg Compression: 0.448063
Article 8000:	Length: 176	Compression: 0.539773	Avg Compression: 0.460366
Article 9000:	Length: 184	Compression: 0.527174	Avg Compression: 0.468583
Article 10000:	Length: 184	Compression: 0.500000	Avg Compression: 0.472024
Article 11000:	Length: 192	Compression: 0.302083	Avg Compression: 0.454594
Article 12000:	Length: 192	Compression: 0.552083	Avg Compression: 0.463663
Article 13000:	Length: 200	Compression: 0.405000	Avg Compression: 0.458481
Article 14000:	Length: 200	Compres

Article 110000:	Length: 18048	Compression: 0.291888	Avg Compression: 0.358502
Article 111000:	Length: 18520	Compression: 0.374298	Avg Compression: 0.359104
Article 112000:	Length: 18920	Compression: 0.495190	Avg Compression: 0.364204
Article 113000:	Length: 19328	Compression: 0.328177	Avg Compression: 0.362875
Article 114000:	Length: 19944	Compression: 0.342609	Avg Compression: 0.362133
Article 115000:	Length: 19848	Compression: 0.093863	Avg Compression: 0.352692
Article 116000:	Length: 20024	Compression: 0.090192	Avg Compression: 0.343692
Article 117000:	Length: 20096	Compression: 0.089719	Avg Compression: 0.335244
Article 118000:	Length: 20176	Compression: 0.090206	Avg Compression: 0.327325
Article 119000:	Length: 20272	Compression: 0.098461	Avg Compression: 0.320127
Article 120000:	Length: 20000	Compression: 0.373050	Avg Compression: 0.321720
Article 121000:	Length: 20544	Compression: 0.100613	Avg Compression: 0.315090
Article 122000:	Length: 20664	Compression: 0.101965	Avg Compress

Article 215000:	Length: 193192	Compression: 0.334020	Avg Compression: 0.302684
Article 216000:	Length: 213000	Compression: 0.293113	Avg Compression: 0.302328
Article 217000:	Length: 238400	Compression: 0.332085	Avg Compression: 0.303517
Article 218000:	Length: 274360	Compression: 0.326881	Avg Compression: 0.304544
Article 219000:	Length: 329056	Compression: 0.300970	Avg Compression: 0.304365
Article 220000:	Length: 458704	Compression: 0.502344	Avg Compression: 0.317284


In [21]:
model.train(384, 96, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [22]:
total_raw = 0
total_compressed = 0

for index, encoded_article in enumerate(articles.articles_generator(1)):
    if index % 1000 == 0:
        article = subword_text_encoder.decode(encoded_article)
        raw = len(article) * 8
        if raw == 0:
            continue
        compressed, _ = huffman_archive_size(model, encoded_article)
        total_raw += raw
        total_compressed += compressed

        print('Article %d:\tLength: %d\tCompression: %f\tAvg Compression: %f' % (index, raw, compressed/raw, total_compressed/total_raw))

Article 1000:	Length: 144	Compression: 7.104167	Avg Compression: 7.104167
Article 2000:	Length: 152	Compression: 6.697368	Avg Compression: 6.895270
Article 3000:	Length: 160	Compression: 6.356250	Avg Compression: 6.706140
Article 4000:	Length: 168	Compression: 6.041667	Avg Compression: 6.527244
Article 5000:	Length: 168	Compression: 5.946429	Avg Compression: 6.404040
Article 6000:	Length: 168	Compression: 6.125000	Avg Compression: 6.355208
Article 7000:	Length: 176	Compression: 5.687500	Avg Compression: 6.251761
Article 8000:	Length: 176	Compression: 5.920455	Avg Compression: 6.207317
Article 9000:	Length: 184	Compression: 5.646739	Avg Compression: 6.138369
Article 10000:	Length: 184	Compression: 5.586957	Avg Compression: 6.077976
Article 11000:	Length: 192	Compression: 5.213542	Avg Compression: 5.989316
Article 12000:	Length: 192	Compression: 5.458333	Avg Compression: 5.939922
Article 13000:	Length: 200	Compression: 5.070000	Avg Compression: 5.863074
Article 14000:	Length: 200	Compres

Article 110000:	Length: 18048	Compression: 0.358932	Avg Compression: 0.526571
Article 111000:	Length: 18520	Compression: 0.356749	Avg Compression: 0.520100
Article 112000:	Length: 18920	Compression: 0.538901	Avg Compression: 0.520804
Article 113000:	Length: 19328	Compression: 0.469319	Avg Compression: 0.518906
Article 114000:	Length: 19944	Compression: 0.330225	Avg Compression: 0.511991
Article 115000:	Length: 19848	Compression: 0.094821	Avg Compression: 0.497310
Article 116000:	Length: 20024	Compression: 0.089592	Avg Compression: 0.483332
Article 117000:	Length: 20096	Compression: 0.091361	Avg Compression: 0.470293
Article 118000:	Length: 20176	Compression: 0.092337	Avg Compression: 0.458079
Article 119000:	Length: 20272	Compression: 0.098017	Avg Compression: 0.446755
Article 120000:	Length: 20000	Compression: 0.357550	Avg Compression: 0.444070
Article 121000:	Length: 20544	Compression: 0.098764	Avg Compression: 0.433716
Article 122000:	Length: 20664	Compression: 0.103223	Avg Compress

Article 215000:	Length: 193192	Compression: 0.343762	Avg Compression: 0.319865
Article 216000:	Length: 213000	Compression: 0.286333	Avg Compression: 0.318618
Article 217000:	Length: 238400	Compression: 0.327815	Avg Compression: 0.318985
Article 218000:	Length: 274360	Compression: 0.330081	Avg Compression: 0.319473
Article 219000:	Length: 329056	Compression: 0.292503	Avg Compression: 0.318122
Article 220000:	Length: 458704	Compression: 0.509937	Avg Compression: 0.330639


In [23]:
model.train(192, 192, epochs=5)

Epoch 1/5
Epoch 2/5

KeyboardInterrupt: 