In [4]:
import numpy as np
import matplotlib.pyplot as plt

def lzw_compress(image, max_dict_size=4096):
    """
    LZW-compress a NumPy image, treating each pixel (last-axis vector) as one symbol.

    The initial dictionary holds every distinct pixel of the image, keyed by its
    raw bytes and numbered in ``np.unique(..., axis=0)`` (sorted) order. Codes for
    multi-pixel strings are appended after those until the dictionary holds
    ``max_dict_size`` entries. At that point it is reset to its initial contents.

    Parameters
    ----------
    image : np.ndarray
        Array whose last axis is the per-pixel vector, e.g. (H, W, C). For a
        2-D array the last axis is a row, so each row becomes one symbol.
    max_dict_size : int
        Dictionary size that triggers a reset. If the image has at least this
        many distinct pixels, no multi-pixel entry can ever be added, and every
        pixel is emitted as its own code.

    Returns
    -------
    compressed : list[int]
        The emitted codes.
    dictionary : dict[bytes, int]
        The encoder's dictionary as it stands at the end of the stream. After a
        reset it no longer describes codes emitted earlier, so a decoder should
        be seeded from ``unique_pixels`` instead.
    unique_pixels : list[np.ndarray]
        The initial alphabet; element ``i`` is the pixel for code ``i``.
    """
    # One row per pixel; each pixel's raw bytes serve as its hashable symbol.
    pixels = image.reshape(-1, image.shape[-1])
    if pixels.shape[0] == 0:
        return [], {}, []
    pixel_bytes = [px.tobytes() for px in pixels]

    # Initialize dictionary with unique pixel values (np.unique's sorted order).
    unique_pixels = [pixels[i] for i in np.unique(pixels, axis=0, return_index=True)[1]]
    dictionary = {p.tobytes(): i for i, p in enumerate(unique_pixels)}
    base_size = len(dictionary)
    dict_size = base_size
    # Keys added since the last reset. Deleting just these restores the initial
    # dictionary in O(len(added)) instead of re-encoding all base entries.
    added = []

    w = b''
    compressed = []

    for pixel in pixel_bytes:
        wc = w + pixel
        if wc in dictionary:
            w = wc
            continue
        compressed.append(dictionary[w])
        if dict_size < max_dict_size:
            dictionary[wc] = dict_size
            added.append(wc)
            dict_size += 1
        elif added:
            # Reset when the limit is reached. If nothing was ever added (the image
            # has >= max_dict_size distinct pixels), the dictionary is already in
            # its initial state, so the reset is a no-op and is skipped.
            for key in added:
                del dictionary[key]
            added.clear()
            dict_size = base_size
        w = pixel

    if w:
        compressed.append(dictionary[w])

    return compressed, dictionary, unique_pixels


def lzw_decompress(compressed, dictionary, shape, unique_pixels, max_dict_size=4096):
    """
    Rebuild an image from the LZW code stream produced by ``lzw_compress``.

    The decoder is seeded with the initial alphabet only (code ``i`` is
    ``unique_pixels[i]``). It regrows every multi-pixel entry from the codes
    themselves, resetting whenever the table reaches ``max_dict_size``, which
    mirrors the encoder's reset points.

    Parameters
    ----------
    compressed : sequence of int
        The code stream.
    dictionary : dict
        Unused; kept so existing call sites keep working. The encoder's
        end-of-stream dictionary only describes codes emitted since its last
        reset, so it cannot seed the decoder.
    shape : tuple of int
        Shape of the original image.
    unique_pixels : sequence of np.ndarray
        The initial alphabet. Its first element fixes the byte width and dtype
        of one pixel, so non-uint8 images round-trip too.
    max_dict_size : int
        Must match the value used for compression.

    Returns
    -------
    np.ndarray
        The decoded pixels reshaped to ``shape``.

    Raises
    ------
    ValueError
        If a code is neither in the table nor the next code to be defined, or
        if ``compressed`` is non-empty while ``unique_pixels`` is empty.
    KeyError
        If the first code is not part of the initial alphabet.
    """
    if len(compressed) == 0:
        dtype = np.asarray(unique_pixels[0]).dtype if len(unique_pixels) else np.uint8
        return np.empty(0, dtype=dtype).reshape(shape)
    if len(unique_pixels) == 0:
        raise ValueError("unique_pixels is empty but compressed is not")

    first = np.asarray(unique_pixels[0])
    pixel_size = first.nbytes  # bytes per pixel (channels * itemsize)
    dtype = first.dtype

    rev_dict = {i: np.asarray(p).tobytes() for i, p in enumerate(unique_pixels)}
    base_size = len(rev_dict)
    dict_size = base_size

    w = rev_dict[compressed[0]]
    result = [w]

    for k in compressed[1:]:
        if k in rev_dict:
            entry = rev_dict[k]
        elif k == dict_size:
            # The code is defined by this very step (LZW's cScSc case).
            entry = w + w[:pixel_size]
        else:
            raise ValueError("Invalid compressed code encountered")

        result.append(entry)

        if dict_size < max_dict_size:
            rev_dict[dict_size] = w + entry[:pixel_size]
            dict_size += 1
        elif dict_size > base_size:
            # Mirror the encoder's reset: drop every code added since the last
            # reset instead of rebuilding the whole alphabet.
            for code in range(base_size, dict_size):
                del rev_dict[code]
            dict_size = base_size

        w = entry

    # Convert the byte string back to an array of the original pixel dtype.
    decoded = np.frombuffer(b''.join(result), dtype=dtype)
    return decoded.reshape(shape)


def calculate_lzw_compression_ratio(image, max_dict_size=4096, code_bits=32):
    """
    LZW-compress ``image`` and estimate the resulting compression ratio.

    Cost model:

    * Every emitted code is stored in ``code_bits`` bits. The default of 32
      (a plain int) is conservative, because codes never reach
      ``max(max_dict_size, len(unique_pixels))``.
    * Side information is the raw size of the initial alphabet
      ``unique_pixels``. Multi-pixel entries are not charged, because an LZW
      decoder regrows them from the code stream. The encoder's final
      dictionary is not charged either, because after a reset it only
      describes the codes emitted since that reset.

    Returns
    -------
    tuple
        ``(ratio, compressed, dictionary, unique_pixels, total_bits, orig_bits)``
        where ``ratio = orig_bits / total_bits`` (NaN when both are zero, i.e.
        for an empty image).
    """
    compressed, dictionary, unique_pixels = lzw_compress(image, max_dict_size)

    orig_bits = image.nbytes * 8
    compressed_bits = len(compressed) * code_bits
    dict_bits = sum(p.nbytes for p in unique_pixels) * 8
    total_bits = compressed_bits + dict_bits
    ratio = orig_bits / total_bits if total_bits else float("nan")

    return ratio, compressed, dictionary, unique_pixels, total_bits, orig_bits


def visualize_lzw_compression(image, max_dict_size=256):
    """
    Run an LZW round trip on ``image``, print its bit accounting, and plot the
    original beside the decoded reconstruction.
    """
    (ratio, codes, dictionary, unique_pixels,
     total_bits, orig_bits) = calculate_lzw_compression_ratio(image, max_dict_size)
    reconstructed = lzw_decompress(codes, dictionary, image.shape, unique_pixels, max_dict_size)

    report = (
        f"\nLZW Compression (max dict size={max_dict_size})",
        f"Original bits: {orig_bits:,}",
        f"Compressed + dictionary bits: {total_bits:,}",
        f"Compression ratio: {ratio:.3f}",
    )
    for line in report:
        print(line)

    panels = (
        (image, "Original"),
        (reconstructed, f"LZW (dict={max_dict_size})"),
    )
    plt.figure(figsize=(10, 4))
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

import matplotlib.image as mpimg

img = mpimg.imread("/mnt/769EC2439EC1FB9D/vsc_projs/DIP/kodim01.png")
if img.dtype != np.uint8:
    # Float images (matplotlib returns PNGs as floats in [0, 1]) are scaled back to
    # 0..255. Round rather than truncate: float32 products such as
    # 0.99999994 * 255 would otherwise become 254 and silently alter pixels.
    img = np.clip(np.round(img * 255), 0, 255).astype(np.uint8)

visualize_lzw_compression(img)

KeyboardInterrupt: 