In [None]:
!pip install compressai

Collecting compressai
  Downloading compressai-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (295 kB)
[?25l[K     |█                               | 10 kB 23.5 MB/s eta 0:00:01[K     |██▏                             | 20 kB 6.0 MB/s eta 0:00:01[K     |███▎                            | 30 kB 6.7 MB/s eta 0:00:01[K     |████▍                           | 40 kB 8.1 MB/s eta 0:00:01[K     |█████▌                          | 51 kB 9.6 MB/s eta 0:00:01[K     |██████▋                         | 61 kB 10.3 MB/s eta 0:00:01[K     |███████▊                        | 71 kB 10.7 MB/s eta 0:00:01[K     |████████▉                       | 81 kB 7.5 MB/s eta 0:00:01[K     |██████████                      | 92 kB 8.3 MB/s eta 0:00:01[K     |███████████                     | 102 kB 9.1 MB/s eta 0:00:01[K     |████████████▏                   | 112 kB 9.1 MB/s eta 0:00:01[K     |█████████████▎                  | 122 kB 9.1 MB/s eta 0:00:01[K     |██████████████▍   

In [None]:
import argparse
import math
import random
import shutil
import sys
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision

from compressai.datasets import ImageFolder
from compressai.zoo import image_models
import compressai

In [None]:
from compressai.zoo import (
    bmshj2018_factorized,
    bmshj2018_hyperprior,
    cheng2020_anchor,
    mbt2018,
    mbt2018_mean,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
metric = 'mse'  # only MSE-optimized pre-trained models are available for now
quality = 1     # lower quality -> lower bit-rate (makes visual differences easy to see)

# Instantiate every pre-trained codec once, in evaluation mode, on `device`.
# The pairs are listed in the order the weights are fetched.
networks = {
    name: factory(quality=quality, pretrained=True).eval().to(device)
    for name, factory in (
        ('bmshj2018-factorized', bmshj2018_factorized),
        ('bmshj2018-hyperprior', bmshj2018_hyperprior),
        ('mbt2018-mean', mbt2018_mean),
        ('mbt2018', mbt2018),
        ('cheng2020-anchor', cheng2020_anchor),
    )
}

# Codec used by the rest of the notebook.
net = networks['bmshj2018-factorized']

Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-1-446d5c7f.pth.tar" to /root/.cache/torch/hub/checkpoints/bmshj2018-factorized-prior-1-446d5c7f.pth.tar


  0%|          | 0.00/11.5M [00:00<?, ?B/s]

Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-1-7eb97409.pth.tar" to /root/.cache/torch/hub/checkpoints/bmshj2018-hyperprior-1-7eb97409.pth.tar


  0%|          | 0.00/20.2M [00:00<?, ?B/s]

Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-1-e522738d.pth.tar" to /root/.cache/torch/hub/checkpoints/mbt2018-mean-1-e522738d.pth.tar


  0%|          | 0.00/27.6M [00:00<?, ?B/s]

Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-1-3f36cd77.pth.tar" to /root/.cache/torch/hub/checkpoints/mbt2018-1-3f36cd77.pth.tar


  0%|          | 0.00/61.8M [00:00<?, ?B/s]

Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020-anchor-1-dad2ebff.pth.tar" to /root/.cache/torch/hub/checkpoints/cheng2020-anchor-1-dad2ebff.pth.tar


  0%|          | 0.00/49.1M [00:00<?, ?B/s]

### Model Classes

In [None]:
class RateDistortionLoss(nn.Module):
    """Rate-distortion objective ``lmbda * 255**2 * MSE + bpp``.

    ``lmbda`` is the Lagrangian weight trading distortion against rate.
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, output, target):
        """Return a dict with ``bpp_loss``, ``mse_loss`` and the combined ``loss``.

        ``output`` must provide ``output["x_hat"]`` (the reconstruction) and
        ``output["likelihoods"]`` (a dict of likelihood tensors); ``target`` is
        a 4-D tensor whose first, third and fourth sizes define the pixel count.
        """
        batch, _, height, width = target.size()
        num_pixels = batch * height * width
        # Converts a sum of natural-log likelihoods into bits per pixel.
        bits_scale = -math.log(2) * num_pixels

        bpp = 0
        for likelihoods in output["likelihoods"].values():
            bpp = bpp + torch.log(likelihoods).sum() / bits_scale

        distortion = self.mse(output["x_hat"], target)

        return {
            "bpp_loss": bpp,
            "mse_loss": distortion,
            "loss": self.lmbda * 255**2 * distortion + bpp,
        }


class AverageMeter:
    """Keep the latest value and the running (weighted) mean of a series."""

    def __init__(self):
        # val: last value seen, sum: weighted total, count: total weight,
        # avg: sum / count (0 until the first update).
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class CustomDataParallel(nn.DataParallel):
    """DataParallel wrapper that forwards unknown attributes to the wrapped module."""

    def __getattr__(self, key):
        # Regular lookup first; anything the wrapper itself does not have is
        # looked up on the wrapped module instead.
        try:
            attribute = super().__getattr__(key)
        except AttributeError:
            attribute = getattr(self.module, key)
        return attribute


def configure_optimizers(net, args):
    """Build the main and the auxiliary Adam optimizers for ``net``.

    Trainable parameters whose name ends with ``.quantiles`` go to the
    auxiliary optimizer (learning rate ``args.aux_learning_rate``); every
    other trainable parameter goes to the main optimizer (learning rate
    ``args.learning_rate``).

    Returns:
        ``(optimizer, aux_optimizer)``
    """
    named_params = dict(net.named_parameters())

    main_names = set()
    aux_names = set()
    for name, param in named_params.items():
        if not param.requires_grad:
            continue
        if name.endswith(".quantiles"):
            aux_names.add(name)
        else:
            main_names.add(name)

    # The two groups must be disjoint and together cover every parameter.
    assert len(main_names & aux_names) == 0
    assert len(main_names | aux_names) - len(named_params.keys()) == 0

    main_params = [named_params[name] for name in sorted(main_names)]
    aux_params = [named_params[name] for name in sorted(aux_names)]

    optimizer = optim.Adam(main_params, lr=args.learning_rate)
    aux_optimizer = optim.Adam(aux_params, lr=args.aux_learning_rate)
    return optimizer, aux_optimizer

### Train and Test Epochs

In [None]:
def train_one_epoch(
    model,
    criterion,
    train_dataloader,
    optimizer,
    aux_optimizer,
    epoch,
    clip_max_norm,
    log_interval=300,
):
    """Run one training pass over ``train_dataloader``.

    For every batch the main loss is back-propagated and ``optimizer`` is
    stepped, then ``model.aux_loss()`` is back-propagated and
    ``aux_optimizer`` is stepped.

    Args:
        model: network to train; must expose ``aux_loss()``.
        criterion: callable ``(output, target)`` returning a dict with the
            keys ``"loss"``, ``"mse_loss"`` and ``"bpp_loss"``.
        train_dataloader: yields batches whose first element is the input.
        optimizer: optimizer stepped on the main loss.
        aux_optimizer: optimizer stepped on the auxiliary loss.
        epoch: epoch index, only used in the progress message.
        clip_max_norm: gradient-norm clipping threshold; clipping is skipped
            when it is not positive.
        log_interval: print progress every ``log_interval`` batches
            (default 300, the previous hard-coded value); ``0`` or ``None``
            disables progress output.
    """
    model.train()
    device = next(model.parameters()).device

    for i, d in enumerate(train_dataloader):
        # Only the first element of the batch (the images) is used.
        d = d[0].to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)
        out_criterion["loss"].backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        # The auxiliary loss is optimized separately, after the main step.
        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        if log_interval and i % log_interval == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i*len(d)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"
                f'\tLoss: {out_criterion["loss"].item():.3f} |'
                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
                f"\tAux loss: {aux_loss.item():.2f}"
            )


def test_epoch(epoch, test_dataloader, model, criterion):
    """Evaluate ``model`` on ``test_dataloader`` and return the average loss.

    The model is put in eval mode and run without gradient tracking. Each
    batch contributes equally to the reported averages, which are printed
    before the average of ``criterion(...)["loss"]`` is returned.
    """
    model.eval()
    device = next(model.parameters()).device

    meters = {
        name: AverageMeter()
        for name in ("loss", "bpp_loss", "mse_loss", "aux_loss")
    }

    with torch.no_grad():
        for batch in test_dataloader:
            # Only the first element of the batch (the images) is used.
            images = batch[0].to(device)
            losses = criterion(model(images), images)

            meters["aux_loss"].update(model.aux_loss())
            for key in ("bpp_loss", "loss", "mse_loss"):
                meters[key].update(losses[key])

    avg_loss = meters["loss"].avg
    avg_mse = meters["mse_loss"].avg
    avg_bpp = meters["bpp_loss"].avg
    avg_aux = meters["aux_loss"].avg

    print(
        f"Test epoch {epoch}: Average losses:"
        f"\tLoss: {avg_loss:.3f} |"
        f"\tMSE loss: {avg_mse:.3f} |"
        f"\tBpp loss: {avg_bpp:.2f} |"
        f"\tAux loss: {avg_aux:.2f}\n"
    )

    return avg_loss


def save_checkpoint(
    state,
    is_best,
    filename="checkpoint.pth.tar",
    best_filename="checkpoint_best_loss.pth.tar",
):
    """Serialize ``state`` to ``filename`` and optionally keep a "best" copy.

    Args:
        state: object handed to ``torch.save``.
        is_best: when true, the file just written is also copied to
            ``best_filename``.
        filename: path of the checkpoint to write.
        best_filename: path of the best-so-far copy; defaults to the path
            that used to be hard-coded.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


### Main

In [None]:
def main(model, num_workers, batch_size, cuda, epoch, patch_size, learning_rate, lmbda):
    """Fine-tune ``model`` on CIFAR-10 with the rate-distortion loss.

    Args:
        model: network to train; it is moved to the selected device and
            trained in place.
        num_workers: number of DataLoader worker processes.
        batch_size: batch size for both the train and the test loader.
        cuda: truthy to train on the GPU when one is available.
        epoch: total number of epochs to run.
        patch_size: crop size (random crop for training, center crop for
            testing).
        learning_rate: learning rate of both the main and the auxiliary
            optimizer.
        lmbda: Lagrangian weight passed to ``RateDistortionLoss``.

    Returns:
        The trained ``model`` (never the multi-GPU wrapper), so the result
        can be used directly by the caller.
    """
    args = argparse.Namespace(
        model=model,
        num_workers=num_workers,
        batch_size=batch_size,
        cuda=cuda,
        epochs=epoch,
        patch_size=patch_size,
        learning_rate=learning_rate,
        aux_learning_rate=learning_rate,
        lmbda=lmbda,
        save=True,
        clip_max_norm=1.0,
        checkpoint=None,  # path of a checkpoint to resume from, if any
    )

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
    )
    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = torchvision.datasets.CIFAR10(
        './CIFAR-10/', train=True, download=True, transform=train_transforms
    )
    test_dataset = torchvision.datasets.CIFAR10(
        './CIFAR-10/', train=False, download=True, transform=test_transforms
    )

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    # Options shared by both loaders; only shuffling differs.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=(device == "cuda"),
    )
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # `model` always refers to the bare network; `net` is what gets executed
    # (the bare network, or its multi-GPU wrapper).
    model = model.to(device)
    net = model
    if args.cuda and torch.cuda.device_count() > 1:
        net = CustomDataParallel(model)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        last_epoch = checkpoint["epoch"] + 1
        # Checkpoints written from a DataParallel wrapper prefix every key
        # with "module."; strip it so both layouts load into the bare model.
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint["state_dict"].items()
        }
        model.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")
    for epoch in range(last_epoch, args.epochs):
        epoch_start = time.time()

        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss = test_epoch(epoch, test_dataloader, net, criterion)
        lr_scheduler.step(loss)
        print(f"Time: {time.time() - epoch_start:.4f}")
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    # Save the bare model's weights so the checkpoint loads
                    # into an unwrapped network regardless of GPU count.
                    "state_dict": model.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
            )

    return model


In [None]:
# Fine-tune the selected codec on CIFAR-10 (30 epochs, 32x32 patches).
ic_net = main(
    model=net,
    num_workers=2,
    batch_size=64,
    cuda=1,
    epoch=30,
    patch_size=32,
    learning_rate=0.001,
    lmbda=0.0003,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./CIFAR-10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./CIFAR-10/cifar-10-python.tar.gz to ./CIFAR-10/
Files already downloaded and verified
Learning rate: 0.001
Test epoch 0: Average losses:	Loss: 0.338 |	MSE loss: 0.013 |	Bpp loss: 0.08 |	Aux loss: 1215.00

Time: 25.9359
Learning rate: 0.001
Test epoch 1: Average losses:	Loss: 0.321 |	MSE loss: 0.012 |	Bpp loss: 0.09 |	Aux loss: 472.20

Time: 25.1789
Learning rate: 0.001
Test epoch 2: Average losses:	Loss: 0.423 |	MSE loss: 0.010 |	Bpp loss: 0.23 |	Aux loss: 239.56

Time: 25.2242
Learning rate: 0.001
Test epoch 3: Average losses:	Loss: 0.304 |	MSE loss: 0.010 |	Bpp loss: 0.11 |	Aux loss: 191.10

Time: 25.2351
Learning rate: 0.001
Test epoch 4: Average losses:	Loss: 0.302 |	MSE loss: 0.010 |	Bpp loss: 0.11 |	Aux loss: 158.22

Time: 25.1001
Learning rate: 0.001
Test epoch 5: Average losses:	Loss: 0.307 |	MSE loss: 0.010 |	Bpp loss: 0.11 |	Aux loss: 107.47

Time: 25.1866
Learning rate: 0.001
Test epoch 6: Average losses:	Loss: 0.301 |	MSE loss: 0.010 |	Bpp loss: 0.11 |	Aux loss:

## Data Preparation for Classification

In [None]:
# Training images get a random 32-pixel crop, test images a center crop;
# both are then converted to tensors.
train_transforms = transforms.Compose([
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])
test_transforms = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.ToTensor(),
])

# CIFAR-10 train / test splits (downloaded into ./CIFAR-10/ when missing).
train_dataset = torchvision.datasets.CIFAR10(
    './CIFAR-10/', train=True, download=True, transform=train_transforms
)
test_dataset = torchvision.datasets.CIFAR10(
    './CIFAR-10/', train=False, download=True, transform=test_transforms
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Restore the codec weights written by the most recent training epoch.
checkpoint = torch.load('checkpoint.pth.tar', map_location=device)
net.load_state_dict(checkpoint["state_dict"])
# Stay on the device the checkpoint was mapped to (no hard-coded 'cuda') and
# make evaluation mode explicit; assigning the result also keeps the cell
# from printing the whole model.
net = net.to(device).eval()

FactorizedPrior(
  (entropy_bottleneck): EntropyBottleneck(
    (likelihood_lower_bound): LowerBound()
  )
  (g_a): Sequential(
    (0): Conv2d(3, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): GDN(
      (beta_reparam): NonNegativeParametrizer(
        (lower_bound): LowerBound()
      )
      (gamma_reparam): NonNegativeParametrizer(
        (lower_bound): LowerBound()
      )
    )
    (2): Conv2d(128, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (3): GDN(
      (beta_reparam): NonNegativeParametrizer(
        (lower_bound): LowerBound()
      )
      (gamma_reparam): NonNegativeParametrizer(
        (lower_bound): LowerBound()
      )
    )
    (4): Conv2d(128, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (5): GDN(
      (beta_reparam): NonNegativeParametrizer(
        (lower_bound): LowerBound()
      )
      (gamma_reparam): NonNegativeParametrizer(
        (lower_bound): LowerBound()
      )
    )
    (6): Conv2d(128, 192, kernel_s

In [None]:
# Both loaders use the same batch size and worker count; only the training
# loader shuffles. Memory is pinned only when running on the GPU.
train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=64, num_workers=2,
    pin_memory=(device == "cuda"),
)
test_dataloader = DataLoader(
    test_dataset, shuffle=False, batch_size=64, num_workers=2,
    pin_memory=(device == "cuda"),
)

## ResNet

In [None]:
# ImageNet-pretrained ResNet-18 with its classification head replaced by a
# 10-way linear layer.
resnet = torchvision.models.resnet18(pretrained=True)
resnet.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=10, bias=True)
# Use the notebook-wide `device` instead of a hard-coded 'cuda'.
resnet = resnet.to(device)
Loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.001)
# Stepped with the test loss at the end of every epoch.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [None]:
Nepoch = 100

# The codec is a fixed pre-processing stage here: only `resnet` is trained,
# so the codec stays in eval mode and runs without gradient tracking.
net.eval()

for epoch in range(Nepoch):
    t1 = time.time()
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0

    # ---- training pass ----
    resnet.train()
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            # Classify the codec's reconstruction, not the original image.
            images = net(images)["x_hat"]

        optimizer.zero_grad()
        predicted_output = resnet(images)
        fit = Loss(predicted_output, labels)
        fit.backward()
        optimizer.step()

        # `fit` is the mean over the batch; weight it by the batch size so
        # that dividing by the dataset size below gives a per-sample mean.
        train_loss += fit.item() * labels.size(0)
        train_acc += torch.sum(labels == predicted_output.argmax(dim=1)).item()

    # ---- evaluation pass ----
    # eval() makes BatchNorm use its running statistics and stops it from
    # updating them on test data.
    resnet.eval()
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            images = net(images)["x_hat"]
            predicted_output = resnet(images)
            fit = Loss(predicted_output, labels)
            test_loss += fit.item() * labels.size(0)
            test_acc += torch.sum(labels == predicted_output.argmax(dim=1)).item()

    train_loss = train_loss/len(train_dataset)
    test_loss = test_loss/len(test_dataset)
    train_acc = train_acc/len(train_dataset)
    test_acc = test_acc/len(test_dataset)

    # Lower the learning rate when the test loss stops improving.
    lr_scheduler.step(test_loss)
    t2 = time.time()

    print(f'Epoch: {epoch} \tTrain Loss: {train_loss:.5f} \tTrain Acc: {train_acc:.4f} \tTest Loss: {test_loss:.5f} \tTest Acc: {test_acc:.4f} \tTime: {t2-t1:.4f}')

Epoch: 0 	Train Loss: 0.02409 	Train Acc: 0.4472 	Test Loss: 0.02280 	Test Acc: 0.4823 	Time: 39.6702
Epoch: 1 	Train Loss: 0.02101 	Train Acc: 0.5229 	Test Loss: 0.02126 	Test Acc: 0.5186 	Time: 39.6350
Epoch: 2 	Train Loss: 0.01949 	Train Acc: 0.5616 	Test Loss: 0.02053 	Test Acc: 0.5378 	Time: 39.6441
Epoch: 3 	Train Loss: 0.01833 	Train Acc: 0.5870 	Test Loss: 0.02015 	Test Acc: 0.5419 	Time: 39.7244
Epoch: 4 	Train Loss: 0.01659 	Train Acc: 0.6290 	Test Loss: 0.02043 	Test Acc: 0.5464 	Time: 39.6776
Epoch: 5 	Train Loss: 0.01510 	Train Acc: 0.6620 	Test Loss: 0.02060 	Test Acc: 0.5607 	Time: 39.6817
Epoch: 6 	Train Loss: 0.01384 	Train Acc: 0.6893 	Test Loss: 0.02117 	Test Acc: 0.5542 	Time: 39.6667
Epoch: 7 	Train Loss: 0.01160 	Train Acc: 0.7411 	Test Loss: 0.02266 	Test Acc: 0.5511 	Time: 39.7196
Epoch: 8 	Train Loss: 0.01001 	Train Acc: 0.7748 	Test Loss: 0.02332 	Test Acc: 0.5463 	Time: 39.8814
Epoch: 9 	Train Loss: 0.00854 	Train Acc: 0.8092 	Test Loss: 0.02569 	Test Acc: 0.