In [1]:
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
import argparse
import math
import random
import shutil
import sys

import torch
import torch.nn as nn
import torch.optim as optim

import wandb

from torch.utils.data import DataLoader
from torchvision import transforms

from compressai.datasets import ImageFolder
from compressai.zoo import models

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

import os.path

In [3]:
class RateDistortionLossMSE(nn.Module):
    """Rate-distortion objective that uses MSE as the distortion term.

    The combined loss is ``lmbda * 255**2 * mse + bpp``, where ``bpp`` is the
    bit-rate estimate computed from the likelihood tensors of the model output.
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, output, target):
        """Return a dict with the keys ``bpp_loss``, ``mse_loss`` and ``loss``.

        ``output`` must provide ``"likelihoods"`` (a mapping whose values are
        tensors) and ``"x_hat"``. ``target`` must have exactly four dimensions;
        the first, third and fourth are multiplied to get the pixel count.
        """
        batch, _, height, width = target.size()
        pixel_count = batch * height * width

        # Every likelihood tensor contributes -sum(log2(p)) / pixel_count.
        rate = sum(
            torch.log(probabilities).sum() / (-math.log(2) * pixel_count)
            for probabilities in output["likelihoods"].values()
        )
        distortion = self.mse(output["x_hat"], target)

        return {
            "bpp_loss": rate,
            "mse_loss": distortion,
            "loss": self.lmbda * 255 ** 2 * distortion + rate,
        }

In [4]:
class RateDistortionLossMSSSIM(nn.Module):
    """Rate-distortion objective that uses ``1 - MS-SSIM`` as the distortion term.

    The combined loss is ``lmbda * (1 - ms_ssim) + bpp``, where ``bpp`` is the
    bit-rate estimate computed from the likelihood tensors of the model output.
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        # Not used by forward(); kept so the attribute set stays unchanged.
        self.mse = nn.MSELoss()
        self.ms_ssim = ms_ssim
        self.lmbda = lmbda

    def forward(self, output, target):
        """Return a dict with the keys ``bpp_loss``, ``msssim_loss`` and ``loss``.

        ``output`` must provide ``"likelihoods"`` (a mapping whose values are
        tensors) and ``"x_hat"``. ``target`` must have exactly four dimensions;
        the first, third and fourth are multiplied to get the pixel count.
        """
        batch, _, height, width = target.size()
        pixel_count = batch * height * width

        # Every likelihood tensor contributes -sum(log2(p)) / pixel_count.
        rate = sum(
            torch.log(probabilities).sum() / (-math.log(2) * pixel_count)
            for probabilities in output["likelihoods"].values()
        )
        # ms_ssim is called with data_range=1.0 and averaged over the batch.
        similarity = self.ms_ssim(
            output["x_hat"], target, data_range=1.0, size_average=True
        )
        distortion = 1 - similarity

        return {
            "bpp_loss": rate,
            "msssim_loss": distortion,
            "loss": self.lmbda * distortion + rate,
        }

In [5]:
class AverageMeter:
    """Track the latest value and the weighted running mean of a series."""

    def __init__(self):
        # val: last value seen; sum/count: weighted totals; avg: sum / count.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [6]:
class CustomDataParallel(nn.DataParallel):
    """DataParallel wrapper that forwards unknown attributes to the wrapped module."""

    def __getattr__(self, key):
        # Look on the wrapper first; anything it does not know about is
        # resolved on the wrapped module instead.
        try:
            attribute = super().__getattr__(key)
        except AttributeError:
            attribute = getattr(self.module, key)
        return attribute

In [7]:
def configure_optimizers(net, learning_rate, aux_learning_rate):
    """Build the main and the auxiliary Adam optimizers for ``net``.

    Trainable parameters whose name ends in ``.quantiles`` go to the auxiliary
    optimizer (learning rate ``aux_learning_rate``); every other trainable
    parameter goes to the main optimizer (learning rate ``learning_rate``).

    Returns:
        A ``(optimizer, aux_optimizer)`` tuple.
    """
    params_by_name = dict(net.named_parameters())

    main_names = set()
    aux_names = set()
    for name, param in params_by_name.items():
        if not param.requires_grad:
            continue
        bucket = aux_names if name.endswith(".quantiles") else main_names
        bucket.add(name)

    # Sanity checks: the two groups are disjoint and together cover every
    # named parameter of the network.
    assert len(main_names & aux_names) == 0
    assert len(main_names | aux_names) - len(params_by_name.keys()) == 0

    # Parameters are handed over in sorted-name order.
    main_params = (params_by_name[name] for name in sorted(main_names))
    aux_params = (params_by_name[name] for name in sorted(aux_names))

    optimizer = optim.Adam(main_params, lr=learning_rate)
    aux_optimizer = optim.Adam(aux_params, lr=aux_learning_rate)
    return optimizer, aux_optimizer

In [8]:
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm, metric
):
    """Run one training pass over ``train_dataloader``.

    For every batch the rate-distortion loss from ``criterion`` is
    back-propagated and applied with ``optimizer``; afterwards the model's
    auxiliary loss is back-propagated and applied with ``aux_optimizer``.

    Args:
        model: network to train; must expose ``aux_loss()``.
        criterion: callable returning a dict with ``"loss"``, ``"bpp_loss"``
            and ``metric + "_loss"`` entries.
        train_dataloader: iterable of input batches; ``.dataset`` is only used
            for the progress print-out.
        optimizer: optimizer for the main loss.
        aux_optimizer: optimizer for the auxiliary loss.
        epoch: epoch index, used in the progress print-out only.
        clip_max_norm: gradient-norm clipping threshold; clipping is skipped
            when this is not greater than zero.
        metric: name of the distortion metric (e.g. ``"mse"``), used to select
            the distortion entry of the criterion output and to label logs.
    """
    model.train()
    device = next(model.parameters()).device
    metric_key = metric + "_loss"

    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)
        out_criterion["loss"].backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        # Report every 10th batch.
        if i % 10 == 0:
            loss_value = out_criterion["loss"].item()
            metric_value = out_criterion[metric_key].item()
            bpp_value = out_criterion["bpp_loss"].item()
            aux_value = aux_loss.item()

            print(
                f"Train epoch {epoch}: ["
                f"{i*len(d)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"
                f'\tLoss: {loss_value:.3f} |'
                f'\t{metric.upper()} loss: {metric_value:.3f} |'
                f'\tBpp loss: {bpp_value:.2f} |'
                f"\tAux loss: {aux_value:.2f}"
            )
            # Log all four values in ONE call: every wandb.log call advances
            # the run's step by default, so separate calls would put metrics
            # that belong to the same batch on four different steps.
            wandb.log(
                {
                    "loss": loss_value,
                    "loss_" + metric: metric_value,
                    "loss_bpp": bpp_value,
                    "loss_aux": aux_value,
                }
            )

In [9]:
def test_epoch(epoch, test_dataloader, model, criterion, metric):
    """Evaluate ``model`` on ``test_dataloader`` and return the mean total loss.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    The per-batch total, bit-rate, distortion (``metric + "_loss"``) and
    auxiliary losses are averaged with equal weight per batch, printed, and
    the average total loss is returned.
    """
    model.eval()
    device = next(model.parameters()).device

    total_meter, rate_meter, distortion_meter, aux_meter = (
        AverageMeter() for _ in range(4)
    )

    with torch.no_grad():
        for batch in test_dataloader:
            batch = batch.to(device)
            results = criterion(model(batch), batch)

            aux_meter.update(model.aux_loss())
            rate_meter.update(results["bpp_loss"])
            total_meter.update(results["loss"])
            distortion_meter.update(results[metric + "_loss"])

    summary = (
        f"Test epoch {epoch}: Average losses:"
        f"\tLoss: {total_meter.avg:.3f} |"
        f"\t{metric.upper()} loss: {distortion_meter.avg:.3f} |"
        f"\tBpp loss: {rate_meter.avg:.2f} |"
        f"\tAux loss: {aux_meter.avg:.2f}\n"
    )
    print(summary)

    return total_meter.avg

In [10]:
def save_checkpoint(state, is_best, filename):
    """Write ``state`` to ``<filename>.pth.tar``.

    When ``is_best`` is true, the file is additionally copied to
    ``<filename>_best_loss.pth.tar``.
    """
    checkpoint_path = filename + ".pth.tar"
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, filename + "_best_loss.pth.tar")

In [11]:
# --- Data and output locations -------------------------------------------
dataset = "/home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/datasets/selection"
model_dir = "/home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/final_training"
save = True  # write a checkpoint after every epoch

# --- Data loading ----------------------------------------------------------
patch_size = (512, 512)
batch_size = 1
test_batch_size = 2
num_workers = 8

# --- Optimisation ----------------------------------------------------------
learning_rate = 1e-4
aux_learning_rate = 1e-4
clip_max_norm = 1.0

# --- Experiment tracking ---------------------------------------------------
wandb_project = "Synthetic Image Compression"

# --- Training schedule -----------------------------------------------------
# Base training runs up to `epoch_split`; fine-tuning continues from there up
# to `epoch_final`.
epoch_split = 1
epoch_final = 2
model = "mbt2018"

# Lagrangian weight per distortion metric, keyed by quality level 1..8.
lmbda = {
    'mse': dict(enumerate(
        (0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800),
        start=1,
    )),
    'msssim': dict(enumerate(
        (2.4, 4.58, 8.73, 16.64, 31.37, 60.5, 115.37, 220),
        start=1,
    )),
}

# Optional W&B run ids to resume, keyed by training stage and then by quality.
wandb_ids = {stage: {} for stage in ('base', 'fine_mse', 'fine_msssim')}

In [12]:
# Hyper-parameters to attach to the W&B run.
# NOTE(review): this only rebinds the module attribute before any run exists;
# confirm these values actually reach the run -- passing `config=` to
# `wandb.init` is the documented route.
wandb.config = dict(
    learning_rate=learning_rate,
    epochs=epoch_split,
    batch_size=batch_size,
)

In [13]:
def create_filename(model, quality, loss_fn, target_epochs):
    """Build the checkpoint base name ``<model>_q<quality>_<loss_fn>_<target_epochs>ep``."""
    parts = (model, "q" + str(quality), loss_fn, str(target_epochs) + "ep")
    return "_".join(parts)

In [14]:
def load_checkpoint_if_exists(filename, net, optimizer, aux_optimizer, lr_scheduler, device):
    """Restore training state from ``filename`` when that file exists.

    Returns:
        The epoch to resume from (the checkpoint's stored ``"epoch"`` plus
        one), or ``0`` when there is no file at ``filename``.
    """
    if not os.path.exists(filename):
        return 0

    print("Loading", filename)
    checkpoint = torch.load(filename, map_location=device)
    resume_epoch = checkpoint["epoch"] + 1

    # Restore each component from its checkpoint entry, in this order.
    restore_plan = (
        (net, "state_dict"),
        (optimizer, "optimizer"),
        (aux_optimizer, "aux_optimizer"),
        (lr_scheduler, "lr_scheduler"),
    )
    for component, entry in restore_plan:
        component.load_state_dict(checkpoint[entry])

    return resume_epoch

In [15]:
def train_base(model, quality, lmbda, train_dataloader, test_dataloader):
    """Train the base (MSE) model up to ``epoch_split`` epochs.

    A fresh ``models[model]`` network is created at the given ``quality`` and
    trained with ``RateDistortionLossMSE(lmbda)``. If a checkpoint for this
    configuration already exists in ``model_dir`` it is loaded first; training
    then resumes from the following epoch, or is skipped entirely when the
    checkpoint has already reached ``epoch_split``.

    Args:
        model: key into ``compressai.zoo.models``.
        quality: quality level passed to the model constructor; also used in
            the checkpoint file name and the W&B run name.
        lmbda: Lagrangian weight of the distortion term.
        train_dataloader: batches used for training.
        test_dataloader: batches used for the per-epoch evaluation.

    Reads the notebook-level settings ``learning_rate``,
    ``aux_learning_rate``, ``model_dir``, ``epoch_split``, ``clip_max_norm``,
    ``save``, ``wandb_project`` and ``wandb_ids``.
    """
    # The base stage always optimises MSE. This used to read a global named
    # `metric` that only existed because the driver loop leaked it, so calling
    # this function on its own raised NameError.
    base_metric = "mse"

    print("Training " + model + " at quality " + str(quality) + " from scratch; lambda=" + str(lmbda))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = models[model](quality=quality)
    net = net.to(device)
    if torch.cuda.device_count() > 1:
        net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, learning_rate, aux_learning_rate)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
    criterion = RateDistortionLossMSE(lmbda=lmbda)

    model_file_base = model_dir + '/' + create_filename(model, quality, base_metric, epoch_split)

    # Returns 0 when no checkpoint exists, otherwise the epoch to resume from.
    last_epoch = load_checkpoint_if_exists(
        model_file_base + ".pth.tar", net, optimizer, aux_optimizer, lr_scheduler, device
    )
    if last_epoch >= epoch_split:
        print("Found checkpoint for this model with " + str(last_epoch) + " epochs - nothing to do")
        return

    # Resume the W&B run registered for the *base* stage of this quality, if
    # there is one; otherwise start a new run and name it.
    base_run_ids = wandb_ids['base']
    if quality in base_run_ids:
        wandb.init(project=wandb_project, entity="cmw98", resume="allow", reinit=True, id=base_run_ids[quality])
    else:
        wandb.init(project=wandb_project, entity="cmw98", resume="allow", reinit=True)
        wandb.run.name = model + "_q" + str(quality) + "_" + base_metric + "_adapted_lmbda"
    wandb.watch(net, criterion=criterion, log="gradients", log_freq=1, log_graph=(False))

    # Note: the best loss is tracked per call, not restored from a checkpoint.
    best_loss = float("inf")
    for epoch in range(last_epoch, epoch_split):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            clip_max_norm,
            base_metric
        )
        loss = test_epoch(epoch, test_dataloader, net, criterion, base_metric)
        lr_scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        if save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
                filename=model_file_base
            )

In [16]:
def train_fine(model, quality, lmbda, train_dataloader, test_dataloader, metric):
    """Fine-tune a base model for ``metric`` from ``epoch_split`` to ``epoch_final``.

    The function first tries to resume from an existing fine-tuning checkpoint
    for this configuration. If that yields fewer than ``epoch_split`` epochs,
    it falls back to the base (MSE) checkpoint written by the base stage.
    Training is skipped when fine-tuning is already complete, or when no
    sufficiently trained base checkpoint is available.

    Args:
        model: key into ``compressai.zoo.models``.
        quality: quality level passed to the model constructor; also used in
            checkpoint file names and the W&B run name.
        lmbda: Lagrangian weight of the distortion term.
        train_dataloader: batches used for training.
        test_dataloader: batches used for the per-epoch evaluation.
        metric: ``'mse'`` selects ``RateDistortionLossMSE``; any other value
            selects ``RateDistortionLossMSSSIM``.

    Reads the notebook-level settings ``learning_rate``,
    ``aux_learning_rate``, ``model_dir``, ``epoch_split``, ``epoch_final``,
    ``clip_max_norm``, ``save``, ``wandb_project`` and ``wandb_ids``.
    """
    print("Finetuning " + model + " at quality " + str(quality) + " for " + metric + "; lambda=" + str(lmbda))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = models[model](quality=quality)
    net = net.to(device)
    if torch.cuda.device_count() > 1:
        net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, learning_rate, aux_learning_rate)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
    criterion = (RateDistortionLossMSE(lmbda=lmbda) if metric == 'mse' else RateDistortionLossMSSSIM(lmbda=lmbda))

    model_file_base = model_dir + '/' + create_filename(model, quality, metric, epoch_final)

    # Prefer an existing fine-tuning checkpoint (returns 0 when none exists).
    last_epoch = load_checkpoint_if_exists(
        model_file_base + ".pth.tar", net, optimizer, aux_optimizer, lr_scheduler, device
    )
    if last_epoch < epoch_split:
        # Nothing usable yet: start from the base MSE checkpoint instead.
        last_epoch = load_checkpoint_if_exists(
            model_dir + '/' + create_filename(model, quality, 'mse', epoch_split) + ".pth.tar",
            net, optimizer, aux_optimizer, lr_scheduler, device
        )

    # The two stop conditions get separate messages. Previously a finished
    # fine-tuning run was reported as "Base model is at N epochs - aborting",
    # which reads like a failure.
    if last_epoch >= epoch_final:
        print("Found checkpoint for this model with " + str(last_epoch) + " epochs - nothing to do")
        return
    if last_epoch < epoch_split:
        print("Base model is at " + str(last_epoch) + " epochs - aborting")
        return

    # Resume the W&B run registered for this metric and quality, if there is
    # one; otherwise start a new run and name it.
    fine_run_ids = wandb_ids['fine_' + metric]
    if quality in fine_run_ids:
        wandb.init(project=wandb_project, entity="cmw98", resume="allow", reinit=True, id=fine_run_ids[quality])
    else:
        wandb.init(project=wandb_project, entity="cmw98", resume="allow", reinit=True)
        wandb.run.name = model + "_q" + str(quality) + "_" + metric + "_adapted_lmbda"
    wandb.watch(net, criterion=criterion, log="gradients", log_freq=1, log_graph=(False))

    # Note: the best loss is tracked per call, not restored from a checkpoint.
    best_loss = float("inf")
    for epoch in range(last_epoch, epoch_final):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            clip_max_norm,
            metric
        )
        loss = test_epoch(epoch, test_dataloader, net, criterion, metric)
        lr_scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        if save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
                filename=model_file_base
            )

In [17]:
# Training uses random crops of `patch_size`; evaluation uses centre crops.
train_transforms = transforms.Compose([
    transforms.RandomCrop(patch_size),
    transforms.ToTensor(),
])
test_transforms = transforms.Compose([
    transforms.CenterCrop(patch_size),
    transforms.ToTensor(),
])

train_dataset = ImageFolder(dataset, split="train", transform=train_transforms)
test_dataset = ImageFolder(dataset, split="test", transform=test_transforms)

# Only the training loader shuffles; memory is pinned when CUDA is available.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)
test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=test_batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

# For each quality: train the MSE base model, then fine-tune once per metric.
# The base stage runs only ahead of the 'mse' fine-tuning pass.
for quality in (4, 6, 8):
    for metric in ('mse', 'msssim'):
        current_lmbda = lmbda[metric][quality]
        if metric == 'mse':
            train_base(model, quality, current_lmbda, train_dataloader, test_dataloader)
        train_fine(model, quality, current_lmbda, train_dataloader, test_dataloader, metric)

Training mbt2018 at quality 4 from scratch; lambda=0.013
Loading /home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/final_training/mbt2018_q4_mse_1ep.pth.tar
Found checkpoint for this model with 1 epochs - nothing to do
Finetuning mbt2018 at quality 4 for mse; lambda=0.013
Loading /home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/final_training/mbt2018_q4_mse_2ep.pth.tar
Base model is at 2 epochs - aborting
Finetuning mbt2018 at quality 4 for msssim; lambda=16.64
Loading /home/clemens/Documents/TU Wien/2021W/Bachelor Thesis/final_training/mbt2018_q4_mse_1ep.pth.tar


[34m[1mwandb[0m: Currently logged in as: [33mcmw98[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Learning rate: 0.0001


KeyboardInterrupt: 