# 4. Powerful Baseline Model

Mocniejsza wersja baseline z większym modelem i lepszym treningiem.

## 1) Importy i konfiguracja

In [10]:
import torch
import torch.nn as nn
import numpy as np
import constriction
import os
import struct
import time
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Configuration
IS_COLAB = False  # Set to True if running on Google Colab

if IS_COLAB:
    TRAIN_PATH = "/content/all_silesia.bin"
    TEST_PATH = "/content/all_canterbury.bin"
    COMPRESSED_PATH = "/content/compressed_powerful.bin"
    DECOMPRESSED_PATH = "/content/decompressed_powerful.txt"
    MODEL_PATH = "/content/model_compressor_powerful.pth"
else:
    TRAIN_PATH = "../data/all_silesia.bin" 
    TEST_PATH = "../data/all_canterbury.bin"
    COMPRESSED_PATH = "../out/compressed_powerful.bin"
    DECOMPRESSED_PATH = "../out/decompressed_powerful.txt"
    MODEL_PATH = "../out/model_compressor_powerful.pth"

# Model Hyperparameters
EMBED_DIM = 128
HIDDEN_SIZE = 512
NUM_LAYERS = 3
DROPOUT = 0.1

# Training Hyperparameters
EPOCHS = 2
SEQ_LEN = 512
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
CLIP_GRAD = 1.0

# if metal apple is availble,  use it
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("Using device: Metal Apple (MPS)")

# Device setup
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print("Using device: CUDA")
# else:
#     DEVICE = torch.device("cpu")
#     print("Using device: CPU")

Using device: Metal Apple (MPS)


## 2) DataLoader

In [11]:
class ByteDataset(Dataset):
    def __init__(self, file_path, seq_len):
        if not os.path.exists(file_path):
            # Create dummy data if file doesn't exist for testing flow
            print(f"Warning: {file_path} not found. Using dummy data.")
            self.data = torch.randint(0, 256, (100000,), dtype=torch.long)
        else:
            with open(file_path, 'rb') as f:
                self.data = np.frombuffer(f.read(), dtype=np.uint8)
            self.data = torch.from_numpy(self.data).long()
            
        self.seq_len = seq_len
        self.n_samples = len(self.data) - seq_len - 1

    def __len__(self):
        # Ensure we don't return negative length
        return max(0, self.n_samples // self.seq_len)

    def __getitem__(self, idx):
        # Use simple strided access. 
        # Ideally, for stateful RNNs, we want contiguous batches, but that requires custom sampler.
        # Here we just rely on long SEQ_LEN to warm up the state.
        start = idx * self.seq_len
        end = start + self.seq_len + 1
        
        chunk = self.data[start:end]
        return chunk[:-1], chunk[1:]

## 3) Model

In [12]:
class PowerfulCompressor(nn.Module):
    def __init__(self, vocab_size=257, embed_dim=EMBED_DIM, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout=DROPOUT):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        self.fc = nn.Linear(hidden_size, 256) # Output 256 byte probabilities

    def forward(self, x, hidden=None):
        x = self.embed(x)
        out, hidden = self.gru(x, hidden)
        logits = self.fc(out)
        return logits, hidden
    
    def _get_probs(self, x, hidden):
        # Helper for inference
        with torch.no_grad():
            logits, hidden = self(x, hidden)
            # Softmax over the last dimension
            probs = torch.softmax(logits[0, 0], dim=0).cpu().numpy().astype(np.float64)
            # Float64 helps with precision for arithmetic coding, though model is float32 usually
        return probs, hidden

## 4) Trening

In [13]:
def train_model(model, train_path, epochs=EPOCHS):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1, factor=0.5)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler('cuda') if DEVICE.type == 'cuda' else None

    dataset = ByteDataset(train_path, SEQ_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    
    start_time = time.time()
    history = []

    print(f"Training on {len(dataset)} sequences of length {SEQ_LEN}...")

    for epoch in range(epochs):
        total_loss = 0
        steps = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")

        for x, y in pbar:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            
            # Mixed Precision Context
            if DEVICE.type == 'cuda':
                with torch.amp.autocast('cuda'):
                    logits, _ = model(x)
                    loss = criterion(logits.view(-1, 256), y.view(-1))
                scaler.scale(loss).backward()
                
                # Gradient Clipping (unscale first)
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
                
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard Precision (CPU/MPS)
                logits, _ = model(x)
                loss = criterion(logits.view(-1, 256), y.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
                optimizer.step()

            total_loss += loss.item()
            steps += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_loss = total_loss / steps
        bpc = avg_loss / 0.693147
        history.append({'loss': avg_loss, 'bpc': bpc})
        
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f} | BPC: {bpc:.4f}")
        scheduler.step(avg_loss)

    total_time = time.time() - start_time
    print(f"Training finished in {total_time:.2f} seconds.")
    return history, total_time

## 5) Funkcje kompresji i dekompresji

In [14]:
def print_time_profile(name, timings, total_time):
    print(f"\n{name} time profile:")
    if total_time <= 0:
        print("  No timing data")
        return

    items = sorted(timings.items(), key=lambda x: x[1], reverse=True)
    tracked = 0.0
    for key, value in items:
        tracked += value
        pct = (value / total_time) * 100
        print(f"  {key:<16} {value:>8.4f}s  ({pct:>6.2f}%)")

    remaining = max(total_time - tracked, 0.0)
    if remaining > 1e-9:
        pct = (remaining / total_time) * 100
        print(f"  {'other':<16} {remaining:>8.4f}s  ({pct:>6.2f}%)")


def compress_file(model, input_path, output_path):
    model.eval()
    encoder = constriction.stream.queue.RangeEncoder()

    timings = {
        'read_bytes': 0.0,
        'model_infer': 0.0,
        'arith_encode': 0.0,
        'state_update': 0.0,
        'write_file': 0.0,
    }

    t0 = time.perf_counter()
    with open(input_path, "rb") as f:
        data_to_compress = np.frombuffer(f.read(), dtype=np.uint8)
    timings['read_bytes'] += time.perf_counter() - t0

    curr_symbol = torch.tensor([[256]], dtype=torch.long, device=DEVICE)
    hidden = None
    length = len(data_to_compress)

    print(f"Compressing {length} bytes...")
    start_time = time.perf_counter()

    for symbol in tqdm(data_to_compress, desc="Encoding"):
        t_model = time.perf_counter()
        probs, hidden = model._get_probs(curr_symbol, hidden)
        timings['model_infer'] += time.perf_counter() - t_model

        probs = np.clip(probs, 1e-9, 1.0)
        probs /= probs.sum()

        t_coder = time.perf_counter()
        dist = constriction.stream.model.Categorical(probs, perfect=False)
        encoder.encode(int(symbol), dist)
        timings['arith_encode'] += time.perf_counter() - t_coder

        t_state = time.perf_counter()
        curr_symbol = torch.tensor([[symbol]], dtype=torch.long, device=DEVICE)
        timings['state_update'] += time.perf_counter() - t_state

    compressed_bits = encoder.get_compressed()

    t0 = time.perf_counter()
    with open(output_path, "wb") as f:
        f.write(struct.pack('<I', length))
        f.write(compressed_bits.tobytes())
    timings['write_file'] += time.perf_counter() - t0

    duration = time.perf_counter() - start_time
    original_size = length
    compressed_size = os.path.getsize(output_path)
    ratio = original_size / compressed_size
    bpc = (compressed_size * 8) / original_size

    print_time_profile("Compression", timings, duration)

    return {
        'time': duration,
        'original_size': original_size,
        'compressed_size': compressed_size,
        'ratio': ratio,
        'bpc': bpc,
        'speed_bps': original_size / duration,
        'timings': timings,
    }


def decompress_file(model, input_path, output_path):
    model.eval()

    timings = {
        'read_file': 0.0,
        'model_infer': 0.0,
        'arith_decode': 0.0,
        'state_update': 0.0,
        'write_file': 0.0,
    }

    start_time = time.perf_counter()

    t0 = time.perf_counter()
    with open(input_path, "rb") as f:
        orig_len = struct.unpack('<I', f.read(4))[0]
        bits = np.frombuffer(f.read(), dtype=np.uint32)
    timings['read_file'] += time.perf_counter() - t0

    decoder = constriction.stream.queue.RangeDecoder(bits)
    decoded_data = []
    curr_symbol = torch.tensor([[256]], dtype=torch.long, device=DEVICE)
    hidden = None

    print(f"Decompressing {orig_len} bytes...")

    for _ in tqdm(range(orig_len), desc="Decoding"):
        t_model = time.perf_counter()
        probs, hidden = model._get_probs(curr_symbol, hidden)
        timings['model_infer'] += time.perf_counter() - t_model

        probs = np.clip(probs, 1e-9, 1.0)
        probs /= probs.sum()

        t_coder = time.perf_counter()
        dist = constriction.stream.model.Categorical(probs, perfect=False)
        symbol = decoder.decode(dist)
        timings['arith_decode'] += time.perf_counter() - t_coder

        decoded_data.append(symbol)

        t_state = time.perf_counter()
        curr_symbol = torch.tensor([[symbol]], dtype=torch.long, device=DEVICE)
        timings['state_update'] += time.perf_counter() - t_state

    t0 = time.perf_counter()
    with open(output_path, "wb") as f:
        f.write(bytes(decoded_data))
    timings['write_file'] += time.perf_counter() - t0

    duration = time.perf_counter() - start_time

    print_time_profile("Decompression", timings, duration)

    return {
        'time': duration,
        'speed_bps': orig_len / duration,
        'timings': timings,
    }

## 6) Train i test

In [15]:
# Setup
TEST_PATH = "../data/canterbury_small.bin"
model = PowerfulCompressor().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# # 1. Train
# if os.path.exists(TRAIN_PATH):
#     print("=== TRAINING ===")
#     train_metrics, train_time = train_model(model, TRAIN_PATH)
#     torch.save(model.state_dict(), MODEL_PATH)
#     print(f"Model saved to {MODEL_PATH}")
# else:
#     print(f"Skipping training: {TRAIN_PATH} not found.")

# 2. Compress
if os.path.exists(TEST_PATH):
    print("=== COMPRESSION ===")
    comp_metrics = compress_file(model, TEST_PATH, COMPRESSED_PATH)
    print(f"Compression Ratio: {comp_metrics['ratio']:.2f}x")
    print(f"BPC: {comp_metrics['bpc']:.2f}")
else:
    print(f"Skipping compression: {TEST_PATH} not found.")

# 3. Decompress
if os.path.exists(COMPRESSED_PATH):
    print("=== DECOMPRESSION ===")
    decomp_metrics = decompress_file(model, COMPRESSED_PATH, DECOMPRESSED_PATH)
    
    # Verify
    with open(TEST_PATH, 'rb') as f1, open(DECOMPRESSED_PATH, 'rb') as f2:
        if f1.read() == f2.read():
            print("SUCCESS: Integrity verified!")
        else:
            print("FAILURE: Data mismatch!")
else:
    print("Skipping decompression: File not found.")

Model Parameters: 4,302,208
=== COMPRESSION ===
Compressing 10846 bytes...


Encoding: 100%|██████████| 10846/10846 [00:07<00:00, 1419.75it/s]



Compression time profile:
  model_infer        6.0512s  ( 79.20%)
  state_update       1.4888s  ( 19.48%)
  arith_encode       0.0150s  (  0.20%)
  write_file         0.0004s  (  0.01%)
  read_bytes         0.0002s  (  0.00%)
  other              0.0850s  (  1.11%)
Compression Ratio: 2.89x
BPC: 2.77
=== DECOMPRESSION ===
Decompressing 10846 bytes...


Decoding: 100%|██████████| 10846/10846 [00:07<00:00, 1432.57it/s]


Decompression time profile:
  model_infer        6.0064s  ( 79.32%)
  state_update       1.4609s  ( 19.29%)
  arith_decode       0.0150s  (  0.20%)
  write_file         0.0003s  (  0.00%)
  read_file          0.0001s  (  0.00%)
  other              0.0898s  (  1.19%)
SUCCESS: Integrity verified!





## 7) Podsumowanie

In [16]:
print(f"Baseline Results:")
print(f"Compression Speed: {comp_metrics['speed_bps']:.2f} B/s")
print(f"Decompression Speed: {decomp_metrics['speed_bps']:.2f} B/s")
print(f"Compression Ratio: {comp_metrics['ratio']:.2f}x")
print(f"BPC: {comp_metrics['bpc']:.2f}")

Baseline Results:
Compression Speed: 1419.52 B/s
Decompression Speed: 1432.30 B/s
Compression Ratio: 2.89x
BPC: 2.77
