# tensorflow-compress

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/byronknoll/tensorflow-compress/blob/master/tensorflow-compress.ipynb)

Made by Byron Knoll. GitHub repository: https://github.com/byronknoll/tensorflow-compress

### Description

tensorflow-compress performs lossless data compression using neural networks in TensorFlow. It can run on GPUs with a large batch size to get a substantial speed improvement. It is made using Colab, which should make it easy to run through a web browser. You can choose a file, perform compression (or decompression), and download the result.

tensorflow-compress is open source, and the code should be easy to understand and modify. Feel free to experiment with the code and create pull requests with improvements.

The neural network is trained from scratch during compression and decompression, so the model weights do not need to be stored. Arithmetic coding is used to encode the model predictions to a file.
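Here is a minimal sketch of why no weights need to be stored. It is not taken from the notebook. An adaptive model starts from a fixed initial state and is updated after every symbol, so a decoder that replays the same updates sees exactly the same predictions. For illustration, the sketch assumes a simple adaptive byte-frequency counter in place of the LSTM. It computes the ideal code length (-log2(p) bits per symbol), which an arithmetic coder approaches.

```python
import math


def adaptive_code_length(data: bytes) -> float:
    """Ideal code length in bits for `data` under an adaptive order-0 model.

    Stand-in for the LSTM (an assumption for illustration): the model starts
    from uniform counts and is updated after each symbol, just as an encoder
    and decoder would both do in lockstep.
    """
    counts = [1] * 256  # every byte value starts with a count of 1
    total = 256
    bits = 0.0
    for byte in data:
        p = counts[byte] / total  # prediction made *before* seeing the byte
        bits += -math.log2(p)  # cost an ideal arithmetic coder would pay
        counts[byte] += 1  # learn from the symbol only after coding it
        total += 1
    return bits


text = b"abracadabra " * 100
bits = adaptive_code_length(text)
print(f"{len(text)} bytes -> about {bits / 8:.0f} bytes ({bits / len(text):.3f} bits per byte)")
```

The first symbol always costs 8 bits because the model starts out uniform. Repeated symbols get cheaper as the counts adapt. tensorflow-compress follows the same principle, with the LSTM's softmax output playing the role of `counts[byte] / total`.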

Feel free to contact me at byron@byronknoll.com if you have any questions.

### Instructions

Basic usage: configure all the fields in the "Parameters" section and select Runtime->Run All.

Advanced usage: save a copy of this notebook and modify the code.

### Related Projects
*   [NNCP](https://bellard.org/nncp/) - this uses a similar architecture to tensorflow-compress, but with transformers rather than LSTM. NNCP currently outperforms tensorflow-compress: it runs faster, on a less powerful GPU, while achieving a better compression rate.
*   [lstm-compress](https://github.com/byronknoll/lstm-compress) - uses LSTM for compression, but limited to running on a CPU with a batch size of one.
*   [cmix](http://www.byronknoll.com/cmix.html) - shares the same LSTM code as lstm-compress, but contains many other components to achieve a better compression rate.
*   [DeepZip](https://github.com/mohit1997/DeepZip) - this also performs compression using TensorFlow. However, it has some substantial architectural differences from tensorflow-compress: it uses pretraining (multiple passes over the training data) and stores the model weights in the compressed file.

### Benchmarks
These benchmarks were performed using tensorflow-compress v4 with the default parameter settings. Some parameters differ between enwik8 and enwik9, as noted below. A Compute Engine VM with an A100 GPU was used. Compression time and decompression time are approximately the same.
*   enwik8: compressed to 15,905,037 bytes in 32,048.55 seconds. NNCP preprocessing time: 206.38 seconds. Dictionary size: 65,987 bytes.
*   enwik9: compressed to 113,542,413 bytes in 289,632.17 seconds. NNCP preprocessing time: 1,762.28 seconds. Dictionary size: 79,876 bytes. The preprocessed enwik9 file was split into four parts using [this notebook](https://colab.sandbox.google.com/github/byronknoll/tensorflow-compress/blob/master/nncp-splitter.ipynb). The "checkpoint" option was used to save and load model weights between parts. For the first part, start_learning_rate=0.0005 and end_learning_rate=0.0002 were used. For the remaining three parts, a constant learning rate of 0.0002 was used.
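As a quick sanity check, the compressed sizes listed above can be converted to bits per input byte. The sketch assumes the standard test file sizes from the Large Text Compression Benchmark (enwik8 = 10^8 bytes, enwik9 = 10^9 bytes). It uses the compressed sizes and times exactly as listed, and makes no assumption about what else (dictionary, decompressor) would count toward a benchmark total.

```python
# Figures copied from the benchmark list above.
ORIGINAL_SIZES = {"enwik8": 10**8, "enwik9": 10**9}  # bytes (standard LTCB file sizes)
COMPRESSED_SIZES = {"enwik8": 15_905_037, "enwik9": 113_542_413}  # bytes
COMPRESSION_SECONDS = {"enwik8": 32_048.55, "enwik9": 289_632.17}


def bits_per_byte(name: str) -> float:
    """Average number of compressed bits per original byte."""
    return COMPRESSED_SIZES[name] * 8 / ORIGINAL_SIZES[name]


def throughput(name: str) -> float:
    """Original bytes processed per second of compression time."""
    return ORIGINAL_SIZES[name] / COMPRESSION_SECONDS[name]


for name in ORIGINAL_SIZES:
    print(f"{name}: {bits_per_byte(name):.3f} bits/byte, {throughput(name):,.0f} bytes/s")
```

This works out to roughly 1.272 bits per byte for enwik8 and 0.908 for enwik9, at a few thousand bytes per second in both cases.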

See the [Large Text Compression Benchmark](http://mattmahoney.net/dc/text.html) for more information about the test files and a comparison with other programs.

### Versions
* v4 - released August 10, 2022. Changes from v3:
  * Added embedding layer
  * Tuned parameters to run on A100 GPU
* v3 - released November 28, 2020. Changes from v2:
  * Parameter tuning
  * [New notebook](https://colab.sandbox.google.com/github/byronknoll/tensorflow-compress/blob/master/nncp-splitter.ipynb) for file splitting
  * Support for learning rate decay
* v2 - released September 6, 2020. Changes from v1:
  * 16-bit floats for improved speed
  * Weight updates occur at every timestep (instead of at spaced intervals)
  * Support for saving/loading model weights
* v1 - released July 20, 2020.

## Parameters

In [None]:
#@markdown ---
batch_size = 256 #@param {type:"integer"}
#@markdown _This will split the file into N batches, and process them in parallel. Increasing this will improve speed but can make compression rate worse. Make this a multiple of 8 to improve speed on certain GPUs._

#@markdown ---
seq_length = 15 #@param {type:"integer"}
#@markdown _This determines the horizon for backpropagation through time. Reducing this will improve speed, but can make compression rate worse._

#@markdown ---
rnn_units = 4000 #@param {type:"integer"}
#@markdown _This is the number of units to use within each LSTM layer. Reducing this will improve speed and reduce memory usage, but can make compression rate worse. Make this a multiple of 8 to improve speed on certain GPUs._

#@markdown ---
num_layers = 8 #@param {type:"integer"}
#@markdown _This is the number of LSTM layers to use. Reducing this will improve speed, but can make compression rate worse._

#@markdown ---
embedding_size = 1024 #@param {type:"integer"}
#@markdown _Size of the embedding vector for each input symbol (the output dimension of the embedding layer)._

#@markdown ---
start_learning_rate = 0.0005 #@param {type:"number"}
#@markdown _Learning rate for the Adam optimizer._

#@markdown ---
end_learning_rate = 0.0002 #@param {type:"number"}
#@markdown _Typically this should be set to the same value as the "start_learning_rate" parameter above. If this is set to a different value, the learning rate will start at "start_learning_rate" and linearly change to "end_learning_rate" by the end of the file. For large files this could be useful for learning rate decay._

#@markdown ---
mode = 'compress' #@param ["compress", "decompress", "both", "preprocess_only"]
#@markdown _Whether to run compression only, decompression only, or both. "preprocess_only" will only run preprocessing and skip compression._

#@markdown ---
preprocess = 'nncp' #@param ["cmix", "nncp", "nncp-done", "none"]
#@markdown _The choice of preprocessor. NNCP works better on enwik8/enwik9. NNCP preprocessing is slower since it constructs a custom dictionary, while cmix uses a pretrained dictionary. "nncp-done" is used for files which have already been preprocessed by NNCP (the dictionary must also be provided, as data/dictionary.words)._

#@markdown ---
n_words = 8192 #@param {type:"integer"}
#@markdown _Only used for NNCP preprocessor: this is the approximate maximum number of words in the dictionary. Recommended value for enwik8/enwik9: 8192._

#@markdown ---
min_freq = 512 #@param {type:"integer"}
#@markdown _Only used for NNCP preprocessor: this is the minimum frequency of the selected words. Recommended value for enwik8: 64, enwik9: 512._

#@markdown ---
path_to_file = "enwik8" #@param ["enwik4", "enwik6", "enwik8", "enwik9", "custom"]
#@markdown _Name of the file to compress or decompress. If "custom" is selected, use the next parameter to set a custom path._

#@markdown ---
custom_path = '' #@param {type:"string"}
#@markdown _Use this if the previous parameter was set to "custom". Set this to the name of the file you want to compress/decompress. You can transfer files using the "http_path" or "local_upload" options below._

#@markdown ---
http_path = '' #@param {type:"string"}
#@markdown _The file from this URL will be downloaded. It is recommended to use Google Drive URLs to get fast transfer speed. Use this format for Google Drive files: https://drive.google.com/uc?id= and paste the file ID at the end of the URL. You can find the file ID from the "Get Link" URL in Google Drive. You can enter multiple URLs here, space separated._

#@markdown ---
local_upload = False #@param {type:"boolean"}
#@markdown _If enabled, you will be prompted in the "Setup Files" section to select files to upload from your local computer. You can upload multiple files. Note: the upload speed can be quite slow (use "http_path" for better transfer speeds)._

#@markdown ---
download_option = "no_download" #@param ["no_download", "local", "google_drive"]
#@markdown _If this is set to "local", the output files will be downloaded to your computer after compression/decompression. If set to "google_drive", they will be copied to your Google Drive account (which is significantly faster than downloading locally)._

#@markdown ---
checkpoint = False #@param {type:"boolean"}
#@markdown _If this is enabled, a checkpoint of the model weights will be downloaded (using the "download_option" parameter). This can be useful for getting around session time limits for Colab, by splitting files into multiple segments and saving/loading the model weights between each segment. Checkpoints (if present) will automatically be loaded when starting compression._
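The `batch_size` parameter splits the file into contiguous chunks that are coded in parallel. The sketch below mirrors the split rule used later in `process()` (`split = math.ceil(length / batch_size)`, with batch `i` starting at position `i * split`). The helper name `batch_ranges` is hypothetical and not part of the notebook.

```python
import math


def batch_ranges(length: int, batch_size: int) -> list:
    """File positions handled by each batch, following the notebook's split rule."""
    split = math.ceil(length / batch_size)  # symbols per batch (rounded up)
    # Batch i covers [i * split, (i + 1) * split), clipped to the file length.
    # Positions past the end of the file are fed as 0 and masked out of the
    # loss in train().
    return [range(i * split, min((i + 1) * split, length)) for i in range(batch_size)]


for i, r in enumerate(batch_ranges(length=10, batch_size=4)):
    print(i, list(r))
```

With 10 symbols and 4 batches, `split` is 3, so the last batch gets a single symbol. If `batch_size` is large relative to the file, some trailing batches can be empty.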


## Setup

In [None]:
#@title Imports

import tensorflow as tf
import numpy as np
import random
from google.colab import files
import time
import math
import sys
import subprocess
import contextlib
import os
from tensorflow.keras import mixed_precision
from google.colab import drive
os.environ['TF_DETERMINISTIC_OPS'] = '1'

In [None]:
#@title System Info

def system_info():
  """Prints out system information."""
  gpu_info = !nvidia-smi
  gpu_info = '\n'.join(gpu_info)
  if gpu_info.find('failed') >= 0:
    print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
    print('and then re-execute this cell.')
  else:
    print(gpu_info)
  print("TensorFlow version: ", tf.__version__)
  !lscpu |grep 'Model name'
  !cat /proc/meminfo | head -n 3

system_info()

In [None]:
#@title Mount Google Drive
if download_option == "google_drive":
  drive.mount('/content/gdrive')

In [None]:
#@title Setup Files

!mkdir -p "data"

if local_upload:
  %cd data
  files.upload()
  %cd ..

if path_to_file == 'enwik8' or path_to_file == 'enwik6' or path_to_file == 'enwik4':
  %cd data
  !gdown --id 1BUbuEUhPOBaVZDdOh0KG8hxvIDgsyiZp
  !unzip enwik8.zip
  !head -c 1000000 enwik8 > enwik6
  !head -c 10000 enwik8 > enwik4
  path_to_file = 'data/' + path_to_file
  %cd ..

if path_to_file == 'enwik9':
  %cd data
  !gdown --id 1D2gCmf9AlXIBP62ARhy0XcIuIolOTRAE
  !unzip enwik9.zip
  path_to_file = 'data/' + path_to_file
  %cd ..

if path_to_file == 'custom':
  path_to_file = 'data/' + custom_path

if http_path:
  %cd data
  paths = http_path.split()
  for path in paths:
    !gdown $path
  %cd ..

if preprocess == 'cmix':
  !gdown --id 1qa7K28tlUDs9GGYbaL_iE9M4m0L1bYm9
  !unzip cmix-v18.zip
  %cd cmix
  !make
  %cd ..

if preprocess == 'nncp' or preprocess == 'nncp-done':
  !gdown --id 1EzVPbRkBIIbgOzvEMeM0YpibDi2R4SHD
  !tar -xf nncp-2019-11-16.tar.gz
  %cd nncp-2019-11-16/
  !make preprocess
  %cd ..

In [None]:
#@title Model Architecture

def build_model(vocab_size):
  """Builds the model architecture.

    Args:
      vocab_size: Int, size of the vocabulary.
  """
  policy = mixed_precision.Policy('mixed_float16')
  mixed_precision.set_global_policy(policy)
  inputs = [
    tf.keras.Input(batch_input_shape=[batch_size, seq_length])]
  # In addition to the primary input, there are also two "state" inputs for each
  # layer of the network.
  for i in range(num_layers):
    inputs.append(tf.keras.Input(shape=(None,)))
    inputs.append(tf.keras.Input(shape=(None,)))
  embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)(inputs[0])
  # Skip connections will be used to connect each LSTM layer output to the final
  # output layer. Each LSTM layer after the first will get as input both the
  # embedded input and the output of the previous layer.
  skip_connections = []
  # In addition to the softmax output, there are also two "state" outputs for
  # each layer of the network.
  outputs = []
  predictions, state_h, state_c = tf.keras.layers.LSTM(rnn_units,
                          return_sequences=True,
                          return_state=True,
                          recurrent_initializer='glorot_uniform',
                          )(embedding, initial_state=[
                          tf.cast(inputs[1], tf.float16),
                          tf.cast(inputs[2], tf.float16)])
  skip_connections.append(predictions)
  outputs.append(state_h)
  outputs.append(state_c)
  for i in range(num_layers - 1):
    layer_input = tf.keras.layers.concatenate(
        [embedding, skip_connections[-1]])
    predictions, state_h, state_c = tf.keras.layers.LSTM(rnn_units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform')(
          layer_input, initial_state=[tf.cast(inputs[i*2+3], tf.float16),
                                      tf.cast(inputs[i*2+4], tf.float16)])
    skip_connections.append(predictions)
    outputs.append(state_h)
    outputs.append(state_c)
  # The dense output layer only needs to be computed for the last timestep, so
  # we can discard the earlier outputs.
  last_timestep = []
  for i in range(num_layers):
    last_timestep.append(tf.slice(skip_connections[i], [0, seq_length - 1, 0],
                                [batch_size, 1, rnn_units]))
  if num_layers == 1:
    layer_input = last_timestep[0]
  else:
    layer_input = tf.keras.layers.concatenate(last_timestep)
  dense = tf.keras.layers.Dense(vocab_size, name='dense_logits')(layer_input)
  output = tf.keras.layers.Activation('softmax', dtype='float32',
                                      name='predictions')(dense)
  outputs.insert(0, output)
  model = tf.keras.Model(inputs=inputs, outputs=outputs)
  return model
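The index arithmetic in `build_model()` can be hard to follow. This pure-Python sketch needs no TensorFlow and reproduces the bookkeeping:

* which entries of `inputs` hold the initial states for each LSTM layer
* the input width seen by each layer
* a parameter count

The parameter count uses the standard Keras LSTM formula `4 * units * (input_dim + units + 1)`. The helper names are hypothetical, and the vocabulary size of 8192 in the example is only an assumed value (the real one depends on the preprocessed file). Compare against `model.summary()` for the actual numbers.

```python
def state_input_indices(num_layers: int) -> list:
    """Positions in `inputs` of the (h, c) initial states for each LSTM layer.

    inputs[0] is the symbol sequence; layer k reads inputs[2k + 1] and inputs[2k + 2].
    """
    return [(2 * k + 1, 2 * k + 2) for k in range(num_layers)]


def lstm_input_widths(num_layers: int, embedding_size: int, rnn_units: int) -> list:
    """Feature width fed to each LSTM layer.

    Layer 0 sees only the embedding. Later layers see the embedding concatenated
    with the previous layer's output (the concatenate() call in build_model).
    """
    return [embedding_size] + [embedding_size + rnn_units] * (num_layers - 1)


def lstm_params(input_dim: int, units: int) -> int:
    """Trainable weights in one Keras LSTM layer: kernel, recurrent kernel and bias."""
    return 4 * units * (input_dim + units + 1)


def total_params(vocab_size: int, num_layers: int, embedding_size: int, rnn_units: int) -> int:
    embedding = vocab_size * embedding_size
    lstms = sum(lstm_params(w, rnn_units)
                for w in lstm_input_widths(num_layers, embedding_size, rnn_units))
    # The dense layer sees the last timestep of every LSTM layer, concatenated.
    dense = (num_layers * rnn_units + 1) * vocab_size
    return embedding + lstms + dense


print(state_input_indices(3))  # layer 0 -> (1, 2), layer 1 -> (3, 4), layer 2 -> (5, 6)
print(f"{total_params(8192, 8, 1024, 4000):,} parameters (assuming vocab_size=8192)")
```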

In [None]:
#@title Compression Library

def get_symbol(index, length, freq, coder, compress, data):
  """Runs arithmetic coding and returns the next symbol.

  Args:
    index: Int, position of the symbol in the file.
    length: Int, size limit of the file.
    freq: ndarray, predicted symbol probabilities.
    coder: this is the arithmetic coder.
    compress: Boolean, True if compressing, False if decompressing.
    data: List containing each symbol in the file.
  
  Returns:
    The next symbol, or 0 if "index" is over the file size limit.
  """
  symbol = 0
  if index < length:
    if compress:
      symbol = data[index]
      coder.write(freq, symbol)
    else:
      symbol = coder.read(freq)
      data[index] = symbol
  return symbol

def train(pos, seq_input, length, vocab_size, coder, model, optimizer, compress,
          data, states):
  """Runs one training step.

  Args:
    pos: Int, position in the file for the current symbol for the *first* batch.
    seq_input: Tensor, containing the last seq_length inputs for the model.
    length: Int, size limit of the file.
    vocab_size: Int, size of the vocabulary.
    coder: this is the arithmetic coder.
    model: the model to generate predictions.
    optimizer: optimizer used to train the model.
    compress: Boolean, True if compressing, False if decompressing.
    data: List containing each symbol in the file.
    states: List containing state information for the layers of the model.
  
  Returns:
    seq_input: Tensor, containing the last seq_length inputs for the model.
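    cross_entropy: Float, sum of log2(probability) of the symbols coded in this
      step (the negated cross entropy numerator, in bits).
    denom: Int, number of symbols coded in this step (cross entropy denominator).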
  """
  loss = cross_entropy = denom = 0
  split = math.ceil(length / batch_size)
  # Keep track of operations while running the forward pass for automatic
  # differentiation.
  with tf.GradientTape() as tape:
    # The model inputs contain both seq_input and the states for each layer.
    inputs = states.pop(0)
    inputs.insert(0, seq_input)
    # Run the model (for all batches in parallel) to get predictions for the
    # next characters.
    outputs = model(inputs)
    predictions = outputs.pop(0)
    states.append(outputs)
    p = predictions.numpy()
    symbols = []
    # When the last batch reaches the end of the file, we start giving it "0"
    # as input. We use a mask to prevent this from influencing the gradients.
    mask = []
    # Go over each batch to run the arithmetic coding and prepare the next
    # input.
    for i in range(batch_size):
      # The "10000000" is used to convert floats into large integers (since
      # the arithmetic coder works on integers).
      freq = np.cumsum(p[i][0] * 10000000 + 1)
      index = pos + 1 + i * split
      symbol = get_symbol(index, length, freq, coder, compress, data)
      symbols.append(symbol)
      if index < length:
        prob = p[i][0][symbol]
        if prob <= 0:
          # Set a small value to avoid error with log2.
          prob = 0.000001
        cross_entropy += math.log2(prob)
        denom += 1
        mask.append(1.0)
      else:
        mask.append(0.0)
    # "input_one_hot" will be used both for the loss function and for the next
    # input.
    input_one_hot = tf.expand_dims(tf.one_hot(symbols, vocab_size), 1)
    loss = tf.keras.losses.categorical_crossentropy(
        input_one_hot, predictions, from_logits=False) * tf.expand_dims(
            tf.convert_to_tensor(mask), 1)
    scaled_loss = optimizer.get_scaled_loss(loss)
    # Remove the oldest input and append the new one.
    seq_input = tf.slice(seq_input, [0, 1],
                          [batch_size, seq_length - 1])
    seq_input = tf.concat([seq_input, tf.expand_dims(symbols, 1)], 1)
  # Run the backwards pass to update model weights.
  scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables)
  grads = optimizer.get_unscaled_gradients(scaled_gradients)
  # Gradient clipping to make training more robust.
  capped_grads = [tf.clip_by_norm(grad, 4) for grad in grads]
  optimizer.apply_gradients(zip(capped_grads, model.trainable_variables))
  return (seq_input, cross_entropy, denom)

def reset_seed():
  """Initializes various random seeds to help with determinism."""
  SEED = 1234
  os.environ['PYTHONHASHSEED']=str(SEED)
  random.seed(SEED)
  np.random.seed(SEED)
  tf.random.set_seed(SEED)

def download(path):
  """Downloads the file at the specified path."""
  if download_option == 'local':
    files.download(path)
  elif download_option == 'google_drive':
    !cp -f $path /content/gdrive/My\ Drive

def process(compress, length, vocab_size, coder, data):
  """This runs compression/decompression.

  Args:
    compress: Boolean, True if compressing, False if decompressing.
    length: Int, size limit of the file.
    vocab_size: Int, size of the vocabulary.
    coder: this is the arithmetic coder.
    data: List containing each symbol in the file.
  """
  start = time.time()
  reset_seed()
  model = build_model(vocab_size = vocab_size)
  checkpoint_path = tf.train.latest_checkpoint('./data')
  if checkpoint_path:
    model.load_weights(checkpoint_path)
  model.summary()

  # Try to split the file into equal size pieces for the different batches. The
  # last batch may have fewer characters if the file can't be split equally.
  split = math.ceil(length / batch_size)

  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
      start_learning_rate,
      split,
      end_learning_rate,
      power=1.0)
  optimizer = tf.keras.optimizers.Adam(
      learning_rate=learning_rate_fn, beta_1=0, beta_2=0.9999, epsilon=1e-5)
  optimizer = mixed_precision.LossScaleOptimizer(optimizer)

  hidden = model.reset_states()
  # Use a uniform distribution for predicting the first batch of symbols. The
  # "10000000" is used to convert floats into large integers (since the
  # arithmetic coder works on integers).
  freq = np.cumsum(np.full(vocab_size, (1.0 / vocab_size)) * 10000000 + 1)
  # Construct the first set of input characters for training.
  symbols = []
  for i in range(batch_size):
    symbols.append(get_symbol(i*split, length, freq, coder, compress, data))
  # Replicate the input tensor seq_length times, to match the input format.
  seq_input = tf.tile(tf.expand_dims(symbols, 1), [1, seq_length])
  pos = cross_entropy = denom = last_output = 0
  template = '{:0.2f}%\tcross entropy: {:0.2f}\ttime: {:0.2f}'
  # This will keep track of layer states. Initialize them to zeros.
  states = []
  for i in range(seq_length):
    states.append([tf.zeros([batch_size, rnn_units])] * (num_layers * 2))
  # Keep repeating the training step until we get to the end of the file.
  while pos < split:
    seq_input, ce, d = train(pos, seq_input, length, vocab_size, coder, model,
                             optimizer, compress, data, states)
    cross_entropy += ce
    denom += d
    pos += 1
    time_diff = time.time() - start
    # If it has been over 20 seconds since the last status message, display a
    # new one.
    if time_diff - last_output > 20:
      last_output = time_diff
      percentage = 100 * pos / split
      if percentage >= 100: continue
      print(template.format(percentage, -cross_entropy / denom, time_diff))
  if compress:
    coder.finish()
  print(template.format(100, -cross_entropy / length, time.time() - start))
  system_info()
  if mode != "both" or not compress:
    model.save_weights('./data/model')
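`process()` builds a `PolynomialDecay` schedule with `power=1.0` and `decay_steps=split`. The optimizer's step counter advances with each applied update, roughly once per `train()` call. Under those settings the schedule reduces to a straight line from `start_learning_rate` to `end_learning_rate` over one batch's chunk of the file, which matches the description of the "end_learning_rate" parameter. The pure-Python function below reimplements TensorFlow's documented formula (with `cycle=False`) for inspection. It is a sketch, not a call into TensorFlow, and the file length in the example is an assumed value.

```python
import math


def linear_decay(step: int, start_lr: float, end_lr: float, decay_steps: int) -> float:
    """PolynomialDecay with power=1.0 and cycle=False, per the TensorFlow docs."""
    step = min(step, decay_steps)  # hold at end_lr once decay_steps is reached
    return (start_lr - end_lr) * (1 - step / decay_steps) + end_lr


length, batch_size = 10**8, 256  # assumed example: an enwik8-sized input
split = math.ceil(length / batch_size)
for step in (0, split // 2, split):
    print(step, linear_decay(step, 0.0005, 0.0002, split))
```

Setting both rates to the same value makes every term except `end_lr` vanish, which gives the constant learning rate the parameter description recommends by default.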


In [None]:
#@title Arithmetic Coding Library

# 
# Reference arithmetic coding
# Copyright (c) Project Nayuki
# 
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
# 

import sys
python3 = sys.version_info.major >= 3


# ---- Arithmetic coding core classes ----

# Provides the state and behaviors that arithmetic coding encoders and decoders share.
class ArithmeticCoderBase(object):
	
	# Constructs an arithmetic coder, which initializes the code range.
	def __init__(self, numbits):
		if numbits < 1:
			raise ValueError("State size out of range")
		
		# -- Configuration fields --
		# Number of bits for the 'low' and 'high' state variables. Must be at least 1.
		# - Larger values are generally better - they allow a larger maximum frequency total (maximum_total),
		#   and they reduce the approximation error inherent in adapting fractions to integers;
		#   both effects reduce the data encoding loss and asymptotically approach the efficiency
		#   of arithmetic coding using exact fractions.
		# - But larger state sizes increase the computation time for integer arithmetic,
		#   and compression gains beyond ~30 bits are essentially zero in real-world applications.
		# - Python has native bigint arithmetic, so there is no upper limit to the state size.
		#   For Java and C++ where using native machine-sized integers makes the most sense,
		#   they have a recommended value of num_state_bits=32 as the most versatile setting.
		self.num_state_bits = numbits
		# Maximum range (high+1-low) during coding (trivial), which is 2^num_state_bits = 1000...000.
		self.full_range = 1 << self.num_state_bits
		# The top bit at width num_state_bits, which is 0100...000.
		self.half_range = self.full_range >> 1  # Non-zero
		# The second highest bit at width num_state_bits, which is 0010...000. This is zero when num_state_bits=1.
		self.quarter_range = self.half_range >> 1  # Can be zero
		# Minimum range (high+1-low) during coding (non-trivial), which is 0010...010.
		self.minimum_range = self.quarter_range + 2  # At least 2
		# Maximum allowed total from a frequency table at all times during coding. This differs from Java
		# and C++ because Python's native bigint avoids constraining the size of intermediate computations.
		self.maximum_total = self.minimum_range
		# Bit mask of num_state_bits ones, which is 0111...111.
		self.state_mask = self.full_range - 1
		
		# -- State fields --
		# Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s.
		self.low = 0
		# High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s.
		self.high = self.state_mask
	
	
	# Updates the code range (low and high) of this arithmetic coder as a result
	# of processing the given symbol with the given frequency table.
	# Invariants that are true before and after encoding/decoding each symbol
	# (letting full_range = 2^num_state_bits):
	# - 0 <= low <= code <= high < full_range. ('code' exists only in the decoder.)
	#   Therefore these variables are unsigned integers of num_state_bits bits.
	# - low < 1/2 * full_range <= high.
	#   In other words, they are in different halves of the full range.
	# - (low < 1/4 * full_range) || (high >= 3/4 * full_range).
	#   In other words, they are not both in the middle two quarters.
	# - Let range = high - low + 1, then full_range/4 < minimum_range
	#   <= range <= full_range. These invariants for 'range' essentially
	#   dictate the maximum total that the incoming frequency table can have.
	def update(self, freqs, symbol):
		# State check
		low = self.low
		high = self.high
		# if low >= high or (low & self.state_mask) != low or (high & self.state_mask) != high:
		# 	raise AssertionError("Low or high out of range")
		range = high - low + 1
		# if not (self.minimum_range <= range <= self.full_range):
		# 	raise AssertionError("Range out of range")
		
		# Frequency table values check
		total = int(freqs[-1])
		symlow = int(freqs[symbol-1]) if symbol > 0 else 0
		symhigh = int(freqs[symbol])
		#total = freqs.get_total()
		#symlow = freqs.get_low(symbol)
		#symhigh = freqs.get_high(symbol)
		# if symlow == symhigh:
		# 	raise ValueError("Symbol has zero frequency")
		# if total > self.maximum_total:
		# 	raise ValueError("Cannot code symbol because total is too large")
		
		# Update range
		newlow  = low + symlow  * range // total
		newhigh = low + symhigh * range // total - 1
		self.low = newlow
		self.high = newhigh
		
		# While low and high have the same top bit value, shift them out
		while ((self.low ^ self.high) & self.half_range) == 0:
			self.shift()
			self.low  = ((self.low  << 1) & self.state_mask)
			self.high = ((self.high << 1) & self.state_mask) | 1
		# Now low's top bit must be 0 and high's top bit must be 1
		
		# While low's top two bits are 01 and high's are 10, delete the second highest bit of both
		while (self.low & ~self.high & self.quarter_range) != 0:
			self.underflow()
			self.low = (self.low << 1) ^ self.half_range
			self.high = ((self.high ^ self.half_range) << 1) | self.half_range | 1
	
	
	# Called to handle the situation when the top bit of 'low' and 'high' are equal.
	def shift(self):
		raise NotImplementedError()
	
	
	# Called to handle the situation when low=01(...) and high=10(...).
	def underflow(self):
		raise NotImplementedError()


# Encodes symbols and writes to an arithmetic-coded bit stream.
class ArithmeticEncoder(ArithmeticCoderBase):
	
	# Constructs an arithmetic coding encoder based on the given bit output stream.
	def __init__(self, numbits, bitout):
		super(ArithmeticEncoder, self).__init__(numbits)
		# The underlying bit output stream.
		self.output = bitout
		# Number of saved underflow bits. This value can grow without bound.
		self.num_underflow = 0
	
	
	# Encodes the given symbol based on the given frequency table.
	# This updates this arithmetic coder's state and may write out some bits.
	def write(self, freqs, symbol):
		self.update(freqs, symbol)
	
	
	# Terminates the arithmetic coding by flushing any buffered bits, so that the output can be decoded properly.
	# It is important that this method be called at the end of each encoding process.
	# Note that this method merely writes data to the underlying output stream but does not close it.
	def finish(self):
		self.output.write(1)
	
	
	def shift(self):
		bit = self.low >> (self.num_state_bits - 1)
		self.output.write(bit)
		
		# Write out the saved underflow bits
		for _ in range(self.num_underflow):
			self.output.write(bit ^ 1)
		self.num_underflow = 0
	
	
	def underflow(self):
		self.num_underflow += 1


# Reads from an arithmetic-coded bit stream and decodes symbols.
class ArithmeticDecoder(ArithmeticCoderBase):
	
	# Constructs an arithmetic coding decoder based on the
	# given bit input stream, and fills the code bits.
	def __init__(self, numbits, bitin):
		super(ArithmeticDecoder, self).__init__(numbits)
		# The underlying bit input stream.
		self.input = bitin
		# The current raw code bits being buffered, which is always in the range [low, high].
		self.code = 0
		for _ in range(self.num_state_bits):
			self.code = self.code << 1 | self.read_code_bit()
	
	
	# Decodes the next symbol based on the given frequency table and returns it.
	# Also updates this arithmetic coder's state and may read in some bits.
	def read(self, freqs):
		#if not isinstance(freqs, CheckedFrequencyTable):
		#	freqs = CheckedFrequencyTable(freqs)
		
		# Translate from coding range scale to frequency table scale
		total = int(freqs[-1])
		#total = freqs.get_total()
		#if total > self.maximum_total:
		#	raise ValueError("Cannot decode symbol because total is too large")
		range = self.high - self.low + 1
		offset = self.code - self.low
		value = ((offset + 1) * total - 1) // range
		#assert value * range // total <= offset
		#assert 0 <= value < total
		
		# A kind of binary search. Find highest symbol such that freqs.get_low(symbol) <= value.
		start = 0
		end = len(freqs)
		#end = freqs.get_symbol_limit()
		while end - start > 1:
			middle = (start + end) >> 1
			low = int(freqs[middle-1]) if middle > 0 else 0
			#if freqs.get_low(middle) > value:
			if low > value:
				end = middle
			else:
				start = middle
		#assert start + 1 == end
		
		symbol = start
		#assert freqs.get_low(symbol) * range // total <= offset < freqs.get_high(symbol) * range // total
		self.update(freqs, symbol)
		#if not (self.low <= self.code <= self.high):
		#	raise AssertionError("Code out of range")
		return symbol
	
	
	def shift(self):
		self.code = ((self.code << 1) & self.state_mask) | self.read_code_bit()
	
	
	def underflow(self):
		self.code = (self.code & self.half_range) | ((self.code << 1) & (self.state_mask >> 1)) | self.read_code_bit()
	
	
	# Returns the next bit (0 or 1) from the input stream. The end
	# of stream is treated as an infinite number of trailing zeros.
	def read_code_bit(self):
		temp = self.input.read()
		if temp == -1:
			temp = 0
		return temp


# ---- Bit-oriented I/O streams ----

# A stream of bits that can be read. Because they come from an underlying byte stream,
# the total number of bits is always a multiple of 8. The bits are read in big endian.
class BitInputStream(object):
	
	# Constructs a bit input stream based on the given byte input stream.
	def __init__(self, inp):
		# The underlying byte stream to read from
		self.input = inp
		# Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached
		self.currentbyte = 0
		# Number of remaining bits in the current byte, always between 0 and 7 (inclusive)
		self.numbitsremaining = 0
	
	
	# Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if
	# the end of stream is reached. The end of stream always occurs on a byte boundary.
	def read(self):
		if self.currentbyte == -1:
			return -1
		if self.numbitsremaining == 0:
			temp = self.input.read(1)
			if len(temp) == 0:
				self.currentbyte = -1
				return -1
			self.currentbyte = temp[0] if python3 else ord(temp)
			self.numbitsremaining = 8
		assert self.numbitsremaining > 0
		self.numbitsremaining -= 1
		return (self.currentbyte >> self.numbitsremaining) & 1
	
	
	# Reads a bit from this stream. Returns 0 or 1 if a bit is available, or raises an EOFError
	# if the end of stream is reached. The end of stream always occurs on a byte boundary.
	def read_no_eof(self):
		result = self.read()
		if result != -1:
			return result
		else:
			raise EOFError()
	
	
	# Closes this stream and the underlying input stream.
	def close(self):
		self.input.close()
		self.currentbyte = -1
		self.numbitsremaining = 0


# A stream where bits can be written to. Because they are written to an underlying
# byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits.
# The bits are written in big endian.
class BitOutputStream(object):
	
	# Constructs a bit output stream based on the given byte output stream.
	def __init__(self, out):
		self.output = out  # The underlying byte stream to write to
		self.currentbyte = 0  # The accumulated bits for the current byte, always in the range [0x00, 0xFF]
		self.numbitsfilled = 0  # Number of accumulated bits in the current byte, always between 0 and 7 (inclusive)
	
	
	# Writes a bit to the stream. The given bit must be 0 or 1.
	def write(self, b):
		if b not in (0, 1):
			raise ValueError("Argument must be 0 or 1")
		self.currentbyte = (self.currentbyte << 1) | b
		self.numbitsfilled += 1
		if self.numbitsfilled == 8:
			towrite = bytes((self.currentbyte,)) if python3 else chr(self.currentbyte)
			self.output.write(towrite)
			self.currentbyte = 0
			self.numbitsfilled = 0
	
	
	# Closes this stream and the underlying output stream. If called when this
	# bit stream is not at a byte boundary, then the minimum number of "0" bits
	# (between 0 and 7 of them) are written as padding to reach the next byte boundary.
	def close(self):
		while self.numbitsfilled != 0:
			self.write(0)
		self.output.close()
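The two stream classes above pack bits most significant bit first and pad the final byte with 0's. The following standalone sketch reproduces that behavior on an in-memory buffer so the padding rule is easy to see. The helpers `pack_bits` and `unpack_bits` are hypothetical and not part of the notebook; they only mirror `BitOutputStream.write`/`close` and `BitInputStream.read`.

```python
import io


def pack_bits(bits):
    """Pack 0/1 ints into bytes, most significant bit first, padding the
    last byte with 0's (mirrors BitOutputStream.write followed by close)."""
    out = bytearray()
    current, filled = 0, 0
    for b in bits:
        if b not in (0, 1):
            raise ValueError("Argument must be 0 or 1")
        current = (current << 1) | b
        filled += 1
        if filled == 8:
            out.append(current)
            current, filled = 0, 0
    if filled:
        # Shift the partial byte left so the padding 0's fill the low bits,
        # which is what repeated write(0) calls in close() produce.
        out.append(current << (8 - filled))
    return bytes(out)


def unpack_bits(stream):
    """Yield bits from a binary stream, most significant bit first, stopping
    at end of stream (mirrors BitInputStream.read until it returns -1)."""
    while True:
        chunk = stream.read(1)
        if not chunk:
            return
        byte = chunk[0]
        for shift in range(7, -1, -1):
            yield (byte >> shift) & 1


bits = [1, 0, 1, 1, 0, 0, 1, 0, 1, 1]
packed = pack_bits(bits)
print(packed.hex())  # prints b2c0
print(list(unpack_bits(io.BytesIO(packed))))  # the 10 bits plus 6 padding 0's
```

Because the end of stream always falls on a byte boundary, a reader cannot tell padding bits from data bits; the arithmetic decoder relies on the separately stored symbol count to know when to stop.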

## Compress

#@title Preprocess

if mode != 'decompress':
  input_path = path_to_file

  if preprocess == 'cmix':
    !./cmix/cmix -s ./cmix/dictionary/english.dic $path_to_file ./data/preprocessed.dat
    input_path = "./data/preprocessed.dat"

  # int_list will contain the symbols of the file (byte indexes or NNCP word indexes).
  int_list = []
  if preprocess == 'nncp' or preprocess == 'nncp-done':
    if preprocess == 'nncp':
      !time ./nncp-2019-11-16/preprocess c data/dictionary.words $path_to_file data/preprocessed.dat $n_words $min_freq
    else:
      !cp $path_to_file data/preprocessed.dat
    input_path = "./data/preprocessed.dat"
    with open(input_path, 'rb') as f:
      orig = f.read()
    # Each NNCP symbol is a 16-bit word index stored big-endian.
    for i in range(0, len(orig), 2):
      int_list.append(orig[i] * 256 + orig[i+1])
    vocab_size = int(subprocess.check_output(
        ['wc', '-l', 'data/dictionary.words']).split()[0])
  else:
    with open(input_path, 'rb') as f:
      text = f.read()
    vocab = sorted(set(text))
    vocab_size = len(vocab)
    # Create a mapping from each distinct byte value to its index.
    char2idx = {u:i for i, u in enumerate(vocab)}
    int_list = [char2idx[c] for c in text]

  # Round up to a multiple of 8 to improve performance.
  vocab_size = math.ceil(vocab_size/8) * 8
  file_len = len(int_list)
  print('Length of file: {} symbols'.format(file_len))
  print('Vocabulary size: {}'.format(vocab_size))
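A standalone sketch of the two symbol encodings used in the Preprocess cell, assuming the same conventions as the code above: without NNCP, each byte is replaced by its rank among the distinct byte values in sorted order; with NNCP, each pair of bytes is read as a big-endian 16-bit word index. The helper names are hypothetical, and the odd-length check in `words_to_symbols` is an addition that the notebook does not perform (it would raise an `IndexError` instead).

```python
import math


def bytes_to_indexes(text):
    """Map each byte to its index in the sorted list of distinct byte values."""
    vocab = sorted(set(text))
    char2idx = {u: i for i, u in enumerate(vocab)}
    return [char2idx[c] for c in text], vocab


def words_to_symbols(data):
    """Interpret NNCP preprocessed output as big-endian 16-bit word indexes."""
    if len(data) % 2:
        raise ValueError("expected an even number of bytes")
    return [data[i] * 256 + data[i + 1] for i in range(0, len(data), 2)]


def padded_vocab_size(n):
    """Round the vocabulary size up to a multiple of 8, as the notebook does."""
    return math.ceil(n / 8) * 8


symbols, vocab = bytes_to_indexes(b"abracadabra")
print(vocab)    # [97, 98, 99, 100, 114]
print(symbols)  # [0, 1, 4, 0, 2, 0, 3, 0, 1, 4, 0]
print(words_to_symbols(b"\x01\x02\x00\xff"))  # [258, 255]
print(padded_vocab_size(len(vocab)))          # 8
```

Rounding only enlarges the model's output layer; the extra indexes never occur in the input, so the model should learn to assign them low probability.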

#@title Compression

if mode == 'compress' or mode == 'both':
  original_file = path_to_file
  path_to_file = "data/compressed.dat"
  with open(path_to_file, "wb") as out, contextlib.closing(BitOutputStream(out)) as bitout:
    length = len(int_list)
    # Write the number of symbols (5 bytes, big-endian) to the compressed file header.
    out.write(length.to_bytes(5, byteorder='big', signed=False))
    if preprocess != 'nncp' and preprocess != 'nncp-done':
      # If NNCP was not used for preprocessing, write a 256-bit vocabulary
      # bitmap to the header: bit i is 1 if byte value i occurs in the input.
      for i in range(256):
        if i in char2idx:
          bitout.write(1)
        else:
          bitout.write(0)
    enc = ArithmeticEncoder(32, bitout)
    process(True, length, vocab_size, enc, int_list)
  print("Compressed size:", os.path.getsize(path_to_file))
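The compressed file therefore starts with a 5-byte big-endian symbol count (enough for up to 2^40 - 1 symbols), followed, when NNCP is not used, by the 256-bit vocabulary bitmap written through `BitOutputStream`. Since 256 bits is exactly 32 bytes, the arithmetic-coded bits that follow start on a byte boundary. A minimal sketch of that layout, with hypothetical `build_header`/`parse_header` helpers that are not part of the notebook:

```python
def build_header(length, char2idx):
    """Build the header for the non-NNCP case: a 5-byte big-endian symbol
    count, then a 256-bit vocabulary bitmap packed most significant bit first."""
    header = bytearray(length.to_bytes(5, byteorder='big', signed=False))
    bitmap = bytearray(32)  # 256 bits = 32 bytes
    for i in range(256):
        if i in char2idx:
            bitmap[i // 8] |= 0x80 >> (i % 8)
    return bytes(header + bitmap)


def parse_header(data):
    """Recover the symbol count and the sorted vocabulary from such a header."""
    length = int.from_bytes(data[:5], byteorder='big')
    vocab = [i for i in range(256) if data[5 + i // 8] & (0x80 >> (i % 8))]
    return length, vocab


header = build_header(11, {97: 0, 98: 1, 114: 2})
print(len(header))           # 37
print(parse_header(header))  # (11, [97, 98, 114])
```

Storing the vocabulary as a fixed 32-byte bitmap costs the same regardless of how many distinct bytes occur, which is negligible next to the compressed payload for all but tiny inputs.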

#@title Download Result

if mode == 'preprocess_only':
  if preprocess == 'nncp':
    download('data/dictionary.words')
  download(input_path)
elif mode != 'decompress':
  download('data/compressed.dat')
  if preprocess == 'nncp':
    download('data/dictionary.words')
  if checkpoint and mode != "both":
    download('data/model.index')
    download('data/model.data-00000-of-00001')
    download('data/checkpoint')

## Decompress

#@title Decompression

if mode == 'decompress' or mode == 'both':
  output_path = "data/decompressed.dat"
  with open(path_to_file, "rb") as inp, open(output_path, "wb") as out:
    # Read the symbol count (5 bytes, big-endian) from the header. Reading
    # only 5 bytes leaves the stream positioned at the start of the bit data.
    length = int.from_bytes(inp.read(5), byteorder='big')
    # Preallocate a list for the decoded symbols.
    output = [0] * length
    bitin = BitInputStream(inp)
    if preprocess == 'nncp' or preprocess == 'nncp-done':
      # If the preprocessor is NNCP, we can get the vocab_size from the
      # dictionary.
      vocab_size = int(subprocess.check_output(
          ['wc', '-l', 'data/dictionary.words']).split()[0])
    else:
      # If the preprocessor is not NNCP, we can get the vocabulary from the file
      # header.
      vocab = []
      for i in range(256):
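        # read_no_eof() raises EOFError on a truncated header; read() would
        # return -1, which is truthy and would be mistaken for a 1 bit.
        if bitin.read_no_eof():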
          vocab.append(i)
      vocab_size = len(vocab)
    # Round up to a multiple of 8 to improve performance.
    vocab_size = math.ceil(vocab_size/8) * 8
    dec = ArithmeticDecoder(32, bitin)
    process(False, length, vocab_size, dec, output)
    # The decompressed data is stored in the "output" list. We can now write the
    # data to file (based on the type of preprocessing used).
    if preprocess == 'nncp' or preprocess == 'nncp-done':
      for i in range(length):
        out.write(bytes(((output[i] // 256),)))
        out.write(bytes(((output[i] % 256),)))
    else:
      # Convert indexes back to the original byte values.
      idx2char = np.array(vocab)
      for i in range(length):
        out.write(bytes((idx2char[output[i]],)))

  if preprocess == 'cmix':
    !./cmix/cmix -d ./cmix/dictionary/english.dic $output_path ./data/final.dat
    output_path = "data/final.dat"
  if preprocess == 'nncp' or preprocess == 'nncp-done':
    !./nncp-2019-11-16/preprocess d data/dictionary.words $output_path ./data/final.dat
    output_path = "data/final.dat"
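The inverse of the preprocessing step can be sketched as two small functions: vocabulary indexes map back to byte values, and NNCP word indexes are written as big-endian 16-bit values. The helper names are hypothetical. Unlike the cell above, this sketch assembles the whole output in memory before a single write, which may be faster than one `out.write` call per byte; both approaches produce the same bytes.

```python
def indexes_to_bytes(symbols, vocab):
    """Map vocabulary indexes back to the original byte values."""
    return bytes(vocab[s] for s in symbols)


def symbols_to_words(symbols):
    """Serialize NNCP word indexes as big-endian 16-bit values."""
    out = bytearray()
    for s in symbols:
        out += bytes((s // 256, s % 256))
    return bytes(out)


print(indexes_to_bytes([0, 1, 4, 0], [97, 98, 99, 100, 114]))  # b'abra'
print(symbols_to_words([258, 255]).hex())                       # 010200ff
```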

#@title Download Result

if mode == 'decompress':
  if preprocess == 'nncp-done':
    download('data/decompressed.dat')
  else:
    download(output_path)
  if checkpoint:
    download('data/model.index')
    download('data/model.data-00000-of-00001')
    download('data/checkpoint')

#@title Validation

if mode == 'decompress' or mode == 'both':
  if preprocess == 'nncp-done':
    !md5sum data/decompressed.dat
  !md5sum $output_path
if mode == 'both':
  !md5sum $original_file
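The `!md5sum` lines rely on the shell and leave the comparison to the reader. If you would rather check the result in Python (for example, outside Colab), a minimal sketch using the standard-library `hashlib` module could look like this; the helper names are hypothetical:

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def files_match(path_a, path_b):
    """Return True if both files have the same MD5 digest."""
    return md5_of_file(path_a) == md5_of_file(path_b)
```

For example, `files_match(original_file, output_path)` mirrors the last two `!md5sum` lines in `'both'` mode. Reading in chunks keeps memory use flat even for large inputs.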