In [None]:
# Colab setup: install CompressAI (the recorded run below installed compressai 1.2.0
# together with its dependency pytorch-msssim).
# NOTE(review): consider `%pip install compressai==<version>` so the version is pinned
# and the install targets the kernel's own environment.
!pip install compressai

Collecting compressai
  Downloading compressai-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (295 kB)
[K     |████████████████████████████████| 295 kB 5.4 MB/s 
Collecting pytorch-msssim
  Downloading pytorch_msssim-0.2.1-py3-none-any.whl (7.2 kB)
Installing collected packages: pytorch-msssim, compressai
Successfully installed compressai-1.2.0 pytorch-msssim-0.2.1


In [None]:
import argparse
import math
import random
import shutil
import sys
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision

from compressai.datasets import ImageFolder
from compressai.zoo import image_models
import compressai

In [None]:
from compressai.zoo import (bmshj2018_factorized, bmshj2018_hyperprior, mbt2018_mean, mbt2018, cheng2020_anchor)
# Run on the GPU whenever one is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Informational only: `metric` is not passed to the zoo builders below
# (only MSE-optimized pretrained weights are used here).
metric = 'mse'
# CompressAI quality index: lower quality -> lower bit-rate, which makes the
# visual differences easy to see in the notebook.
quality = 1

# (display name, zoo constructor) pairs, built in this order.
_zoo_builders = (
    ('bmshj2018-factorized', bmshj2018_factorized),
    ('bmshj2018-hyperprior', bmshj2018_hyperprior),
    ('mbt2018-mean', mbt2018_mean),
    ('mbt2018', mbt2018),
    ('cheng2020-anchor', cheng2020_anchor),
)
networks = {
    name: build(quality=quality, pretrained=True).eval().to(device)
    for name, build in _zoo_builders
}

# This is the network later handed to main() in the lambda sweep.
net = networks['bmshj2018-hyperprior']

In [None]:
# Auxiliary loss of the selected pretrained network: the quantity that the
# auxiliary optimizer minimizes after every training step in train_one_epoch.
net.aux_loss()

tensor(172.0104, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
from google.colab import drive
# Mount Google Drive at the relative mount point 'MyDrive' (the output below reports
# "Mounted at MyDrive"). This presumably resolves against the Colab working directory
# (/content), which would match the /content/MyDrive/MyDrive/... path used in the next
# cell. Confirm before changing either cell, because the two must stay in sync.
drive.mount('MyDrive')

Mounted at MyDrive


In [None]:
import os

# Work from the project folder on the mounted Drive. The CIFAR-10 download
# ('./CIFAR-10/') and the per-epoch checkpoints are written relative to it.
PROJECT_DIR = '/content/MyDrive/MyDrive/DL_Project_HP'
os.chdir(PROJECT_DIR)

### Loss, Metric, and Optimizer Helpers

In [None]:
class RateDistortionLoss(nn.Module):
    """Rate/accuracy trade-off loss: ``lmbda * cross_entropy + bpp``.

    The rate term is the estimated bits per pixel summed over every likelihood
    tensor produced by the compression model. The task term is the classification
    cross-entropy of ``preds`` against ``labels``.
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.crossEntropy = nn.CrossEntropyLoss()
        self.lmbda = lmbda

    def forward(self, output, target, preds, labels):
        """Compute the individual loss terms and their weighted sum.

        Args:
            output: dict whose ``"likelihoods"`` entry maps names to likelihood tensors.
            target: input image batch. Only its size is read, unpacked as (N, _, H, W).
            preds: classifier logits, fed to ``nn.CrossEntropyLoss``.
            labels: target class indices.

        Returns:
            dict with keys ``"bpp_loss"``, ``"log_loss"`` and ``"loss"``, in that order.
        """
        batch, _, height, width = target.size()
        pixel_count = batch * height * width
        # log2(p) = ln(p) / ln(2); the minus sign turns log-likelihood into bits.
        bits_scale = -math.log(2) * pixel_count

        bpp = 0
        for lk in output["likelihoods"].values():
            bpp = bpp + torch.log(lk).sum() / bits_scale

        log_loss = self.crossEntropy(preds, labels)
        return {
            "bpp_loss": bpp,
            "log_loss": log_loss,
            "loss": self.lmbda * log_loss + bpp,
        }


class AverageMeter:
    """Running (optionally weighted) mean of a stream of values."""

    def __init__(self):
        # val: latest value; avg: sum / count; sum: weighted total; count: total weight.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CustomDataParallel(nn.DataParallel):
    """DataParallel that falls back to the wrapped module for unknown attributes.

    This lets code call model-specific methods (e.g. ``aux_loss``) on the wrapper
    exactly as on the bare model.
    """

    def __getattr__(self, key):
        try:
            found = super().__getattr__(key)
        except AttributeError:
            found = getattr(self.module, key)
        return found


def configure_optimizers(net, args):
    """Split ``net``'s trainable parameters between two Adam optimizers.

    Parameters whose name ends in ``.quantiles`` go to the auxiliary optimizer
    (``lr=args.aux_learning_rate``). Every other trainable parameter goes to the
    main optimizer (``lr=args.learning_rate``). Parameters are added in sorted
    name order.

    Returns:
        ``(optimizer, aux_optimizer)``.

    Raises:
        AssertionError: if the two groups overlap or do not cover every parameter
            (which also happens when ``net`` has frozen parameters).
    """
    named = dict(net.named_parameters())

    main_names = set()
    aux_names = set()
    for pname, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if pname.endswith(".quantiles"):
            aux_names.add(pname)
        else:
            main_names.add(pname)

    # Sanity checks: disjoint groups that together cover all parameters.
    assert len(main_names & aux_names) == 0
    assert len(main_names | aux_names) - len(named.keys()) == 0

    optimizer = optim.Adam(
        [named[k] for k in sorted(main_names)],
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        [named[k] for k in sorted(aux_names)],
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer

### Train and Test Epochs

In [None]:
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    """Run one training epoch of the joint compression + classification model.

    Each batch goes through these steps in order:
    1. forward pass;
    2. backward of the main loss;
    3. optional gradient-norm clipping;
    4. main optimizer step;
    5. a separate backward/step of ``aux_optimizer`` on ``model.aux_loss()``.

    Args:
        model: module whose forward returns a dict holding the classifier logits
            under ``"y_hat"`` (plus whatever ``criterion`` reads) and which exposes
            ``aux_loss()``.
        criterion: callable ``(out_net, images, preds, labels) -> dict`` returning
            ``"loss"``, ``"bpp_loss"`` and ``"log_loss"`` tensors.
        train_dataloader: yields batches whose first two items are images and labels.
        optimizer: optimizer for the main loss.
        aux_optimizer: optimizer for ``model.aux_loss()``.
        epoch: epoch index, used only for logging.
        clip_max_norm: maximum total gradient norm; ``<= 0`` disables clipping.

    Returns:
        Training accuracy in percent over the whole dataset (0.0 if it is empty).
    """
    model.train()
    device = next(model.parameters()).device
    n_correct = 0

    for i, d in enumerate(train_dataloader):
        # .to(device) already targets the model's device; no extra .cuda() is
        # needed, so this also runs on CPU.
        images = d[0].to(device)
        labels = d[1].to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(images)
        preds = out_net['y_hat']
        n_correct += torch.sum(labels == preds.argmax(dim=1)).item()

        out_criterion = criterion(out_net, images, preds, labels)
        out_criterion["loss"].backward()

        # Clipping must happen after backward() so it acts on this batch's
        # gradients; doing it right after zero_grad() would be a no-op.
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        if i % 100 == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i*len(images)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"

                f'\tLoss: {out_criterion["loss"].item():.3f} |'
                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
                f'\tLog loss: {out_criterion["log_loss"].item():.2f} |'
                f"\tAux loss: {aux_loss.item():.2f}"
            )

    # Accuracy in percent, normalized by the actual dataset size rather than a
    # hard-coded constant.
    n_samples = len(train_dataloader.dataset)
    train_acc = 100.0 * n_correct / n_samples if n_samples else 0.0
    print(f'\nTrain epoch {epoch}: \tAcc: {train_acc:.3f} |')
    return train_acc


def test_epoch(epoch, test_dataloader, model, criterion):
    """Evaluate the joint model on ``test_dataloader`` without gradient tracking.

    Prints accuracy (percent of the whole dataset) and per-batch averages of the
    total, bpp, classification (log) and auxiliary losses.

    Args:
        epoch: epoch index, used only for logging.
        test_dataloader: yields batches whose first two items are images and labels.
        model: module whose forward returns classifier logits under ``"y_hat"`` and
            which exposes ``aux_loss()``.
        criterion: callable ``(out_net, images, preds, labels) -> dict`` returning
            ``"loss"``, ``"bpp_loss"`` and ``"log_loss"``.

    Returns:
        The unweighted per-batch average of ``criterion``'s ``"loss"``, as held by
        ``AverageMeter.avg``. This is a tensor when the criterion returns tensors,
        and 0 for an empty loader.
    """
    model.eval()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    log_loss = AverageMeter()
    aux_loss = AverageMeter()
    n_correct = 0
    with torch.no_grad():
        for d in test_dataloader:
            # .to(device) already targets the model's device; no extra .cuda().
            images = d[0].to(device)
            labels = d[1].to(device)

            out_net = model(images)
            preds = out_net['y_hat']
            n_correct += torch.sum(labels == preds.argmax(dim=1)).item()

            out_criterion = criterion(out_net, images, preds, labels)

            aux_loss.update(model.aux_loss())
            bpp_loss.update(out_criterion["bpp_loss"])
            # Average the log loss too, instead of reporting only the last batch.
            log_loss.update(out_criterion["log_loss"])
            loss.update(out_criterion["loss"])

    # Accuracy in percent, normalized by the actual dataset size.
    n_samples = len(test_dataloader.dataset)
    test_acc = 100.0 * n_correct / n_samples if n_samples else 0.0
    print(
        f"Test epoch {epoch}: Average losses:"
        f'\tAcc: {test_acc:.3f} |'
        f"\tLoss: {loss.avg:.3f} |"
        f"\tBpp loss: {bpp_loss.avg:.2f} |"
        f'\tLog loss: {log_loss.avg:.2f} |'
        f"\tAux loss: {aux_loss.avg:.2f}\n"
    )

    return loss.avg


def save_checkpoint(state, epoch, is_best, filename, best_filename):
    """Save ``state`` to ``"<epoch><filename>"`` in the working directory.

    When ``is_best`` is true, that file is also copied to ``best_filename``.
    """
    epoch_path = str(epoch) + filename
    torch.save(state, epoch_path)
    if not is_best:
        return
    shutil.copyfile(epoch_path, best_filename)


### Main: joint fine-tuning of the compression model and a ResNet-18 classifier on CIFAR-10

In [None]:
def main(model, num_workers, batch_size, cuda, epoch, patch_size, learning_rate, lmbda):
    """Jointly fine-tune a pretrained compression model and a ResNet-18 on CIFAR-10.

    Images pass through the compression model (``g_a`` -> hyperprior -> ``g_s``).
    The reconstruction is then classified by torchvision's pretrained ResNet-18,
    whose final layer is replaced with a 10-way head. Training minimizes
    ``RateDistortionLoss(lmbda)``. A checkpoint is written every epoch, and the
    one with the lowest test loss is also copied to ``best_filename``.

    The trained network *shares* the caller's ``model`` submodules. ``model``'s
    weights are therefore snapshotted on entry and restored on exit, so repeated
    calls (e.g. a sweep over ``lmbda``) all start from the same pretrained weights.

    Args:
        model: pretrained CompressAI model. It must expose ``g_a``, ``h_a``, ``g_s``,
            ``h_s``, ``entropy_bottleneck``, ``gaussian_conditional`` and
            ``aux_loss()``, because ``Net`` below accesses them directly.
        num_workers: DataLoader worker processes.
        batch_size: training batch size (the test batch size is fixed to 100).
        cuda: truthy to train on CUDA when it is available.
        epoch: number of training epochs.
        patch_size: crop size applied after resizing images to 64x64.
        learning_rate: learning rate of both the main and the auxiliary Adam optimizer.
        lmbda: weight of the classification term in ``RateDistortionLoss``.

    Returns:
        None. Results are written to checkpoint files in the working directory.
    """

    class arguments:
        """Plain attribute bag standing in for an argparse namespace."""

        def __init__(self, model, num_workers, batch_size, cuda, epoch, patch_size, learning_rate, lmbda):
            self.model = model
            self.num_workers = num_workers
            self.batch_size = batch_size
            self.test_batch_size = 100
            self.cuda = cuda
            self.epochs = epoch
            self.patch_size = patch_size
            self.learning_rate = learning_rate
            self.aux_learning_rate = learning_rate
            self.lmbda = lmbda
            self.save = True
            self.seed = False
            self.clip_max_norm = 1.0
            self.checkpoint = False  # checkpoint path to resume from, or False

    # Per-channel mean/std used by Normalize below.
    # NOTE(review): the normalized images are fed to the pretrained compression
    # model first (not directly to the ResNet); confirm this input range is intended.
    tr_mean = np.asarray([0.4914, 0.4822, 0.4465])
    tr_std = np.asarray([0.247, 0.243, 0.261])

    args = arguments(model, num_workers, batch_size, cuda, epoch, patch_size, learning_rate, lmbda)

    train_transforms = transforms.Compose(
        [transforms.Resize((64,64)), transforms.ToTensor(), transforms.RandomCrop(args.patch_size), torchvision.transforms.Normalize(tr_mean, tr_std)]
    )

    test_transforms = transforms.Compose(
        [transforms.Resize((64,64)), transforms.ToTensor(), transforms.CenterCrop(args.patch_size), torchvision.transforms.Normalize(tr_mean, tr_std)]
    )

    train_dataset = torchvision.datasets.CIFAR10('./CIFAR-10/',train=True,download=True, transform=train_transforms)
    test_dataset = torchvision.datasets.CIFAR10('./CIFAR-10/',train=False,download=True, transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    # --- Joint model: compression network followed by a ResNet-18 classifier ---
    resnet = torchvision.models.resnet18(pretrained = True)
    resnet.fc = nn.Linear(in_features=512, out_features=10, bias=True)

    class Net(nn.Module):
        """Compression model (analysis, hyperprior, synthesis) followed by a classifier.

        Its forward returns the reconstruction under ``"x_hat"``, the classifier
        logits under ``"y_hat"`` and the latent likelihoods under ``"likelihoods"``.
        """

        def __init__(self, resnet, net):
            super(Net, self).__init__()

            # These are references to `net`'s submodules, not copies.
            self.g_a = net.g_a
            self.h_a = net.h_a
            self.g_s = net.g_s
            self.h_s = net.h_s
            self.entropy_bottleneck = net.entropy_bottleneck
            self.gaussian_conditional = net.gaussian_conditional
            self.res = resnet

        def forward(self, x):
            y = self.g_a(x)
            z = self.h_a(torch.abs(y))
            z_hat, z_likelihoods = self.entropy_bottleneck(z)
            scales_hat = self.h_s(z_hat)
            y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
            x_hat = self.g_s(y_hat)
            l_hat = self.res(x_hat)

            return {
                "x_hat": x_hat,
                "y_hat": l_hat,
                "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
            }

    # Net trains the caller's compression submodules in place. Snapshot their
    # weights so they can be restored afterwards; otherwise the next call (e.g. the
    # next lmbda in a sweep) would start from this run's fine-tuned codec weights.
    pretrained_state = {
        name: tensor.detach().clone() for name, tensor in model.state_dict().items()
    }

    net = Net(resnet, model)
    # The auxiliary loss is taken from the source model, whose entropy bottleneck
    # is the same object that Net holds.
    net.aux_loss = model.aux_loss
    net = net.to(device)

    if args.cuda and torch.cuda.device_count() > 1:
        net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min",factor=0.5)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    filename = str(args.lmbda) + '_check.pth.tar'
    best_filename = 'best' + filename

    try:
        last_epoch = 0
        if args.checkpoint:  # load from previous checkpoint
            print("Loading", args.checkpoint)
            checkpoint = torch.load(args.checkpoint, map_location=device)
            last_epoch = checkpoint["epoch"] + 1
            net.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

        best_loss = float("inf")
        for epoch_idx in range(last_epoch, args.epochs):
            t_start = time.time()

            print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
            train_one_epoch(
                net,
                criterion,
                train_dataloader,
                optimizer,
                aux_optimizer,
                epoch_idx,
                args.clip_max_norm,
            )
            loss = test_epoch(epoch_idx, test_dataloader, net, criterion)
            lr_scheduler.step(loss)
            t_end = time.time()
            print(f"Time: {t_end-t_start:.4f}")
            is_best = loss < best_loss
            best_loss = min(loss, best_loss)

            if args.save:
                save_checkpoint(
                    {
                        "epoch": epoch_idx,
                        "state_dict": net.state_dict(),
                        "loss": loss,
                        "optimizer": optimizer.state_dict(),
                        "aux_optimizer": aux_optimizer.state_dict(),
                        "lr_scheduler": lr_scheduler.state_dict(),
                    },
                    epoch_idx,
                    is_best,
                    filename,
                    best_filename,
                )
    finally:
        # Restore the pretrained weights, even if training was interrupted, so
        # the caller's model is left as it was passed in. Trained weights live on
        # in the checkpoint files.
        model.load_state_dict(pretrained_state)


In [None]:
# Sweep the trade-off weight: each value scales the classification (cross-entropy)
# term against the bits-per-pixel term in RateDistortionLoss.
lmbda = np.asarray([0.01, 0.015, 0.020, 0.05, 0.125, 0.5, 1, 5]).astype(float) * 0.01

for lmb in lmbda:
    net_out = main(
        net,
        num_workers=2,
        batch_size=64,
        cuda=1,
        epoch=200,
        patch_size=64,
        learning_rate=0.001,
        lmbda=lmb,
    )