The following was requested:

1. We take the same original picture for all attempts.
2. We take the same "noisy picture" for all attempts - and we give it MSE and SSIM score.
3. We denoise with original "float" - give MSE and SSIM.
4. We denoise with all types of quantization that you tried and give MSE and SSIM.
5. Please give a conclusion in the end - if some types of quantization are better than others.


In [1]:
%matplotlib notebook

import argparse
import time
import os


In [2]:

import matplotlib.pyplot as plt
from data import NoisyBSDSDataset
from argument import Args
from model import DnCNN, UDnCNN, DUDnCNN, QDUDnCNN
import nntools as nt
#from utils import DenoisingStatsManager, plot, NoisyBSDSDataset
import utils
import utils2
from skimage import metrics

import cv2
import numpy as np
import torch


In [3]:

import torch.utils.data as td
import torch.quantization.quantize_fx as quantize_fx
from torch.quantization.fuse_modules import fuse_known_modules
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
import torchvision as tv
from PIL import Image
import model
import inference
from prettytable import PrettyTable

In [4]:
def main(img_set,model_path,qat=False,quantize='none'):
    """Denoise one sample, print timing and quality, and plot the three images.

    Args:
        img_set: Two-element ``(noisy, clean)`` pair. The noisy element must
            support ``unsqueeze(0).to(device)`` — presumably a torch tensor
            from ``NoisyBSDSDataset``; confirm against the caller.
        model_path: Checkpoint path handed to ``inference.load_model``.
        qat: When true the checkpoint is loaded into ``QDUDnCNN``, otherwise
            into ``DUDnCNN``.
        quantize: Mode string forwarded to ``utils2.quantize_model``.

    Returns:
        None. Results are printed and drawn on a new matplotlib figure.
    """
    args = Args()
    args.quantize = quantize
    device = 'cpu'
    model_cls = QDUDnCNN if qat else DUDnCNN
    denoise, qat_state = inference.load_model(model_path, model_cls, args.D, args.C, device=device)

    x, clean = img_set
    # Add a leading batch dimension before feeding the model.
    x = x.unsqueeze(0).to(device)

    # NOTE(review): the default 'none' is a non-empty string, so this branch is
    # also taken for the "no quantization" case; presumably
    # utils2.quantize_model treats 'none' as a pass-through — confirm.
    if args.quantize:
        print('Quantize model...'+quantize)
        denoise = utils2.quantize_model(args.quantize, denoise, input_example=x, qat_state=qat_state)

    # perf_counter is the monotonic, high-resolution clock meant for measuring
    # short intervals; read it immediately after inference so that no
    # bookkeeping is included in the reported time.
    start = time.perf_counter()
    with torch.no_grad():
        y = denoise(x)
    elapsed_ms = (time.perf_counter() - start) * 1000

    print(f'Elapsed: {elapsed_ms:.2f}ms')
    print(f'Image size is {x[0].shape}.')

    # Quality of the denoised output relative to the clean reference.
    utils2.compare_images(utils2.myimret(clean), utils2.myimret(y[0]))

    panels = [('clean', clean), ('noise', x[0]), ('denoise', y[0])]
    _, axes = plt.subplots(ncols=len(panels), figsize=(9,5), sharex='all', sharey='all')
    for ax, (title, image) in zip(axes, panels):
        utils2.myimshow(image, ax=ax)
        ax.set_title(title)



In [5]:
def eval_origin(img_set,display=False):
    """Score the clean image against itself to obtain the reference baseline.

    Args:
        img_set: Two-element ``(noisy, clean)`` pair; only the clean element
            is used here.
        display: Forwarded unchanged as the third argument of
            ``utils2.compare_images``.

    Returns:
        The ``(mse, ssim)`` pair produced by ``utils2.compare_images`` when
        both inputs are the same converted clean image.
    """
    _, clean = img_set
    reference = utils2.myimret(clean)
    mse, ssim = utils2.compare_images(reference, reference, display)
    return mse, ssim

In [6]:
def eval_noisy(img_set,display=False):
    """Score the noisy image against the clean one.

    Args:
        img_set: Two-element ``(noisy, clean)`` pair.
        display: Forwarded unchanged as the third argument of
            ``utils2.compare_images``.

    Returns:
        The ``(mse, ssim)`` pair produced by ``utils2.compare_images`` for
        the clean image versus the noisy image.
    """
    noisy, clean = img_set

    # Convert clean first, then noisy, before comparing them.
    clean_view = utils2.myimret(clean)
    noisy_view = utils2.myimret(noisy)

    noisy_mse, noisy_ssim = utils2.compare_images(clean_view, noisy_view, display)
    return noisy_mse, noisy_ssim

In [7]:
def eval_image_set(img_set,model_path,qat=False,quantize='none',display=False,print_out=False):
    """Denoise one sample with the given checkpoint and report time and quality.

    Args:
        img_set: Two-element ``(noisy, clean)`` pair. The noisy element must
            support ``unsqueeze(0).to(device)`` — presumably a torch tensor
            from ``NoisyBSDSDataset``; confirm against the caller.
        model_path: Checkpoint path handed to ``inference.load_model``.
        qat: When true the checkpoint is loaded into ``QDUDnCNN``, otherwise
            into ``DUDnCNN``.
        quantize: Mode string forwarded to ``utils2.quantize_model``; it is
            also echoed back in the return value.
        display: When true, draw clean / noisy / denoised side by side.
        print_out: Forwarded as the third argument of
            ``utils2.compare_images``.

    Returns:
        ``(t_elapsed, quantize, mse, ssim)`` where ``t_elapsed`` is the
        duration of the forward pass alone in milliseconds and ``mse``/``ssim``
        are the values ``utils2.compare_images`` returns for the clean image
        versus the denoised image.
    """
    args = Args()
    args.quantize = quantize
    device = 'cpu'
    model_cls = QDUDnCNN if qat else DUDnCNN
    denoise, qat_state = inference.load_model(model_path, model_cls, args.D, args.C, device=device)

    x, clean = img_set
    # Add a leading batch dimension before feeding the model.
    x = x.unsqueeze(0).to(device)

    # NOTE(review): the default 'none' is a non-empty string, so this branch is
    # also taken for the "no quantization" case; presumably
    # utils2.quantize_model treats 'none' as a pass-through — confirm.
    if args.quantize:
        denoise = utils2.quantize_model(args.quantize, denoise, input_example=x, qat_state=qat_state)

    # perf_counter is the monotonic, high-resolution clock intended for
    # measuring short intervals; time.time() can jump with wall-clock changes.
    start = time.perf_counter()
    with torch.no_grad():
        y = denoise(x)
    t_elapsed = (time.perf_counter() - start)*1000

    clean_view = utils2.myimret(clean)
    noisy_view = utils2.myimret(x[0])
    denoised_view = utils2.myimret(y[0])
    # The noisy comparison's result is not returned; the call is kept because
    # print_out is forwarded to it.
    utils2.compare_images(clean_view, noisy_view, print_out)
    mse, ssim = utils2.compare_images(clean_view, denoised_view, print_out)

    if display:
        panels = [('clean', clean), ('noise', x[0]), ('denoise', y[0])]
        _, axes = plt.subplots(ncols=len(panels), figsize=(9,5), sharex='all', sharey='all')
        for ax, (title, image) in zip(axes, panels):
            utils2.myimshow(image, ax=ax)
            ax.set_title(title)
    return t_elapsed, quantize, mse, ssim

In [8]:
def experiment(img,display=True):
    """Run every training/quantization combination once and print one table.

    The table starts with two baseline rows (clean vs. clean, clean vs. noisy)
    followed by one row per run. The ``time`` column is each run's inference
    time as a percentage of the first run (standard training, no quantization).

    Args:
        img: Two-element ``(noisy, clean)`` pair passed through to
            ``eval_origin``, ``eval_noisy`` and ``eval_image_set``.
        display: Forwarded to ``eval_image_set`` for the first run only; all
            other runs are evaluated without a figure.

    Raises:
        KeyError: If the ``TRAINING_DIR`` environment variable is not set.
    """
    # Index directly so an unset variable is reported by name rather than as a
    # TypeError from concatenating None with a string.
    training_dir = os.environ['TRAINING_DIR']

    # (row label, checkpoint sub-directory, qat flag, quantization mode)
    runs = [
        ('standard', 'no_qat', False, 'none'),
        ('standard', 'no_qat', False, 'fx_static'),
        ('standard', 'no_qat', False, 'fx_dynamic'),
        ('QAT', 'qat', False, 'none'),
        ('QAT', 'qat', True, 'fx_static'),
        ('QAT', 'qat', True, 'fx_dynamic'),
        ('QAT_fx_static', 'qat_fx_static', False, 'none'),
        ('QAT_fx_static', 'qat_fx_static', True, 'fx_static'),
        ('QAT_fx_static', 'qat_fx_static', True, 'fx_dynamic'),
    ]

    t = PrettyTable(['training','quantization','time','mse','ssim'])
    t.float_format = ".5"
    t.float_format['time'] = ".2"
    t.border = True
    t.align = "l"
    # The column in this table is named 'time'.
    t.align["time"] = "r"

    mse_origin, ssim_origin = eval_origin(img, False)
    mse_noisy, ssim_noisy = eval_noisy(img, False)
    t.add_row(['NA (origin)',"NA","NA",mse_origin,ssim_origin])
    t.add_row(['NA (noisy)',"NA","NA",mse_noisy,ssim_noisy])

    baseline_ms = None
    for index, (label, subdir, qat, mode) in enumerate(runs):
        path = training_dir + '/' + subdir + '/checkpoint.pth'
        show = display if index == 0 else False
        eval_time, q, mse, ssim = eval_image_set(img, path, qat, mode, show)
        if baseline_ms is None:
            # The first run is the 100% reference for every later row.
            baseline_ms = eval_time
        t.add_row([label, q, (eval_time/baseline_ms)*100, mse, ssim])
    print(t)
    

In [9]:
def experiment_mean(img,cycles,display=True,mean_only=False):
    """Repeat the benchmark ``cycles`` times and print the per-run averages.

    Each cycle evaluates six checkpoint/quantization combinations with
    ``eval_image_set``. Times are recorded as a percentage of that cycle's
    first run (standard training, no quantization). After all cycles a summary
    table of mean time / mse / ssim is printed for four of the six runs; the
    two runs on the ``qat`` checkpoint are measured but left out of the
    summary, as before.

    Args:
        img: Two-element ``(noisy, clean)`` pair passed through to the
            evaluation helpers.
        cycles: Number of repetitions to average over.
        display: Forwarded to ``eval_image_set`` for the first run of each
            cycle; the other runs are evaluated without a figure.
        mean_only: When true, suppress the per-cycle header and table and
            print only the final summary.

    Raises:
        KeyError: If the ``TRAINING_DIR`` environment variable is not set.
    """
    # Index directly so an unset variable is reported by name rather than as a
    # TypeError from concatenating None with a string.
    training_dir = os.environ['TRAINING_DIR']

    # (row label, checkpoint sub-directory, qat flag, quantization mode,
    #  summary labels or None when the run is excluded from the summary)
    runs = [
        ('standard', 'no_qat', False, 'none', ('standard', 'none')),
        ('standard', 'no_qat', False, 'fx_static', ('standard', 'static')),
        ('QAT', 'qat', False, 'none', None),
        ('QAT', 'qat', True, 'fx_static', None),
        ('QAT_fx_static', 'qat_fx_static', False, 'none', ('Q-aware', 'none')),
        ('QAT_fx_static', 'qat_fx_static', True, 'fx_static', ('Q-aware', 'static')),
    ]

    shape = (len(runs), cycles)
    time_array = np.zeros(shape, dtype=np.float32)
    mse_array = np.zeros(shape, dtype=np.float32)
    ssim_array = np.zeros(shape, dtype=np.float32)

    # One pass up front with the figure enabled, outside the measured cycles.
    eval_image_set(img, training_dir + '/no_qat/checkpoint.pth', False, 'none', True)

    for cycle in range(cycles):
        if not mean_only:
            print("==> experiment %2d" %cycle)
        t = PrettyTable(['training','quantization','time','mse','ssim'])
        t.float_format = ".5"
        t.float_format['time'] = ".2"
        t.border = True
        t.align = "l"
        # The column in the per-cycle table is named 'time'.
        t.align["time"] = "r"

        mse_origin, ssim_origin = eval_origin(img, False)
        mse_noisy, ssim_noisy = eval_noisy(img, False)
        t.add_row(['NA (origin)',"NA","NA",mse_origin,ssim_origin])
        t.add_row(['NA (noisy)',"NA","NA",mse_noisy,ssim_noisy])

        baseline_ms = None
        for row, (label, subdir, qat, mode, _) in enumerate(runs):
            path = training_dir + '/' + subdir + '/checkpoint.pth'
            show = display if row == 0 else False
            eval_time, q, mse, ssim = eval_image_set(img, path, qat, mode, show)
            if baseline_ms is None:
                # The first run of the cycle is the 100% reference.
                baseline_ms = eval_time
            relative_time = (eval_time/baseline_ms)*100
            time_array[row, cycle] = relative_time
            mse_array[row, cycle] = mse
            ssim_array[row, cycle] = ssim
            t.add_row([label, q, relative_time, mse, ssim])

        if not mean_only:
            print(t)

    print("Summary of experiments")

    pt = PrettyTable(['training','quantization','mean time','mean mse','mean ssim'])
    pt.float_format = ".5"
    pt.float_format['mean time'] = ".2"
    pt.border = True
    pt.align = "l"
    pt.align["mean time"] = "r"

    for row, (*_, summary_labels) in enumerate(runs):
        if summary_labels is None:
            continue
        # np.mean over a float32 array yields np.float32, which is not a Python
        # float, so PrettyTable would skip float_format; convert explicitly.
        pt.add_row([
            *summary_labels,
            float(np.mean(time_array[row, :])),
            float(np.mean(mse_array[row, :])),
            float(np.mean(ssim_array[row, :])),
        ])
    print(pt)



In [10]:
# Index os.environ directly: an unset DATA_DIR then raises a KeyError naming
# the variable instead of a TypeError from concatenating None with a string.
dataset_root_dir = os.environ['DATA_DIR'] + '/images'
test_set = NoisyBSDSDataset(dataset_root_dir, mode='test', image_size=(320, 320))
# Sample at index 1 of the test split.
img_set = test_set[1]

## Single Image Test

In [11]:
# Average over 10 cycles; no per-run figure, summary table only.
experiment_mean(img_set, cycles=10, display=False, mean_only=True)

<IPython.core.display.Javascript object>

  reduce_range will be deprecated in a future release of PyTorch."
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


Summary of experiments
+----------+--------------+-----------+--------------+------------+
| training | quantization | mean time | mean mse     | mean ssim  |
+----------+--------------+-----------+--------------+------------+
| standard | none         |     100.0 | 0.0011254027 | 0.91576415 |
| standard | static       | 60.100544 | 0.001145053  | 0.9137362  |
| Q-aware  | none         | 100.09557 | 0.11767028   | 0.30691195 |
| Q-aware  | static       | 60.278267 | 0.0011510642 | 0.91286737 |
+----------+--------------+-----------+--------------+------------+


## Three Images Test

In [None]:
# Number of test-set samples to benchmark, starting at index 0.
x = 3
for i in range(x):
    img_set = test_set[i]
    print("Experiment {}".format(i))
    # Average over 10 cycles; no per-run figure, summary table only.
    experiment_mean(img_set, cycles=10, display=False, mean_only=True)

Experiment 0


<IPython.core.display.Javascript object>

Summary of experiments
+----------+--------------+-----------+--------------+------------+
| training | quantization | mean time | mean mse     | mean ssim  |
+----------+--------------+-----------+--------------+------------+
| standard | none         |     100.0 | 0.0020540669 | 0.8655432  |
| standard | static       |  61.99311 | 0.0020798005 | 0.86423504 |
| Q-aware  | none         | 102.02004 | 0.10235883   | 0.30135372 |
| Q-aware  | static       |  61.20986 | 0.0020745979 | 0.86549693 |
+----------+--------------+-----------+--------------+------------+
Experiment 1


<IPython.core.display.Javascript object>

Summary of experiments
+----------+--------------+-----------+--------------+------------+
| training | quantization | mean time | mean mse     | mean ssim  |
+----------+--------------+-----------+--------------+------------+
| standard | none         |     100.0 | 0.0010614317 | 0.92161465 |
| standard | static       | 60.833027 | 0.0010791349 | 0.9194784  |
| Q-aware  | none         | 101.64209 | 0.13920563   | 0.31612575 |
| Q-aware  | static       | 60.810005 | 0.0010804672 | 0.91957444 |
+----------+--------------+-----------+--------------+------------+
Experiment 2


<IPython.core.display.Javascript object>

Summary of experiments
+----------+--------------+-----------+--------------+------------+
| training | quantization | mean time | mean mse     | mean ssim  |
+----------+--------------+-----------+--------------+------------+
| standard | none         |     100.0 | 0.0011897937 | 0.9188344  |
| standard | static       |  57.61261 | 0.001213433  | 0.91595393 |
| Q-aware  | none         | 98.935684 | 0.12426771   | 0.25497526 |
| Q-aware  | static       | 56.383747 | 0.0012208648 | 0.9152276  |
+----------+--------------+-----------+--------------+------------+


## Summary

The experiments clearly showed a significant processing-time improvement without much quality degradation for all variants of static quantization. Surprisingly, Quantization Aware Training (QAT) did not produce superior results over the fx_static method, and in some cases the quality was reduced. It is possible that the benefit of QAT depends on the type of data, and that a different dataset or algorithm would produce significantly better results; that assumption needs to be explored. The field is still very new. The project used PyTorch's latest quantization techniques and software, some of which are still in pre-release. PyTorch quantization supports only x86 and ARM architectures; the x86 backend was used in this project.

The testing was done with one original image and one derived noisy image. The time fluctuation between experiments can be attributed to the compute infrastructure and cluster resources. The summary values appeared to be consistent over the multiple experiment runs performed.