In [1]:
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
import argparse
import math
import random
import shutil
import sys

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

from piq import psnr

from compressai.datasets import ImageFolder
from compressai.zoo import models

import matplotlib as mpl
import matplotlib.pyplot as plt

import pandas as pd

import os.path

In [3]:
class RateDistortionLoss(nn.Module):
    """Rate-distortion criterion weighted by a Lagrangian multiplier.

    ``forward`` returns a dict with the estimated rate (``bpp_loss``), the
    MS-SSIM and PSNR scores of the clamped reconstruction (``msssim_loss``,
    ``psnr_loss``) and the combined value ``loss``.
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.lmbda = lmbda
        self.psnr = psnr
        self.ms_ssim = ms_ssim

    def forward(self, output, target):
        """Evaluate the criterion for one batch.

        Args:
            output: mapping holding the reconstruction under ``"x_hat"`` and a
                dict of likelihood tensors under ``"likelihoods"``.
            target: 4-D batch of reference images; dimensions 0, 2 and 3 are
                multiplied together to obtain the pixel count.

        Returns:
            dict with the keys ``"bpp_loss"``, ``"msssim_loss"``,
            ``"psnr_loss"`` and ``"loss"``.
        """
        batch, _, height, width = target.size()
        pixel_count = batch * height * width
        # Dividing a summed natural log by -ln(2) * pixels yields bits per pixel.
        bits_denominator = -math.log(2) * pixel_count

        rate = sum(
            torch.log(entry).sum() / bits_denominator
            for entry in output["likelihoods"].values()
        )

        # Both quality scores are computed on the reconstruction limited to [0, 1].
        reconstruction = output["x_hat"].clamp(min=0, max=1)
        msssim_score = self.ms_ssim(reconstruction, target, data_range=1.0, size_average=True)
        psnr_score = self.psnr(reconstruction, target, data_range=1.0)

        # NOTE(review): this combines the rate with lmbda * 255**2 * PSNR. PSNR is
        # a quality score, not a distortion -- confirm that "loss" is only reported
        # and never minimised.
        return {
            "bpp_loss": rate,
            "msssim_loss": msssim_score,
            "psnr_loss": psnr_score,
            "loss": self.lmbda * 255 ** 2 * psnr_score + rate,
        }

In [4]:
class AverageMeter:
    """Tracks the latest value and the weighted running mean of a series.

    Attributes are ``val`` (last value passed in), ``sum`` (weighted total),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        # All statistics start from zero; `avg` is only meaningful after update().
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [5]:
class CustomDataParallel(nn.DataParallel):
    """DataParallel wrapper that forwards unknown attributes to the wrapped module."""

    def __getattr__(self, key):
        """Look ``key`` up on the wrapper first, then on ``self.module``."""
        try:
            attribute = super().__getattr__(key)
        except AttributeError:
            # Not an attribute of the wrapper itself: delegate to the wrapped
            # module so that its own methods stay reachable.
            attribute = getattr(self.module, key)
        return attribute

In [6]:
def configure_optimizers(net, learning_rate, aux_learning_rate):
    """Build the main and the auxiliary Adam optimizer for ``net``.

    Trainable parameters whose name ends in ``.quantiles`` go to the auxiliary
    optimizer (learning rate ``aux_learning_rate``); every other trainable
    parameter goes to the main optimizer (learning rate ``learning_rate``).

    Returns:
        tuple ``(optimizer, aux_optimizer)``.
    """
    named_params = dict(net.named_parameters())

    main_names = set()
    aux_names = set()
    for name, param in named_params.items():
        if not param.requires_grad:
            continue
        if name.endswith(".quantiles"):
            aux_names.add(name)
        else:
            main_names.add(name)

    # Sanity checks: the two groups are disjoint and together contain every
    # named parameter of the network.
    assert len(main_names & aux_names) == 0
    assert len(main_names | aux_names) - len(named_params.keys()) == 0

    # Sorting by name gives both optimizers a deterministic parameter order.
    main_params = (named_params[name] for name in sorted(main_names))
    aux_params = (named_params[name] for name in sorted(aux_names))

    optimizer = optim.Adam(main_params, lr=learning_rate)
    aux_optimizer = optim.Adam(aux_params, lr=aux_learning_rate)
    return optimizer, aux_optimizer

In [7]:
def test_model(test_dataloader, model, criterion):
    model.eval()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    msssim_loss = AverageMeter()
    psnr_loss = AverageMeter()
    aux_loss = AverageMeter()

    with torch.no_grad():
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            out_criterion = criterion(out_net, d)

            aux_loss.update(model.aux_loss())
            bpp_loss.update(out_criterion["bpp_loss"])
            loss.update(out_criterion["loss"])
            msssim_loss.update(out_criterion["msssim_loss"])
            psnr_loss.update(out_criterion["psnr_loss"])

    print(
        f"Average losses:"
        f"\tLoss: {loss.avg:.3f} |"
        f"\tMS-SSIM loss: {msssim_loss.avg:.3f} |"
        f"\tPSNR loss: {psnr_loss.avg:.3f} |"
        f"\tBpp loss: {bpp_loss.avg:.2f} |"
        f"\tAux loss: {aux_loss.avg:.2f}\n"
    )

    return [bpp_loss.avg, msssim_loss.avg, psnr_loss.avg];

In [8]:
def create_modelname(model, quality, loss_fn, reference_trained):
    """Return the result label ``<model>_q<quality>_<loss_fn>_<ref|special>``.

    The last component is ``ref`` when ``reference_trained`` is truthy and
    ``special`` otherwise.
    """
    training_tag = "ref" if reference_trained else "special"
    return "_".join((model, "q" + str(quality), loss_fn, training_tag))

def create_filename(model, quality, loss_fn, target_epochs):
    """Return the checkpoint base name ``<model>_q<quality>_<loss_fn>_<target_epochs>ep``."""
    pieces = (model, "q" + str(quality), loss_fn, str(target_epochs) + "ep")
    return "_".join(pieces)

In [9]:
# --- Evaluation configuration ---------------------------------------------

# Size passed to the centre crop applied to every test image.
patch_size = (512, 512)

# NOTE(review): absolute, machine-specific paths -- adjust before running elsewhere.
dataset = "/home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/datasets/selection"
model_dir = "/home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/final_training"
ref_model_dir = "/home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/reference_training"

batch_size = 8
test_batch_size = 2
num_workers = 8

# Epoch count that forms part of the checkpoint file names to look for.
epoch_final = 500

# Lagrangian multipliers, looked up as lmbda[metric][quality] with quality 1..8.
lmbda = {
    'mse': dict(enumerate(
        (0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800),
        start=1,
    )),
    'msssim': dict(enumerate(
        (2.4, 4.58, 8.73, 16.64, 31.37, 60.5, 115.37, 220),
        start=1,
    )),
}

In [10]:
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Every test image is centre-cropped to `patch_size` and converted to a tensor.
test_transforms = transforms.Compose([
    transforms.CenterCrop(patch_size),
    transforms.ToTensor(),
])
test_dataset = ImageFolder(dataset, split="test", transform=test_transforms)

test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=test_batch_size,
    num_workers=num_workers,
    pin_memory=device == "cuda",
)

# Filled by the evaluation cells: model label -> [bpp, ms-ssim, psnr].
results = {}

In [11]:
""" Models trained with animated images """
model = "mbt2018"
for quality in [4, 6, 8]:
    for metric in ["mse", "msssim"]:
        filename = model_dir + '/' + create_filename(model, quality, metric, epoch_final) + ".pth.tar"
        print("checking for " + filename)
        if not os.path.exists(filename):
            # No checkpoint for this combination: it is left out of `results`.
            continue

        net = models[model](quality=quality, pretrained=False)
        net = net.to(device)

        if torch.cuda.device_count() > 1:
            net = CustomDataParallel(net)

        criterion = RateDistortionLoss(lmbda=lmbda[metric][quality])

        checkpoint = torch.load(filename, map_location=device)
        net.load_state_dict(checkpoint["state_dict"])
        model_label_custom = create_modelname(model, quality, metric, False)

        # The stored epoch is compared after adding 1 (presumably it is a
        # zero-based index). The message is formatted with an f-string because
        # concatenating the integers to a str raised TypeError.
        trained_epochs = checkpoint["epoch"] + 1
        if trained_epochs != epoch_final:
            print(f"WARNING: {model_label_custom} has {trained_epochs}, desired: {epoch_final}")

        losses = test_model(test_dataloader, net, criterion)
        results[model_label_custom] = losses

checking for /home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/final_training/mbt2018_q4_mse_500ep.pth.tar


Traceback (most recent call last):
  File "/home/clemens/.conda/envs/venv-ba/lib/python3.9/multiprocessing/queues.py", line 251, in _feed
    send_bytes(obj)
  File "/home/clemens/.conda/envs/venv-ba/lib/python3.9/multiprocessing/connection.py", line 205, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/clemens/.conda/envs/venv-ba/lib/python3.9/multiprocessing/connection.py", line 416, in _send_bytes
    self._send(header + buf)
  File "/home/clemens/.conda/envs/venv-ba/lib/python3.9/multiprocessing/connection.py", line 373, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

In [None]:
""" Models trained with generic photographs """
model = "mbt2018"
for quality in [4, 6, 8]:
    for metric in ["mse", "msssim"]:
        filename = ref_model_dir + '/' + create_filename(model, quality, metric, epoch_final) + ".pth.tar"
        print("checking for " + filename)
        if not os.path.exists(filename):
            # No checkpoint for this combination: it is left out of `results`.
            continue

        net = models[model](quality=quality, pretrained=False)
        net = net.to(device)

        if torch.cuda.device_count() > 1:
            net = CustomDataParallel(net)

        criterion = RateDistortionLoss(lmbda=lmbda[metric][quality])

        checkpoint = torch.load(filename, map_location=device)
        net.load_state_dict(checkpoint["state_dict"])
        model_label_custom = create_modelname(model, quality, metric, True)

        # The stored epoch is compared after adding 1 (presumably it is a
        # zero-based index). The message is formatted with an f-string because
        # concatenating the integers to a str raised TypeError.
        trained_epochs = checkpoint["epoch"] + 1
        if trained_epochs != epoch_final:
            print(f"WARNING: {model_label_custom} has {trained_epochs}, desired: {epoch_final}")

        losses = test_model(test_dataloader, net, criterion)
        results[model_label_custom] = losses

In [None]:
# Turn `results` (label -> [bpp, ms-ssim, psnr]) into one DataFrame row per model.
# The value lists deliberately avoid the names `psnr` / `ms_ssim`: rebinding the
# imported `psnr` function here would break every RateDistortionLoss created
# afterwards, because its __init__ stores whatever the global `psnr` refers to.
labels = list(results)
bpp_values = [float(results[label][0]) for label in labels]
msssim_values = [float(results[label][1]) for label in labels]
psnr_values = [float(results[label][2]) for label in labels]

# Labels have the form <model>_q<quality>_<metric>_<training>.
label_parts = [label.split("_") for label in labels]

df = pd.DataFrame({
    'bpp': bpp_values,
    'msssim': msssim_values,
    'psnr': psnr_values,
    'model': [parts[0] for parts in label_parts],
    'quality': [int(parts[1][1:]) for parts in label_parts],
    'metric': [parts[2] for parts in label_parts],
    'training': [parts[3] for parts in label_parts],
})
print(df)

In [None]:
mpl.rcParams['figure.dpi'] = 300
fig, ax = plt.subplots()
model = "mbt2018"

# Value of the `training` column -> (legend label, line style). The column holds
# the last component of the model label, i.e. "ref" or "special"; filtering on
# the legend words "vanilla"/"custom" matched no rows and left the plot empty.
training_styles = {
    "ref": ("vanilla", '-'),
    "special": ("custom", ':'),
}

for training, (training_label, linestyle) in training_styles.items():
    for metric in ["mse", "msssim"]:
        # A single combined mask avoids chained boolean indexing, where the
        # later masks no longer align with the already-filtered frame.
        selected = df.loc[
            (df['model'] == model)
            & (df['metric'] == metric)
            & (df['training'] == training)
        ]
        ax.plot(
            selected['bpp'],
            selected['psnr'],
            label=model + " " + metric + " " + training_label,
            linestyle=linestyle,
            marker='o'
        )

# Annotate every point with its quality level. The loop variables are named so
# that they do not rebind the imported `psnr` function.
for point_bpp, point_psnr, point_quality in zip(df['bpp'], df['psnr'], df['quality']):
    ax.annotate('%s' % point_quality, xy=(point_bpp, point_psnr), xytext=(10, -10), textcoords='offset points')

ax.set_xlabel("bpp")
ax.set_ylabel("PSNR")
fig.set_figwidth(8)
fig.set_figheight(6)
ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
ax.grid(True)

plt.show()