# 5. GPT-2

## 1) Importy i konfiguracja

In [1]:
import os
import struct
import time
import constriction
import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

# Input file and the locations of the compressed / restored outputs.
TEST_PATH = "../data/canterbury_small.bin"
COMPRESSED_PATH = "../out/compressed_gpt2.bin"
DECOMPRESSED_PATH = "../out/decompressed_gpt2.bin"

# Once CONTEXT_SIZE tokens have been fed to the model, the KV cache is rebuilt
# from the STRIDE most recently seen tokens.
CONTEXT_SIZE = 1024
STRIDE = 512
# Checkpoint name passed to ``from_pretrained``.
MODEL_NAME = "gpt2"

  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


## 2) Ładowanie modelu i tokenizera

In [2]:
def load_model(model_name=MODEL_NAME):
    """Load the tokenizer and language model for ``model_name``.

    The model is switched to evaluation mode and moved to ``DEVICE``; its
    total parameter count is printed once loading has finished.

    Args:
        model_name: Checkpoint identifier forwarded to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading {model_name}...")
    tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
    # ``eval()`` and ``to()`` both return the module, so they can be chained.
    model = GPT2LMHeadModel.from_pretrained(model_name).eval().to(DEVICE)

    params = 0
    for tensor in model.parameters():
        params += tensor.numel()
    print(f"Loaded: {params:,} parameters on {DEVICE}")
    return model, tokenizer

## 3) Inferencja modelu i funkcje pomocnicze

In [None]:
def _probs_from_logits(logits):
    """Convert ``(batch, token, vocab)`` logits into one next-token distribution.

    Only the last token position of the first batch row is used. The result is
    a float64 NumPy vector whose entries are clipped to ``[1e-9, 1.0]`` and
    renormalised to sum to one, so no entry is exactly zero (presumably so
    that the entropy coder can always encode any token).

    The encoder and the decoder must derive probabilities in exactly the same
    way, which is why this logic lives in a single shared helper.
    """
    last_step = logits[:, -1, :]
    probs = F.softmax(last_step, dim=-1).cpu().numpy().astype(np.float64)[0]
    probs = np.clip(probs, 1e-9, 1.0)
    probs /= probs.sum()
    return probs


def _get_next_probs(model, input_ids, past_kv=None):
    """Advance the model by ``input_ids`` and return the next-token distribution.

    Args:
        model: Callable language model accepting ``past_key_values`` and
            ``use_cache`` keyword arguments.
        input_ids: Tensor of token ids to feed on top of the cached context.
        past_kv: Key/value cache returned by a previous call, or ``None``.

    Returns:
        ``(probs, past_key_values)`` where ``probs`` is a float64 NumPy vector
        over the vocabulary and ``past_key_values`` is the updated cache.
    """
    with torch.no_grad():
        out = model(input_ids, past_key_values=past_kv, use_cache=True)
        probs = _probs_from_logits(out.logits)
    return probs, out.past_key_values


def _reset_context(model, context_tokens):
    """Rebuild the KV cache from scratch using ``context_tokens``.

    Args:
        model: Callable language model accepting a ``use_cache`` keyword.
        context_tokens: Sequence of token ids forming the new context.

    Returns:
        ``(probs, past_key_values, length)`` where ``probs`` is the next-token
        distribution after the context, ``past_key_values`` is the fresh cache
        and ``length`` is ``len(context_tokens)``.
    """
    input_ids = torch.tensor([context_tokens], device=DEVICE)
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
        probs = _probs_from_logits(out.logits)
    return probs, out.past_key_values, len(context_tokens)


def print_time_profile(name, timings, total_time):
    """Print a per-stage timing breakdown, largest stage first.

    Args:
        name: Label used in the heading line.
        timings: Mapping of stage name to seconds spent in that stage.
        total_time: Overall wall-clock duration in seconds. When it is not
            positive, only a "No timing data" line is printed.

    Time not covered by any entry of ``timings`` is reported as ``other`` when
    it exceeds one nanosecond.
    """
    print(f"\n{name} time profile:")
    if total_time <= 0:
        print("  No timing data")
        return

    def _row(label, seconds):
        share = (seconds / total_time) * 100
        print(f"  {label:<16} {seconds:>8.4f}s  ({share:>6.2f}%)")

    accounted = 0.0
    # Sorting the keys by their value keeps ties in insertion order.
    for label in sorted(timings, key=timings.get, reverse=True):
        seconds = timings[label]
        accounted += seconds
        _row(label, seconds)

    leftover = max(total_time - accounted, 0.0)
    if leftover > 1e-9:
        _row("other", leftover)

## 4) Kompresja

In [4]:
def compress_file(model, tokenizer, input_path, output_path):
    """Compress ``input_path`` into ``output_path`` using model-driven range coding.

    The input bytes are decoded as UTF-8, tokenized, and each token is encoded
    with a range coder whose distribution is the model's next-token prediction
    given the preceding tokens. The output file holds two little-endian
    unsigned 32-bit integers (original byte count, token count) followed by
    the raw bytes of the encoder's compressed array.

    Args:
        model: Language model used for next-token prediction.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id``.
        input_path: Path of the file to compress.
        output_path: Path of the compressed file to write. Its parent
            directory is created if missing.

    Returns:
        A dict with the keys ``time``, ``original_size``, ``compressed_size``,
        ``ratio``, ``bpc``, ``speed_bps``, ``num_tokens`` and ``timings``.
        For an empty input ``bpc`` is ``inf`` rather than raising
        ``ZeroDivisionError``.
    """
    model.eval()
    start_time = time.perf_counter()

    timings = {
        "read_bytes": 0.0,
        "tokenize": 0.0,
        "model_infer": 0.0,
        "arith_encode": 0.0,
        "write_file": 0.0,
    }

    # Fail fast: make sure the destination directory exists before the slow
    # model loop, so a missing directory cannot discard the finished result.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    t0 = time.perf_counter()
    with open(input_path, "rb") as f:
        raw_bytes = f.read()
    timings["read_bytes"] += time.perf_counter() - t0

    t0 = time.perf_counter()
    # NOTE: invalid UTF-8 sequences are replaced here, so such inputs cannot
    # be restored byte-exactly; the roundtrip check below only warns.
    text = raw_bytes.decode("utf-8", errors="replace")
    tokens = tokenizer.encode(text)
    timings["tokenize"] += time.perf_counter() - t0

    num_tokens = len(tokens)
    original_size = len(raw_bytes)

    if tokenizer.decode(tokens).encode("utf-8") != raw_bytes:
        print("⚠️ Tokenization roundtrip not perfect - minor differences possible")

    print(f"Original: {original_size:,} bytes → {num_tokens:,} tokens")

    encoder = constriction.stream.queue.RangeEncoder()
    # The end-of-sequence id doubles as the start-of-sequence marker.
    bos_id = tokenizer.eos_token_id

    past_kv = None
    kv_len = 0
    all_seen = [bos_id]

    for token in tqdm(tokens, desc="Compressing"):
        t_model = time.perf_counter()
        if kv_len >= CONTEXT_SIZE:
            # Cache is full: rebuild it from the most recent STRIDE tokens.
            context = all_seen[-STRIDE:]
            probs, past_kv, kv_len = _reset_context(model, context)
        else:
            # Feed only the newest token on top of the existing cache.
            input_ids = torch.tensor([[all_seen[-1]]], device=DEVICE)
            probs, past_kv = _get_next_probs(model, input_ids, past_kv)
            kv_len += 1
        timings["model_infer"] += time.perf_counter() - t_model

        t_coder = time.perf_counter()
        dist = constriction.stream.model.Categorical(probs, perfect=False)
        encoder.encode(int(token), dist)
        timings["arith_encode"] += time.perf_counter() - t_coder

        all_seen.append(token)

    compressed_bits = encoder.get_compressed()

    t0 = time.perf_counter()
    with open(output_path, "wb") as f:
        f.write(struct.pack("<I", original_size))
        f.write(struct.pack("<I", num_tokens))
        f.write(compressed_bits.tobytes())
    timings["write_file"] += time.perf_counter() - t0

    duration = time.perf_counter() - start_time
    compressed_size = os.path.getsize(output_path)

    print_time_profile("Compression", timings, duration)

    # Guard the ratios that divide by values which can legitimately be zero.
    if original_size:
        bpc = (compressed_size * 8) / original_size
    else:
        bpc = float("inf")
    speed_bps = original_size / duration if duration > 0 else 0.0

    return {
        "time": duration,
        "original_size": original_size,
        "compressed_size": compressed_size,
        "ratio": original_size / compressed_size,
        "bpc": bpc,
        "speed_bps": speed_bps,
        "num_tokens": num_tokens,
        "timings": timings,
    }

## 5) Dekompresja i walidacja

In [5]:
def decompress_file(model, tokenizer, input_path, output_path):
    """Restore a file written by ``compress_file``.

    The header (two little-endian unsigned 32-bit integers: original byte
    count and token count) is read first; the remaining bytes are interpreted
    as ``uint32`` words for the range decoder. Tokens are decoded one at a
    time using the same model predictions and context-reset schedule as the
    compressor, then detokenized and written as UTF-8.

    Args:
        model: Language model used for next-token prediction.
        tokenizer: Tokenizer providing ``decode`` and ``eos_token_id``.
        input_path: Path of the compressed file.
        output_path: Path of the restored file to write. Its parent directory
            is created if missing.

    Returns:
        A dict with the keys ``time``, ``speed_bps`` and ``timings``.

    Raises:
        struct.error: If the file is shorter than the 8-byte header.
    """
    model.eval()
    start_time = time.perf_counter()

    timings = {
        "read_file": 0.0,
        "model_infer": 0.0,
        "arith_decode": 0.0,
        "detokenize": 0.0,
        "write_file": 0.0,
    }

    # Fail fast: make sure the destination directory exists before the slow
    # model loop, so a missing directory cannot discard the decoded result.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    t0 = time.perf_counter()
    with open(input_path, "rb") as f:
        original_size, num_tokens = struct.unpack("<II", f.read(8))
        bits = np.frombuffer(f.read(), dtype=np.uint32)
    timings["read_file"] += time.perf_counter() - t0

    decoder = constriction.stream.queue.RangeDecoder(bits)
    # The end-of-sequence id doubles as the start-of-sequence marker.
    bos_id = tokenizer.eos_token_id

    print(f"Decompressing: {num_tokens:,} tokens → {original_size:,} bytes")

    past_kv = None
    kv_len = 0
    all_seen = [bos_id]
    decoded_tokens = []

    for _ in tqdm(range(num_tokens), desc="Decompressing"):
        t_model = time.perf_counter()
        if kv_len >= CONTEXT_SIZE:
            # Cache is full: rebuild it from the most recent STRIDE tokens.
            context = all_seen[-STRIDE:]
            probs, past_kv, kv_len = _reset_context(model, context)
        else:
            # Feed only the newest token on top of the existing cache.
            input_ids = torch.tensor([[all_seen[-1]]], device=DEVICE)
            probs, past_kv = _get_next_probs(model, input_ids, past_kv)
            kv_len += 1
        timings["model_infer"] += time.perf_counter() - t_model

        t_coder = time.perf_counter()
        dist = constriction.stream.model.Categorical(probs, perfect=False)
        token = int(decoder.decode(dist))
        timings["arith_decode"] += time.perf_counter() - t_coder

        decoded_tokens.append(token)
        all_seen.append(token)

    t0 = time.perf_counter()
    text = tokenizer.decode(decoded_tokens)
    timings["detokenize"] += time.perf_counter() - t0

    t0 = time.perf_counter()
    with open(output_path, "wb") as f:
        f.write(text.encode("utf-8"))
    timings["write_file"] += time.perf_counter() - t0

    duration = time.perf_counter() - start_time
    print_time_profile("Decompression", timings, duration)

    return {
        "time": duration,
        "speed_bps": original_size / duration if duration > 0 else 0.0,
        "timings": timings,
    }

In [6]:
def validate_roundtrip(original_path, decoded_path):
    """Compare two files byte for byte and report the outcome.

    Args:
        original_path: Path of the reference file.
        decoded_path: Path of the file expected to be identical to it.

    Returns:
        ``True`` when both files have identical contents, ``False`` otherwise.
        On a mismatch both sizes are printed, followed by the first differing
        byte within the common prefix (if there is one).
    """
    with open(original_path, "rb") as src, open(decoded_path, "rb") as dst:
        expected = src.read()
        actual = dst.read()

    identical = expected == actual
    if not identical:
        print("❌ MISMATCH!")
        print(f"   Original: {len(expected)} bytes, Decompressed: {len(actual)} bytes")
        # zip stops at the shorter input, i.e. at the end of the common prefix.
        for offset, (want, got) in enumerate(zip(expected, actual)):
            if want != got:
                print(f"   First diff at byte {offset}: {want} vs {got}")
                break
    else:
        print("✅ SUCCESS: Perfect match!")
    return identical

## 6) Uruchomienie: kompresja, dekompresja i weryfikacja

In [7]:
# End-to-end run: compress the test file, decompress it again, verify that the
# result matches the input, and print a summary of the measured metrics.
print(f"Test file: {TEST_PATH}")

# Raise explicitly instead of using ``assert``, which is stripped under
# ``python -O`` and would let a missing file slip through to model loading.
if not os.path.exists(TEST_PATH):
    raise FileNotFoundError(f"File not found: {TEST_PATH}")

model, tokenizer = load_model()

print("\n=== COMPRESSION ===")
comp_metrics = compress_file(model, tokenizer, TEST_PATH, COMPRESSED_PATH)
print(f"Ratio: {comp_metrics['ratio']:.2f}x | BPC: {comp_metrics['bpc']:.2f}")

print("\n=== DECOMPRESSION ===")
decomp_metrics = decompress_file(model, tokenizer, COMPRESSED_PATH, DECOMPRESSED_PATH)
print(f"Speed: {decomp_metrics['speed_bps']:.2f} B/s")

print("\n=== VERIFICATION ===")
validate_roundtrip(TEST_PATH, DECOMPRESSED_PATH)

print("\n=== SUMMARY ===")
print("Baseline Results:")
print(f"Compression Speed: {comp_metrics['speed_bps']:.2f} B/s")
print(f"Decompression Speed: {decomp_metrics['speed_bps']:.2f} B/s")
print(f"Compression Ratio: {comp_metrics['ratio']:.2f}x")
print(f"BPC: {comp_metrics['bpc']:.2f}")

Test file: ../data/canterbury_small.bin
Loading gpt2...


Loading weights: 100%|██████████| 148/148 [00:00<00:00, 3017.63it/s, Materializing param=transformer.wte.weight]             
[1mGPT2LMHeadModel LOAD REPORT[0m from: gpt2
Key                  | Status     |  | 
---------------------+------------+--+-
h.{0...11}.attn.bias | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m
Token indices sequence length is longer than the specified maximum sequence length for this model (3064 > 1024). Running this sequence through the model will result in indexing errors


Loaded: 124,439,808 parameters on cpu

=== COMPRESSION ===
Original: 10,846 bytes → 3,064 tokens


Compressing: 100%|██████████| 3064/3064 [00:25<00:00, 119.42it/s]



Compression time profile:
  model_infer       25.3544s  ( 98.78%)
  arith_encode       0.2091s  (  0.81%)
  tokenize           0.0067s  (  0.03%)
  write_file         0.0010s  (  0.00%)
  read_bytes         0.0003s  (  0.00%)
  other              0.0953s  (  0.37%)
Ratio: 4.75x | BPC: 1.68

=== DECOMPRESSION ===
Decompressing: 3,064 tokens → 10,846 bytes


Decompressing: 100%|██████████| 3064/3064 [00:24<00:00, 124.66it/s]


Decompression time profile:
  model_infer       24.2668s  ( 98.72%)
  arith_decode       0.2179s  (  0.89%)
  write_file         0.0019s  (  0.01%)
  detokenize         0.0008s  (  0.00%)
  read_file          0.0001s  (  0.00%)
  other              0.0951s  (  0.39%)
Speed: 441.21 B/s

=== VERIFICATION ===
✅ SUCCESS: Perfect match!

=== SUMMARY ===
Baseline Results:
Compression Speed: 422.57 B/s
Decompression Speed: 441.21 B/s
Compression Ratio: 4.75x
BPC: 1.68



