In [None]:
# Compress MNIST into the smallest possible self-extracting archive.
# gzip gets 9.63MB, or 166 bytes per sample.
# This is MDL (minimum description length): you add the bytes for the model, the bytes for the signal, and the bytes for the noise.
# And this is rather poorly done; ideally you want to build the model online.

"""
compressing model.json      with len   11986 complen    1140
compressing model.h5        with len  664504 complen  559453
compressing signal.dat      with len 3840000 complen 2579712
compressing noise.dat       with len 5882854 complen 5884667
"""

# Total size of x_train compressed is 9027925, about 6% better than gzip. Can you beat this?

# See http://prize.hutter1.net for the inspiration.
# I want to get at the question of what percent of data is noise.
# Lossless compression is AI and is a true optimization metric you can compute.
# Lossy compression is not good in the extremes and has fake loss functions.

# As the world gets weird, it's important to get your optimization target exactly right.
# Get it close, and we'll live in a hyperoptimized dystopia.

In [None]:
%load_ext Cython
%pylab inline
import os
import time
import random
import gzip
from tqdm import trange
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
# Load MNIST and one-hot encode the label vectors of both splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [None]:
# Baseline to beat with a net: plain gzip of the raw training pixels.
# That comes to 9.63MB, i.e. 166 bytes per sample.
x_train_gz = gzip.compress(memoryview(x_train))
len(x_train_gz)

In [None]:
# Keras computes in float32 throughout
K.set_floatx('float32')
# width of the code vector the encoder emits for each image
BITS = 64

def block(x_in, chan=32, strides=(2,2), bn=True, long=True):
  """Convolutional building block shared by the encoder and the decoder.

  x_in    -- input feature map tensor
  chan    -- number of output channels of every convolution in the block
  strides -- strides of the first 3x3 convolution and of the 1x1 shortcut
  bn      -- when true, batch-normalise after each 3x3 convolution
  long    -- when true, add a second 3x3 convolution and a 1x1-convolution
             shortcut from x_in (residual form); when false the block is
             just conv -> (bn) -> ELU

  Returns the ELU-activated output tensor.
  """
  def _maybe_bn(t):
    return BatchNormalization()(t) if bn else t

  main = Conv2D(chan, (3,3), padding='same', strides=strides)(x_in)
  main = _maybe_bn(main)
  if not long:
    return ELU()(main)

  main = ELU()(main)
  main = Conv2D(chan, (3,3), padding='same')(main)
  main = _maybe_bn(main)
  main = ELU()(main)
  # the shortcut uses the same strides so both branches have matching shapes
  shortcut = Conv2D(chan, (1,1), padding='same', strides=strides)(x_in)
  merged = Add()([main, shortcut])
  return ELU()(merged)

# encoder: 28x28 image -> BITS real-valued code units
x = in1 = Input((28,28))
x = Reshape((28,28,1))(x)
# rescale pixel values so that 0 maps to -1 and 255 maps to +1
x = Lambda(lambda x: (x-127.5)/127.5)(x)

# three stride-2 residual stages with growing channel counts
for width in (64, 128, 192):
  x = block(x, width)

x = Flatten()(x)
x = Dense(BITS)(x)

@tf.custom_gradient
def binarize(x):
  """Hard 0/1 threshold whose backward pass is the identity.

  Forward: sign(x) clipped to [0, 1], so positive inputs give 1 and
  everything else gives 0. Backward: the incoming gradient is passed
  through unchanged.
  """
  hard_bits = tf.clip_by_value(tf.math.sign(x), 0, 1)

  def pass_through(dy):
    return dy

  return hard_bits, pass_through

#x = BatchNormalization()(x)
#x = GaussianNoise(0.5)(x)
#x = Dropout(rate=0.2)(x)
#x = Lambda(lambda x: binarize(x), name="round")(x)
#x = Activation('tanh')(x)

# stand-alone encoder model, so the codes can be extracted after training
enc = Model(inputs=in1, outputs=x)
enc.compile(optimizer='sgd', loss='mse')

# decoder: maps a BITS-wide code to a 256-way distribution for each of the 28x28 pixels
x = in1 = Input((BITS,))

# project the code to a 4x4 feature map with 64 channels
x = Dense(4*4*64)(x)
x = ELU()(x)
x = Reshape((4,4,64))(x)

# 4x4 -> 8x8
x = UpSampling2D((2,2))(x)
x = block(x, 64, strides=(1,1))
# 8x8 -> 16x16, then trim one pixel from every edge -> 14x14
x = UpSampling2D((2,2))(x)
x = Cropping2D((1,1))(x)
x = block(x, 64, strides=(1,1))

# 14x14 -> 28x28 with a stride-2 transposed convolution
x = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same')(x)
x = ELU()(x)
# 1x1 convolutions act per pixel; the last one has one output per possible byte value
x = Conv2D(128, (1,1), padding='same')(x)
x = ELU()(x)
x = Conv2D(256, (1,1), padding='same')(x)
# normalise the 256 channel values of every pixel into probabilities
x = Activation('softmax')(x)

# sparse_categorical_crossentropy, only expressed in different units
def loss(y_true, y_pred):
  """Cross-entropy rescaled from nats per pixel to bytes per 784-pixel image."""
  nats = tf.keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits=False)
  # (nats per input channel) * 784 / (log(2)*8); the factor itself is 141.384
  nats_per_image = nats*784
  nats_per_byte = np.log(2)*8
  return nats_per_image/nats_per_byte

# stand-alone decoder; compiled so that dec.evaluate() reports the coding cost
dec = Model(inputs=in1, outputs=x)
dec.compile(optimizer='sgd', loss=loss)
dec.summary()

# autoencoder: the decoder applied to the encoder output, trained end to end
x = dec(enc.output)
ae = Model(inputs=enc.inputs, outputs=x)
ae.compile(optimizer=Adam(3e-4), loss=loss)

In [None]:
# the images are both the input and the target; the test split is only monitored
ae.fit(
  x_train, x_train,
  batch_size=256,
  epochs=20,
  shuffle=True,
  validation_data=(x_test, x_test),
)

In [None]:
# Reuse a saved checkpoint when one exists; otherwise write one from the
# weights currently in memory. Previously the save line was commented out,
# so this cell raised on any machine where the file had never been written.
# Delete the file to force the freshly trained weights to be kept.
AE_WEIGHTS_PATH = "/tmp/ae_model.h5"
if os.path.exists(AE_WEIGHTS_PATH):
  ae.load_weights(AE_WEIGHTS_PATH)
else:
  ae.save_weights(AE_WEIGHTS_PATH, overwrite=True)

In [None]:
# extract the data using the encoder
z_float = enc.predict(x_train, batch_size=256, verbose=1)
# Quantise to half-unit steps stored as int8 (the decoder is fed z_swag/2).
# Saturate first: casting an out-of-range float to int8 is not well defined
# and can wrap around to a value of the opposite sign.
z_swag = np.clip(z_float*2, -128, 127).astype(np.int8)
z_swag.shape

In [None]:
# this converts float32 into float16
# The decoder weights are saved, every dataset is rewritten as float16 with
# all groups and attributes copied across, and the float16 file is loaded
# back so the decoder used from here on matches what gets shipped.

import h5py

t1 = "/tmp/weights32.h5"
t2 = "/tmp/weights16.h5"

dec.save_weights(t1, overwrite=True)

weights16 = {}
def print_attrs(name, obj):
  """visititems callback: store (array or None, attrs) for every HDF5 node.

  Despite the name, nothing is printed. Datasets are stored with their full
  contents, groups with None.
  """
  if isinstance(obj, h5py.Dataset):
    # obj[()] reads the whole dataset; the `.value` accessor was removed in h5py 3.0
    weights16[name] = obj[()], list(obj.attrs.items())
  else:
    weights16[name] = None, list(obj.attrs.items())

# explicit read-only mode, and a context manager so the handle is closed on error
with h5py.File(t1, "r") as weights:
  base_attrs = list(weights.attrs.items())
  weights.visititems(print_attrs)

with h5py.File(t2, "w") as out:
  for k,v in base_attrs:
    out.attrs.create(k,v)
  for name, (val, attrs) in weights16.items():
    if val is not None:
      out[name] = val.astype(np.float16)
    else:
      # require_group tolerates a group that already exists
      out.require_group(name)
    for k,v in attrs:
      out[name].attrs.create(k,v)

dec.load_weights(t2)
print(os.path.getsize(t1), os.path.getsize(t2))

In [None]:
# compute the number of bytes required
# model = gzipped architecture JSON + gzipped float16 weights
# signal = gzipped int8 codes
# noise = cross-entropy of the pixels under the decoder (the loss is already in bytes per sample)
num_samples = z_swag.shape[0]
model_bytes = len(gzip.compress(dec.to_json().encode('utf-8')))
with open("/tmp/weights16.h5", "rb") as weights_file:
  model_bytes += len(gzip.compress(weights_file.read()))
signal_bytes = len(gzip.compress(z_swag.data))
noise_bytes = dec.evaluate(z_swag/2, x_train, batch_size=1024) * num_samples
all_bytes = model_bytes + signal_bytes + noise_bytes

print("%.2f model + %.2f sig + %.2f noise = %.2f" %
      tuple([x/num_samples for x in [model_bytes, signal_bytes, noise_bytes, all_bytes]]))
print("total size is %f bytes" % all_bytes)

In [None]:
# Three stacked strips for the first 16 digits:
# ground truth, most likely pixel value, and per-pixel entropy in bits (scaled by 16)
figsize(16,16)
gt = np.concatenate(x_train[0:16], axis=1)
coding_probs = dec.predict(z_swag[0:16]/2)
coding_probs = coding_probs.reshape((-1,28,28,256))
# lay the 16 images side by side
coding_probs = np.concatenate(coding_probs, axis=1)
mu = coding_probs.argmax(axis=2)
# the small constant keeps log2 finite where a probability is exactly zero
H = -np.sum(coding_probs * np.log2(coding_probs+1e-10), axis=-1)
imshow(np.vstack([gt, mu, H*16]))

In [None]:
%%cython
from cpython cimport array
import array
import numpy as np
cimport numpy as np
from libc.stdint cimport uint64_t, int64_t, uint8_t

cdef class Coder(object):
  # Byte-oriented arithmetic coder over a 256-symbol alphabet.
  #
  # The coder keeps an interval [l, h] held in nbytes*8 bits. For each symbol
  # the interval is cut into 256 sub-intervals sized by the supplied
  # probabilities, and the sub-interval of the coded symbol becomes the new
  # interval. Whenever the top byte of l and h agree, that byte is emitted
  # (encoding) or a fresh input byte is shifted into val (decoding).
  #
  # The mode is chosen at construction: a non-empty `ob` means decode,
  # an empty one means encode.
  #
  # Fields:
  #   l, h      current interval bounds
  #   val       decode only: window of input bytes compared against splits
  #   sz        bit position of the top byte, (nbytes-1)*8
  #   szmask    mask selecting the bits below the top byte
  #   prec      number of bits of the cumulative probability scale
  #   precmult  factor applied to float probabilities before truncation
  #   ptr       next index in ob to read (decode) or write (encode)
  #   count     number of symbols processed so far
  #   splits    the 257 sub-interval boundaries for the current symbol
  #   lx        not referenced anywhere in this class
  cdef uint64_t szmask,l,h,sz,val,precmult
  cdef unsigned int prec,ptr,count,nbytes,lx,decode
  cdef array.array ob
  cdef uint64_t splits[257]
  
  # ob: byte sequence to decode, or empty (the default) to start an encoder
  def __init__(self, ob=[]):
    self.ob = array.array('B', ob)
    
    # interval width in bytes; the top byte is the one compared and shifted out
    self.nbytes = 6
    self.sz = (self.nbytes-1)*8
    self.szmask = (1L << self.sz) - 1
    
    # init these
    self.ptr = 0
    self.l = 0
    self.h = (1L << (self.nbytes*8))-1
    
    # read in first bytes
    if len(self.ob) > 0:
      self.decode = 1
      self.val = 0
      for i in range(self.nbytes):
        self.val = ((self.val & self.szmask) << 8) | self.ob[self.ptr]
        self.ptr += 1
    else:
      self.decode = 0
      # val is only read on the decode paths (search_splits, push_bytes)
      self.val = -1
    
    # compute the splits for the probability distribution
    # the precision doesn't have to be perfect, just the same
    # note: 16 is too high here, the network isn't totally deterministic
    self.prec = 14
    # 256 is held back because update_splits adds 1 to each of the 256 symbols
    self.precmult = ((1L<<self.prec) - 257)
    
    # number done so far
    self.count = 0

  # Returns the underlying byte array: the output so far when encoding.
  def data(self):
    return self.ob

  # Fill self.splits with the 257 boundaries that divide [l, h] among the
  # 256 symbols. p holds the discretised probabilities of the current symbol.
  cdef update_splits(self, np.ndarray[np.uint64_t, ndim=1] p):
    cdef uint64_t crange = self.h - self.l

    cdef int i
    cdef uint64_t cp = 0
    for i in range(0,257):
      # boundary i = l + (range * cumulative count before symbol i) >> prec
      self.splits[i] = crange*cp
      self.splits[i] >>= self.prec
      self.splits[i] += self.l
      
      if i != 256:
        # minimum is 1
        # assumes sum(p) + 256 stays below 1<<prec, i.e. the input
        # probabilities sum to at most 1 -- TODO confirm
        cp += p[i] + 1
      
    # if any splits are the same, we messed up
    # this bug isn't fixed, but increasing the bytes to 6 hides it
    #if np.any((splits[1:] - splits[:-1]) == 0):
    #  print(self.count, hex(self.l), hex(self.h), hex(self.h-self.l))
    #  assert False
      
  # Decode step: return the index before the first boundary that exceeds val.
  # NOTE(review): yields -1 when splits[0] > val and 256 when no boundary
  # exceeds val; both are outside 0..255 and the guarding asserts are
  # commented out.
  cdef int search_splits(self):
    cdef int i
    for i in range(0,257):
      if self.splits[i] > self.val:
        #assert i != 0
        return i-1
    #assert False
    return 256
  
  # Renormalise: while l and h share their top byte, consume or emit one byte
  # and shift the interval up by 8 bits.
  # The `decode` parameter is not used; the mode comes from self.decode.
  cdef push_bytes(self, decode=False):
    while self.l>>self.sz == self.h>>self.sz:
      #if self.l > self.h:
      #  print("ISSUE", hex(self.l), hex(self.h))
      if self.decode:
        #assert self.val >> self.sz == self.l >> self.sz
        # drop the top byte of val and shift in the next input byte
        self.val = ((self.val & self.szmask) << 8) | self.ob[self.ptr]
        self.ptr += 1
      else:
        #assert (self.l >> self.sz) < 256
        # grow the output by one and store the settled top byte
        array.resize_smart(self.ob, self.ptr+1)
        self.ob[self.ptr] = self.l >> self.sz
        self.ptr += 1
        
      # shift out the top byte; l is refilled with zeros, h with ones
      self.l = ((self.l & self.szmask) << 8)
      self.h = ((self.h & self.szmask) << 8) | 0xFF
      
  # Narrow the interval to the sub-interval of symbol x, then renormalise.
  # Relies on update_splits having been called for the current symbol.
  cdef void update(self, int x):
    self.l, self.h = self.splits[x]+1, self.splits[x+1]
    self.push_bytes()
    #assert self.l < self.h
      
  # Append the nbytes bytes of l to the output, most significant first.
  # Called once after the last symbol has been encoded.
  def flush(self):
    for i in range(self.nbytes):
      self.ob.append(self.l >> self.sz)
      self.l = ((self.l & self.szmask) << 8)
      
  # Encode or decode pf.shape[0] symbols.
  #   pf: float32 probabilities, one row per symbol; update_splits reads
  #       columns 0..255 of each row
  #   x:  uint8 symbols; read when encoding, overwritten in place when decoding
  def code(self, np.ndarray[np.float32_t, ndim=2] pf, np.ndarray[np.uint8_t, ndim=1] x):
    cdef int n = pf.shape[0]
    cdef int i
    # p is declared but not referenced below
    cdef uint64_t[256] p
    
    # discretize all
    # astype truncates, so each entry is floor(precmult * probability)
    cdef np.ndarray[np.uint64_t, ndim=2] pn = (self.precmult*pf).astype(np.uint64)
    
    for i in range(0, n):
      self.update_splits(pn[i])
    
      if self.decode:
        x[i] = self.search_splits()

      self.update(x[i])
      self.count += 1

In [None]:
# Arithmetic-code every training pixel under the decoder's predicted
# distribution, one batch of images at a time.
c = Coder()
DD = z_swag.shape[0]
BS = 1024
for i in range(0, DD, BS):
  # time.clock() was removed in Python 3.8; perf_counter() is its replacement
  st = time.perf_counter()
  # one row of 256 probabilities per pixel, in the same order as the flattened pixels
  coding_probs = dec.predict(z_swag[i:i+BS]/2).reshape((-1,256))
  c.code(coding_probs, x_train[i:i+BS].flatten())
  tlen = len(c.data())
  et = time.perf_counter() - st
  # the final batch can be short, so divide by the samples actually coded
  done = min(i+BS, DD)
  print("%5d: total bytes: %7d  bytes per sample: %.3f  in %.2f ms" % (i, tlen, tlen/done, et*1000.))
# write out the bytes still held in the coder's interval
c.flush()

In [None]:
import tarfile, io

def add_file(tt, nm, data):
  """Add an in-memory buffer to an open tar archive under the name nm.

  tt   -- tarfile.TarFile opened for writing
  nm   -- member name inside the archive
  data -- any object supporting the buffer protocol (bytes, memoryview,
          array.array, ...)

  Also prints the raw size and the size the buffer would have if gzipped on
  its own; that figure is informational only, the member is stored as-is.
  """
  # Work on the raw bytes: len() of a typed buffer counts elements, not
  # bytes, which would record too small a size and truncate the member.
  payload = memoryview(data).tobytes()
  gzdata = gzip.compress(payload)
  print("compressing %-15s with len %7d complen %7d" % (nm, len(payload), len(gzdata)))
  tarinfo = tarfile.TarInfo(nm)
  tarinfo.size = len(payload)
  tt.addfile(tarinfo, io.BytesIO(payload))

# Bundle everything the decompressor needs into one gzip-compressed tar.
# Context managers make sure the archive and the weights file are closed
# even if adding a member fails.
with open("/tmp/weights16.h5", "rb") as weights_file:
  weights_blob = weights_file.read()

with tarfile.open("/tmp/mnist_compressed.gz", "w:gz") as tt:
  add_file(tt, "model.json", dec.to_json().encode('utf-8'))
  add_file(tt, "model.h5", weights_blob)
  add_file(tt, "signal.dat", z_swag.flatten().data)
  add_file(tt, "noise.dat", c.data())

os.path.getsize("/tmp/mnist_compressed.gz")

In [None]:
# ********** decompression starts here **********

In [None]:
%load_ext Cython
%pylab inline

# for verification
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

import os
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tarfile
import tempfile
from tensorflow.keras.models import model_from_json

# Read the archive written by the compression half of this notebook.
# (This used to open "/tmp/mnist_compressed_2.gz", which nothing here writes.)
with tarfile.open("/tmp/mnist_compressed.gz") as tt:
  model_json = tt.extractfile("model.json").read()
  weights_h5 = tt.extractfile("model.h5").read()
  # signal.dat holds int8 codes; reading them as uint8 would turn every
  # negative code into a large positive one and change the decoder output.
  # 64 is the code width (BITS) used when compressing.
  z_swag = np.frombuffer(tt.extractfile("signal.dat").read(), dtype=np.int8).reshape(-1, 64)
  noise = tt.extractfile("noise.dat").read()

model = model_from_json(model_json)

# load_weights needs a path, so spill the weights to a temporary file
with tempfile.NamedTemporaryFile(suffix=".h5") as tmpf:
  tmpf.write(weights_h5)
  # flush, otherwise buffered bytes may not be on disk when the file is re-opened by name
  tmpf.flush()
  model.load_weights(tmpf.name)

In [None]:
# Decode the pixels batch by batch and verify them against the real x_train.
import time  # the decompression half's import cell does not bring this in

DD = z_swag.shape[0]
BS = 2000

dc = Coder(noise)
for i in range(0, DD, BS):
  # time.clock() was removed in Python 3.8; perf_counter() is its replacement
  st = time.perf_counter()
  coding_probs = model.predict(z_swag[i:i+BS]/2).reshape((-1,256))
  # filled in place by the decoder, one byte per pixel
  outs = np.zeros((coding_probs.shape[0],), np.uint8)
  dc.code(coding_probs, outs)
  # fail on ANY mismatching pixel; the old np.all(a != b) check only
  # tripped when every single pixel was wrong
  assert np.array_equal(x_train[i:i+BS].flatten(), outs), \
    "decoded pixels differ from x_train in the batch starting at %d" % i
  et = time.perf_counter() - st
  print("%5d: decoded successfully in %.2f ms" % (i, et*1000.))

# show the first image of the last decoded batch
figsize(8,8)
imshow(np.array(outs).reshape(-1, 28, 28)[0])