In [1]:
import argparse
import math
import random
import shutil
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

from compressai.datasets import ImageFolder
from compressai.zoo import mbt2018

import os

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt




In [2]:
device = "cpu"
model = "mbt2018"
suffix = "_500ep"

In [6]:
directory = '/Users/clemens/vsc-data/test_images'

def pad (x):
    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    return F.pad(
        x,
        (padding_left, padding_right, padding_top, padding_bottom),
        mode="constant",
        value=0,
    )

for quality in [4, 6, 8]:
    for metric in ['mse', 'msssim']:
        net = mbt2018(quality=quality, pretrained=False).eval().to(device)
        file = "/Users/clemens/vsc-data/mbt/" + model + "_q" + str(quality) + "_" + metric + suffix + ".pth.tar"
        print(file)
        checkpoint = torch.load(file, map_location=device)
        net.load_state_dict(checkpoint["state_dict"])
        print("Actual epochs:" + str(checkpoint["epoch"] + 1))

        for file in os.listdir(directory):
            filename = os.path.join(directory, os.fsdecode(file))
            if filename.endswith(".png"):
                print(filename)
                img = Image.open(filename).convert('RGB')
                x = transforms.ToTensor()(img).unsqueeze(0).to(device)
                x = pad(x)
                
                with torch.no_grad():
                    out_net = net.forward(x)
                out_net['x_hat'].clamp_(0, 1)
                
                decoded_filename = os.path.join('/Users/clemens/vsc-data/mbt2018_q' + str(quality) + '_' + metric + '_fullres', file)
                print(decoded_filename)
                save_image(out_net['x_hat'], decoded_filename)

/Users/clemens/vsc-data/mbt/mbt2018_q4_mse_500ep.pth.tar
Actual epochs:500
/Users/clemens/vsc-data/test_images/c0580.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/c0580.png
/Users/clemens/vsc-data/test_images/g0173.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/g0173.png
/Users/clemens/vsc-data/test_images/g0601.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/g0601.png
/Users/clemens/vsc-data/test_images/g0629.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/g0629.png
/Users/clemens/vsc-data/test_images/c0796.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/c0796.png
/Users/clemens/vsc-data/test_images/c0972.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/c0972.png
/Users/clemens/vsc-data/test_images/g0417.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/g0417.png
/Users/clemens/vsc-data/test_images/g0365.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/g0365.png
/Users/clemens/vsc-data/test_images/g1053.png
/Users/clemens/vsc-data/mbt2018_q4_mse_fullres/g1053.pn

KeyboardInterrupt: 