<a href="https://colab.research.google.com/github/dqtuan99/SlimCAE_PyTorch/blob/main/SlimCAE_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

There are many library conflicts in Colab due to preinstalled packages. My first setup cell locks compatible versions and then deliberately kills the kernel (os.kill(os.getpid(), 9)) to force a clean restart. This is expected. After the runtime reconnects, start from Cell 2. Any errors shown when the kernel is terminated can be ignored.

In [None]:
# ⬇️ Run this cell first: pin NumPy/torch/core deps → auto-restart
import os, sys, subprocess

def run(*args):
    """Echo a command line, then execute it.

    Args:
        *args: The program followed by its arguments, one string per token.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    command = [*args]
    echoed = " ".join(command)
    print(">>>", echoed)
    subprocess.check_call(command)

# Invoke pip as `<this interpreter> -m pip` so packages are installed into the
# environment the kernel is actually running in, rather than whichever `pip`
# executable happens to be first on PATH.
PIP = [sys.executable, "-m", "pip"]

# (1) Proactively uninstall packages that may pull in NumPy 2.x (ignore if missing)
for pkg in [
    "numpy","scipy","opencv-python","opencv-contrib-python","opencv-python-headless",
    "thinc","tsfresh","sentence-transformers","peft","transformers"
]:
    # subprocess.call (not check_call): a failed uninstall must not abort setup.
    subprocess.call([*PIP, "uninstall", "-y", pkg])

# (2) Pin the core versions
run(*PIP, "install", "--no-cache-dir", "numpy==1.26.4", "scipy==1.11.4")
run(*PIP, "install", "--no-cache-dir",
    "torch==2.1.2+cu118", "torchvision==0.16.2+cu118", "torchaudio==2.1.2+cu118",
    "--index-url", "https://download.pytorch.org/whl/cu118")
run(*PIP, "install", "--no-cache-dir", "pytorch-lightning==1.9.5", "hydra-core", "omegaconf")
run(*PIP, "install", "--no-cache-dir", "pybind11", "ninja", "pybind11-stubgen==0.10.0")

# (3) CompressAI (incl. GDN deps) and MS-SSIM
run(*PIP, "install", "--no-cache-dir", "compressai==1.2.6", "pytorch-msssim==0.2.1")

# Show versions (imported only now, after the pinned builds are installed)
import numpy, torch
print("NumPy:", numpy.__version__)
print("Torch:", torch.__version__)

# Kill this process so the next cell starts in a fresh kernel; signal 9 cannot
# be caught, so nothing after this line runs.
print("\n🔁 Restarting runtime… (after reconnect, continue from Cell 2)")
os.kill(os.getpid(), 9)  # ← force-restart the Colab kernel


>>> pip install --no-cache-dir numpy==1.26.4 scipy==1.11.4
>>> pip install --no-cache-dir torch==2.1.2+cu118 torchvision==0.16.2+cu118 torchaudio==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118
>>> pip install --no-cache-dir pytorch-lightning==1.9.5 hydra-core omegaconf
>>> pip install --no-cache-dir pybind11 ninja pybind11-stubgen==0.10.0
>>> pip install --no-cache-dir compressai==1.2.6 pytorch-msssim==0.2.1


In [None]:
import os, sys, subprocess, shutil, site, inspect, pathlib

REPO = "/content/cbench_BaSIC"
URL  = "https://github.com/worldlife123/cbench_BaSIC.git"

def run(cmd, cwd=None, check=False):
    """Run a command, echo its captured output, and report whether it succeeded.

    Args:
        cmd: Command as a list of strings (program followed by its arguments).
        cwd: Working directory for the command, or None to use the current one.
        check: If True, raise instead of returning False when the command fails.

    Returns:
        True if the command exited with status 0, otherwise False.

    Raises:
        subprocess.CalledProcessError: If ``check`` is True and the exit status
            is non-zero. The captured stdout and stderr are attached to the
            exception so callers can inspect them.
    """
    print(">>>", " ".join(cmd))
    p = subprocess.run(cmd, cwd=cwd, text=True, capture_output=True)
    succeeded = p.returncode == 0

    # dump full logs
    if p.stdout:
        print(p.stdout)
    if not succeeded:
        print("❌ STDERR:\n", p.stderr)
        if check:
            raise subprocess.CalledProcessError(
                p.returncode, cmd, output=p.stdout, stderr=p.stderr
            )
    elif p.stderr:
        # A command can exit 0 and still write warnings to stderr; show them
        # instead of silently discarding the captured text.
        print("⚠️ STDERR:\n", p.stderr)
    return succeeded

# 0) Required packages (Cell 1 should have installed these; reinforce if missing)
# The return values are ignored and `check` is left False, so a failure here is
# only logged and does not stop the cell.
run([sys.executable,"-m","pip","install","--no-cache-dir",
     "compressai==1.2.6","pytorch-msssim==0.2.1"])
run([sys.executable,"-m","pip","install","--no-cache-dir",
     "pybind11","ninja","pybind11-stubgen==0.10.0"])

# 1) Clean clone/reset
# check=True: a failing git command raises and aborts the cell.
# The reset branch discards any local changes in REPO and assumes the remote's
# branch is named `main`.
if not os.path.isdir(REPO):
    run(["git","clone",URL,REPO], check=True)
else:
    run(["git","fetch"], cwd=REPO, check=True)
    run(["git","reset","--hard","origin/main"], cwd=REPO, check=True)

# 2) Remove previous build artifacts + stale egg-link
# ignore_errors=True: directories that do not exist are skipped silently.
for p in ["build","cbench.egg-info"]:
    shutil.rmtree(os.path.join(REPO,p), ignore_errors=True)

# Look in every global and user site-packages directory for a leftover
# `cbench.egg-link` from an earlier editable install and delete it.
for sp in set(site.getsitepackages() + [site.getusersitepackages()]):
    egg = os.path.join(sp, "cbench.egg-link")
    if os.path.exists(egg):
        print("🧹 remove:", egg)
        os.remove(egg)

# 3) Editable install with build isolation disabled (important!)
# NOTE(review): presumably isolation is disabled so the package's setup.py can
# import the already-pinned torch/pybind11 — confirm against the repo's setup.py.
ok = run([sys.executable,"-m","pip","install","--no-build-isolation","-e",REPO])

# 4) Fallback to setup.py if pip -e fails
if not ok:
    print("⚠️ pip -e failed → trying fallback: setup.py build develop")
    ok = run([sys.executable,"setup.py","build","develop"], cwd=REPO)

# Both install routes failed: stop the cell with a message instead of
# continuing to the import test.
if not ok:
    raise SystemExit("❌ Install failed. Check STDERR above.")

# 5) Import test
try:
    import cbench
except ModuleNotFoundError:
    # Path fallback (in case only egg-link was created or nested layout)
    # Put the checkout (and a nested `cbench_BaSIC` directory, if present) at
    # the front of sys.path, then retry; a second failure propagates.
    sys.path.insert(0, REPO)
    alt = os.path.join(REPO, "cbench_BaSIC")
    if os.path.isdir(alt):
        sys.path.insert(0, alt)
    import cbench

# Show which file was imported so a stale or unexpected install is easy to spot.
print("cbench from:", inspect.getfile(cbench))

# Importing the slimmable layers is the real success criterion for this cell.
from cbench.nn.layers.slimmable_layers import (
    DynamicConv2d, DynamicGDN,
    DynamicResidualBlock, DynamicResidualBlockUpsample,
)
print("✅ Layers OK:", DynamicConv2d, DynamicGDN)


>>> /usr/bin/python3 -m pip install --no-cache-dir compressai==1.2.6 pytorch-msssim==0.2.1

>>> /usr/bin/python3 -m pip install --no-cache-dir pybind11 ninja pybind11-stubgen==0.10.0

>>> git clone https://github.com/worldlife123/cbench_BaSIC.git /content/cbench_BaSIC
>>> /usr/bin/python3 -m pip install --no-build-isolation -e /content/cbench_BaSIC
Obtaining file:///content/cbench_BaSIC
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'error'

❌ STDERR:
   error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

⚠️ pip -e failed → trying fallback: setup.py build 

In [None]:
# slimcae_wrapper.py
import torch
import torch.nn as nn
from cbench.nn.layers.slimmable_layers import (
    DynamicConv2d, DynamicGDN,
    DynamicResidualBlock, DynamicResidualBlockUpsample
)

class SlimCAE(nn.Module):
    """First-draft slimmable convolutional autoencoder.

    The encoder contains two stride-2 convolutions, while the decoder contains
    a single ``upsample=2`` block. In this notebook a 256x256 input comes back
    as a 128x128 reconstruction, i.e. the output resolution does not match the
    input.

    Args:
        channels_list: Channel widths, one entry per slim level
            (e.g. ``(64, 96, 128)``).
    """

    def __init__(self, channels_list=(64, 96, 128)):
        super().__init__()
        widths = list(channels_list)
        # The output convolution also takes a per-level list, so repeat the
        # RGB channel count once for every slim level.
        rgb_out = [3] * len(widths)
        widest = widths[-1]

        encoder_layers = [
            DynamicConv2d(3, widths, 5, 2, None, bias=True),
            DynamicGDN(widths),
            DynamicResidualBlock(widest, widths),
            DynamicConv2d(widths, widths, 5, 2, None, bias=True),
            DynamicGDN(widths),
        ]
        decoder_layers = [
            DynamicResidualBlockUpsample(widest, widths, upsample=2),
            DynamicResidualBlock(widest, widths),
            DynamicConv2d(widths, rgb_out, 5, 1, None, bias=True),
        ]

        self.enc = nn.Sequential(*encoder_layers)
        self.dec = nn.Sequential(*decoder_layers)

    def set_level(self, level: int):
        """Forward ``level`` to every submodule exposing ``set_complex_level``."""
        adjustable = (m for m in self.modules() if hasattr(m, "set_complex_level"))
        for module in adjustable:
            module.set_complex_level(level)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(reconstruction, latent)`` tuple.
        """
        latent = self.enc(x)
        return self.dec(latent), latent



In [None]:
# Smoke test: build the model, select slim level 1 and run one forward pass on
# a random batch of one 256x256 RGB image.
net = SlimCAE((64,96,128))
net.set_level(1)
out, z = net(torch.randn(1,3,256,256))
# Print the reconstruction's shape so it can be compared with the input size.
print(out.shape)

torch.Size([1, 3, 128, 128])


This revised SlimCAE class ensures the output image has the same spatial dimensions as the input image (e.g., 256x256).

In [None]:
# slimcae_wrapper.py (Corrected)
import torch
import torch.nn as nn
from cbench.nn.layers.slimmable_layers import (
    DynamicConv2d, DynamicGDN,
    DynamicResidualBlock, DynamicResidualBlockUpsample
)

class SlimCAE(nn.Module):
    """Slimmable convolutional autoencoder with matching input/output size.

    The encoder uses two stride-2 convolutions and the decoder uses two
    ``upsample=2`` blocks, so the reconstruction is intended to have the same
    spatial dimensions as the input.

    Args:
        channels_list: Channel widths, one entry per slim level
            (e.g. ``(64, 96, 128)``).
    """

    def __init__(self, channels_list=(64, 96, 128)):
        super().__init__()
        widths = list(channels_list)
        # The last block emits an RGB image at every level, so its
        # out-channel list repeats 3 once per slim level.
        rgb_out = [3] * len(widths)
        # The residual upsample blocks are given the last width as in_channels.
        widest = widths[-1]

        # Encoder: two stride-2 convolutions, each followed by GDN.
        encoder_layers = [
            DynamicConv2d(3, widths, kernel_size=5, stride=2, groups_list=None, bias=True),
            DynamicGDN(widths),
            DynamicConv2d(widths, widths, kernel_size=5, stride=2, groups_list=None, bias=True),
            DynamicGDN(widths),
        ]
        # Decoder: inverse GDN followed by a 2x upsampling block, twice.
        decoder_layers = [
            DynamicGDN(widths, inverse=True),
            DynamicResidualBlockUpsample(widest, widths, upsample=2),
            DynamicGDN(widths, inverse=True),
            DynamicResidualBlockUpsample(widest, rgb_out, upsample=2),
        ]

        self.enc = nn.Sequential(*encoder_layers)
        self.dec = nn.Sequential(*decoder_layers)

    def set_level(self, level: int):
        """Select the complexity level on all dynamic submodules.

        Every submodule that exposes ``set_complex_level`` receives ``level``.
        Level 0 is the smallest model, and level n-1 is the largest.
        """
        adjustable = (m for m in self.modules() if hasattr(m, "set_complex_level"))
        for module in adjustable:
            module.set_complex_level(level)

    def forward(self, x):
        """Run the autoencoder.

        Returns:
            A ``(reconstruction, latent)`` tuple, where ``latent`` is the
            encoder output and ``reconstruction`` is the decoder output.
        """
        latent = self.enc(x)
        return self.dec(latent), latent

# --- Verification Step ---
# Build the model at its highest slim level and create a random 256x256 RGB
# batch of size 1 to push through it.
test_net = SlimCAE((64, 96, 128))
test_net.set_level(2) # Index 2 is the last (largest) of the three levels
test_input = torch.randn(1, 3, 256, 256)

# Pass the input through the model
reconstruction, latent = test_net(test_input)

# Check if the output shape matches the input shape
print(f"Input shape: {test_input.shape}")
print(f"Latent shape: {latent.shape}")
print(f"Reconstruction shape: {reconstruction.shape}")
# Fails loudly (AssertionError) if the decoder does not restore the input size.
assert reconstruction.shape == test_input.shape
print("✅ Model architecture verified successfully!")

This cell sets up all the necessary components for training: hyperparameters, device configuration, and data loaders. We will use the CIFAR-10 dataset from torchvision and resize the images to 256x256 to work with the model's architecture.

In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration & Hyperparameters ---
# Settings read by the training, evaluation and plotting cells.
CONFIG = {
    "epochs": 10,
    "batch_size": 16,
    "learning_rate": 1e-4,
    "lambda_rate": 1e-2, # Weight for the rate term in the loss function
    "num_levels": 3, # Number of slim levels looped over; keep equal to len(channels_list)
    "channels_list": (64, 96, 128),
    "val_split": 0.1, # 10% of training data for validation
}

# --- Device Setup ---
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- Data Preparation ---
# Define transformations: Resize, convert to tensor, and normalize
# Normalize with mean 0.5 / std 0.5 per channel computes (x - 0.5) / 0.5, so
# ToTensor's [0, 1] values end up in [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Download and load the full CIFAR-10 training dataset
full_train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                  download=True, transform=transform)

# Split training data into training and validation sets
# The first `split` shuffled indices become validation, the rest training.
# NOTE(review): the shuffle is unseeded, so the split differs on every run.
num_train = len(full_train_dataset)
indices = list(range(num_train))
split = int(np.floor(CONFIG["val_split"] * num_train))
np.random.shuffle(indices)
train_idx, val_idx = indices[split:], indices[:split]

# Both samplers draw from the same dataset object but from disjoint index sets.
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)

# Create data loaders
train_loader = DataLoader(full_train_dataset, batch_size=CONFIG["batch_size"], sampler=train_sampler, num_workers=2)
val_loader = DataLoader(full_train_dataset, batch_size=CONFIG["batch_size"], sampler=val_sampler, num_workers=2)

# Load the test dataset
# Same transform as training; iterated in fixed order (shuffle=False).
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                 download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=2)

print("Data loaders created successfully.")

Training and Evaluation

In [None]:
from tqdm import tqdm
import math

def rate_distortion_loss(reconstruction, original, latent, lambda_rate):
    """Combine reconstruction error and a latent-magnitude penalty.

    The loss is ``mse + lambda_rate * rate`` where

    * ``mse`` is the mean squared error between ``reconstruction`` and
      ``original``;
    * ``rate`` is ``mean(log2(1 + latent**2))`` taken over all latent
      elements.

    The rate term is a heuristic proxy: it grows with the magnitude of the
    latent values. It is not computed from a probability model or from
    quantized symbols, and it is averaged per latent element rather than per
    image pixel, so it is not a measured bits-per-pixel figure.

    Args:
        reconstruction: Decoder output.
        original: Target tensor, compared element-wise with ``reconstruction``.
        latent: Encoder output used for the rate proxy.
        lambda_rate: Weight applied to the rate proxy.

    Returns:
        A ``(total_loss, mse, bpp)`` tuple of scalar tensors, where ``bpp`` is
        the rate proxy described above.
    """
    distortion = F.mse_loss(reconstruction, original)
    rate_proxy = torch.log2(1 + latent.pow(2)).mean()
    weighted_sum = distortion + lambda_rate * rate_proxy
    return weighted_sum, distortion, rate_proxy

def train_one_epoch(model, dataloader, optimizer, epoch, num_levels, lambda_rate, device):
    """Train ``model`` for one pass over ``dataloader`` at every slim level.

    For each batch the model is run at levels ``0 .. num_levels - 1``; the
    per-level rate-distortion losses are averaged and a single optimizer step
    is taken.

    Args:
        model: Network exposing ``set_level(level)`` and returning
            ``(reconstruction, latent)`` from its forward pass.
        dataloader: Iterable of ``(images, labels)`` batches; labels are unused.
        optimizer: Optimizer holding the model's parameters.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        num_levels: Number of slim levels to train; must be at least 1.
        lambda_rate: Weight of the rate term passed to ``rate_distortion_loss``.
        device: Device the image batches are moved to.

    Returns:
        The level-averaged loss, averaged over all batches of the epoch.

    Raises:
        ValueError: If ``num_levels`` is smaller than 1.
    """
    if num_levels < 1:
        raise ValueError("num_levels must be at least 1")

    model.train()
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    total_loss_avg = 0

    for i, (images, _) in enumerate(progress_bar):
        images = images.to(device)
        optimizer.zero_grad()

        batch_loss = 0.0

        # --- Train on all complexity levels with the same batch ---
        for level in range(num_levels):
            model.set_level(level)
            reconstruction, latent = model(images)
            loss, _, _ = rate_distortion_loss(reconstruction, images, latent, lambda_rate)

            # Back-propagate each level right away. Gradients accumulate in
            # .grad, so the sum equals the gradient of the level-averaged
            # loss, but only one level's autograd graph is alive at a time.
            scaled_loss = loss / num_levels
            scaled_loss.backward()
            batch_loss = batch_loss + scaled_loss.detach()

        optimizer.step()

        # One host sync per batch to read the averaged loss value.
        total_loss_avg += float(batch_loss)
        progress_bar.set_postfix(loss=f"{total_loss_avg / (i+1):.4f}")

    return total_loss_avg / len(dataloader)


def evaluate(model, dataloader, num_levels, lambda_rate, device, data_range=2.0):
    """Evaluate ``model`` at every slim level and return averaged metrics.

    Args:
        model: Network exposing ``set_level(level)`` and returning
            ``(reconstruction, latent)`` from its forward pass.
        dataloader: Iterable of ``(images, labels)`` batches; labels are unused.
        num_levels: Number of slim levels to evaluate (``0 .. num_levels - 1``).
        lambda_rate: Weight of the rate term passed to ``rate_distortion_loss``.
        device: Device the image batches are moved to.
        data_range: Difference between the largest and smallest possible pixel
            value, used as the PSNR peak. The default of 2.0 matches images
            normalized to [-1, 1], which is what this notebook's
            ``Normalize((0.5, ...), (0.5, ...))`` transform produces; pass 1.0
            for images in [0, 1].

    Returns:
        A list with one dict per level containing ``level``, ``loss``, ``mse``,
        ``bpp`` and ``psnr``. Each metric is the mean of the per-batch values;
        ``psnr`` is in dB and is ``inf`` for a batch with zero error.
    """
    model.eval()
    # Per-level running sums of each metric (index == level).
    losses = [0.0] * num_levels
    mses = [0.0] * num_levels
    bpps = [0.0] * num_levels
    psnrs = [0.0] * num_levels
    peak_squared = data_range ** 2

    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)

            for level in range(num_levels):
                model.set_level(level)
                reconstruction, latent = model(images)
                loss, mse, bpp = rate_distortion_loss(reconstruction, images, latent, lambda_rate)

                mse_value = mse.item()

                # Accumulate metrics
                losses[level] += loss.item()
                mses[level] += mse_value
                bpps[level] += bpp.item()
                # PSNR = 10 * log10(peak^2 / MSE); a perfect reconstruction
                # has no finite PSNR, so report infinity instead of dividing
                # by zero.
                if mse_value > 0:
                    psnrs[level] += 10 * math.log10(peak_squared / mse_value)
                else:
                    psnrs[level] += float("inf")

    # Average the metrics over the dataset
    num_batches = len(dataloader)
    results = []
    for level in range(num_levels):
        results.append({
            "level": level,
            "loss": losses[level] / num_batches,
            "mse": mses[level] / num_batches,
            "bpp": bpps[level] / num_batches,
            "psnr": psnrs[level] / num_batches,
        })

    return results

Main Execution and Visualization

In [None]:
# --- Initialization ---
# Build the model on the selected device and give Adam all of its parameters.
model = SlimCAE(channels_list=CONFIG["channels_list"]).to(device)
optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])
train_losses = []          # one averaged training loss per epoch
val_results_history = []   # one list of per-level metric dicts per epoch

# --- Training Loop ---
for epoch in range(CONFIG["epochs"]):
    # Train for one epoch
    train_loss = train_one_epoch(
        model, train_loader, optimizer, epoch,
        CONFIG["num_levels"], CONFIG["lambda_rate"], device
    )
    train_losses.append(train_loss)

    # Evaluate on the validation set
    val_results = evaluate(
        model, val_loader, CONFIG["num_levels"], CONFIG["lambda_rate"], device
    )
    val_results_history.append(val_results)

    # Print summary
    # "BPP" is the rate term returned by rate_distortion_loss.
    print(f"\n--- Epoch {epoch+1} Validation Results ---")
    for res in val_results:
        print(f"Level {res['level']}: Loss={res['loss']:.4f}, PSNR={res['psnr']:.2f} dB, BPP={res['bpp']:.4f}")
    print("-" * 35)


print("\n✅ Training finished.")

# --- Final Testing ---
# Evaluate once on the held-out test loader; `test_results` is reused by the
# plotting code that follows.
print("\nRunning final evaluation on the test set...")
test_results = evaluate(model, test_loader, CONFIG["num_levels"], CONFIG["lambda_rate"], device)

print("\n--- Final Test Results ---")
for res in test_results:
    print(f"Level {res['level']} (Channels: {CONFIG['channels_list'][res['level']]}): PSNR={res['psnr']:.2f} dB, BPP={res['bpp']:.4f}")
print("-" * 28)


# --- Visualization ---
# 1. Plot training loss (left panel): one point per epoch.
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.title("Average Training Loss vs. Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)

# 2. Plot Rate-Distortion-Complexity Tradeoff (right panel): one point per
#    slim level, taken from the test-set results.
plt.subplot(1, 2, 2)
colors = ['blue', 'green', 'red']
for level in range(CONFIG["num_levels"]):
    psnr = test_results[level]['psnr']
    bpp = test_results[level]['bpp']
    channels = CONFIG['channels_list'][level]
    # Cycle through the palette so that configuring more levels than colors
    # reuses colors instead of raising IndexError.
    color = colors[level % len(colors)]
    plt.scatter(bpp, psnr, color=color, label=f'Level {level} ({channels} channels)', s=100)

plt.title("Rate-Distortion-Complexity Tradeoff (Test Set)")
plt.xlabel("Rate (BPP)")
plt.ylabel("Distortion (PSNR)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# 3. Visualize some reconstructions
model.eval()
images, _ = next(iter(test_loader))
images = images[:4].to(device) # Take 4 images

# One row for the originals plus one row per complexity level. If the
# originals shared row 0 with level 0, the level-0 reconstructions would be
# drawn on top of them while the "Original" titles stayed in place.
num_rows = CONFIG['num_levels'] + 1
fig, axes = plt.subplots(num_rows, 5, figsize=(15, 2.7 * num_rows))
fig.suptitle("Original vs. Reconstructions at Different Complexity Levels")

# Row 0: label cell followed by the four input images.
axes[0, 0].text(0.5, 0.5, "Original", ha='center', va='center', fontsize=12)
axes[0, 0].axis('off')
for i in range(4):
    # `* 0.5 + 0.5` undoes Normalize(0.5, 0.5): [-1, 1] -> [0, 1] for imshow.
    axes[0, i+1].imshow(images[i].cpu().permute(1, 2, 0) * 0.5 + 0.5)
    axes[0, i+1].set_title(f"Original {i+1}")
    axes[0, i+1].axis('off')

# Rows 1..num_levels: reconstructions, one row per level. Gradients are not
# needed for plotting, so skip building the autograd graph.
with torch.no_grad():
    for level in range(CONFIG['num_levels']):
        row = level + 1
        model.set_level(level)
        reconstructions, _ = model(images)
        axes[row, 0].text(0.5, 0.5, f"Level {level}", ha='center', va='center', fontsize=12)
        axes[row, 0].axis('off')

        for i in range(4):
            recon = reconstructions[i].cpu().permute(1, 2, 0) * 0.5 + 0.5
            # The network output is unbounded; clamp to the [0, 1] range
            # imshow expects for float RGB data.
            axes[row, i+1].imshow(recon.clamp(0, 1))
            axes[row, i+1].axis('off')

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

TODO: Analyze the rate–distortion–complexity trade-off.