In [1]:
import itertools
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
def articles():
    with open('page_revisions_text', 'rb') as text_file:
        pending_article_data = b''
        while True:
            data = text_file.read(1024 ** 2)
            if len(data) == 0:
                break

            articles = data.split(b'\0')
            articles[0] = pending_article_data + articles[0]
            for index, article in enumerate(articles):
                if index + 1 == len(articles):
                    pending_article_data = article
                else:
                    yield article

        print(pending_article_data)
        if len(pending_article_data) != 0:
            yield pending_article_data

In [3]:
subword_text_encoder = tfds.features.text.SubwordTextEncoder.load_from_file('vocab_4096')

In [4]:
BATCH_SIZE = 192
BATCHED_ITEM_LENGTH = 256
BUFFER_SIZE = 1024
TYPE=np.int16

def articles_generator():
    for index, article in enumerate(itertools.islice(articles(), 0, 10000)):
        yield np.array(subword_text_encoder.encode(article + b'\0'), dtype=TYPE)

    # Pad the article count to the batch size
    # We do this to ensure that no data is dropped
    index += 1
    while index % BATCH_SIZE != 0:
        yield np.array([0], dtype=TYPE)
        index += 1

def subbatches():
    dataset = tf.data.Dataset.from_generator(articles_generator, output_types=TYPE)
    dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=([None]), drop_remainder=True)
    dataset = dataset.shuffle(2000)

    for batch in dataset.as_numpy_iterator():
        remaining = batch
        while remaining.shape[1] > 1:
            yield remaining[:, :BATCHED_ITEM_LENGTH + 1]
            remaining = remaining[:, BATCHED_ITEM_LENGTH:]

dataset = tf.data.Dataset.from_generator(subbatches, output_types=TYPE, output_shapes=(BATCH_SIZE, None))
dataset = dataset.map(lambda batch: (batch[:, :-1], batch[:, 1:]))

dataset

<MapDataset shapes: ((192, None), (192, None)), types: (tf.int16, tf.int16)>

In [5]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    return tf.keras.Sequential([
        tf.keras.layers.Masking(mask_value=0, batch_input_shape=[batch_size, None]),
        tf.keras.layers.Embedding(vocab_size, embedding_dim),
        tf.keras.layers.LSTM(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        tf.keras.layers.LSTM(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dense(vocab_size),
    ])

In [6]:
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

def average_batch_length(true_labels, predictions):
    return tf.shape(true_labels)[1]

model = build_model(vocab_size = subword_text_encoder.vocab_size, embedding_dim=512, rnn_units=1024, batch_size=BATCH_SIZE)
model.compile(optimizer='adam', loss=loss, metrics=[average_batch_length])

In [7]:
checkpoint_dir = './training_checkpoints-1' # Directory where the checkpoints will be saved
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}") # Name of the checkpoint files

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True)

In [8]:
class ModelStateResetter(tf.keras.callbacks.Callback):
    def __init__(self):
        self.last_total_length = 0

    def on_batch_end(self, batch, logs={}):
        average_batch_length = logs.get('average_batch_length', 0)
        total_length = int(round(average_batch_length * (batch + 1)))
        current_batch_length = total_length - self.last_total_length
        self.last_total_length = total_length
        
        if current_batch_length < BATCHED_ITEM_LENGTH:
            self.model.reset_states()
        
model_state_resetter_callback = ModelStateResetter()

In [9]:
total_epochs = 10

for epoch in range(total_epochs):
    print('Epoch %d/%d' % (epoch + 1, total_epochs))
    model.fit(dataset, callbacks=[checkpoint_callback, model_state_resetter_callback])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [10]:
total_epochs = 10

for epoch in range(total_epochs):
    print('Epoch %d/%d' % (epoch + 1, total_epochs))
    model.fit(dataset, callbacks=[checkpoint_callback, model_state_resetter_callback])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [11]:
with open('page_revisions_text', 'rb') as text_file:
    data = text_file.read()

article = data.split(b'\0')[120]
del data

encoded_article = np.array(subword_text_encoder.encode(article + b'\0'), dtype=TYPE)

print('Raw:', len(article))
print('Encoded:', len(encoded_article))

Raw: 25541
Encoded: 8222


In [12]:
import ctypes

class Huffman:
    huffman = ctypes.CDLL('x64/Release/huffman')
    
    huffman.create_tree.restype = ctypes.c_void_p
    huffman.destroy_tree.restype = None
    huffman.load_weights.restype = None
    huffman.create_code_string.restype = ctypes.c_char_p
    
    def __init__(self, category_count):
        self.tree = ctypes.c_void_p(self.huffman.create_tree(category_count))

    def __del__(self):
        self.huffman.destroy_tree(self.tree)
        
    def load_weights(self, weights):
        self.huffman.load_weights(self.tree, weights.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
    
    def get_code_length(self, category):
        return self.huffman.get_code_length(self.tree, category)

    def get_code_zero_count(self, category):
        return self.huffman.get_code_zero_count(self.tree, category)

In [13]:
def huffman_archive_size(model, text):
    archived_size = 0
    zeros = 0
    input_eval = np.array([[0]], dtype=TYPE)
    huffman_tree = Huffman(subword_text_encoder.vocab_size)

    text_generated = []

    model.reset_states()

    for index, byte in enumerate(text):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0) # remove the batch dimension

        weights = tf.nn.softmax(predictions[0]).numpy()
        huffman_tree.load_weights(weights)
        zeros += huffman_tree.get_code_zero_count(byte.item())
        archived_size += huffman_tree.get_code_length(byte.item())

        input_eval = tf.expand_dims([byte], 0)
  
    return archived_size, zeros

In [14]:
tf.train.latest_checkpoint(checkpoint_dir)
model = build_model(vocab_size = subword_text_encoder.vocab_size, embedding_dim=512, rnn_units=1024, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))

In [15]:
total_raw = 0
total_compressed = 0

for index, article in enumerate(articles()):
    raw = (len(article) + 1) * 8
    encoded_article = np.array(subword_text_encoder.encode(article + b'\0'), dtype=TYPE)
    compressed, _ = huffman_archive_size(model, encoded_article)
    total_raw += raw
    total_compressed += compressed
    print('Article %d:\tCompression: %f\tAvg Compression: %f' % (index, compressed/raw, total_compressed/total_raw))

Article 0:	Compression: 0.430556	Avg Compression: 0.430556
Article 1:	Compression: 0.250000	Avg Compression: 0.304167
Article 2:	Compression: 0.229592	Avg Compression: 0.270642
Article 3:	Compression: 0.283333	Avg Compression: 0.273381
Article 4:	Compression: 0.314286	Avg Compression: 0.281609
Article 5:	Compression: 0.217949	Avg Compression: 0.269953
Article 6:	Compression: 0.176171	Avg Compression: 0.176511
Article 7:	Compression: 0.209459	Avg Compression: 0.176531
Article 8:	Compression: 0.246875	Avg Compression: 0.176579
Article 9:	Compression: 0.187500	Avg Compression: 0.176587
Article 10:	Compression: 0.236842	Avg Compression: 0.176626
Article 11:	Compression: 0.187500	Avg Compression: 0.176634
Article 12:	Compression: 0.184659	Avg Compression: 0.176640
Article 13:	Compression: 0.220395	Avg Compression: 0.176668
Article 14:	Compression: 0.172872	Avg Compression: 0.176665
Article 15:	Compression: 0.375000	Avg Compression: 0.176746
Article 16:	Compression: 0.314286	Avg Compression:

Article 139:	Compression: 0.173697	Avg Compression: 0.173913
Article 140:	Compression: 0.180190	Avg Compression: 0.174328
Article 141:	Compression: 0.189798	Avg Compression: 0.174445
Article 142:	Compression: 0.161267	Avg Compression: 0.174222
Article 143:	Compression: 0.174685	Avg Compression: 0.174226
Article 144:	Compression: 0.230769	Avg Compression: 0.174231
Article 145:	Compression: 0.175165	Avg Compression: 0.174264
Article 146:	Compression: 0.182842	Avg Compression: 0.174405
Article 147:	Compression: 0.188155	Avg Compression: 0.174774
Article 148:	Compression: 0.155593	Avg Compression: 0.173658
Article 149:	Compression: 0.165172	Avg Compression: 0.173123
Article 150:	Compression: 0.328571	Avg Compression: 0.173132
Article 151:	Compression: 0.155807	Avg Compression: 0.172928
Article 152:	Compression: 0.177003	Avg Compression: 0.173301
Article 153:	Compression: 0.189852	Avg Compression: 0.173681
Article 154:	Compression: 0.336957	Avg Compression: 0.173686
Article 155:	Compression

Article 274:	Compression: 0.156826	Avg Compression: 0.174778
Article 275:	Compression: 0.167002	Avg Compression: 0.174773
Article 276:	Compression: 0.135851	Avg Compression: 0.174727
Article 277:	Compression: 0.165159	Avg Compression: 0.174672
Article 278:	Compression: 0.198008	Avg Compression: 0.174806
Article 279:	Compression: 0.193740	Avg Compression: 0.174925
Article 280:	Compression: 0.317708	Avg Compression: 0.174927
Article 281:	Compression: 0.172388	Avg Compression: 0.174874
Article 282:	Compression: 0.184721	Avg Compression: 0.174896
Article 283:	Compression: 0.195087	Avg Compression: 0.175084
Article 284:	Compression: 0.190909	Avg Compression: 0.175290
Article 285:	Compression: 0.217391	Avg Compression: 0.175291
Article 286:	Compression: 0.235714	Avg Compression: 0.175292
Article 287:	Compression: 0.235294	Avg Compression: 0.175382
Article 288:	Compression: 0.175918	Avg Compression: 0.175383
Article 289:	Compression: 0.235714	Avg Compression: 0.175384
Article 290:	Compression

KeyboardInterrupt: 

Получаваме компресия `~ 0.176`. По-зле от преди. Но пък и обучавахме по-малък брой епохи и имахме по-висок loss.

За сметка на това всяка от по-дългите епохи доведве до по-нисък loss за сметка на по-дълго изчисление. Нищо от това не е изненадващо. Просто проверяваме, че сме на прав път.