# JPEG (baseline) — Implementation from scratch

This notebook follows the workflow your professor described: preprocessing (resize, RGB→YCbCr, division into 8×8 blocks), 8×8 DCT, quantization, zig-zag scan + RLE, Huffman entropy coding, then decoding and reconstruction. It is an educational implementation and does not produce a standard ISO/IEC 10918-1 (JFIF) bitstream.

Only `Pillow` and `matplotlib` are used for I/O and display.

## 1) Upload an input JPG (Colab / local)
Run the cell and choose a `.jpg` file.

In [None]:
try:
    from google.colab import files  # upload widget; only exists inside Google Colab
except ImportError:
    files = None

if files is not None:
    uploaded = files.upload()
    if not uploaded:
        raise RuntimeError('No file was uploaded; please choose a .jpg file.')
    INPUT_JPG = list(uploaded.keys())[0]
    print('Uploaded', INPUT_JPG)
else:
    # Running locally (Jupyter / plain Python): ask for a path instead of the widget.
    INPUT_JPG = input('Path to a .jpg file: ').strip()
    print('Using', INPUT_JPG)

: 

## 2) Read image, convert to RGB and then to YCbCr; optionally resize
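We convert the image to 8-bit RGB and then to YCbCr. If needed, the image is first resized (stretched) so both dimensions are multiples of 16; this lets the 4:2:0-subsampled chroma planes split evenly into 8×8 blocks too.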

In [None]:
from PIL import Image
import math

img = Image.open(INPUT_JPG).convert('RGB')

# Optional: round each side up to a multiple of 16 (the image is stretched,
# not padded) so luma and the half-size 4:2:0 chroma both tile into 8x8 blocks.
W, H = img.size
newW, newH = -(-W // 16) * 16, -(-H // 16) * 16
if (newW, newH) != (W, H):
    img = img.resize((newW, newH))
    print(f'Resized to {newW}x{newH} to fit block boundaries.')
W, H = img.size
print('Image size:', W, H)

# Convert to YCbCr; every pixel becomes a (Y, Cb, Cr) tuple with values 0..255.
img_ycbcr = img.convert('YCbCr')
pixels = list(img_ycbcr.getdata())


def _plane(index):
    """Return component `index` of `pixels` as H rows of W values."""
    return [[px[index] for px in pixels[row * W:(row + 1) * W]] for row in range(H)]


Y, Cb, Cr = _plane(0), _plane(1), _plane(2)
print('Prepared Y, Cb, Cr channels.')

## 3) Chroma subsampling 4:2:0 (average 2×2 blocks)
Reduce resolution of Cb/Cr by factor 2 horizontally and vertically.

In [None]:
def subsample_420(channel):
    """Downsample a 2-D channel by 2 in both directions (4:2:0 chroma).

    Each output sample is the mean of a 2x2 input neighbourhood, rounded to the
    nearest integer (ties round up). An odd trailing row/column is dropped.

    Args:
        channel: list of equal-length rows of ints.

    Returns:
        A (h // 2) x (w // 2) list of rows of ints; [] for an empty channel.
    """
    h = len(channel)
    w = len(channel[0]) if h else 0
    out = []
    for y in range(h // 2):
        top = channel[2 * y]
        bottom = channel[2 * y + 1]
        row = []
        for x in range(w // 2):
            s = top[2 * x] + top[2 * x + 1] + bottom[2 * x] + bottom[2 * x + 1]
            # +2 rounds to nearest; a bare s // 4 always truncates and
            # shifts every chroma sample down by 0.375 on average.
            row.append((s + 2) // 4)
        out.append(row)
    return out

# Halve both chroma planes in each direction; luma keeps full resolution.
Cb_s, Cr_s = (subsample_420(plane) for plane in (Cb, Cr))
print('Cb/Cr subsampled to', len(Cb_s[0]), 'x', len(Cb_s))

## 4) Block-splitting into 8×8 blocks and level shift (subtract 128)
We'll create functions to iterate over blocks and apply level shift (JPEG uses values centered around 0).

In [None]:
def split_blocks(channel, block_size=8):
    """Cut a 2-D channel into square tiles, returned in raster (row-major) tile order.

    The channel's height and width are expected to be multiples of
    ``block_size``; otherwise indexing past the edge raises IndexError.
    """
    height = len(channel)
    width = len(channel[0])
    offsets = range(block_size)
    return [
        [[channel[top + r][left + c] for c in offsets] for r in offsets]
        for top in range(0, height, block_size)
        for left in range(0, width, block_size)
    ]

def merge_blocks(blocks, w, h, block_size=8):
    """Inverse of split_blocks: paste raster-ordered tiles into a new h x w grid."""
    canvas = [[0] * w for _ in range(h)]
    positions = [
        (top, left)
        for top in range(0, h, block_size)
        for left in range(0, w, block_size)
    ]
    for tile_no, (top, left) in enumerate(positions):
        tile = blocks[tile_no]
        for dy in range(block_size):
            source = tile[dy]
            target = canvas[top + dy]
            for dx in range(block_size):
                target[left + dx] = source[dx]
    return canvas

def level_shift_block(block):
    """Return a new 8x8 block with 128 subtracted from every sample (0..255 -> -128..127)."""
    shifted = []
    for r in range(8):
        row = block[r]
        shifted.append([row[c] - 128 for c in range(8)])
    return shifted

def inv_level_shift_block(block):
    """Return a new 8x8 block with 128 added back to every value (undoes level_shift_block)."""
    restored = []
    for r in range(8):
        row = block[r]
        restored.append([row[c] + 128 for c in range(8)])
    return restored

# Split channels into blocks
# Tile full-resolution luma and both subsampled chroma planes into 8x8 blocks.
Y_blocks, Cb_blocks, Cr_blocks = (split_blocks(plane, 8) for plane in (Y, Cb_s, Cr_s))
print('Blocks:', len(Y_blocks), 'Y blocks,', len(Cb_blocks), 'Cb blocks')

## 5) 2D 8×8 DCT (classic JPEG DCT) and inverse DCT
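We implement JPEG's forward and inverse 8×8 2-D DCT (ITU-T T.81, A.3.3) in plain floating-point Python using a precomputed cosine table. This is slow but easy to follow, which is fine for an educational implementation.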

In [None]:
import math

# Precompute the DCT basis: C[u][x] = cos((2x + 1) * u * pi / 16) for N = 8.
N = 8
C = [[math.cos((2*x+1) * u * math.pi / 16.0) for x in range(N)] for u in range(N)]
# Normalisation: alpha(0) = 1/sqrt(2), alpha(k) = 1 for k > 0.
_ALPHA = [1.0 / math.sqrt(2.0)] + [1.0] * (N - 1)


def dct_2d(block):
    """Forward 8x8 2-D DCT-II as used by JPEG.

    out[u][v] = 1/4 * a(u) * a(v) * sum_x sum_y block[x][y] * C[u][x] * C[v][y]

    Evaluated separably (1-D transform along y for every x, then along x):
    2 * 8**3 multiply-adds instead of 8**4 for the direct double sum. Results
    match the direct formula up to floating-point rounding.

    Args:
        block: 8x8 list of numbers, already level-shifted (centred on 0).

    Returns:
        8x8 list of floats indexed [u][v].
    """
    # Pass 1 (along y): tmp[x][v] = sum_y block[x][y] * C[v][y]
    tmp = [[sum(block[x][y] * C[v][y] for y in range(N)) for v in range(N)] for x in range(N)]
    # Pass 2 (along x): out[u][v] = 1/4 * a(u) * a(v) * sum_x tmp[x][v] * C[u][x]
    return [
        [0.25 * _ALPHA[u] * _ALPHA[v] * sum(tmp[x][v] * C[u][x] for x in range(N))
         for v in range(N)]
        for u in range(N)
    ]


def idct_2d(coeff):
    """Inverse 8x8 2-D DCT (exact inverse of dct_2d up to float rounding).

    out[x][y] = 1/4 * sum_u sum_v a(u) * a(v) * coeff[u][v] * C[u][x] * C[v][y]

    Args:
        coeff: 8x8 list of DCT coefficients indexed [u][v].

    Returns:
        8x8 list of floats indexed [x][y] (still level-shifted).
    """
    # Pass 1 (along v): tmp[u][y] = sum_v a(v) * coeff[u][v] * C[v][y]
    tmp = [[sum(_ALPHA[v] * coeff[u][v] * C[v][y] for v in range(N)) for y in range(N)]
           for u in range(N)]
    # Pass 2 (along u): out[x][y] = 1/4 * sum_u a(u) * tmp[u][y] * C[u][x]
    return [
        [0.25 * sum(_ALPHA[u] * tmp[u][y] * C[u][x] for u in range(N)) for y in range(N)]
        for x in range(N)
    ]

# Test DCT on a single block
# Round-trip the first luma block through DCT and IDCT.
b = level_shift_block(Y_blocks[0])
d = dct_2d(b)
re = idct_2d(d)
# Total absolute reconstruction error; only float rounding should remain.
err = sum(
    abs(rec_val - src_val)
    for rec_row, src_row in zip(re, b)
    for rec_val, src_val in zip(rec_row, src_row)
)
print('DCT test error (single block):', err)

## 6) Quantization: standard JPEG quantization tables (luminance & chrominance)
We use the example quantization tables from Annex K of the JPEG standard (ITU-T T.81), scaled by a libjpeg-style quality factor.

In [None]:
# Standard JPEG quantization tables (quality scale will be applied)
# Base tables from ITU-T T.81 Annex K: Table K.1 (luminance, QY) and
# Table K.2 (chrominance, QC). Indexed [u][v] like dct_2d's output:
# row = vertical frequency, column = horizontal frequency.
QY = [
[16,11,10,16,24,40,51,61],
[12,12,14,19,26,58,60,55],
[14,13,16,24,40,57,69,56],
[14,17,22,29,51,87,80,62],
[18,22,37,56,68,109,103,77],
[24,35,55,64,81,104,113,92],
[49,64,78,87,103,121,120,101],
[72,92,95,98,112,100,103,99],
]

QC = [
[17,18,24,47,99,99,99,99],
[18,21,26,66,99,99,99,99],
[24,26,56,99,99,99,99,99],
[47,66,99,99,99,99,99,99],
[99,99,99,99,99,99,99,99],
[99,99,99,99,99,99,99,99],
[99,99,99,99,99,99,99,99],
[99,99,99,99,99,99,99,99],
]

def scale_quant_table(Q, quality):
    """Scale an 8x8 base quantization table by a libjpeg-style quality setting.

    Mirrors libjpeg's jpeg_quality_scaling / jpeg_add_quant_table: quality is
    clamped to 1..100 and mapped to a percentage scale (5000 // q below 50,
    200 - 2q otherwise) using integer arithmetic. Each entry becomes
    (Q * scale + 50) // 100, clamped to 1..255 (the baseline 8-bit range).

    Args:
        Q: 8x8 base table (e.g. QY or QC).
        quality: nominally 1..100; out-of-range values are clamped instead of
            dividing by zero (q <= 0) or producing negative scales (q > 100).

    Returns:
        A new 8x8 list of ints.
    """
    q = max(1, min(100, int(quality)))
    # Integer division keeps every entry an int (5000 / q would give floats).
    scale = 5000 // q if q < 50 else 200 - 2 * q
    return [[max(1, min(255, (Q[i][j] * scale + 50) // 100)) for j in range(8)] for i in range(8)]

quality = 75  # try changing to 50, 90, etc.
# Luminance and chrominance tables scaled for the chosen quality.
QY_s, QC_s = (scale_quant_table(base, quality) for base in (QY, QC))
print('Using quality =', quality)

## 7) Zig-zag scan and run-length encoding (RLE) for AC coefficients
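JPEG scans the coefficients in zig-zag order and codes zero runs among the AC coefficients as (RUNLENGTH, SIZE) symbols. We implement a simpler RLE that outputs (run, value) pairs, uses (15, 0) (ZRL) for a run of 16 zeros, and uses an end-of-block (EOB) marker for trailing zeros.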

In [None]:
# JPEG zig-zag scan order as (row, col) pairs. Anti-diagonals d = row + col are
# visited in increasing order; on even diagonals the scan runs up-right (column
# increasing), on odd diagonals down-left (row increasing).
zigzag_index = sorted(
    ((r, c) for r in range(8) for c in range(8)),
    key=lambda rc: (rc[0] + rc[1], rc[1] if (rc[0] + rc[1]) % 2 == 0 else rc[0]),
)


def zigzag_flat(coeff):
    """Return the 64 entries of an 8x8 block as a list in zig-zag order (DC first)."""
    return [coeff[r][c] for r, c in zigzag_index]

def rle_ac(flat_ac):
    """Run-length encode the AC part of a zig-zag-ordered block.

    Args:
        flat_ac: the full 64-entry zig-zag sequence. Index 0 (DC) is skipped
            here; DC is coded separately as a difference.

    Returns:
        A list of tokens:
        * ``(run, value)``: ``run`` zeros (0..15) followed by nonzero ``value``;
        * ``(15, 0)``: ZRL, 16 zeros, emitted only when a nonzero value follows
          (trailing ZRLs are redundant before EOB, cf. T.81 F.1.2.2);
        * ``'EOB'``: every remaining coefficient is zero. Omitted when the
          last coefficient is nonzero.
    """
    out = []
    run = 0
    pending_zrl = 0  # ZRLs held back until we know a nonzero value follows
    for v in flat_ac[1:]:
        if v == 0:
            run += 1
            if run == 16:
                pending_zrl += 1
                run = 0
        else:
            out.extend([(15, 0)] * pending_zrl)
            pending_zrl = 0
            out.append((run, v))
            run = 0
    if run or pending_zrl:
        # Plain string: emit_symbol, the decoder and read_block compare the
        # AC payload against 'EOB'; a tuple such as ('EOB', 0) would be
        # treated as a (run, value) pair and crash.
        out.append('EOB')
    return out

# Test on a sample block
# Demo: zig-zag scan and RLE of the first luma block's unquantised DCT output.
coeff = dct_2d(level_shift_block(Y_blocks[0]))
flat = zigzag_flat(coeff)
print('Zigzag first 10:', flat[:10])
print('RLE sample:', rle_ac(flat))

## 8) DC differential coding and simple Huffman entropy coding
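We encode DC differences and RLE AC tuples as symbols, count their frequencies, and build a Huffman code. For simplicity there is a single Huffman table over all symbols of all components, and each symbol is an exact value rather than a (RUN, SIZE) category. Baseline JPEG instead uses separate DC and AC tables and codes categories.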

In [None]:
from collections import Counter, defaultdict
import heapq

# Simple Huffman implementation
class HuffmanNode:
    """Huffman tree node: leaves carry a symbol, internal nodes carry None.

    Nodes compare by frequency only, so heapq pops the least frequent first.
    """

    def __init__(self, symbol=None, freq=0, left=None, right=None):
        self.symbol = symbol
        self.freq = freq
        self.left = left
        self.right = right

    def __lt__(self, other):
        return self.freq < other.freq


def build_huffman(freqs):
    """Build a prefix code from a {symbol: count} mapping.

    Returns {symbol: bitstring}. An empty mapping gives {}; a lone symbol gets
    '0'. Symbols must not be None, because None marks internal nodes.
    """
    heap = [HuffmanNode(sym, count) for sym, count in freqs.items()]
    heapq.heapify(heap)
    if not heap:
        return {}
    while len(heap) > 1:
        lo = heapq.heappop(heap)
        hi = heapq.heappop(heap)
        heapq.heappush(heap, HuffmanNode(None, lo.freq + hi.freq, lo, hi))

    codes = {}
    # Iterative pre-order walk; the right child is pushed first so the left
    # subtree ('0' branch) is visited first.
    stack = [(heap[0], '')]
    while stack:
        node, prefix = stack.pop()
        if node.symbol is not None:
            codes[node.symbol] = prefix or '0'
            continue
        stack.append((node.right, prefix + '1'))
        stack.append((node.left, prefix + '0'))
    return codes

# Symbolization helpers
def size_in_bits(x):
    """Return the magnitude category of x: the bit length of abs(x), or 0 for zero."""
    if x == 0:
        return 0
    magnitude = -x if x < 0 else x
    return magnitude.bit_length()

def amplitude_to_bits(v):
    """Return (size, bits) for a coefficient using JPEG's variable-length integer coding.

    ``size`` is the magnitude category (bit length of abs(v)). For v > 0 the
    bits are abs(v) in binary. For v < 0 they are the ones' complement of
    abs(v) within ``size`` bits, i.e. the binary of v + 2**size - 1
    (e.g. -5 -> '010'). v == 0 gives (0, '').

    Note: emit_symbol in this notebook does not call this helper; it writes a
    magnitude plus a separate sign bit instead.
    """
    if v == 0:
        return 0, ''
    magnitude = abs(v)
    size = magnitude.bit_length()
    if v > 0:
        return size, format(magnitude, 'b')
    # Negative: flip every bit of abs(v) itself (not of abs(v) - 1).
    return size, format(magnitude ^ ((1 << size) - 1), f'0{size}b')

# Collect symbols by scanning all blocks and storing DC diffs and AC (run,value)
def collect_symbols_for_huffman(Y_blocks_coeff, Cb_blocks_coeff, Cr_blocks_coeff):
    """Turn quantized blocks into the ordered symbol stream and its frequency table.

    Components are processed Y, then Cb, then Cr (all blocks of one component
    before the next). For each block the stream holds ``('DC', comp, diff)``,
    where ``diff`` is this block's rounded DC minus the previous block's DC in
    the same component (the predictor starts at 0 per component). It is
    followed by ``('AC', comp, token)`` for every token returned by rle_ac.

    Returns:
        (freq, symbols): ``freq`` is a Counter keyed by ``str(symbol)``, the
        key form emit_symbol looks up; ``symbols`` is the ordered list of
        symbol tuples.
    """
    symbols = []
    components = (('Y', Y_blocks_coeff), ('Cb', Cb_blocks_coeff), ('Cr', Cr_blocks_coeff))
    for comp, blocks in components:
        prev = 0  # DC predictor restarts for every component
        for coeff in blocks:
            flat = zigzag_flat(coeff)
            dc = int(round(flat[0]))
            symbols.append(('DC', comp, dc - prev))
            prev = dc
            symbols.extend(('AC', comp, token) for token in rle_ac(flat))
    # Huffman keys are the string form of each tuple, so they are hashable and
    # can be parsed back by the decoder.
    freq = Counter(str(s) for s in symbols)
    return freq, symbols

# We'll build codes after computing coefficients (later).

## 9) Full encode pipeline: DCT -> Quantize -> Zigzag -> Symbol collection -> Huffman -> Bitstream
This cell performs end-to-end encoding using the functions above.

In [None]:
# Encode blocks: DCT -> quantize
def quantize_block(coeff, Q):
    """Divide each of the 8x8 coefficients by its quantizer step and round to an int.

    Uses Python's round(), which sends exact .5 ties to the nearest even integer.
    """
    quantized = []
    for i in range(8):
        c_row, q_row = coeff[i], Q[i]
        quantized.append([int(round(c_row[j] / q_row[j])) for j in range(8)])
    return quantized

def dequantize_block(qcoeff, Q):
    """Multiply each quantized 8x8 coefficient by its quantizer step."""
    restored = []
    for i in range(8):
        q_row, step_row = qcoeff[i], Q[i]
        restored.append([q_row[j] * step_row[j] for j in range(8)])
    return restored

# Compute coeffs for each block set
Y_coeffs, Cb_coeffs, Cr_coeffs = (
    [dct_2d(level_shift_block(blk)) for blk in component]
    for component in (Y_blocks, Cb_blocks, Cr_blocks)
)

# Quantize: luminance table for Y, chrominance table for both Cb and Cr.
Y_q = [quantize_block(c, QY_s) for c in Y_coeffs]
Cb_q, Cr_q = (
    [quantize_block(c, QC_s) for c in chroma] for chroma in (Cb_coeffs, Cr_coeffs)
)

print('DCT + Quantization done for all blocks.')

# One Huffman codebook over the frequencies of all distinct symbols.
freqs, symbols_list = collect_symbols_for_huffman(Y_q, Cb_q, Cr_q)
print('Unique symbol types:', len(freqs))
codes = build_huffman(freqs)
print('Built Huffman codes for', len(codes), 'symbols')

## 10) Serialize to a simple container and write compressed size
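We build a byte stream by writing each symbol's Huffman codeword followed by fixed-width side information: a 4-bit run for AC symbols, a 6-bit size field, and, for nonzero values, the magnitude bits and a sign bit.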

In [None]:
# BitWriter for building final stream
class BitWriterSimple:
    """Accumulate bits MSB-first and pack them into bytes."""

    def __init__(self):
        self.buf = bytearray()
        self.acc = 0  # bits of the byte currently being assembled
        self.n = 0    # number of bits held in acc (0..7)

    def _push(self, bit):
        """Append one bit, moving a completed byte into buf."""
        self.acc = (self.acc << 1) | bit
        self.n += 1
        if self.n == 8:
            self.buf.append(self.acc)
            self.acc = 0
            self.n = 0

    def write_bits_str(self, s):
        """Append bits given as text: '1' is a one bit, any other character a zero bit."""
        for ch in s:
            self._push(1 if ch == '1' else 0)

    def write_bits(self, val, length):
        """Append the low `length` bits of `val`, most significant first."""
        for shift in range(length - 1, -1, -1):
            self._push((val >> shift) & 1)

    def flush(self):
        """Zero-pad a partial final byte and return everything written so far as bytes."""
        if self.n > 0:
            self.buf.append(self.acc << (8 - self.n))
            self.acc = 0
            self.n = 0
        return bytes(self.buf)

# Helper to emit a symbol + additional magnitude bits
def emit_symbol(sym, bw, codes):
    """Write one symbol: its Huffman codeword followed by fixed-width side bits.

    Layout (must match the decoder, which always reads a size after a run):
      * ('DC', comp, diff): 6-bit size = bit_length(|diff|); if size > 0, then
        |diff| in `size` bits and a sign bit (1 = negative).
      * ('AC', comp, (run, val)): 4-bit run, then the same 6-bit size /
        magnitude / sign layout. ZRL (15, 0) therefore still writes size 0.
      * ('AC', comp, 'EOB'): codeword only.

    Args:
        sym: symbol tuple as produced by collect_symbols_for_huffman.
        bw: writer with write_bits_str(str) and write_bits(int, int).
        codes: {str(symbol): codeword bitstring}.

    Raises:
        KeyError: if str(sym) has no codeword; writing a placeholder would
            silently corrupt the rest of the stream.
    """
    key = str(sym)
    try:
        code = codes[key]
    except KeyError:
        raise KeyError(f'No Huffman code for symbol {key}') from None
    bw.write_bits_str(code)

    kind, payload = sym[0], sym[2]
    if kind == 'DC':
        _write_signed_value(bw, payload)
    elif kind == 'AC' and isinstance(payload, tuple):
        run, val = payload
        bw.write_bits(run, 4)
        _write_signed_value(bw, val)


def _write_signed_value(bw, value):
    """Write a 6-bit size, then (if nonzero) |value| in `size` bits and a sign bit."""
    size = abs(value).bit_length()
    bw.write_bits(size, 6)
    if size:
        bw.write_bits(abs(value), size)
        bw.write_bits(1 if value < 0 else 0, 1)

# Build stream by traversing symbols_list in order and emitting        
bw = BitWriterSimple()
for s in symbols_list:
    emit_symbol(s, bw, codes)
stream = bw.flush()
print('Compressed stream length (bytes):', len(stream))

# Container layout: 4-byte big-endian header length, JSON header, bitstream.
# The header must carry the Huffman codebook and the symbol count; without
# them the stream cannot be decoded from the file alone.
import json
header = {
    'W': W,
    'H': H,
    'quality': quality,
    'QY': QY_s,
    'QC': QC_s,
    'codes': codes,
    'num_symbols': len(symbols_list),
}
with open('jpeg_simple.bin', 'wb') as f:
    h = json.dumps(header).encode('utf8')
    f.write(len(h).to_bytes(4, 'big'))
    f.write(h)
    f.write(stream)
print('Saved jpeg_simple.bin')

## 11) Decode: rebuild the Huffman prefix tree and decode the bitstream
We decode the in-memory stream by walking a prefix tree built from the Huffman codebook, then read each symbol's side-information bits using the same conventions as the encoder.

In [None]:
# Build reverse codebook (binary string -> symbol string)
# Reverse codebook (codeword bitstring -> symbol string). It is not referenced
# elsewhere in this notebook; kept for inspection.
rev_codes = dict(zip(codes.values(), codes.keys()))

# BitReader for the stream
class BitReaderSimple:
    """Read bits MSB-first from a bytes-like object.

    Reading past the end yields 0 bits instead of raising, so callers need
    their own stop condition.
    """

    def __init__(self, data):
        self.data = data
        self.i = 0    # index of the next byte to load
        self.acc = 0  # byte currently being consumed
        self.n = 0    # unread bits remaining in acc

    def read_bit(self):
        """Return the next bit (0 once the data is exhausted)."""
        if not self.n:
            if self.i >= len(self.data):
                return 0
            self.acc = self.data[self.i]
            self.i += 1
            self.n = 8
        self.n -= 1
        return (self.acc >> self.n) & 1

    def read_bits(self, k):
        """Read k bits as an unsigned int, first bit most significant."""
        value = 0
        for _ in range(k):
            value = value << 1 | self.read_bit()
        return value

# Build prefix tree from codes for fast decoding
class PrefixNode:
    """Binary decoding-trie node: bit '0' goes left, '1' right; leaves hold a symbol string."""

    def __init__(self):
        self.left = None
        self.right = None
        self.symbol = None


# Insert every codeword so decoding can walk the trie one bit at a time.
root = PrefixNode()
for sym_str, code in codes.items():
    node = root
    for ch in code:
        branch = 'left' if ch == '0' else 'right'
        if getattr(node, branch) is None:
            setattr(node, branch, PrefixNode())
        node = getattr(node, branch)
    node.symbol = sym_str

# Now decode by walking tree bit-by-bit and then parsing magnitude bits as per our encoding
import ast


def _read_signed_value(reader):
    """Read a 6-bit size, then (if size > 0) |value| in `size` bits and a sign bit (1 = negative)."""
    size = reader.read_bits(6)
    if size == 0:
        return 0
    mag = reader.read_bits(size)
    sign = reader.read_bits(1)
    return -mag if sign == 1 else mag


# Walk the prefix tree bit by bit to find each codeword, then read the side
# bits the encoder wrote after it. The reader yields 0 bits past the end of
# the stream, so decoding stops after as many symbols as were encoded.
br = BitReaderSimple(stream)
decoded_symbols = []
while len(decoded_symbols) < len(symbols_list):
    node = root
    while node is not None and node.symbol is None:
        node = node.left if br.read_bit() == 0 else node.right
    if node is None:
        # The bit pattern matches no codeword: corrupt or exhausted stream.
        break
    # Symbol strings are str() of tuples. literal_eval parses them without
    # executing arbitrary code, unlike eval.
    sym = ast.literal_eval(node.symbol)
    if sym[0] == 'DC':
        sym = ('DC', sym[1], _read_signed_value(br))
    elif sym[0] == 'AC' and sym[2] != 'EOB':
        # Every non-EOB AC token carries a 4-bit run before its value.
        run = br.read_bits(4)
        sym = ('AC', sym[1], (run, _read_signed_value(br)))
    decoded_symbols.append(sym)

print('Decoded symbols count:', len(decoded_symbols))

## 12) Reconstruct quantized blocks from decoded symbols, inverse quantization, IDCT, upsample chroma, and merge channels
This reverses earlier steps and rebuilds the RGB image.

In [None]:
# Helper: rebuild blocks sequentially from decoded symbols
def rebuild_blocks_from_symbols(decoded_symbols, num_Y_blocks, num_C_blocks):
    """Rebuild quantized 8x8 blocks from the decoded symbol stream.

    The stream is expected in encoder order: ``num_Y_blocks`` Y blocks, then
    ``num_C_blocks`` Cb blocks, then ``num_C_blocks`` Cr blocks. Each block is
    one DC-difference symbol followed by AC tokens up to an EOB, or until all
    63 AC positions are filled. DC differences are accumulated per component,
    with the predictor starting at 0 for each component.

    Returns:
        (Yq_blocks, Cbq_blocks, Crq_blocks), each a list of 8x8 grids.

    Raises:
        ValueError: if a block does not start with a DC symbol.
    """

    def read_flat(symbols, pos):
        """Read one block's 64 zig-zag values starting at symbols[pos]; return (flat, next_pos).

        flat[0] is the DC *difference*; the caller converts it to an absolute value.
        """
        dc_sym = symbols[pos]
        if dc_sym[0] != 'DC':
            raise ValueError(f'expected a DC symbol at position {pos}, got {dc_sym!r}')
        flat = [0] * 64
        flat[0] = dc_sym[2]
        pos += 1
        i = 1
        while i < 64 and pos < len(symbols):
            s = symbols[pos]
            if s[0] == 'DC':
                # The next block starts here: leave its DC for the next call
                # instead of consuming and dropping it.
                break
            pos += 1
            token = s[2]
            # EOB may be spelled 'EOB' or ('EOB', 0).
            if token == 'EOB' or token == ('EOB', 0):
                break
            run, val = token
            i += run
            if i < 64:
                flat[i] = val
                i += 1
        return flat, pos

    def unzigzag(flat):
        """Place 64 zig-zag-ordered values back into an 8x8 grid."""
        block = [[0] * 8 for _ in range(8)]
        for k, (i, j) in enumerate(zigzag_index):
            block[i][j] = flat[k]
        return block

    pos = 0
    per_component = []
    for count in (num_Y_blocks, num_C_blocks, num_C_blocks):  # Y, Cb, Cr
        prev = 0  # DC predictor restarts at 0 for each component
        blocks = []
        for _ in range(count):
            flat, pos = read_flat(decoded_symbols, pos)
            flat[0] += prev  # DC difference -> absolute DC
            prev = flat[0]
            blocks.append(unzigzag(flat))
        per_component.append(blocks)
    Yq_blocks, Cbq_blocks, Crq_blocks = per_component
    return Yq_blocks, Cbq_blocks, Crq_blocks

num_Y = len(Y_q)
num_C = len(Cb_q)
Yq_rec, Cbq_rec, Crq_rec = rebuild_blocks_from_symbols(decoded_symbols, num_Y, num_C)
print('Rebuilt quantized blocks counts:', len(Yq_rec), len(Cbq_rec), len(Crq_rec))


def _decode_block(qblock, Q):
    """Dequantize, inverse-DCT and undo the -128 level shift; round to ints.

    The IDCT is essential: dequantized values are frequency coefficients, not
    pixel samples. All three components were level-shifted before the forward
    DCT, so all three need the +128 shift back.
    """
    samples = inv_level_shift_block(idct_2d(dequantize_block(qblock, Q)))
    return [[int(round(v)) for v in row] for row in samples]


Y_rec_blocks = [_decode_block(b, QY_s) for b in Yq_rec]
Cb_rec_blocks = [_decode_block(b, QC_s) for b in Cbq_rec]
Cr_rec_blocks = [_decode_block(b, QC_s) for b in Crq_rec]

# Merge chroma back to full resolution by upsampling 2x nearest (simple)
def upsample_420(channel_sub, W, H):
    """Nearest-neighbour 2x upsampling of a subsampled plane into an H x W grid.

    Each source sample fills a 2x2 output square. Output cells not covered
    (when W or H exceed twice the source size) stay 0.
    """
    out = [[0] * W for _ in range(H)]
    src_h = len(channel_sub)
    src_w = len(channel_sub[0])
    for y in range(src_h):
        for x in range(src_w):
            value = channel_sub[y][x]
            for dy in (0, 1):
                for dx in (0, 1):
                    out[2 * y + dy][2 * x + dx] = value
    return out

# Merge blocks back to 2D channels
Y_rec = merge_blocks(Y_rec_blocks, W, H, 8)
# Chroma planes are half size in each direction (4:2:0).
Cb_sub_rec = merge_blocks(Cb_rec_blocks, W // 2, H // 2, 8)
Cr_sub_rec = merge_blocks(Cr_rec_blocks, W // 2, H // 2, 8)
Cb_rec = upsample_420(Cb_sub_rec, W, H)
Cr_rec = upsample_420(Cr_sub_rec, W, H)

# Compose YCbCr to RGB and save
from PIL import Image


def _clip8(v):
    """Clamp an int to the 8-bit sample range 0..255."""
    return max(0, min(255, v))


out_img = Image.new('RGB', (W, H))
pixels_out = []
for y in range(H):
    for x in range(W):
        Yv = _clip8(int(round(Y_rec[y][x])))
        Cbv = _clip8(int(round(Cb_rec[y][x])))
        Crv = _clip8(int(round(Cr_rec[y][x])))
        # JFIF YCbCr -> RGB, then clip to 0..255.
        r = _clip8(int(round(Yv + 1.402 * (Crv - 128))))
        g = _clip8(int(round(Yv - 0.344136 * (Cbv - 128) - 0.714136 * (Crv - 128))))
        b = _clip8(int(round(Yv + 1.772 * (Cbv - 128))))
        pixels_out.append((r, g, b))
out_img.putdata(pixels_out)
out_img.save('reconstructed.jpg')
print('Saved reconstructed.jpg')
# PIL's JPEG writer compresses again (lossily), so also keep a lossless copy
# containing exactly this codec's reconstruction.
out_img.save('reconstructed.png')
print('Saved reconstructed.png')

## 13) Display original vs reconstructed and PSNR
Run this cell to see the visual results and PSNR metric.

In [None]:
import math

import matplotlib.pyplot as plt

orig = img.convert('RGB')
# Use the in-memory reconstruction. 'reconstructed.jpg' was re-encoded by
# PIL's lossy JPEG writer, which would add its own error to what is shown
# and to the PSNR.
reco = out_img.convert('RGB')
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(orig)
plt.title('Original')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(reco)
plt.title('Reconstructed')
plt.axis('off')
plt.show()

# PSNR
def psnr_img(a, b):
    """Peak signal-to-noise ratio in dB between two same-size 3-channel 8-bit images.

    Args:
        a, b: objects whose getdata() yields per-pixel 3-tuples (e.g. PIL RGB images).

    Returns:
        10 * log10(255**2 / MSE) over all pixels and channels, or inf if identical.

    Raises:
        ValueError: if the pixel counts differ (zip would silently compare only
            the overlap) or the images are empty (MSE undefined).
    """
    a_pix = list(a.getdata())
    b_pix = list(b.getdata())
    if len(a_pix) != len(b_pix):
        raise ValueError(f'image sizes differ: {len(a_pix)} vs {len(b_pix)} pixels')
    if not a_pix:
        raise ValueError('cannot compute PSNR of empty images')
    sq_err = sum((p[i] - q[i]) ** 2 for p, q in zip(a_pix, b_pix) for i in range(3))
    mse = sq_err / (len(a_pix) * 3)
    if mse == 0:
        return float('inf')
    return 10 * math.log10((255 * 255) / mse)
psnr_value = psnr_img(orig, reco)
print('PSNR:', psnr_value)