# PixelCNN Image Compression on CIFAR-10

This notebook implements a PixelCNN-based learned lossless image compression model trained on CIFAR-10.
The model learns to predict pixel probability distributions, which are then used with ANS (Asymmetric Numeral Systems) 
entropy coding via the `constriction` library for efficient compression.

## Features:
- Masked convolutions for autoregressive pixel prediction
- Residual blocks for deeper architecture
- Checkpoint saving during training
- Comparison with traditional codecs (JPEG, PNG, WebP)

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch torchvision constriction pillow numpy matplotlib tqdm

In [None]:
import os
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import json
from pathlib import Path

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create directories for checkpoints and results
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('results', exist_ok=True)

## Model Architecture

The PixelCNN uses masked convolutions to enforce the autoregressive property: each pixel depends only on pixels that precede it in raster-scan order (top to bottom, left to right).
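
To make the two mask types concrete, the following sketch builds the mask pattern for a small kernel using plain Python lists. The helper name `build_mask` is hypothetical and exists only for this illustration; the notebook's actual masks are created inside `MaskedConv2d` as tensors.

```python
def build_mask(kernel_size, mask_type):
    """Return a kernel_size x kernel_size grid of 0/1 mask entries (illustrative helper)."""
    assert kernel_size % 2 == 1 and mask_type in ("A", "B")
    center = kernel_size // 2
    mask = [[1] * kernel_size for _ in range(kernel_size)]
    for row in range(kernel_size):
        for col in range(kernel_size):
            below = row > center                               # rows after the current pixel
            right_of_center = row == center and col > center   # same row, later columns
            is_center = row == center and col == center
            # Type A additionally hides the current pixel itself
            if below or right_of_center or (is_center and mask_type == "A"):
                mask[row][col] = 0
    return mask


for mask_type in ("A", "B"):
    print(f"type {mask_type}:")
    for row in build_mask(3, mask_type):
        print(" ", row)
```

For a 3x3 kernel, type A keeps the top row and the left neighbor only, while type B also keeps the center position. Type A is needed in the first layer because its input is the image itself; later layers may look at the center because their input features already exclude the current pixel.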

In [None]:
class MaskedConv2d(nn.Conv2d):
    """
    Masked Convolution for autoregressive models.
    
    mask_type='A': First layer - excludes center pixel (for input layer)
    mask_type='B': Subsequent layers - includes center pixel (for hidden layers)
    """
    def __init__(self, mask_type, in_channels, out_channels, kernel_size, **kwargs):
        assert mask_type in ['A', 'B'], "mask_type must be 'A' or 'B'"
        
        # Ensure kernel_size is odd for symmetric masking
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1, "Kernel size must be odd"
        
        # Set padding to maintain spatial dimensions
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        super().__init__(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        
        self.mask_type = mask_type
        self.register_buffer('mask', torch.ones_like(self.weight))
        
        # Create mask
        _, _, h, w = self.weight.shape
        center_h, center_w = h // 2, w // 2
        
        # Zero out lower half
        self.mask[:, :, center_h + 1:, :] = 0
        # Zero out right side of center row
        self.mask[:, :, center_h, center_w + 1:] = 0
        
        # For type A, also zero out center pixel
        if mask_type == 'A':
            self.mask[:, :, center_h, center_w] = 0
    
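    def forward(self, x):
        # Apply the mask at call time rather than mutating self.weight.data in place
        return F.conv2d(x, self.weight * self.mask, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)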


class ResidualBlock(nn.Module):
    """Residual block with masked convolutions."""
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        self.conv1 = MaskedConv2d('B', channels, channels // 2, 1)
        self.conv2 = MaskedConv2d('B', channels // 2, channels // 2, kernel_size)
        self.conv3 = MaskedConv2d('B', channels // 2, channels, 1)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return self.relu(out + residual)


class PixelCNN(nn.Module):
    """
    PixelCNN for image compression.
    
    Predicts probability distribution over 256 pixel values for each pixel
    conditioned on previously seen pixels.
    """
    def __init__(self, in_channels=3, hidden_channels=128, num_residual_blocks=12, 
                 num_classes=256, kernel_size=7):
        super().__init__()
        
        self.num_classes = num_classes
        
        # Initial convolution (mask type A - excludes center)
        self.input_conv = MaskedConv2d('A', in_channels, hidden_channels, kernel_size)
        
        # Residual blocks
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(hidden_channels, kernel_size=3) 
            for _ in range(num_residual_blocks)
        ])
        
        # Output layers
        self.output_conv1 = MaskedConv2d('B', hidden_channels, hidden_channels, 1)
        self.output_conv2 = MaskedConv2d('B', hidden_channels, hidden_channels, 1)
        
        # Final layer outputs logits for each pixel value (0-255) for each channel
        self.final_conv = MaskedConv2d('B', hidden_channels, in_channels * num_classes, 1)
        
        self.relu = nn.ReLU(inplace=True)
        self.in_channels = in_channels
    
    def forward(self, x):
        """
        Forward pass.
        
        Args:
            x: Input image tensor of shape (B, C, H, W) with values in [0, 255] as long tensor
               or normalized values
        
        Returns:
            logits: Shape (B, C, num_classes, H, W) - logits for each pixel value
        """
        # Normalize input to [-1, 1] for better training
        if x.dtype == torch.long:
            x = x.float() / 127.5 - 1
        
        out = self.relu(self.input_conv(x))
        
        for block in self.residual_blocks:
            out = block(out)
        
        out = self.relu(self.output_conv1(out))
        out = self.relu(self.output_conv2(out))
        out = self.final_conv(out)
        
        # Reshape to (B, C, num_classes, H, W)
        B, _, H, W = out.shape
        out = out.view(B, self.in_channels, self.num_classes, H, W)
        
        return out
    
    def get_probabilities(self, x):
        """Get probability distributions for each pixel."""
        logits = self.forward(x)
        return F.softmax(logits, dim=2)
    
    def loss(self, x, target):
        """
        Compute cross-entropy loss.
        
        Args:
            x: Input images (B, C, H, W)
            target: Target pixel values as long tensor (B, C, H, W) with values 0-255
        
        Returns:
            loss: Cross-entropy loss (bits per sub-pixel)
        """
        logits = self.forward(x)  # (B, C, num_classes, H, W)
        B, C, num_classes, H, W = logits.shape
        
        # Reshape for cross entropy
        logits = logits.permute(0, 1, 3, 4, 2).contiguous()  # (B, C, H, W, num_classes)
        logits = logits.view(-1, num_classes)  # (B*C*H*W, num_classes)
        target = target.view(-1)  # (B*C*H*W,)
        
        # Cross entropy loss in nats, convert to bits
        loss = F.cross_entropy(logits, target, reduction='mean')
        bits_per_subpixel = loss / np.log(2)
        
        return bits_per_subpixel


# Test model instantiation
model = PixelCNN(in_channels=3, hidden_channels=128, num_residual_blocks=12).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Data Loading

Load CIFAR-10 dataset and create train/validation split.

In [None]:
class CIFAR10Compression(torch.utils.data.Dataset):
    """
    CIFAR-10 dataset wrapper for compression task.
    Returns images as uint8 tensors (0-255).
    """
    def __init__(self, root='./data', train=True, download=True):
        # Load without normalization - we need raw pixel values
        self.dataset = torchvision.datasets.CIFAR10(
            root=root, 
            train=train, 
            download=download,
            transform=transforms.ToTensor()  # Converts to [0, 1] float
        )
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        img, _ = self.dataset[idx]  # We don't need labels
        # Convert to [0, 255] uint8; round first so float error cannot truncate a value
        img_uint8 = (img * 255).round().to(torch.uint8)
        return img_uint8


# Load datasets
print("Loading CIFAR-10 dataset...")
full_train_dataset = CIFAR10Compression(train=True, download=True)
test_dataset = CIFAR10Compression(train=False, download=True)

# Split training into train and validation (90/10)
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_train_dataset, 
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Create dataloaders
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                          num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=2, pin_memory=True)

# Visualize some samples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
sample_batch = next(iter(train_loader))
for i, ax in enumerate(axes.flat):
    img = sample_batch[i].permute(1, 2, 0).numpy()
    ax.imshow(img)
    ax.axis('off')
plt.suptitle('Sample CIFAR-10 Images')
plt.tight_layout()
plt.savefig('results/sample_images.png', dpi=150)
plt.show()

## Training

Train the PixelCNN model with checkpoint saving.
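
The training loop reports the loss in bits per sub-pixel, obtained by dividing the cross-entropy in nats by ln 2. As a sanity check on that unit, the sketch below computes the loss of a model that predicts a uniform distribution over the 256 values; the helper names are illustrative and not part of the notebook.

```python
import math


def nats_to_bits(nats):
    """Convert a quantity measured in nats to bits."""
    return nats / math.log(2)


def cross_entropy_nats(probs, target):
    """Cross-entropy of a single prediction: -ln p(target)."""
    return -math.log(probs[target])


# A model that knows nothing assigns probability 1/256 to every value
uniform = [1.0 / 256] * 256
baseline_bits = nats_to_bits(cross_entropy_nats(uniform, target=17))
print(f"uniform baseline: {baseline_bits:.4f} bits per sub-pixel")
```

The uniform baseline is 8 bits per sub-pixel, i.e. 24 bits per RGB pixel, which is the cost of storing the raw image. A training loss that stays near or above this value suggests the model has not learned anything useful yet.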

In [None]:
class CheckpointManager:
    """Manages model checkpoints during training."""
    
    def __init__(self, checkpoint_dir='checkpoints', model_name='pixelcnn'):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True)
        self.model_name = model_name
        self.best_loss = float('inf')
    
    def save_checkpoint(self, model, optimizer, scheduler, epoch, train_loss, val_loss, is_best=False):
        """Save a training checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'train_loss': train_loss,
            'val_loss': val_loss,
        }
        
        # Save latest checkpoint
        latest_path = self.checkpoint_dir / f'{self.model_name}_latest.pt'
        torch.save(checkpoint, latest_path)
        
        # Save epoch checkpoint (every 5 epochs)
        if (epoch + 1) % 5 == 0:
            epoch_path = self.checkpoint_dir / f'{self.model_name}_epoch_{epoch+1}.pt'
            torch.save(checkpoint, epoch_path)
            print(f"  Saved epoch checkpoint: {epoch_path}")
        
        # Save best model
        if is_best:
            best_path = self.checkpoint_dir / f'{self.model_name}_best.pt'
            torch.save(checkpoint, best_path)
            print(f"  Saved best model: {best_path}")
    
    def load_checkpoint(self, model, optimizer=None, scheduler=None, checkpoint_path=None):
        """Load a checkpoint."""
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_dir / f'{self.model_name}_latest.pt'
        
        checkpoint_path = Path(checkpoint_path)  # Accept plain string paths too
        if not checkpoint_path.exists():
            print("No checkpoint found, starting from scratch")
            return 0
        
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        
        if optimizer and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        if scheduler and checkpoint.get('scheduler_state_dict'):
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']+1}")
        print(f"  Train loss: {checkpoint['train_loss']:.4f} bits/sub-pixel")
        print(f"  Val loss: {checkpoint['val_loss']:.4f} bits/sub-pixel")
        
        return checkpoint['epoch'] + 1


def train_epoch(model, train_loader, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc='Training', leave=False)
    for batch in pbar:
        batch = batch.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        
        # Input is the image, target is the same image (autoregressive prediction)
        input_img = batch.float() / 127.5 - 1  # Normalize to [-1, 1]
        target = batch.long()  # Target as class indices (0-255)
        
        loss = model.loss(input_img, target)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f} bits/sub-pixel'})
    
    return total_loss / num_batches


@torch.no_grad()
def validate(model, val_loader, device):
    """Validate the model."""
    model.eval()
    total_loss = 0
    num_batches = 0
    
    for batch in tqdm(val_loader, desc='Validating', leave=False):
        batch = batch.to(device)
        
        input_img = batch.float() / 127.5 - 1
        target = batch.long()
        
        loss = model.loss(input_img, target)
        total_loss += loss.item()
        num_batches += 1
    
    return total_loss / num_batches


def train_model(model, train_loader, val_loader, num_epochs=50, lr=3e-4, resume=True):
    """Full training loop with checkpointing."""
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # `verbose` is deprecated in recent PyTorch releases; the LR is printed each epoch below
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    
    checkpoint_manager = CheckpointManager()
    
    # Resume from checkpoint if exists
    start_epoch = 0
    if resume:
        start_epoch = checkpoint_manager.load_checkpoint(model, optimizer, scheduler)
    
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    
    print(f"\nStarting training from epoch {start_epoch + 1}")
    print(f"Total epochs: {num_epochs}")
    print("-" * 50)
    
    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, device)
        train_losses.append(train_loss)
        
        # Validate
        val_loss = validate(model, val_loader, device)
        val_losses.append(val_loss)
        
        # Update scheduler
        scheduler.step(val_loss)
        
        # Check if best model
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
        
        print(f"  Train Loss: {train_loss:.4f} bits/sub-pixel | Val Loss: {val_loss:.4f} bits/sub-pixel")
        print(f"  Best Val Loss: {best_val_loss:.4f} bits/sub-pixel")
        print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
        
        # Save checkpoint
        checkpoint_manager.save_checkpoint(
            model, optimizer, scheduler, epoch, train_loss, val_loss, is_best
        )
    
    # Plot training curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (bits per sub-pixel)')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True)
    plt.savefig('results/training_curve.png', dpi=150)
    plt.show()
    
    # Save training history
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_val_loss': best_val_loss
    }
    with open('results/training_history.json', 'w') as f:
        json.dump(history, f, indent=2)
    
    return train_losses, val_losses

In [None]:
# Training configuration
NUM_EPOCHS = 50  # Adjust based on your compute budget
LEARNING_RATE = 3e-4

# Initialize fresh model or use existing one
model = PixelCNN(
    in_channels=3, 
    hidden_channels=128, 
    num_residual_blocks=12
).to(device)

# Train the model
train_losses, val_losses = train_model(
    model, 
    train_loader, 
    val_loader, 
    num_epochs=NUM_EPOCHS, 
    lr=LEARNING_RATE,
    resume=True  # Set to False to start fresh
)

## Compression with ANS Encoding

Use the `constriction` library for Asymmetric Numeral Systems (ANS) entropy coding.
The PixelCNN provides a probability distribution for every sub-pixel, and ANS encodes each symbol according to its distribution.
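
The interplay between a predictive model and a stack-based ANS coder can be sketched without any external library. The toy coder below is an assumption-laden illustration, not how `constriction` is implemented: it keeps the whole coder state in one arbitrary-precision Python integer, and `context_freqs` is a hypothetical stand-in for the PixelCNN that predicts the next symbol from the previous one.

```python
TOTAL = 16  # All frequency tables sum to this value


def context_freqs(prev_symbol):
    """Hypothetical model: integer frequencies for the next symbol (alphabet of 4)."""
    freqs = [1, 1, 1, 1]
    freqs[prev_symbol] += TOTAL - 4  # Favor repeating the previous symbol
    return freqs


def push(state, symbol, freqs):
    """Encode one symbol onto the state (rANS update without renormalization)."""
    f, c = freqs[symbol], sum(freqs[:symbol])
    return (state // f) * TOTAL + c + (state % f)


def pop(state, freqs):
    """Decode the most recently pushed symbol and restore the previous state."""
    slot = state % TOTAL
    symbol, c = 0, 0
    while c + freqs[symbol] <= slot:  # Find the symbol whose interval contains slot
        c += freqs[symbol]
        symbol += 1
    return freqs[symbol] * (state // TOTAL) + slot - c, symbol


def encode(symbols, first_context=0):
    # The encoder sees the whole sequence, so every context is known up front
    contexts = [first_context] + symbols[:-1]
    state = 1
    # A stack is last-in-first-out: push in reverse so symbols pop in forward order
    for symbol, ctx in reversed(list(zip(symbols, contexts))):
        state = push(state, symbol, context_freqs(ctx))
    return state


def decode(state, n, first_context=0):
    out, ctx = [], first_context
    for _ in range(n):
        # The decoder must rebuild each distribution from symbols decoded so far
        state, symbol = pop(state, context_freqs(ctx))
        out.append(symbol)
        ctx = symbol
    return out


message = [0, 0, 2, 2, 2, 1, 1, 3]
state = encode(message)
print(decode(state, len(message)) == message, state.bit_length(), "bits of state")
```

Two points carry over to the real pipeline. The encoder can compute all distributions in one pass because it has the full image, whereas the decoder has to alternate between running the model and decoding, one position at a time. Both sides must also derive exactly the same distributions, or the decoded symbols diverge.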

In [None]:
import constriction

class PixelCNNCompressor:
    """Compress images using PixelCNN + ANS encoding (constriction library)."""
    
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.eval()
    
    def _probabilities(self, image):
        """Return per-symbol distributions as a float32 array of shape (H, W, C, 256)."""
        input_img = image.float().unsqueeze(0) / 127.5 - 1
        logits = self.model(input_img)
        probs = F.softmax(logits, dim=2).squeeze(0)  # (C, 256, H, W)
        
        # Raster order with channels innermost, so the decoder can recover
        # symbols in the same order in which it can compute their distributions
        probs = probs.permute(2, 3, 0, 1).contiguous()  # (H, W, C, 256)
        
        # Normalize probabilities and convert to numpy
        probs = probs.clamp(min=1e-9)
        probs = probs / probs.sum(dim=-1, keepdim=True)
        return probs.cpu().numpy().astype(np.float32)
    
    @torch.no_grad()
    def compress(self, image):
        """Compress a single (C, H, W) uint8 image using ANS encoding with constriction."""
        self.model.eval()
        C, H, W = image.shape
        image = image.to(self.device)
        
        # Masking makes every output depend only on earlier pixels, so one
        # forward pass over the full image yields all distributions at once
        probs_np = self._probabilities(image).reshape(-1, 256)
        symbols = image.permute(1, 2, 0).reshape(-1).to(torch.int32).cpu().numpy()
        
        # Use constriction's ANS encoder (a stack: encode in reverse, decode forward)
        ans = constriction.stream.stack.AnsCoder()
        model_family = constriction.stream.model.Categorical(perfect=False)
        ans.encode_reverse(symbols, model_family, probs_np)
        
        compressed = ans.get_compressed()
        return compressed.tobytes(), (C, H, W)
    
    @torch.no_grad()
    def decompress(self, compressed_bytes, shape):
        """Decompress to a (C, H, W) uint8 tensor, decoding pixels in raster order."""
        self.model.eval()
        C, H, W = shape
        
        compressed = np.frombuffer(compressed_bytes, dtype=np.uint32)
        ans = constriction.stream.stack.AnsCoder(compressed)
        model_family = constriction.stream.model.Categorical(perfect=False)
        
        # The distribution at (h, w) depends on the pixels decoded so far, so
        # the model is re-run for every position (H * W forward passes).
        # NOTE: decoding is only correct if these probabilities match the
        # encoder's exactly; non-deterministic GPU kernels can break this.
        image = torch.zeros(C, H, W, dtype=torch.uint8, device=self.device)
        for h in range(H):
            for w in range(W):
                probs_np = self._probabilities(image)[h, w]  # (C, 256)
                decoded = ans.decode(model_family, probs_np)
                image[:, h, w] = torch.from_numpy(decoded.astype(np.uint8)).to(self.device)
        
        return image.cpu()

## Evaluation: Compare with Traditional Codecs

Compare PixelCNN compression with PNG (lossless), JPEG (lossy), and WebP.
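
The comparison relies on two metrics: bits per pixel for rate and PSNR for distortion. The sketch below works through both on small hypothetical numbers (a 1200-byte file and a four-value signal) using only the standard library; the function names are illustrative and separate from the notebook's `compute_bpp` and `compute_metrics`.

```python
import math


def bits_per_pixel(num_bytes, height, width):
    """Rate in bits per pixel, counting all color channels of a pixel together."""
    return num_bytes * 8 / (height * width)


def psnr(original, reconstructed, peak=255):
    """Peak signal-to-noise ratio in dB for two equal-length value sequences."""
    mse = sum((a - b) ** 2 for a, b in zip(original, reconstructed)) / len(original)
    return math.inf if mse == 0 else 10 * math.log10(peak ** 2 / mse)


# Hypothetical 1200-byte encoding of a 32x32 RGB image
bpp = bits_per_pixel(1200, 32, 32)
print(f"{bpp} bits per pixel = {bpp / 3} bits per sub-pixel")

# Raw storage for reference: 32 * 32 * 3 bytes
print(f"uncompressed: {bits_per_pixel(32 * 32 * 3, 32, 32)} bits per pixel")

# Lossless codecs reproduce the input exactly, so their PSNR is infinite
print(psnr([10, 20, 30, 40], [10, 20, 30, 40]))
print(f"{psnr([10, 20, 30, 40], [11, 19, 30, 42]):.2f} dB")
```

Keeping the per-pixel and per-sub-pixel units apart matters when comparing codec file sizes against the training loss, since the two differ by a factor of three for RGB images.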

In [None]:
class TraditionalCodec:
    """Wrapper for traditional image codecs."""
    
    @staticmethod
    def compress_png(image_np):
        """Compress using PNG (lossless)."""
        img = Image.fromarray(image_np)
        buffer = io.BytesIO()
        img.save(buffer, format='PNG', optimize=True)
        return buffer.getvalue()
    
    @staticmethod
    def compress_jpeg(image_np, quality=95):
        """Compress using JPEG."""
        img = Image.fromarray(image_np)
        buffer = io.BytesIO()
        img.save(buffer, format='JPEG', quality=quality)
        return buffer.getvalue()
    
    @staticmethod
    def decompress_jpeg(compressed_bytes):
        """Decompress JPEG."""
        buffer = io.BytesIO(compressed_bytes)
        img = Image.open(buffer)
        return np.array(img)
    
    @staticmethod
    def compress_webp(image_np, quality=100, lossless=True):
        """Compress using WebP."""
        img = Image.fromarray(image_np)
        buffer = io.BytesIO()
        img.save(buffer, format='WEBP', quality=quality, lossless=lossless)
        return buffer.getvalue()
    
    @staticmethod
    def decompress_webp(compressed_bytes):
        """Decompress WebP."""
        buffer = io.BytesIO(compressed_bytes)
        img = Image.open(buffer)
        return np.array(img)


def compute_metrics(original, reconstructed):
    """Compute image quality metrics."""
    original = original.astype(np.float64)
    reconstructed = reconstructed.astype(np.float64)
    
    mse = np.mean((original - reconstructed) ** 2)
    
    if mse == 0:
        psnr = float('inf')
    else:
        psnr = 10 * np.log10((255 ** 2) / mse)
    
    return {'mse': mse, 'psnr': psnr}


def compute_bpp(compressed_size_bytes, image_shape):
    """Compute bits per pixel (all color channels combined)."""
    _, h, w = image_shape
    total_pixels = h * w
    total_bits = compressed_size_bytes * 8
    return total_bits / total_pixels


@torch.no_grad()
def evaluate_compression(model, val_loader, num_samples=100, device='cuda'):
    """
    Evaluate compression performance on validation set.
    Compare PixelCNN with traditional codecs.
    """
    model.eval()
    compressor = PixelCNNCompressor(model, device)
    codec = TraditionalCodec()
    
    results = {
        'pixelcnn': {'bpp': [], 'psnr': []},
        'png': {'bpp': [], 'psnr': []},
        'jpeg_95': {'bpp': [], 'psnr': []},
        'jpeg_75': {'bpp': [], 'psnr': []},
        'webp_lossless': {'bpp': [], 'psnr': []},
        'webp_lossy': {'bpp': [], 'psnr': []},
    }
    
    sample_count = 0
    
    print(f"Evaluating on {num_samples} samples...")
    
    for batch in tqdm(val_loader, desc='Evaluating'):
        for img_tensor in batch:
            if sample_count >= num_samples:
                break
            
            img_np = img_tensor.permute(1, 2, 0).numpy()  # (H, W, C)
            
            # PixelCNN compression (using fast parallel method)
            try:
                compressed_pixelcnn, shape = compressor.compress(img_tensor)
                pixelcnn_bpp = compute_bpp(len(compressed_pixelcnn), img_tensor.shape)
                results['pixelcnn']['bpp'].append(pixelcnn_bpp)
                results['pixelcnn']['psnr'].append(float('inf'))  # Lossless
            except Exception as e:
                print(f"PixelCNN compression error: {e}")
                results['pixelcnn']['bpp'].append(float('nan'))
                results['pixelcnn']['psnr'].append(float('nan'))
            
            # PNG (lossless)
            compressed_png = codec.compress_png(img_np)
            png_bpp = compute_bpp(len(compressed_png), img_tensor.shape)
            results['png']['bpp'].append(png_bpp)
            results['png']['psnr'].append(float('inf'))
            
            # JPEG quality 95
            compressed_jpeg95 = codec.compress_jpeg(img_np, quality=95)
            jpeg95_bpp = compute_bpp(len(compressed_jpeg95), img_tensor.shape)
            decoded_jpeg95 = codec.decompress_jpeg(compressed_jpeg95)
            jpeg95_metrics = compute_metrics(img_np, decoded_jpeg95)
            results['jpeg_95']['bpp'].append(jpeg95_bpp)
            results['jpeg_95']['psnr'].append(jpeg95_metrics['psnr'])
            
            # JPEG quality 75
            compressed_jpeg75 = codec.compress_jpeg(img_np, quality=75)
            jpeg75_bpp = compute_bpp(len(compressed_jpeg75), img_tensor.shape)
            decoded_jpeg75 = codec.decompress_jpeg(compressed_jpeg75)
            jpeg75_metrics = compute_metrics(img_np, decoded_jpeg75)
            results['jpeg_75']['bpp'].append(jpeg75_bpp)
            results['jpeg_75']['psnr'].append(jpeg75_metrics['psnr'])
            
            # WebP lossless
            compressed_webp_ll = codec.compress_webp(img_np, lossless=True)
            webp_ll_bpp = compute_bpp(len(compressed_webp_ll), img_tensor.shape)
            results['webp_lossless']['bpp'].append(webp_ll_bpp)
            results['webp_lossless']['psnr'].append(float('inf'))
            
            # WebP lossy quality 95
            compressed_webp_lossy = codec.compress_webp(img_np, quality=95, lossless=False)
            webp_lossy_bpp = compute_bpp(len(compressed_webp_lossy), img_tensor.shape)
            decoded_webp = codec.decompress_webp(compressed_webp_lossy)
            webp_metrics = compute_metrics(img_np, decoded_webp)
            results['webp_lossy']['bpp'].append(webp_lossy_bpp)
            results['webp_lossy']['psnr'].append(webp_metrics['psnr'])
            
            sample_count += 1
        
        if sample_count >= num_samples:
            break
    
    # Compute averages
    summary = {}
    for codec_name, metrics in results.items():
        valid_bpp = [b for b in metrics['bpp'] if not np.isnan(b)]
        valid_psnr = [p for p in metrics['psnr'] if not np.isnan(p) and not np.isinf(p)]
        
        summary[codec_name] = {
            'avg_bpp': np.mean(valid_bpp) if valid_bpp else float('nan'),
            'std_bpp': np.std(valid_bpp) if valid_bpp else float('nan'),
            'avg_psnr': np.mean(valid_psnr) if valid_psnr else 'lossless',
            'num_samples': len(valid_bpp)
        }
    
    return results, summary


def print_results_table(summary):
    """Print a formatted results table."""
    print("\n" + "=" * 70)
    print("COMPRESSION COMPARISON RESULTS")
    print("=" * 70)
    print(f"{'Codec':<20} {'Avg BPP':<15} {'Std BPP':<15} {'Avg PSNR (dB)':<15}")
    print("-" * 70)
    
    for codec_name, stats in summary.items():
        psnr_str = f"{stats['avg_psnr']:.2f}" if isinstance(stats['avg_psnr'], float) else stats['avg_psnr']
        print(f"{codec_name:<20} {stats['avg_bpp']:<15.4f} {stats['std_bpp']:<15.4f} {psnr_str:<15}")
    
    print("=" * 70)


def plot_results(results, summary):
    """Plot comparison results."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # BPP comparison
    codecs = list(summary.keys())
    bpps = [summary[c]['avg_bpp'] for c in codecs]
    stds = [summary[c]['std_bpp'] for c in codecs]
    
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#e67e22', '#9b59b6', '#1abc9c']
    
    bars = axes[0].bar(codecs, bpps, yerr=stds, capsize=5, color=colors)
    axes[0].set_ylabel('Bits Per Pixel (BPP)')
    axes[0].set_title('Compression Ratio Comparison')
    axes[0].tick_params(axis='x', rotation=45)
    axes[0].grid(axis='y', alpha=0.3)
    
    # Add value labels
    for bar, bpp in zip(bars, bpps):
        axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                     f'{bpp:.2f}', ha='center', va='bottom', fontsize=9)
    
    # BPP distribution (box plot)
    bpp_data = [results[c]['bpp'] for c in codecs]
    bp = axes[1].boxplot(bpp_data, labels=codecs, patch_artist=True)
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
    axes[1].set_ylabel('Bits Per Pixel (BPP)')
    axes[1].set_title('BPP Distribution')
    axes[1].tick_params(axis='x', rotation=45)
    axes[1].grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('results/compression_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    return fig

In [None]:
# Load best model for evaluation
checkpoint_path = Path('checkpoints/pixelcnn_best.pt')

if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
    print(f"Validation loss: {checkpoint['val_loss']:.4f} bits/sub-pixel")
else:
    print("No checkpoint found. Using current model state.")
    print("Make sure to train the model first!")

# Run evaluation
results, summary = evaluate_compression(model, val_loader, num_samples=100, device=device)

# Print results
print_results_table(summary)

# Plot results
plot_results(results, summary)

# Save results
with open('results/compression_results.json', 'w') as f:
    # Convert numpy to python types for JSON serialization
    json_results = {}
    for codec, metrics in summary.items():
        json_results[codec] = {
            'avg_bpp': float(metrics['avg_bpp']) if not np.isnan(metrics['avg_bpp']) else None,
            'std_bpp': float(metrics['std_bpp']) if not np.isnan(metrics['std_bpp']) else None,
            'avg_psnr': float(metrics['avg_psnr']) if isinstance(metrics['avg_psnr'], float) else metrics['avg_psnr'],
            'num_samples': metrics['num_samples']
        }
    json.dump(json_results, f, indent=2)
print("\nResults saved to results/compression_results.json")

## Visualize Compression on Individual Images

In [None]:
@torch.no_grad()
def visualize_compression_example(model, val_loader, device='cuda', num_examples=3):
    """Visualize compression on individual images."""
    model.eval()
    compressor = PixelCNNCompressor(model, device)
    codec = TraditionalCodec()
    
    # Get some samples
    samples = next(iter(val_loader))[:num_examples]
    
    fig, axes = plt.subplots(num_examples, 5, figsize=(15, 3*num_examples))
    if num_examples == 1:
        axes = axes.reshape(1, -1)
    
    for idx, img_tensor in enumerate(samples):
        img_np = img_tensor.permute(1, 2, 0).numpy()
        
        # Compute compressed sizes
        try:
            compressed_pixelcnn, _ = compressor.compress(img_tensor)
            pixelcnn_size = len(compressed_pixelcnn)
        except Exception:  # Avoid swallowing KeyboardInterrupt/SystemExit
            pixelcnn_size = -1
        
        compressed_png = codec.compress_png(img_np)
        compressed_jpeg = codec.compress_jpeg(img_np, quality=75)
        compressed_webp = codec.compress_webp(img_np, lossless=True)
        
        # Decode JPEG for visualization
        decoded_jpeg = codec.decompress_jpeg(compressed_jpeg)
        
        # Original size (uncompressed)
        original_size = img_np.size  # H * W * C bytes
        
        # Plot
        axes[idx, 0].imshow(img_np)
        axes[idx, 0].set_title(f'Original\n{original_size} bytes')
        axes[idx, 0].axis('off')
        
        axes[idx, 1].imshow(img_np)
        pixelcnn_str = f'{pixelcnn_size}' if pixelcnn_size > 0 else 'N/A'
        axes[idx, 1].set_title(f'PixelCNN+ANS\n{pixelcnn_str} bytes')
        axes[idx, 1].axis('off')
        
        axes[idx, 2].imshow(img_np)
        axes[idx, 2].set_title(f'PNG\n{len(compressed_png)} bytes')
        axes[idx, 2].axis('off')
        
        axes[idx, 3].imshow(decoded_jpeg)
        axes[idx, 3].set_title(f'JPEG (Q=75)\n{len(compressed_jpeg)} bytes')
        axes[idx, 3].axis('off')
        
        axes[idx, 4].imshow(img_np)
        axes[idx, 4].set_title(f'WebP Lossless\n{len(compressed_webp)} bytes')
        axes[idx, 4].axis('off')
    
    plt.suptitle('Compression Comparison (Byte Sizes)', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig('results/compression_examples.png', dpi=150, bbox_inches='tight')
    plt.show()

# Visualize examples
visualize_compression_example(model, val_loader, device=device, num_examples=3)

## Theoretical Analysis

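Compute the theoretical compression rate from the model's cross-entropy loss. Measured in bits, the cross-entropy is the expected code length an ideal entropy coder achieves when it codes the data with the model's predicted distributions; it is never lower than the true entropy of the data.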
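
The relationship between cross-entropy and entropy can be checked numerically. In the sketch below, `p` and `q` are hypothetical distributions chosen so that all values are exact in floating point: `p` plays the role of the true data distribution and `q` the model's prediction.

```python
import math


def entropy_bits(p):
    """Shannon entropy H(p) in bits."""
    return -sum(pi * math.log2(pi) for pi in p if pi > 0)


def cross_entropy_bits(p, q):
    """Cross-entropy H(p, q) in bits: expected code length when coding p with q."""
    return -sum(pi * math.log2(qi) for pi, qi in zip(p, q) if pi > 0)


p = [0.5, 0.25, 0.125, 0.125]  # Hypothetical true distribution
q = [0.25, 0.25, 0.25, 0.25]   # Hypothetical model prediction (uniform)

h = entropy_bits(p)
ce = cross_entropy_bits(p, q)
print(f"entropy       H(p)    = {h} bits")
print(f"cross-entropy H(p, q) = {ce} bits")
print(f"overhead      KL(p||q) = {ce - h} bits")
```

Here the mismatched model costs 2 bits per symbol against an entropy of 1.75 bits, so 0.25 bits per symbol are lost to the imperfect prediction. Training the PixelCNN reduces this overhead, and the cross-entropy reaches the entropy only when the prediction matches the data distribution.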

In [None]:
@torch.no_grad()
def compute_theoretical_bpp(model, data_loader, device='cuda', num_batches=50):
    """
    Compute the theoretical bits-per-pixel based on model's predictions.
    
    The cross-entropy loss in bits is equivalent to the expected codelength
    under optimal entropy coding (Shannon's source coding theorem).
    """
    model.eval()
    
    total_bits = 0
    total_subpixels = 0
    
    for i, batch in enumerate(tqdm(data_loader, desc='Computing theoretical BPP')):
        if i >= num_batches:
            break
            
        batch = batch.to(device)
        B, C, H, W = batch.shape
        
        input_img = batch.float() / 127.5 - 1
        target = batch.long()
        
        # Get bits per sub-pixel
        loss_bpp = model.loss(input_img, target)
        
        # Total bits for this batch
        num_subpixels = B * C * H * W
        total_bits += loss_bpp.item() * num_subpixels
        total_subpixels += num_subpixels
    
    # Bits per sub-pixel (per channel)
    bpp_subpixel = total_bits / total_subpixels
    
    # Bits per pixel (all channels combined) - for CIFAR-10: 3 channels
    bpp_pixel = bpp_subpixel * 3
    
    return bpp_subpixel, bpp_pixel


# Compute theoretical BPP
bpp_subpixel, bpp_pixel = compute_theoretical_bpp(model, val_loader, device=device)

print("\n" + "=" * 50)
print("THEORETICAL COMPRESSION RATE (based on model loss)")
print("=" * 50)
print(f"Bits per sub-pixel (channel): {bpp_subpixel:.4f}")
print(f"Bits per pixel (RGB):         {bpp_pixel:.4f}")
print(f"Theoretical file size for 32x32 RGB image:")
print(f"  - Uncompressed: {32*32*3*8} bits ({32*32*3} bytes)")
print(f"  - PixelCNN:     {bpp_pixel * 32 * 32:.0f} bits ({bpp_pixel * 32 * 32 / 8:.0f} bytes)")
print(f"Compression ratio: {(32*32*3*8) / (bpp_pixel * 32 * 32):.2f}x")
print("=" * 50)

## Summary

### Files Created:
- `pixelCNN.ipynb` - This notebook with full training and evaluation pipeline
- `evaluate_compression.py` - Standalone script for detailed evaluation
- `checkpoints/` - Directory containing model checkpoints
- `results/` - Directory containing evaluation results and plots

### Usage:

1. **Training**: Run the training cells above. Checkpoints are saved automatically.

2. **Evaluation in notebook**: Run the evaluation cells to compare with traditional codecs.

3. **Standalone evaluation**:
```bash
python evaluate_compression.py --checkpoint checkpoints/pixelcnn_best.pt --num_samples 500
```

### Expected Results:

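After training for ~50 epochs, the PixelCNN should achieve approximately:
- **4-5 bits per sub-pixel** on CIFAR-10

Comparison with traditional lossless codecs on CIFAR-10:
- **PNG**: ~6-7 bits per sub-pixel
- **WebP Lossless**: ~5-6 bits per sub-pixel

These figures are best read as bits per sub-pixel (per color channel value), the unit of the training loss. The `compute_bpp` helper above divides by `H * W` only, so the evaluation table reports bits per RGB pixel, which is three times larger.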

The learned compression model can potentially achieve better rates by exploiting the statistical regularities learned from the training data.