In [None]:
import os
import math
from math import sqrt

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Small constant used to clamp probabilities away from 0 and 1 before taking their logarithm.
EPS = 1.e-7

**DISCLAIMER**

The presented code is not optimized; it serves an educational purpose. It is written for CPU, and it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all the components needed to understand how neural compression works, and it should be rather easy to extend it to more sophisticated models. This code can be run on almost any laptop/PC, and it takes a couple of minutes at most to get the result.

### Dataset

In this example, we go wild and use a dataset that is simpler than MNIST! We use a scikit-learn dataset called Digits. It consists of 1,797 images of size 8x8, and each pixel can take values in $\{0, 1, \ldots, 16\}$.

The reason for using this dataset is that everyone can run the code on a laptop, without a GPU or any other special hardware.

In [None]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset, split by row position.

    ``mode='train'`` keeps rows [0, 1000), ``mode='val'`` keeps rows
    [1000, 1350) and any other value keeps rows from 1350 onwards.
    Rows are stored as float32; ``transforms`` (if truthy) is applied to
    whatever ``__getitem__`` reads from the stored array.
    """

    def __init__(self, mode='train', transforms=None):
        all_rows = load_digits().data

        # Every mode other than 'train' / 'val' falls back to the test slice.
        known_splits = {
            'train': slice(None, 1000),
            'val': slice(1000, 1350),
        }
        chosen = known_splits.get(mode, slice(1350, None))

        self.data = all_rows[chosen].astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.transforms(item) if self.transforms else item

### Auxiliary code

Auxiliary code for running some parts, e.g., Causal Convolution 1D for the autoregressive model (ARM).

**Causal Convolution for ARM**

In [None]:
class CausalConv1d(nn.Module):
    """
    A causal 1D convolution.

    The input is padded on the left only, so an output position never sees
    inputs to its right. With ``A=True`` one extra element of left padding is
    added and the last output is dropped, which shifts the receptive field so
    that an output position does not see the input at the same position.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, A=False, **kwargs):
        super().__init__()

        self.kernel_size = kernel_size
        self.dilation = dilation
        self.A = A

        # amount of left padding; option A needs one more element
        self.padding = (kernel_size - 1) * dilation + A * 1

        # the convolution itself is unpadded: padding is applied manually in forward
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=dilation,
            **kwargs
        )

    def forward(self, x):
        left_padded = nn.functional.pad(x, (self.padding, 0))
        out = self.conv1d(left_padded)
        # option A produced one output too many because of the extra padding
        return out[:, :, :-1] if self.A else out

### Neural Compression code

Please see the blogpost for details.

**Quantizer**

The quantizer is the crucial component of a neural compressor. It consists of a codebook, i.e., a vector of floats. It takes real-valued inputs and replaces each of them with the closest value in the codebook.
Please note that we use a real-valued codebook; however, in practice, we can implement it using integers. As a result, we use $K$ bits.

In [None]:
class Quantizer(nn.Module):
    """Scalar quantizer with a learnable, real-valued codebook.

    Every entry of a B x M input is compared with all ``codebook_dim``
    codebook values. The module returns both a differentiable (soft) and a
    non-differentiable (hard, one-hot) assignment of entries to codebook
    values, together with the quantized code built from the soft assignment.

    Args:
        input_dim: length M of the code vector (stored, not otherwise used here).
        codebook_dim: number of values in the codebook.
        temp: multiplier applied to the similarities before the softmax; a
            large value pushes the soft assignment towards a one-hot vector.
    """

    def __init__(self, input_dim, codebook_dim, temp=1.e7):
        super(Quantizer, self).__init__()

        # temperature for softmax
        self.temp = temp

        # dimensionality of the inputs and the codebook
        self.input_dim = input_dim
        self.codebook_dim = codebook_dim

        # codebook layer (a codebook)
        # - we initialize it uniformly
        # - we make it Parameter, namely, it is learnable
        self.codebook = nn.Parameter(torch.FloatTensor(1, self.codebook_dim,).uniform_(-1/self.codebook_dim, 1/self.codebook_dim))

    def indices2codebook(self, indices_onehot):
        """Map (soft or hard) one-hot indices, ... x codebook_dim, to codebook values."""
        # Only the trailing singleton axis produced by the matmul is removed;
        # a bare squeeze() would also drop the batch axis when B == 1.
        return torch.matmul(indices_onehot, self.codebook.t()).squeeze(-1)

    def indices_to_onehot(self, inputs_shape, indices):
        """Turn integer indices (B x M x 1) into a one-hot tensor (B x M x codebook_dim)."""
        indices_hard = torch.zeros(inputs_shape[0], inputs_shape[1], self.codebook_dim,
                                   device=indices.device)
        indices_hard.scatter_(2, indices, 1)
        return indices_hard

    def forward(self, inputs):
        """Quantize ``inputs`` (a B x M matrix of floats).

        Steps:
        - similarities between the input values and the codebook values are computed,
        - hard (non-differentiable) and soft (differentiable) indices are derived from them,
        - the quantized code is computed from the soft indices.

        Returns:
            A tuple ``(indices_soft, indices_hard, quantized)`` with shapes
            B x M x codebook_dim, B x M x codebook_dim and B x M.
        """
        inputs_shape = inputs.shape

        # Similarity exp(-|input - codebook value|), broadcast to B x M x codebook_dim.
        # abs() is used instead of sqrt(pow(., 2)): the values are the same, but the
        # gradient stays finite when an input coincides with a codebook value.
        distances = torch.exp(-torch.abs(inputs.unsqueeze(2) - self.codebook.unsqueeze(1)))

        # indices (hard, i.e., nondiff): the most similar codebook value
        indices = torch.argmax(distances, dim=2).unsqueeze(2)
        indices_hard = self.indices_to_onehot(inputs_shape=inputs_shape, indices=indices)

        # indices (soft, i.e., diff)
        indices_soft = torch.softmax(self.temp * distances, -1)

        # quantized values: we use soft indices here because it allows backpropagation
        quantized = self.indices2codebook(indices_onehot=indices_soft)

        return (indices_soft, indices_hard, quantized)

**Encoder**

In [None]:
# Thin wrapper: the encoder network maps an image to its (non-quantized) code.
class Encoder(nn.Module):
    """Wraps ``encoder_net`` and applies it to the input."""

    def __init__(self, encoder_net):
        super().__init__()
        self.encoder = encoder_net

    def encode(self, x):
        """Return the output of the wrapped network for ``x``."""
        return self.encoder(x)

    def forward(self, x):
        code = self.encode(x)
        return code

**Decoder**

In [None]:
# Thin wrapper: the decoder network maps a quantized code back to an image.
class Decoder(nn.Module):
    """Wraps ``decoder_net`` and applies it to the (quantized) code."""

    def __init__(self, decoder_net):
        super().__init__()
        self.decoder = decoder_net

    def decode(self, z):
        """Return the output of the wrapped network for ``z``."""
        return self.decoder(z)

    def forward(self, z, x=None):
        # ``x`` is accepted for interface compatibility but is not used.
        return self.decode(z)

**Entropy Coding**

Entropy coding is the crucial step in the compression scheme. At this point, we have a quantized code that we want to transmit. In order to send the quantized code, which is typically represented by discrete (non-binary) symbols, it must be translated into a bitstream (a stream of bits).

An entropy coder assigns a unique prefix-free code (e.g., unique binary codes like Huffman codes) to each unique symbol that occurs in the input. Two of the most common entropy encoding techniques are Huffman coding and arithmetic coding that require knowing the (estimates of) probabilities of the symbols.

In the following code, we present three options for entropy coding: a non-learnable uniform distribution over symbols, learnable independent distributions over symbols (i.e., a product of categorical distributions), and an autoregressive model.

In [None]:
class UniformEntropyCoding(nn.Module):
    """Non-learnable entropy model: a uniform distribution over codebook symbols
    at every position of the code."""

    def __init__(self, code_dim, codebook_dim):
        super().__init__()
        self.code_dim = code_dim
        self.codebook_dim = codebook_dim

        # softmax over equal logits gives 1 / codebook_dim everywhere
        equal_logits = torch.ones(1, self.code_dim, self.codebook_dim)
        self.probs = torch.softmax(equal_logits, -1)

    def sample(self, quantizer=None, B=10):
        """Draw B codes and map them to codebook values with ``quantizer.indices2codebook``."""
        onehot = torch.zeros(B, self.code_dim, self.codebook_dim)
        # per-position symbol probabilities, code_dim x codebook_dim
        symbol_probs = torch.softmax(self.probs, -1).squeeze(0)
        for b in range(B):
            drawn = torch.multinomial(symbol_probs, 1).squeeze()
            for position, symbol in enumerate(drawn.tolist()):
                onehot[b, position, symbol] = 1

        return quantizer.indices2codebook(onehot)

    def forward(self, z, x=None):
        """Cross-entropy between the (soft) one-hot code ``z`` and the uniform probabilities."""
        log_p = torch.log(torch.clamp(self.probs, EPS, 1. - EPS))
        return -(z * log_p).sum(2)

In [None]:
class IndependentEntropyCoding(nn.Module):
    """Learnable entropy model: an independent categorical distribution over
    codebook symbols for every position of the code."""

    def __init__(self, code_dim, codebook_dim):
        super().__init__()
        self.code_dim = code_dim
        self.codebook_dim = codebook_dim

        # learnable logits (despite the name); softmax turns them into probabilities
        initial_logits = torch.ones(1, self.code_dim, self.codebook_dim)
        self.probs = nn.Parameter(initial_logits)

    def sample(self, quantizer=None, B=10):
        """Draw B codes and map them to codebook values with ``quantizer.indices2codebook``."""
        onehot = torch.zeros(B, self.code_dim, self.codebook_dim)
        # per-position symbol probabilities, code_dim x codebook_dim
        symbol_probs = torch.softmax(self.probs, -1).squeeze(0)
        for b in range(B):
            drawn = torch.multinomial(symbol_probs, 1).squeeze()
            for position, symbol in enumerate(drawn.tolist()):
                onehot[b, position, symbol] = 1

        return quantizer.indices2codebook(onehot)

    def forward(self, z, x=None):
        """Cross-entropy between the (soft) one-hot code ``z`` and the learned probabilities."""
        categorical = torch.softmax(self.probs, -1)
        log_p = torch.log(torch.clamp(categorical, EPS, 1. - EPS))
        return -(z * log_p).sum(2)

In [None]:
class ARMEntropyCoding(nn.Module):
    """Autoregressive entropy model over the quantized code.

    Args:
        code_dim: length of the code.
        codebook_dim: number of codebook symbols.
        arm_net: network that takes B x 1 x code_dim and outputs
            B x codebook_dim x code_dim (one score per symbol and position).
    """

    def __init__(self, code_dim, codebook_dim, arm_net):
        super(ARMEntropyCoding, self).__init__()
        self.code_dim = code_dim
        self.codebook_dim = codebook_dim
        self.arm_net = arm_net # it takes B x 1 x code_dim and outputs B x codebook_dim x code_dim

    def _logits(self, x):
        """Run the ARM network on a B x code_dim input; returns B x code_dim x codebook_dim scores."""
        h = self.arm_net(x.unsqueeze(1))
        return h.permute(0, 2, 1)

    def f(self, x):
        """Per-position symbol probabilities, B x code_dim x codebook_dim."""
        return torch.softmax(self._logits(x), 2)

    def sample(self, quantizer=None, B=10):
        """Sample B codes position by position, feeding the sampled codebook values back in."""
        x_new = torch.zeros((B, self.code_dim))

        for d in range(self.code_dim):
            p = self.f(x_new)
            indx_d = torch.multinomial(p[:, d, :], num_samples=1)
            codebook_value = quantizer.codebook[0, indx_d].squeeze()
            x_new[:, d] = codebook_value

        return x_new

    def forward(self, z, x):
        """Cross-entropy between the (soft) one-hot code ``z`` and the ARM probabilities given ``x``.

        log_softmax is used instead of log(softmax(.)): a probability that
        underflows to zero would otherwise produce ``0 * -inf = NaN`` for the
        symbols that ``z`` does not select.
        """
        log_p = torch.log_softmax(self._logits(x), 2)
        return -torch.sum(z * log_p, 2)

**Full Neural Compressor**

In [None]:
class NeuralCompressor(nn.Module):
    """Encoder -> quantizer -> decoder, trained with a rate-distortion objective.

    Args:
        encoder: module mapping an input to a real-valued code.
        decoder: module mapping a quantized code back to the input space.
        entropy_coding: module called as ``entropy_coding(indices_soft, quantized)``;
            its output is averaged over dim 1 to give the rate.
        quantizer: module returning ``(indices_soft, indices_hard, quantized)``.
        beta: weight of the rate term in the objective.
        detaching: if True, the quantizer outputs are detached before they are
            passed to the entropy coder, so the rate term only updates the
            entropy coder and not the encoder/quantizer.
    """

    def __init__(self, encoder, decoder, entropy_coding, quantizer, beta=1., detaching=False):
        super(NeuralCompressor, self).__init__()

        print('VAE by JT.')

        # the four components of the compressor
        self.encoder = encoder
        self.decoder = decoder
        self.entropy_coding = entropy_coding
        self.quantizer = quantizer

        # beta determines how strongly we focus on compression against reconstruction quality
        self.beta = beta

        # We can detach inputs to the rate, then we learn rate and distortion separately
        self.detaching = detaching

    def forward(self, x, reduction='avg'):
        """Return ``(objective, distortion, rate)`` for a batch ``x``.

        The per-example values are summed over the batch when
        ``reduction == 'sum'`` and averaged otherwise.
        """
        # encoding
        #-non-quantized values
        z = self.encoder(x)
        #-quantizing
        quantizer_out = self.quantizer(z)
        indices_soft, quantized = quantizer_out[0], quantizer_out[2]

        # decoding
        x_rec = self.decoder(quantized)

        # Distortion (e.g., MSE)
        Distortion = torch.mean(torch.pow(x - x_rec, 2), 1)

        # Rate: we use the entropy coding here.
        # With detaching, no gradient of the rate flows back into the quantizer/encoder.
        if self.detaching:
            indices_soft, quantized = indices_soft.detach(), quantized.detach()
        Rate = torch.mean(self.entropy_coding(indices_soft, quantized), 1)

        # Objective
        objective = Distortion + self.beta * Rate

        if reduction == 'sum':
            return objective.sum(), Distortion.sum(), Rate.sum()
        else:
            return objective.mean(), Distortion.mean(), Rate.mean()

### Auxiliary functions: training, evaluation, plotting

It's rather self-explanatory, isn't it?

In [None]:
def evaluation(test_loader, name=None, model_best=None, epoch=None):
    """Evaluate a model on ``test_loader`` and print/return per-example averages.

    Args:
        test_loader: iterable of batches (tensors whose first dimension is the batch size).
        name: path prefix of a saved model; ``name + '.model'`` is loaded when
            ``model_best`` is None.
        model_best: model to evaluate; it is put in eval mode.
        epoch: if given, the printed line is labelled as a validation result of that epoch,
            otherwise as the final loss.

    Returns:
        ``(loss, distortion, rate)`` as Python floats, each divided by the number
        of evaluated examples.
    """
    # EVALUATION
    if model_best is None:
        # load best performing model
        # The file holds a whole pickled module (see torch.save(model, ...) in training),
        # which needs weights_only=False on recent PyTorch versions. Unpickling can
        # execute arbitrary code, so only load files you created yourself.
        model_best = torch.load(name + '.model', weights_only=False)

    model_best.eval()
    loss = 0.
    distortion = 0.
    rate = 0.
    N = 0.
    # no gradients are needed for evaluation; this avoids building autograd graphs
    with torch.no_grad():
        for test_batch in test_loader:
            loss_t, distortion_t, rate_t = model_best(test_batch, reduction='sum')
            loss = loss + loss_t.item()
            distortion = distortion + distortion_t.item()
            rate = rate + rate_t.item()
            N = N + test_batch.shape[0]
    loss = loss / N
    distortion = distortion / N
    rate = rate / N

    if epoch is None:
        print(f'FINAL LOSS: objective={loss} (distortion={distortion}, rate={rate})')
    else:
        print(f'Epoch: {epoch}, objective val={loss} (distortion={distortion}, rate={rate})')

    return loss, distortion, rate

def plot_curve(name, nll_val, metric_name='loss'):
    """Plot ``nll_val`` against the epoch index and save the figure as a PDF.

    The output path is ``name + metric_name + '_val_curve.pdf'``; the current
    figure is closed afterwards.
    """
    epochs = np.arange(len(nll_val))
    out_path = f'{name}{metric_name}_val_curve.pdf'

    plt.plot(epochs, nll_val, linewidth='3')
    plt.xlabel('epochs')
    plt.ylabel(metric_name)
    plt.savefig(out_path, bbox_inches='tight')
    plt.close()

In [None]:
def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):
    """Train ``model`` with early stopping on the validation objective.

    After every epoch the model is evaluated on ``val_loader``. The whole model is
    saved to ``name + '.model'`` after the first epoch and whenever the validation
    objective improves. Training stops once the objective has failed to improve for
    more than ``max_patience`` consecutive epochs, or after ``num_epochs`` epochs.

    Returns:
        Three numpy arrays with the per-epoch validation objective, distortion and rate.
    """
    objective_loss_val = []
    objective_distortion_val = []
    objective_rate_val = []
    loss_best = float('inf')
    patience = 0

    # Main loop
    for e in range(num_epochs):
        # TRAINING
        model.train()
        for batch in training_loader:
            # optional uniform dequantization, only for models that ask for it
            if getattr(model, 'dequantization', False):
                batch = batch + torch.rand(batch.shape)
            loss, _, _ = model(batch)

            optimizer.zero_grad()
            # A fresh graph is built by every forward pass, so there is nothing to
            # retain; retain_graph=True would only keep buffers alive for longer.
            loss.backward()
            optimizer.step()

        # Validation
        loss_val, distortion_val, rate_val = evaluation(val_loader, model_best=model, epoch=e)
        objective_loss_val.append(loss_val)  # save for plotting
        objective_distortion_val.append(distortion_val)  # save for plotting
        objective_rate_val.append(rate_val)  # save for plotting

        # The first epoch is always saved so that a checkpoint exists.
        if e == 0 or loss_val < loss_best:
            print('saved!')
            torch.save(model, name + '.model')
            loss_best = loss_val
            patience = 0
        else:
            patience = patience + 1

        if patience > max_patience:
            break

    objective_loss_val = np.asarray(objective_loss_val)
    objective_distortion_val = np.asarray(objective_distortion_val)
    objective_rate_val = np.asarray(objective_rate_val)

    return objective_loss_val, objective_distortion_val, objective_rate_val

### Initialize dataloaders

In [None]:
# The three splits of the Digits dataset (see the Digits class for the row ranges).
train_data = Digits(mode='train')
val_data = Digits(mode='val')
test_data = Digits(mode='test')

# Only the training loader shuffles; validation and test batches keep a fixed order.
training_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

### Hyperparams

In [None]:
# Which entropy-coding model is built below.
entropy_coding_type = 'arm' # arm or indp or uniform
D = 64   # input dimension
C = 16  # code length
E = 8 # codebook size (i.e., the number of quantized values)
M = 256  # the number of neurons
M_kernels = 32 # the number of kernels in causal conv1d layers

# beta: how much we weight rate
# (the uniform entropy coder has no learnable parameters, so its rate term is switched off)
if entropy_coding_type == 'uniform':
    beta = 0. 
else:
    beta = 1.

lr = 1e-3 # learning rate
num_epochs = 1000 # max. number of epochs
max_patience = 50 # early stopping: training stops once the validation objective has not improved for more than max_patience epochs

In [None]:
# Directory for saved models, logs and figures; created if it does not exist yet.
result_dir = 'results/'
os.makedirs(result_dir, exist_ok=True)
# Run name: encodes the entropy-coder type, the code length and the codebook size.
name = 'neural_compressor_' + entropy_coding_type + '_C_' + str(C) + '_E_' + str(E)

### Initialize Neural Compressor

In [None]:
# ENCODER: D -> 2M -> M -> M/2 -> C
encoder_net = nn.Sequential(nn.Linear(D, M*2), nn.BatchNorm1d(M*2), nn.ReLU(),
                            nn.Linear(M*2, M), nn.BatchNorm1d(M), nn.ReLU(),
                            nn.Linear(M, M//2), nn.BatchNorm1d(M//2), nn.ReLU(),
                            nn.Linear(M//2, C))

encoder = Encoder(encoder_net=encoder_net)

# DECODER: C -> M/2 -> M -> 2M -> D (mirror of the encoder)
decoder_net = nn.Sequential(nn.Linear(C, M//2), nn.BatchNorm1d(M//2), nn.ReLU(),
                            nn.Linear(M//2, M), nn.BatchNorm1d(M), nn.ReLU(),
                            nn.Linear(M, M*2), nn.BatchNorm1d(M*2), nn.ReLU(),
                            nn.Linear(M*2, D))

decoder = Decoder(decoder_net=decoder_net)

# QUANTIZER
quantizer = Quantizer(input_dim=C, codebook_dim=E)

# ENTROPY CODING
if entropy_coding_type == 'uniform':
    entropy_coding = UniformEntropyCoding(code_dim=C, codebook_dim=E)

elif entropy_coding_type == 'indp':
    entropy_coding = IndependentEntropyCoding(code_dim=C, codebook_dim=E)

elif entropy_coding_type == 'arm':
    kernel = 4
    # Only the first layer uses option A (it must not see the current position).
    arm_net = nn.Sequential(
        CausalConv1d(in_channels=1, out_channels=M_kernels, dilation=1, kernel_size=kernel, A=True, bias=True),
        nn.LeakyReLU(),
        CausalConv1d(in_channels=M_kernels, out_channels=M_kernels, dilation=1, kernel_size=kernel, A=False, bias=True),
        nn.LeakyReLU(),
        CausalConv1d(in_channels=M_kernels, out_channels=E, dilation=1, kernel_size=kernel, A=False, bias=True))

    entropy_coding = ARMEntropyCoding(code_dim=C, codebook_dim=E, arm_net=arm_net)

else:
    # Fail here rather than with a NameError below, or by silently reusing an
    # `entropy_coding` object left over in the kernel from a previous run.
    raise ValueError(f"Unknown entropy_coding_type: {entropy_coding_type!r} (expected 'uniform', 'indp' or 'arm')")

# MODEL
model = NeuralCompressor(encoder=encoder, decoder=decoder, entropy_coding=entropy_coding, quantizer=quantizer, beta=beta)

### Let's play! Training

In [None]:
# OPTIMIZER: Adamax over every parameter of the model that requires a gradient
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adamax(trainable_params, lr=lr)

In [None]:
# Run training; the returned arrays hold the per-epoch validation curves.
objective_loss_val, objective_distortion_val, objective_rate_val = training(
    name=result_dir + name,
    max_patience=max_patience,
    num_epochs=num_epochs,
    model=model,
    optimizer=optimizer,
    training_loader=training_loader,
    val_loader=val_loader,
)

In [None]:
# Evaluate the saved (best on validation) model on the test set.
test_loss, test_distortion, test_rate = evaluation(name=result_dir + name, test_loader=test_loader)

# Store the test results; the context manager closes the file even if writing fails.
with open(result_dir + name + '_test_loss.txt', "w") as f:
    f.write(str(test_loss) + ', ' + str(test_distortion) + ', ' + str(test_rate))

# Validation curves collected during training.
plot_curve(result_dir + name + '_objective_', objective_loss_val, metric_name='objective')
plot_curve(result_dir + name + '_distortion_', objective_distortion_val, metric_name='distortion')
plot_curve(result_dir + name + '_rate_', objective_rate_val, metric_name='rate')

### Qualitative inspection

Here, we visualize samples and reconstructions. 

In [None]:
# We specify the ids of the images from the test set to reconstruct.
IMG_IDs = [110, 120, 130, 140]

# Inference only: use the running BatchNorm statistics and do not track gradients.
model.eval()
with torch.no_grad():
    # samples
    z_sampled = model.entropy_coding.sample(quantizer=model.quantizer, B=9)
    x_sampled = model.decoder(z_sampled)

    # reconstructions (the last quantizer output is the quantized code)
    x_real = torch.from_numpy(test_data[IMG_IDs])
    x_rec = model.decoder(model.quantizer(model.encoder(x_real))[-1])

# plotting: one row per image, columns = original / reconstruction / sample
# (squeeze=False keeps axs two-dimensional even for a single row)
fig, axs = plt.subplots(len(IMG_IDs), 3, figsize=(6, 8), squeeze=False)
for i in range(len(IMG_IDs)):
    panels = [
        ('original', x_real[i]),
        ('reconstruction', x_rec[i]),
        ('sample', x_sampled[i].squeeze()),
    ]
    for col, (title, image) in enumerate(panels):
        axs[i, col].imshow(image.reshape(8, 8).detach().numpy())
        axs[i, col].set_title(title)
        axs[i, col].axis('off')

plt.savefig(result_dir + name + 'recon_sample.pdf', bbox_inches='tight')