In [1]:
import numpy
from itertools import chain
from collections import Counter
from PIL import Image, ImageChops

In [2]:
class InputBitStream:
    """Read a binary file bit by bit, most significant bit of each byte first.

    Attributes (all set in __init__):
        file_name: path of the file being read.
        file: the underlying file object, opened in 'rb' mode.
        bytes_read: number of whole bytes pulled from the file so far.
        buffer: bits loaded from the file but not yet returned (ints 0/1).

    Can be used as a context manager so the file is closed even when
    decoding raises.
    """

    def __init__(self, file_name):
        self.file_name = file_name
        self.file = open(self.file_name, 'rb')
        self.bytes_read = 0
        self.buffer = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions

    def read_bit(self):
        """Return the next bit as an int (0 or 1)."""
        return self.read_bits(1)[0]

    def read_bits(self, count):
        """Return the next `count` bits as a list of ints.

        Raises EOFError if the file ends before `count` bits are available.
        """
        while len(self.buffer) < count:
            self._load_byte()
        result = self.buffer[:count]
        del self.buffer[:count]
        return result

    def flush(self):
        """Discard the rest of the current byte so the next read is byte-aligned.

        Raises ValueError if any discarded bit is 1. The writer pads with
        zeros, so a set padding bit means the stream is out of sync or
        corrupt. An explicit check is used instead of `assert` because
        asserts are skipped under `python -O`.
        """
        if any(self.buffer):
            raise ValueError('non-zero padding bits in %s at byte %d'
                             % (self.file_name, self.bytes_read))
        del self.buffer[:]

    def _load_byte(self):
        # Append the 8 bits of the next byte to the buffer, MSB first.
        data = self.file.read(1)
        if not data:
            # file.read returns b'' at EOF; report it clearly instead of
            # letting ord(b'') raise a TypeError.
            raise EOFError('unexpected end of file in %s after %d bytes'
                           % (self.file_name, self.bytes_read))
        value = data[0]
        self.buffer.extend((value >> shift) & 1 for shift in range(7, -1, -1))
        self.bytes_read += 1

    def close(self):
        """Close the underlying file."""
        self.file.close()

In [3]:
class OutputBitStream:
    """Write bits to a binary file, packing them into bytes MSB first.

    Attributes (all set in __init__):
        file_name: path of the file being written.
        file: the underlying file object, opened in 'wb' mode.
        bytes_written: number of whole bytes written to the file so far.
        buffer: pending bits (ints 0/1), always fewer than 8 between calls.

    Can be used as a context manager; leaving the `with` block flushes any
    partial byte and closes the file.
    """

    def __init__(self, file_name):
        self.file_name = file_name
        self.file = open(self.file_name, 'wb')
        self.bytes_written = 0
        self.buffer = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions

    def write_bit(self, value):
        """Append a single bit (0 or 1)."""
        self.write_bits([value])

    def write_bits(self, values):
        """Append a sequence of bits, writing out every completed byte."""
        self.buffer += values
        while len(self.buffer) >= 8:
            self._save_byte()

    def flush(self):
        """Pad the pending bits with trailing zeros to a full byte and write it.

        Afterwards the next write starts on a byte boundary. Does nothing
        when the buffer is already empty.
        """
        if self.buffer:
            self.buffer += [0] * (8 - len(self.buffer))
            self._save_byte()
        assert len(self.buffer) == 0

    def _save_byte(self):
        # Pack the first 8 buffered bits (MSB first) into one byte and write it.
        bits = self.buffer[:8]
        del self.buffer[:8]
        byte_value = 0
        for bit in bits:
            byte_value = (byte_value << 1) | bit
        self.file.write(bytes([byte_value]))
        self.bytes_written += 1

    def close(self):
        """Flush the final partial byte, then close the file.

        The file is closed even if flushing fails.
        """
        try:
            self.flush()
        finally:
            self.file.close()

In [4]:
def raw_size(width, height):
    """Byte size of the image stored without compression.

    Counts a 32-bit header (16-bit height + 16-bit width) plus 24 bits per
    pixel (3 channels of 8 bits). Returns a float because of true division.
    """
    header_bits = 16 + 16
    bits_per_pixel = 8 * 3
    total_bits = header_bits + bits_per_pixel * width * height
    return total_bits / 8

In [5]:
def images_equal(file_name_a, file_name_b):
    """Return True when two image files hold exactly the same pixels.

    Images that differ in size or mode are reported as unequal. The size
    check matters because ImageChops.difference only covers the region the
    two images share, so a cropped copy would otherwise compare equal.
    Both files are closed before returning.
    """
    with Image.open(file_name_a) as image_a, Image.open(file_name_b) as image_b:
        if image_a.size != image_b.size or image_a.mode != image_b.mode:
            return False
        diff = ImageChops.difference(image_a, image_b)
        # getbbox() is None when every pixel of the difference image is zero.
        return diff.getbbox() is None

In [6]:
def PSNR(X1, X2, max_pixel=255):
    """Peak signal-to-noise ratio, in dB, between two equally shaped pixel arrays.

    Parameters:
        X1, X2: array-likes of pixel values (numpy arrays or nested lists).
        max_pixel: largest possible pixel value (255 for 8-bit channels).

    Returns numpy.inf when the inputs are identical (zero mean squared error).

    Both inputs are converted to float64 before subtracting. With uint8
    arrays (such as numpy.asarray(image)), X1 - X2 and the squaring would
    otherwise wrap around modulo 256 and give a wrong MSE.
    """
    a = numpy.asarray(X1, dtype=numpy.float64)
    b = numpy.asarray(X2, dtype=numpy.float64)
    mse = numpy.mean((a - b) ** 2)
    if mse > 0:
        return 20 * numpy.log10(max_pixel / mse ** .5)
    return numpy.inf

In [7]:
def count_symbols(image):
    """Count how often each channel value occurs across all pixels.

    Every channel of every pixel is counted as one symbol. Returns a list of
    (value, count) pairs sorted by ascending count, ties broken by value.
    """
    channel_values = chain.from_iterable(image.getdata())
    histogram = Counter(channel_values)
    return sorted(histogram.items(), key=lambda pair: (pair[1], pair[0]))

In [8]:
def build_tree(counts):
    """Build a Huffman tree from (symbol, count) pairs.

    Repeatedly merges the two lowest-frequency nodes until one remains.

    Parameters:
        counts: iterable of (symbol, count) pairs, e.g. from count_symbols().
            It does not need to be sorted.

    Returns the root as a (frequency, payload) tuple. payload is either a
    symbol (leaf) or a pair of child (frequency, payload) nodes (branch).
    A single distinct symbol yields a bare (count, symbol) leaf.

    Raises ValueError if counts is empty.
    """
    # Reverse each (symbol, count) pair so the frequency sits at index 0.
    nodes = [entry[::-1] for entry in counts]
    if not nodes:
        raise ValueError('build_tree() needs at least one (symbol, count) pair')
    # The loop always merges nodes[0] and nodes[1], so the list must start
    # ordered by frequency. The sort is stable, so the caller's tie order
    # (and the tree for already-sorted input) is unchanged.
    nodes.sort(key=lambda t: t[0])
    while len(nodes) > 1:
        least_two = tuple(nodes[0:2])  # the two rarest nodes
        comb_freq = least_two[0][0] + least_two[1][0]  # the branch point's frequency
        nodes = nodes[2:] + [(comb_freq, least_two)]
        # Stable sort: the new branch lands after existing nodes of equal frequency.
        nodes.sort(key=lambda t: t[0])
    return nodes[0]

In [9]:
def trim_tree(tree):
    """Strip the frequencies from a build_tree() result, keeping only its shape.

    Leaves become bare symbols; branches become (left, right) tuples of
    trimmed subtrees.
    """
    payload = tree[1]  # tree[0] is the frequency, which coding does not need
    if type(payload) is not tuple:
        return payload  # leaf: the symbol itself
    return trim_tree(payload[0]), trim_tree(payload[1])

In [10]:
def assign_codes_impl(codes, node, pat):
    """Store a code (list of bits) in `codes` for every leaf under `node`.

    `pat` is the code prefix that leads to `node`. A branch (tuple) extends
    it with 0 for the left child and 1 for the right child. Leaves are
    visited left to right, so if a symbol appears more than once, its
    right-most occurrence wins.
    """
    pending = [(node, pat)]
    while pending:
        current, prefix = pending.pop()
        if type(current) == tuple:
            # Push right first so the left subtree is handled before it.
            pending.append((current[1], prefix + [1]))
            pending.append((current[0], prefix + [0]))
        else:
            codes[current] = prefix

In [11]:
def assign_codes(tree):
    """Map every symbol of a trimmed tree to its Huffman code (a list of 0/1 bits).

    Codes start empty at the root; see assign_codes_impl for the bit convention.
    """
    symbol_codes = dict()
    assign_codes_impl(symbol_codes, tree, [])
    return symbol_codes

In [12]:
def to_binary_list(n):
    """Convert a non-negative integer into a list of bits, most significant first.

    No leading zeros are produced, except that 0 maps to [0]. Use pad_bits()
    to widen the result to a fixed field width.

    Raises ValueError for negative n. The previous recursive version
    returned [n] for those, which is not a list of bits.
    """
    if n < 0:
        raise ValueError('to_binary_list() needs a non-negative integer, got %d' % n)
    return [int(digit) for digit in format(n, 'b')]

In [13]:
def from_binary_list(bits):
    """Convert a list of bits (most significant first) into an integer.

    An empty list gives 0.
    """
    value = 0
    for b in bits:
        value <<= 1
        value |= b
    return value

In [14]:
def pad_bits(bits, n):
    """Prefix a list of bits with zeros so it is exactly n digits long.

    Returns a new list; `bits` itself is not modified.

    Raises ValueError when `bits` already has more than n digits, i.e. the
    value does not fit the field. This is an explicit check rather than an
    `assert`: under `python -O` an assert is skipped, and the oversize field
    would silently corrupt the bit stream.
    """
    missing = n - len(bits)
    if missing < 0:
        raise ValueError('%d bits do not fit in a %d-bit field' % (len(bits), n))
    return [0] * missing + bits

In [15]:
def compressed_size(counts, codes):
    """Predict the byte size of the compressed file.

    Parameters:
        counts: (symbol, count) pairs, one per distinct symbol.
        codes: mapping of symbol -> code (list of bits).

    The layout is a 32-bit header (16-bit height and width), then the tree,
    then the pixel codes, with the tree and pixel sections each padded to a
    whole byte. Returns a float because of true division.
    """
    def round_up_to_byte(bit_count):
        # Round a bit count up to the next multiple of 8.
        remainder = bit_count % 8
        return bit_count if remainder == 0 else bit_count + (8 - remainder)

    header_bits = 16 + 16
    leaf_count = len(counts)
    # Each leaf takes a 1-bit flag plus an 8-bit symbol; each of the
    # leaf_count - 1 internal nodes takes a 1-bit flag.
    tree_bits = round_up_to_byte(leaf_count * 9 + (leaf_count - 1))
    # Every occurrence of a symbol costs the length of its code.
    pixel_bits = round_up_to_byte(
        sum(count * len(codes[symbol]) for symbol, count in counts))
    return (header_bits + tree_bits + pixel_bits) / 8

In [16]:
def encode_header(image, bitstream):
    """Write the image height, then its width, each as a 16-bit field, MSB first."""
    bitstream.write_bits(pad_bits(to_binary_list(image.height), 16))  # height first
    bitstream.write_bits(pad_bits(to_binary_list(image.width), 16))  # then width

In [17]:
def encode_tree(tree, bitstream):
    """Serialize a trimmed Huffman tree in pre-order.

    A node (tuple) is written as bit 0 followed by its left and right
    subtrees. A leaf is written as bit 1 followed by its symbol as 8 bits.
    """
    if type(tree) != tuple:
        bitstream.write_bit(1)
        bitstream.write_bits(pad_bits(to_binary_list(tree), 8))
        return
    bitstream.write_bit(0)
    left, right = tree[0], tree[1]
    encode_tree(left, bitstream)
    encode_tree(right, bitstream)

In [18]:
def encode_pixels(image, codes, bitstream):
    """Write the code of each channel value, pixel by pixel, in image order."""
    for value in chain.from_iterable(image.getdata()):
        bitstream.write_bits(codes[value])

In [19]:
def compress_image(in_file_name, out_file_name):
    """Huffman-compress an image file into this notebook's bitstream format.

    The output holds four sections, each starting on a byte boundary: a
    16-bit height, a 16-bit width, the serialized code tree, and one code per
    channel value. Progress and size statistics are printed along the way.

    The input is converted to RGB first. decode_pixels() always rebuilds an
    'RGB' image from height * width * 3 values, so other modes (L, P, RGBA)
    would otherwise fail or not round-trip.
    """
    print('Compressing "%s" -> "%s"' % (in_file_name, out_file_name))
    # `with` releases the file handle Image.open keeps open. convert() loads
    # the pixel data before that happens.
    with Image.open(in_file_name) as source_image:
        image = source_image.convert('RGB')
    print('Image shape: (height=%d, width=%d)' % (image.height, image.width))
    size_raw = raw_size(image.width, image.height)  # raw_size takes (width, height)
    print('RAW image size: %d bytes' % size_raw)
    counts = count_symbols(image)
    print('Counts: %s' % counts)
    tree = build_tree(counts)
    print('Tree: %s' % str(tree))
    trimmed_tree = trim_tree(tree)
    print('Trimmed tree: %s' % str(trimmed_tree))
    codes = assign_codes(trimmed_tree)
    print('Codes: %s' % codes)

    size_estimate = compressed_size(counts, codes)
    print('Estimated size: %d bytes' % size_estimate)

    print('Writing...')
    stream = OutputBitStream(out_file_name)
    try:
        print('* Header offset: %d' % stream.bytes_written)
        encode_header(image, stream)
        stream.flush()  # Ensure next chunk is byte-aligned
        print('* Tree offset: %d' % stream.bytes_written)
        encode_tree(trimmed_tree, stream)
        stream.flush()  # Ensure next chunk is byte-aligned
        print('* Pixel offset: %d' % stream.bytes_written)
        encode_pixels(image, codes, stream)
    finally:
        # close() flushes the final partial byte; always release the file.
        stream.close()

    size_real = stream.bytes_written
    print('Wrote %d bytes.' % size_real)

    print('Estimate is %scorrect.' % ('' if size_estimate == size_real else 'in'))
    print('Compression ratio: %0.2f' % (float(size_raw) / size_real))

In [20]:
def decode_header(bitstream):
    """Read the two 16-bit header fields and return them as (height, width)."""
    fields = [from_binary_list(bitstream.read_bits(16)) for _ in range(2)]
    return tuple(fields)  # height was written first, then width

In [21]:
def decode_tree(bitstream):
    """Rebuild a trimmed Huffman tree from its pre-order serialization.

    A 1 bit introduces a leaf whose 8-bit symbol follows. A 0 bit introduces
    a (left, right) node whose two subtrees follow in order.
    """
    is_leaf = bitstream.read_bits(1)[0] == 1
    if is_leaf:
        return from_binary_list(bitstream.read_bits(8))
    left_subtree = decode_tree(bitstream)
    return (left_subtree, decode_tree(bitstream))

In [22]:
def decode_value(tree, bitstream):
    """Decode one symbol by walking `tree` one bit at a time from `bitstream`.

    `tree` is a trimmed tree: a (left, right) tuple is a branch, anything
    else is a leaf symbol. Bit 0 selects the left child and bit 1 the right.

    A tree that is a lone leaf (an image with a single distinct channel
    value) gives that symbol an empty code. In that case the symbol is
    returned without reading any bits. Previously this raised TypeError
    when indexing an int.
    """
    node = tree
    # Descend until a leaf is reached; a leaf-only tree skips the loop.
    while type(node) == tuple:
        node = node[bitstream.read_bits(1)[0]]
    return node

In [23]:
def decode_pixels(height, width, tree, bitstream):
    """Decode height * width pixels of 3 channel values each into an RGB image."""
    value_count = height * width * 3
    pixel_bytes = bytes(decode_value(tree, bitstream) for _ in range(value_count))
    return Image.frombytes('RGB', (width, height), pixel_bytes)

In [24]:
def decompress_image(in_file_name, out_file_name):
    """Decode a file in compress_image()'s format and save it as an image.

    Reads the header, code tree and pixel codes in that order, realigning to
    a byte boundary between sections, and prints progress. The rebuilt RGB
    image is saved to out_file_name. The input file is closed even if
    decoding raises (e.g. on a truncated or corrupt file).
    """
    print('Decompressing "%s" -> "%s"' % (in_file_name, out_file_name))

    print('Reading...')
    stream = InputBitStream(in_file_name)
    try:
        print('* Header offset: %d' % stream.bytes_read)
        height, width = decode_header(stream)
        stream.flush()  # Ensure next chunk is byte-aligned
        print('* Tree offset: %d' % stream.bytes_read)
        trimmed_tree = decode_tree(stream)
        stream.flush()  # Ensure next chunk is byte-aligned
        print('* Pixel offset: %d' % stream.bytes_read)
        image = decode_pixels(height, width, trimmed_tree, stream)
    finally:
        stream.close()
    print('Read %d bytes.' % stream.bytes_read)

    print('Image size: (height=%d, width=%d)' % (height, width))
    print('Trimmed tree: %s' % str(trimmed_tree))
    image.save(out_file_name)

In [25]:
# Round trip, step 1: compress the sample bitmap into the Huffman bitstream format.
compress_image(in_file_name="tiger.bmp", out_file_name="tiger.bin")

Compressing "tiger.bmp" -> "tiger.bin"
Image shape: (height=354, width=630)
RAW image size: 669064 bytes
Counts: [(246, 243), (242, 254), (243, 256), (240, 257), (241, 272), (245, 285), (244, 290), (237, 294), (235, 295), (238, 300), (239, 301), (236, 305), (247, 306), (233, 307), (234, 313), (248, 316), (230, 329), (232, 340), (225, 351), (249, 355), (231, 360), (227, 363), (229, 363), (226, 364), (255, 367), (228, 373), (250, 381), (223, 391), (215, 405), (219, 406), (224, 406), (222, 407), (218, 415), (220, 427), (221, 429), (214, 434), (251, 436), (213, 452), (217, 460), (208, 464), (216, 467), (253, 468), (206, 481), (211, 481), (210, 484), (254, 490), (209, 491), (212, 494), (252, 501), (207, 503), (204, 512), (205, 518), (200, 530), (198, 534), (201, 554), (194, 559), (202, 561), (196, 563), (203, 572), (197, 574), (191, 580), (192, 584), (195, 587), (193, 596), (190, 597), (185, 609), (199, 622), (189, 628), (188, 631), (187, 633), (184, 650), (186, 655), (181, 684), (183, 697)

In [26]:
# Round trip, step 2: decode the bitstream back into a bitmap.
decompress_image(in_file_name="tiger.bin", out_file_name="tiger2.bmp")

Decompressing "tiger.bin" -> "tiger2.bmp"
Reading...
* Header offset: 0
* Tree offset: 4
* Pixel offset: 324
Read 617757 bytes.
Image size: (height=354, width=630)
Trimmed tree: (((((((20, 24), 9), (((108, ((217, 208), 166)), 22), (21, 18))), (((23, (107, (161, (216, 253)))), (25, 81)), ((82, 15), (((162, (206, 211)), 105), 16)))), ((((26, (((210, 254), 163), ((209, 212), 160))), (27, 80)), (((106, (((246, 242), 252), 156)), 28), (79, (104, (153, 159))))), (((14, (((207, 204), 155), 103)), (77, ((((243, 240), 205), 158), 102))), ((30, 29), (((154, 151), (157, (200, 198))), 78))))), (((((31, (101, (149, 152))), 8), ((13, 32), (((147, 145), ((201, (241, 245)), 148)), 76))), ((((100, 99), 34), (33, (((194, 202), (196, 203)), (150, 146)))), ((75, 74), ((97, 98), (((197, 191), 142), ((192, (244, 237)), (195, (235, 238)))))))), ((((35, ((139, 141), 96)), (12, (((193, 190), 143), (144, ((239, 236), 185))))), ((36, 73), (37, 71))), ((0, ((95, (((247, 233), 199), (189, (234, 248)))), 72)), (1, 