Could you please summarize your results in a separate section at the end?
The data should be organized as follows:
1. Use the same original picture for all attempts.
2. Use the same "noisy picture" for all attempts, and report its MSE and SSIM scores.
3. Denoise with the original "float" (unquantized) model and report MSE and SSIM.
4. Denoise with every type of quantization you tried and report MSE and SSIM for each.
5. Finish with a conclusion: are some types of quantization better than others?


    Display the original picture
    Give MSE and SSIM scores of the original against itself
    Display the noisy picture
    Give MSE and SSIM scores between the original and noisy pictures
     _____________________________________________________________________________________
    | No Quantization                                          | performance %| MSE| SSIM |
    | Static Quantization of standard model - fx_static        | performance %| MSE| SSIM |
    | Dynamic Quantization of standard model - fx_dynamic      | performance %| MSE| SSIM |
    | Static Quantization of QAT model - fx_static             | performance %| MSE| SSIM |
    | Dynamic Quantization of QAT model - fx_dynamic           | performance %| MSE| SSIM |
    | Static Quantization of QAT_fx_static model - fx_static   | performance %| MSE| SSIM |
    | Dynamic Quantization of QAT_fx_static model - fx_dynamic | performance %| MSE| SSIM |
     _____________________________________________________________________________________


In [1]:
%matplotlib notebook

import argparse
import time
import os


In [2]:

import matplotlib.pyplot as plt
from data import NoisyBSDSDataset
from argument import Args
from model import DnCNN, UDnCNN, DUDnCNN, QDUDnCNN
import nntools as nt
#from utils import DenoisingStatsManager, plot, NoisyBSDSDataset
import utils
import utils2
from skimage import metrics

import cv2
import numpy as np
import torch


In [3]:

import torch.utils.data as td
import torch.quantization.quantize_fx as quantize_fx
from torch.quantization.fuse_modules import fuse_known_modules
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
import torchvision as tv
from PIL import Image
import model
import inference
from prettytable import PrettyTable

In [4]:
def main(img_set, model_path, qat=False, quantize='none'):
    """Denoise one image with an optionally quantized model and plot the result.

    Loads the checkpoint at ``model_path`` into ``QDUDnCNN`` when ``qat`` is
    true, otherwise into ``DUDnCNN``. It then quantizes the model with
    ``utils2.quantize_model`` and runs one forward pass on the noisy image.
    It prints the inference time and image size, compares clean vs. denoised
    with ``utils2.compare_images``, and shows the clean / noisy / denoised
    images side by side.

    Args:
        img_set: ``(noisy, clean)`` pair from the dataset. A batch dimension
            is added to ``noisy`` here before inference.
        model_path: checkpoint path passed to ``inference.load_model``.
        qat: load the checkpoint into the QAT model class (``QDUDnCNN``).
        quantize: quantization mode forwarded to ``utils2.quantize_model``.
    """
    args = Args()
    args.quantize = quantize
    device = 'cpu'
    model_cls = QDUDnCNN if qat else DUDnCNN
    denoise, qat_state = inference.load_model(model_path, model_cls, args.D, args.C, device=device)

    titles = ['clean', 'noise', 'denoise']
    x, clean = img_set
    x = x.unsqueeze(0).to(device)
    img = [clean, x[0]]

    # NOTE(review): the default 'none' is a non-empty string and therefore
    # truthy, so this branch runs for every mode; presumably
    # utils2.quantize_model treats 'none' as a no-op -- confirm.
    if args.quantize:
        print('Quantize model...'+quantize)
        denoise = utils2.quantize_model(args.quantize, denoise, input_example=x, qat_state=qat_state)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    with torch.no_grad():
        y = denoise(x)
    elapsed_ms = (time.perf_counter() - start) * 1000
    img.append(y[0])

    print(f'Elapsed: {elapsed_ms:.2f}ms')
    print(f'Image size is {x[0].shape}.')
    clean_np = utils2.myimret(img[0])
    denoised_np = utils2.myimret(img[2])
    utils2.compare_images(clean_np, denoised_np)

    _, axes = plt.subplots(ncols=3, figsize=(9,5), sharex='all', sharey='all')
    for ax, image, title in zip(axes, img, titles):
        utils2.myimshow(image, ax=ax)
        ax.set_title(f'{title}')



In [5]:
def eval_q(img_set, model_path, qat=False, quantize='none', display=False):
    """Denoise one image with an optionally quantized model and time inference.

    Loads the checkpoint into ``QDUDnCNN`` (``qat=True``) or ``DUDnCNN``,
    applies ``utils2.quantize_model`` and runs one forward pass. It then
    compares clean vs. denoised with ``utils2.compare_images``. When
    ``display`` is true, it also plots clean / noisy / denoised images.

    Args:
        img_set: ``(noisy, clean)`` pair from the dataset.
        model_path: checkpoint path passed to ``inference.load_model``.
        qat: load the checkpoint into the QAT model class.
        quantize: quantization mode forwarded to ``utils2.quantize_model``.
        display: plot the three images.

    Returns:
        ``(elapsed_ms, quantize)``. ``elapsed_ms`` is the time of the forward
        pass alone, in milliseconds.
    """
    args = Args()
    args.quantize = quantize
    device = 'cpu'
    model_cls = QDUDnCNN if qat else DUDnCNN
    denoise, qat_state = inference.load_model(model_path, model_cls, args.D, args.C, device=device)

    titles = ['clean', 'noise', 'denoise']
    x, clean = img_set
    x = x.unsqueeze(0).to(device)
    img = [clean, x[0]]

    # NOTE(review): 'none' is truthy, so this runs for every mode; presumably
    # utils2.quantize_model treats 'none' as a no-op -- confirm.
    if args.quantize:
        denoise = utils2.quantize_model(args.quantize, denoise, input_example=x, qat_state=qat_state)

    # Measure only the forward pass. Reading the clock at return time would
    # also count metric computation and plotting.
    start = time.perf_counter()
    with torch.no_grad():
        y = denoise(x)
    elapsed_ms = (time.perf_counter() - start) * 1000
    img.append(y[0])

    clean_np = utils2.myimret(img[0])
    denoised_np = utils2.myimret(img[2])
    utils2.compare_images(clean_np, denoised_np)

    if display:
        _, axes = plt.subplots(ncols=3, figsize=(9,5), sharex='all', sharey='all')
        for ax, image, title in zip(axes, img, titles):
            utils2.myimshow(image, ax=ax)
            ax.set_title(f'{title}')
    return elapsed_ms, quantize

In [6]:
def eval_origin(img_set, display=False):
    """Score the clean image against itself (the reference row of the table).

    Args:
        img_set: ``(noisy, clean)`` pair from the dataset.
        display: unused; kept so the signature matches the other ``eval_*``
            helpers.

    Returns:
        ``(mse, ssim)`` from ``utils2.compare_images(clean, clean, False)``.
    """
    noisy, clean = img_set
    batched_noisy = noisy.unsqueeze(0).to('cpu')

    clean_np = utils2.myimret(clean)
    # The noisy image is converted as well, although only the clean one is scored.
    utils2.myimret(batched_noisy[0])

    mse, ssim = utils2.compare_images(clean_np, clean_np, False)
    return mse, ssim

In [7]:
def eval_noisy(img_set, display=False):
    """Score the noisy image against the clean one (the noise-only baseline).

    Args:
        img_set: ``(noisy, clean)`` pair from the dataset.
        display: unused; kept so the signature matches the other ``eval_*``
            helpers.

    Returns:
        ``(mse, ssim)`` from ``utils2.compare_images(clean, noisy, False)``.
    """
    noisy, clean = img_set
    batched_noisy = noisy.unsqueeze(0).to('cpu')

    clean_np = utils2.myimret(clean)
    noisy_np = utils2.myimret(batched_noisy[0])

    mse, ssim = utils2.compare_images(clean_np, noisy_np, False)
    return mse, ssim

In [8]:
def eval_image_set(img_set, model_path, qat=False, quantize='none', display=False):
    """Denoise one image with an optionally quantized model and score it.

    Loads the checkpoint into ``QDUDnCNN`` (``qat=True``) or ``DUDnCNN``,
    applies ``utils2.quantize_model`` and runs one forward pass. It then
    compares clean vs. denoised with ``utils2.compare_images``. When
    ``display`` is true, it also plots clean / noisy / denoised images.

    Args:
        img_set: ``(noisy, clean)`` pair from the dataset.
        model_path: checkpoint path passed to ``inference.load_model``.
        qat: load the checkpoint into the QAT model class.
        quantize: quantization mode forwarded to ``utils2.quantize_model``.
        display: plot the three images.

    Returns:
        ``(elapsed_ms, quantize, mse, ssim)``. ``elapsed_ms`` is the time of
        the single forward pass in milliseconds; ``mse``/``ssim`` compare the
        denoised image with the clean one.
    """
    args = Args()
    args.quantize = quantize
    device = 'cpu'
    model_cls = QDUDnCNN if qat else DUDnCNN
    denoise, qat_state = inference.load_model(model_path, model_cls, args.D, args.C, device=device)

    titles = ['clean', 'noise', 'denoise']
    x, clean = img_set
    x = x.unsqueeze(0).to(device)
    img = [clean, x[0]]

    # NOTE(review): 'none' is truthy, so this runs for every mode; presumably
    # utils2.quantize_model treats 'none' as a no-op -- confirm.
    if args.quantize:
        denoise = utils2.quantize_model(args.quantize, denoise, input_example=x, qat_state=qat_state)

    # perf_counter is monotonic and high-resolution. time.time() can be
    # too coarse for a single forward pass, and this value is later used
    # as a divisor for relative timings.
    start = time.perf_counter()
    with torch.no_grad():
        y = denoise(x)
    t_elapsed = (time.perf_counter() - start) * 1000
    img.append(y[0])

    clean_np = utils2.myimret(img[0])
    denoised_np = utils2.myimret(img[2])
    mse, ssim = utils2.compare_images(clean_np, denoised_np)

    if display:
        _, axes = plt.subplots(ncols=3, figsize=(9,5), sharex='all', sharey='all')
        for ax, image, title in zip(axes, img, titles):
            utils2.myimshow(image, ax=ax)
            ax.set_title(f'{title}')
    return t_elapsed, quantize, mse, ssim

In [9]:
def experiment(img, display=True):
    """Evaluate every training/quantization combination on one image and print a table.

    The first two rows are references: the clean image scored against itself,
    then the noisy image scored against the clean one. The remaining rows come
    from ``eval_image_set`` for the ``no_qat``, ``qat`` and ``qat_fx_static``
    checkpoints under ``$TRAINING_DIR``. Each checkpoint is run unquantized,
    with ``fx_static`` and with ``fx_dynamic``.

    The ``time`` column is each run's inference time as a percentage of the
    first (standard, unquantized) run. It is NaN if that baseline measured
    0 ms.

    Args:
        img: ``(noisy, clean)`` pair from the dataset.
        display: plot the images for the first (baseline) run only.

    Raises:
        RuntimeError: if the ``TRAINING_DIR`` environment variable is not set.
    """
    img_set = img

    table = PrettyTable(['training', 'quantization', 'time', 'mse', 'ssim'])
    table.float_format = ".5"
    table.float_format['time'] = ".2"
    table.border = True
    table.align = "r"
    table.align["quantization"] = "l"
    table.align["training"] = "l"

    training_dir = os.environ.get('TRAINING_DIR')
    if training_dir is None:
        raise RuntimeError('TRAINING_DIR environment variable is not set')

    mse_origin, ssim_origin = eval_origin(img_set, False)
    mse_noisy, ssim_noisy = eval_noisy(img_set, False)
    table.add_row(['NA (origin)', "NA", "NA", mse_origin, ssim_origin])
    table.add_row(['NA (noisy)', "NA", "NA", mse_noisy, ssim_noisy])

    # (row label, checkpoint sub-directory, load as QAT model, quantization mode)
    runs = [
        ('standard', 'no_qat', False, 'none'),
        ('standard', 'no_qat', False, 'fx_static'),
        ('standard', 'no_qat', False, 'fx_dynamic'),
        ('QAT', 'qat', False, 'none'),
        ('QAT', 'qat', True, 'fx_static'),
        ('QAT', 'qat', True, 'fx_dynamic'),
        ('QAT_fx_static', 'qat_fx_static', False, 'none'),
        ('QAT_fx_static', 'qat_fx_static', True, 'fx_static'),
        ('QAT_fx_static', 'qat_fx_static', True, 'fx_dynamic'),
    ]

    baseline_ms = None
    for i, (label, subdir, qat, mode) in enumerate(runs):
        path = training_dir + '/' + subdir + '/checkpoint.pth'
        eval_time, q, mse, ssim = eval_image_set(
            img_set, path, qat, mode, display if i == 0 else False)
        if baseline_ms is None:
            # The first run (standard model, no quantization) is the 100% reference.
            baseline_ms = eval_time
        # Guard the division: a 0 ms baseline would otherwise abort the table.
        relative = (eval_time / baseline_ms) * 100 if baseline_ms > 0 else float('nan')
        table.add_row([label, q, relative, mse, ssim])
    print(table)
    

In [10]:

def experiment_test(img, display=False):
    """Run every evaluation once, discarding results (no table is printed).

    Calls ``eval_origin`` three times. It then calls ``eval_image_set`` for
    each checkpoint under ``$TRAINING_DIR`` (``no_qat``, ``qat``,
    ``qat_fx_static``) with no quantization, ``fx_static`` and ``fx_dynamic``.
    Only the very first ``eval_image_set`` call receives ``display``.

    Args:
        img: ``(noisy, clean)`` pair from the dataset.
        display: forwarded to the ``eval_origin`` calls and to the first
            ``eval_image_set`` call.
    """
    img_set = img

    for _ in range(3):
        eval_origin(img_set, display)

    # (checkpoint sub-directory, load as QAT model, quantization mode)
    plan = (
        ('no_qat', False, 'none'),
        ('no_qat', False, 'fx_static'),
        ('no_qat', False, 'fx_dynamic'),
        ('qat', False, 'none'),
        ('qat', True, 'fx_static'),
        ('qat', True, 'fx_dynamic'),
        ('qat_fx_static', False, 'none'),
        ('qat_fx_static', True, 'fx_static'),
        ('qat_fx_static', True, 'fx_dynamic'),
    )
    training_dir = os.environ.get('TRAINING_DIR')
    for idx, (subdir, use_qat, mode) in enumerate(plan):
        checkpoint = training_dir + '/' + subdir + '/checkpoint.pth'
        show = display if idx == 0 else False
        eval_image_set(img_set, checkpoint, use_qat, mode, show)


In [11]:
# Build the test split of the noisy BSDS dataset from $DATA_DIR/images,
# requesting image_size=(320, 320).
data_dir = os.environ.get('DATA_DIR')
if data_dir is None:
    # Fail with a clear message instead of an opaque `None + str` TypeError.
    raise RuntimeError('DATA_DIR environment variable is not set')
dataset_root_dir = data_dir + '/images'
test_set = utils2.NoisyBSDSDataset(dataset_root_dir, mode='test', image_size=(320, 320))

In [12]:
# Sanity check: display the dataset object's class.
# NOTE(review): the bare expression below appears to be this cell's recorded
# output pasted from the notebook rather than intended code -- confirm.
type(test_set)

utils2.NoisyBSDSDataset

In [None]:
# Run the full comparison table on test images 1, 0 and 99. Each image is run
# once with display=True, which plots images for the baseline (first) run
# only. It is then run three more times with display=False -- presumably to
# check run-to-run variation of the timing column; confirm.
#experiment_test(test_set[1],True)

In [None]:
experiment(test_set[1],True)

In [None]:
experiment(test_set[1],False)
experiment(test_set[1],False)
experiment(test_set[1],False)

In [None]:
experiment(test_set[0],True)

In [None]:
experiment(test_set[0],False)
experiment(test_set[0],False)
experiment(test_set[0],False)

In [None]:
experiment(test_set[99],True)

In [None]:
experiment(test_set[99],False)
experiment(test_set[99],False)
experiment(test_set[99],False)