In [1]:
from collections import Counter
from heapq import *
from collections import namedtuple
import numpy as np
import torch.nn.functional as F
import torch
import re

# load some text
with open('shakespeare.txt') as f:
    text = f.read()

# get rid of a few characters that break test coverage
text = re.sub('[$&3]', '', text)
test_split_idx = len(text)//3 * 2
train_text = text[:test_split_idx]
test_text = text[test_split_idx:]
assert set(train_text) == set(test_text)

# something short for testing
short_test_text = test_text[13:1000]


stoi = {char: i for i,char in enumerate(set(train_text))}
# remove this
stoi['>'] = len(stoi)
itos = {i: char for char, i in stoi.items()}


n_char_train = len(train_text)

abs_frequencies = Counter(train_text)
frequencies = {k: v/n_char_train for k,v in abs_frequencies.items()}

# Build Huffman code
This is done using a priority queue. The two elements with the lowest counts are merged and inserted back into the queue with added counts.
The most frequent characters will get the shortest codes

In [2]:
Node = namedtuple('Node', ['left', 'right', 'chars'])

def huffman_coding(frequencies):
    """
    Builds a Huffman tree:
    Args:
        -frequences: Dictionary of characters and their frequencies
    Returns:
        codes: dictonary from characters to their binary codes for encoding
        root: root node of the Huffman tree, for decoding
    """
    codes = {k: '' for k in frequencies}

    h = []
    for char, freq in frequencies.items():
        heappush(h, (freq, Node(None, None, char)))

    while True:
        freq1, node1 = heappop(h)
        if not h:
            break
        freq2, node2 = heappop(h)

        for char in node1.chars:
            codes[char] += ('0')
        for char in node2.chars:
            codes[char] += ('1')

        parent = Node(node1, node2, node1.chars+node2.chars)
        parent_freq = freq1+freq2
        heappush(h, (parent_freq, parent))

    codes = {k: v[::-1] for k,v in codes.items()}
    root = parent

    return codes, root

codes, root = huffman_coding(frequencies)
print(codes)

{'F': '011011011', 'i': '10111', 'r': '0000', 's': '0001', 't': '1000', ' ': '110', 'C': '01101100', 'z': '111101010110', 'e': '1110', 'n': '11111', ':': '1001111', '\n': '10100', 'B': '101100110', 'f': '011010', 'o': '0111', 'w': '100110', 'p': '1011000', 'c': '010111', 'd': '01100', 'a': '0100', 'y': '101011', 'u': '00110', 'h': '0010', ',': '101010', 'm': '101101', 'k': '0101001', '.': '0101011', 'A': '0101010', 'l': '10010', 'S': '01101111', 'Y': '011011010', 'v': '0101101', '?': '100111001', 'R': '11110100', 'M': '111101011', 'W': '01011001', "'": '11110110', 'L': '01010000', 'I': '1111001', 'N': '10110010', 'g': '001111', ';': '01010001', 'b': '1111000', '!': '011011101', 'O': '10011101', 'j': '11110101001', 'V': '11110101010', '-': '010110001', 'T': '0011100', 'H': '00111010', 'E': '11110111', 'U': '00111011', 'D': '100111000', 'P': '0110111000', 'q': '01101110011', 'x': '01101110010', 'J': '111101010111', 'G': '101100111', 'K': '010110000', 'Q': '111101010001', 'Z': '1111010100

## One-gram model
Here we see that the Huffman code gives us a bit of advantage over constant length coding.

In [3]:
expected_code_length = sum(len(codes[k]) * frequencies[k] for k in frequencies)
expected_code_length

frequencies_np = np.array(list(frequencies.values()))
entropy = -(np.log2(frequencies_np) * frequencies_np).sum()

print('Entropy: ', entropy)
print('Expected code length: ', expected_code_length)
print('Constant code length: ', np.log2(len(frequencies)))

Entropy:  4.7833957052643425
Expected code length:  4.82071185920971
Constant code length:  5.954196310386875


In [4]:
# encoding is easy here. Just join the codes for each character in the text.
def encode(text, codes):
    '''Encode text using the code we created
    '''
    return ''.join(codes[k] for k in text)

encoded_text = encode(short_test_text, codes)

# assert that encoded text is binary
assert len(set(encoded_text)) == 2
print(f'Encoded length: ', len(encoded_text))
print(encoded_text[:100], '...')

Encoded length:  4699
0101010000111010110000100000100011110000111001011101111001001110011000001011111111001111010101110100 ...


In [5]:
def decode(encoded_text, root):
    """Decoding encoded text given a huffman tree. 
    Goes down the tree till a leave is reached"""
    decoded_text = ''
    
    current_node = root
    for bit in encoded_text:
        if current_node.left is None:
            decoded_text += current_node.chars
            current_node = root
        if bit == '0':
            current_node = current_node.left
        else:
            current_node = current_node.right
    
    decoded_text += current_node.chars
    return decoded_text

# now we decode the text and check if we can reconstruct the input
decoded_text = decode(encoded_text, root)
assert decoded_text == short_test_text

# this should be in the range of the entropy
print('BPC: ', len(encoded_text) / len(short_test_text))

BPC:  4.760891590678825


## LSTM encoding/decoding

In [54]:
# define the charRNN
# TODO: test set, make it actually generalize

class CharRNN(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, batch_first=True):
        super(CharRNN, self).__init__()
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
        self.linear = torch.nn.Linear(hidden_size, input_size)
    
    def forward(self, inputs, hidden=None):
        x = F.one_hot(inputs, num_classes=len(stoi)).float()
        x, hidden = self.lstm(x, hidden)
        x = self.linear(x)
        return x, hidden

n_block = 64
net = CharRNN(len(stoi), 64, 2)
opt = torch.optim.Adam(net.parameters())

In [66]:
# train it
for i in range(5000):
    indices = np.random.randint(n_char_train - n_block, size=128)
    batch_clear = [train_text[i:i+n_block] for i in indices]
    batch_input = [sample[:-1] for sample in batch_clear]
    batch_target = [sample[1:] for sample in batch_clear]
    
    inputs = torch.tensor([[stoi[c] for c in sample] for sample in batch_input])
    targets = torch.tensor([[stoi[c] for c in sample] for sample in batch_target])
    
    opt.zero_grad()
    logits, hidden = net(inputs)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))
    loss.backward()
    opt.step()
    # convert loss to bits
    print(loss.item() / np.log(2), end='\r')

2.6810500990338373

KeyboardInterrupt: 

In [67]:
print(loss.item() / np.log(2), end='\r')

2.6341760555585276

In [68]:
encoded_text = ''
# first letter must be encoded without context
codes, root = huffman_coding(frequencies)
encoded_text += codes[short_test_text[0]]

text_num = torch.tensor([[stoi[c] for c in short_test_text[:-1]]])
logits, _ = net(text_num)
probs = F.softmax(logits, -1).squeeze()

for i in range(probs.shape[0]):
    frequencies_step = probs[i]
    frequencies_step = {itos[j]: frequencies_step[j] for j in range(frequencies_step.shape[0])}
    codes, root = huffman_coding(frequencies_step)
    encoded_text += codes[short_test_text[i+1]]
    
len(encoded_text)

2961

In [69]:
def decode_char(encoded_text, root):
    current_node = root
    for bit in encoded_text:
        if bit == '0':
            current_node = current_node.left
        elif bit == '1':
            current_node = current_node.right
        else:
            raise ValueError
        
        if current_node.left is None:
            return current_node.chars
    return '!'
# decode first letter
index = 0
decoded_text = ''
codes, root = huffman_coding(frequencies)
current_char = decode_char(encoded_text[index:], root)

decoded_text += current_char
index += len(codes[current_char])

hidden = None
while True:
    batch_num = torch.tensor(stoi[current_char]).view(1,1)
    logits, hidden = net(batch_num, hidden)
    probs = F.softmax(logits, -1).squeeze()
    frequencies_step = {itos[j]: probs[j].item() for j in range(probs.shape[0])}
    codes, root = huffman_coding(frequencies_step)
    current_char = decode_char(encoded_text[index:], root)
    
    decoded_text += current_char
    index += len(codes[current_char])
    
    if index >= len(encoded_text):
        break

print(decoded_text[:100])

print(len(encoded_text) / len(short_test_text))

As passes colouring.
Dear gentlewoman,
How fares our gracious lady?

EMILIA:
As well as one so great
3.0


In [73]:
779 * 8 / len(short_test_text)

6.314083080040527