# IDF Training Notebook

Train an Integer Discrete Flow (IDF) model for lossless image compression on ImageNet-1k.

This notebook uses random 64×64 crops from the streaming ImageNet dataset.

In [None]:
# ============================================================
# COLAB SETUP - Run this cell first!
# ============================================================
import os
import sys

# Check if running on Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install dependencies first
    print('Installing dependencies...')
    %pip install -q datasets huggingface_hub tqdm
    
    # Clone repo if not exists
    if not os.path.exists('AI-Compression'):
        !git clone https://github.com/darren10101/AI-Compression
    os.chdir('AI-Compression')
    
    print(f'Working directory: {os.getcwd()}')
    print(f'Contents: {os.listdir(".")}')
    
    # Check if src folder exists
    if not os.path.exists('src'):
        print('\n⚠️  WARNING: src folder not found!')
        print('Please clone your GitHub repo or upload the src folder.')

# Add current directory to path
sys.path.insert(0, '.')

import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

Installing dependencies...
Working directory: /content/AI-Compression/AI-Compression
Contents: ['.git', 'environment.yml', 'configs', 'README.md', 'train_notebook.ipynb', 'scripts', '.cursor', 'src', 'AI-Compression']
⚠️  HF_TOKEN not found!
   Create a .env file with: HF_TOKEN=your_token_here
   Get your token at: https://huggingface.co/settings/tokens
   Make sure you have accepted ImageNet terms at:
   https://huggingface.co/datasets/ILSVRC/imagenet-1k
Using device: cuda
GPU: Tesla T4
Memory: 15.8 GB


## 1. Setup Dataset

Create dataloaders that stream ImageNet and extract random 64×64 crops.

In [None]:
# HuggingFace login for ImageNet-1k access
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
from src.dataset.crop_dataset import create_dataloader, RandomCropDataset

# Configuration
CROP_SIZE = 64
BATCH_SIZE = 32
# num_workers=0 is required for streaming datasets on Colab
# (multiprocessing with IterableDataset causes issues)
NUM_WORKERS = 0

# Shuffle buffer size: larger = better randomness, but slower first batch
# - buffer_size=1: instant first batch (no shuffling)
# - buffer_size=100: ~10-20 sec first batch, moderate shuffling
# - buffer_size=1000: ~1-2 min first batch, good shuffling
BUFFER_SIZE = 100  # Good tradeoff for Colab

# Create dataloaders
print('Creating training dataloader...')
train_loader = create_dataloader(
    split='train',
    crop_size=CROP_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    buffer_size=BUFFER_SIZE,
)

print('Creating validation dataloader...')
val_loader = create_dataloader(
    split='validation',
    crop_size=CROP_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    buffer_size=BUFFER_SIZE,
)

Creating training dataloader...
Creating validation dataloader...


In [None]:
# Visualize some training samples
# Note: First iteration may take 10-30 seconds to fill the shuffle buffer
print('Loading first batch (this may take a moment on first run)...')

# Create iterator once and reuse it
train_iter = iter(train_loader)
sample_batch = next(train_iter)

print(f'Batch shape: {sample_batch.shape}')
print(f'Value range: [{sample_batch.min():.1f}, {sample_batch.max():.1f}]')

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    if i < len(sample_batch):
        img = sample_batch[i].permute(1, 2, 0).numpy() / 255.0
        ax.imshow(img.clip(0, 1))
        ax.axis('off')
        ax.set_title(f'Sample {i+1}')
plt.suptitle('Random 64×64 Crops from ImageNet', fontsize=14)
plt.tight_layout()
plt.show()

Loading first batch (this may take a moment on first run)...


Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

## 2. Create Model

Initialize the Integer Discrete Flow model.

In [4]:
from src.model import create_idf_model

# Model configuration
model_config = {
    'in_channels': 3,
    'hidden_channels': 64,    # Width of coupling networks
    'num_levels': 3,           # Hierarchical levels (squeeze + flow block)
    'num_steps': 8,            # Flow steps per level
}

model = create_idf_model(model_config)
model = model.to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f'Model parameters: {num_params:,}')
print(f'Model config: {model_config}')

Model parameters: 545,588
Model config: {'in_channels': 3, 'hidden_channels': 64, 'num_levels': 3, 'num_steps': 8}


In [6]:
# Test forward pass
test_batch = sample_batch[:4].to(device)
print(f'Input shape: {test_batch.shape}')

with torch.no_grad():
    loss, bpd = model.compute_loss(test_batch)
    print(f'Initial BPD: {bpd.item():.3f}')
    print(f'(Lower is better, theoretical minimum ~4-5 bpd for natural images)')

NameError: name 'sample_batch' is not defined

## 3. Training Setup

In [None]:
from src.train import IDFTrainer

# Training configuration
LEARNING_RATE = 1e-4
GRAD_CLIP = 1.0
CHECKPOINT_DIR = '../checkpoints'

trainer = IDFTrainer(
    model=model,
    device=device,
    lr=LEARNING_RATE,
    grad_clip=GRAD_CLIP,
    checkpoint_dir=CHECKPOINT_DIR,
    use_amp=True,  # Use mixed precision if available
)

print(f'Trainer initialized')
print(f'Checkpoints will be saved to: {CHECKPOINT_DIR}')

## 4. Train the Model

Run the training loop. You can interrupt at any time - checkpoints are saved periodically.

In [None]:
# Training parameters
TOTAL_STEPS = 50000     # Adjust based on compute budget
LOG_INTERVAL = 100      # Log every N steps
VAL_INTERVAL = 1000     # Validate every N steps
SAVE_INTERVAL = 5000    # Save checkpoint every N steps

# Start training
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_steps=TOTAL_STEPS,
    log_interval=LOG_INTERVAL,
    val_interval=VAL_INTERVAL,
    save_interval=SAVE_INTERVAL,
)

## 5. Visualize Training Progress

In [None]:
# Plot training curves
if len(history) > 0:
    steps = [h['step'] for h in history]
    bpds = [h['bpd'] for h in history]
    lrs = [h['lr'] for h in history]
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # BPD plot
    axes[0].plot(steps, bpds)
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Bits per Dimension (BPD)')
    axes[0].set_title('Training Loss')
    axes[0].grid(True, alpha=0.3)
    
    # Learning rate plot
    axes[1].plot(steps, lrs)
    axes[1].set_xlabel('Step')
    axes[1].set_ylabel('Learning Rate')
    axes[1].set_title('Learning Rate Schedule')
    axes[1].set_yscale('log')
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f'Final BPD: {bpds[-1]:.4f}')
    print(f'Best BPD: {min(bpds):.4f}')
else:
    print('No training history yet.')

## 6. Test Compression / Reconstruction

Verify the model can perfectly reconstruct images (lossless).

In [None]:
# Test lossless reconstruction
model.eval()

test_batch = next(iter(val_loader))[:4].to(device)

with torch.no_grad():
    # Compress
    latents, prior_params = model.compress(test_batch)
    
    # Decompress
    reconstructed = model.decompress(latents)
    
# Check reconstruction error
error = (test_batch - reconstructed).abs()
max_error = error.max().item()
mean_error = error.mean().item()

print(f'Max reconstruction error: {max_error:.6f}')
print(f'Mean reconstruction error: {mean_error:.6f}')

if max_error < 1e-4:
    print('✓ Reconstruction is lossless!')
else:
    print('⚠ Reconstruction has errors (expected for untrained model)')

In [None]:
# Visualize original vs reconstructed
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

for i in range(4):
    # Original
    orig = test_batch[i].cpu().permute(1, 2, 0).numpy() / 255.0
    axes[0, i].imshow(orig.clip(0, 1))
    axes[0, i].set_title(f'Original {i+1}')
    axes[0, i].axis('off')
    
    # Reconstructed
    recon = reconstructed[i].cpu().permute(1, 2, 0).numpy() / 255.0
    axes[1, i].imshow(recon.clip(0, 1))
    axes[1, i].set_title(f'Reconstructed {i+1}')
    axes[1, i].axis('off')

plt.suptitle('Lossless Compression Test', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Save Final Model

In [None]:
# Save the trained model
trainer.save_checkpoint('notebook_final.pt')
print('Model saved!')

---

## Resume Training (Optional)

To resume from a checkpoint:

In [None]:
# # Uncomment to resume from checkpoint
# trainer.load_checkpoint('../checkpoints/checkpoint_step10000.pt')
# 
# # Continue training
# history = trainer.train(
#     train_loader=train_loader,
#     val_loader=val_loader,
#     num_steps=100000,
# )