-
Notifications
You must be signed in to change notification settings - Fork 0
/
postprocess.py
108 lines (88 loc) · 3.21 KB
/
postprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Post-process network outputs: turn raw ``.bin`` files into images and score them.

For each file in ``--dataset_path`` the matching ``<name>_0.bin`` in ``--bin_path``
(``<name>`` being the file name up to its first dot) is read as a square,
3-channel float32 output, clipped to [0, 255], converted to 8-bit and saved
under the same file name in ``--save_path``. Finally the mean PSNR and SSIM
between ``--dataset_path`` and ``--save_path`` are printed.
"""
import os
import math
import glob
import argparse
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from PIL import Image
# Command-line options; parsed once at import time into the module-level ``args``.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset_path",
    default="/cache/data",
    type=str,
    help="dataset path.",
)
parser.add_argument(
    "--save_path",
    default="/cache/output",
    type=str,
    help="output path.",
)
parser.add_argument(
    "--bin_path",
    default="/cache/data",
    type=str,
    help="lr bin path.",
)
args = parser.parse_args()
def read_bin(bin_path):
    """Load a raw float32 network output as a single square 3-channel NCHW image.

    The file is read as a flat array of float32 values and reshaped to
    ``(1, 3, side, side)``, where ``side`` is derived from the element count,
    so only square 3-channel outputs are supported.

    Args:
        bin_path: path of the ``.bin`` file to read.

    Returns:
        numpy.ndarray of dtype float32 and shape ``(1, 3, side, side)``.

    Raises:
        RuntimeError: if the file is empty, or its element count is not
            ``3 * side * side`` for some integer ``side``.
    """
    img = np.fromfile(bin_path, dtype=np.float32)
    num_pix = img.size
    # An empty file would otherwise pass the size check below and yield a
    # (1, 3, 0, 0) array that only fails much later.
    if num_pix == 0:
        raise RuntimeError(f'bin file is empty, {bin_path}')
    # Integer division plus round() keeps the side-length estimate exact even
    # if sqrt lands a hair below the true integer; the product check below is
    # what actually validates the size.
    side = int(round(math.sqrt(num_pix // 3)))
    if 3 * side * side != num_pix:
        raise RuntimeError(f'bin file error, it is not output from dncnn network, {bin_path}')
    return img.reshape(1, 3, side, side)
def read_bin_as_hwc(bin_path):
    """Load a network output ``.bin`` file as a channels-last ``(H, W, C)`` array.

    The ``(1, C, H, W)`` array produced by ``read_bin`` has its batch axis
    dropped and its channel axis moved to the end.
    """
    # Drop the single batch entry, then move channels from axis 0 to the last axis.
    return np.moveaxis(read_bin(bin_path)[0], 0, -1)
def PSNR(img1, img2):
    """Return the peak signal-to-noise ratio (in dB) between two 8-bit images.

    ``data_range`` is fixed at 255, the same value ``SSIM`` uses, so the score
    does not depend on scikit-image inferring the range from the array dtype.

    Args:
        img1: reference image array.
        img2: image array to compare against ``img1``; must have the same shape.

    Returns:
        The PSNR value as a float.
    """
    return peak_signal_noise_ratio(img1, img2, data_range=255)
def SSIM(img1, img2):
    """Return the mean structural similarity (SSIM) between two 8-bit images.

    Colour images (3-D arrays) are compared per channel along the last axis;
    2-D grayscale images are compared directly.

    Args:
        img1: reference image array, ``(H, W)`` or ``(H, W, C)``.
        img2: image array to compare against ``img1``; must have the same shape.

    Returns:
        The SSIM value as a float.
    """
    import inspect  # local: only used to probe the installed scikit-image API

    kwargs = {"data_range": 255}
    if np.ndim(img1) == 3:
        # scikit-image 0.19 replaced ``multichannel`` with ``channel_axis`` and
        # later releases no longer accept ``multichannel``; use whichever the
        # installed version supports.
        if "channel_axis" in inspect.signature(structural_similarity).parameters:
            kwargs["channel_axis"] = -1
        else:
            kwargs["multichannel"] = True
    return structural_similarity(img1, img2, **kwargs)
def _list_image_names(ori_path):
    """Return the sorted base names of the regular, non-hidden files in ``ori_path``."""
    # glob("*") already skips dot-files; isfile() additionally skips subdirectories.
    paths = glob.glob(os.path.join(ori_path, "*"))
    return sorted(os.path.basename(p) for p in paths if os.path.isfile(p))


def _load_pair(ori_path, res_path, name):
    """Load the reference and restored images called ``name`` as numpy arrays."""
    # ``with`` closes each file handle once its pixels have been copied into numpy.
    with Image.open(os.path.join(ori_path, name)) as ori:
        ori_arr = np.array(ori)
    with Image.open(os.path.join(res_path, name)) as gen:
        gen_arr = np.array(gen)
    return ori_arr, gen_arr


def _mean_metric(metric, names, ori_path, res_path):
    """Average ``metric(reference, restored)`` over every image in ``names``."""
    total = 0
    for name in names:
        total += metric(*_load_pair(ori_path, res_path, name))
    return total / len(names)


def get_metric(ori_path, res_path):
    """Print and return the mean PSNR and SSIM of restored images versus references.

    Every regular, non-hidden file in ``ori_path`` is compared with the file of
    the same name in ``res_path``.

    Args:
        ori_path: directory of reference (ground-truth) images.
        res_path: directory of restored images with matching file names.

    Returns:
        Tuple ``(psnr, ssim)`` of the averages over all compared images.

    Raises:
        ValueError: if ``ori_path`` contains no files to compare.
    """
    names = _list_image_names(ori_path)
    if not names:
        raise ValueError(f"no images found in {ori_path}")
    print("PSNR...")
    psnr_res = _mean_metric(PSNR, names, ori_path, res_path)
    print("SSIM...")
    ssim_res = _mean_metric(SSIM, names, ori_path, res_path)
    print("PSNR: ", psnr_res)
    print("SSIM: ", ssim_res)
    return psnr_res, ssim_res
def run_post_process(dataset_path, save_path, bin_path):
    """Convert network ``.bin`` outputs into 8-bit images and score them.

    For every regular, non-hidden file in ``dataset_path`` the network output
    ``<name>_0.bin`` is read from ``bin_path`` (``<name>`` is the file name up
    to its first dot). The output is clipped to [0, 255], rounded to uint8 and
    saved to ``save_path`` under the original file name. ``get_metric`` then
    compares ``dataset_path`` with ``save_path``.

    Args:
        dataset_path: directory of reference images; defines which files are processed.
        save_path: directory the reconstructed images are written to (created if missing).
        bin_path: directory holding the ``<name>_0.bin`` network outputs.
    """
    os.makedirs(save_path, exist_ok=True)
    # Same selection get_metric's glob("*") makes: skip hidden entries and directories.
    files = sorted(
        name for name in os.listdir(dataset_path)
        if not name.startswith(".") and os.path.isfile(os.path.join(dataset_path, name))
    )
    for file in files:
        # NOTE(review): keeps the text before the FIRST dot, presumably matching how
        # the preprocessing step named the .bin inputs -- confirm before using splitext.
        file_name = file.split('.')[0]
        bin_file = os.path.join(bin_path, file_name + "_0.bin")
        sr = read_bin_as_hwc(bin_file)
        # Round before the uint8 cast: a bare cast truncates (e.g. 254.9 -> 254),
        # biasing every pixel downward.
        out_img = np.uint8(np.round(np.clip(sr, 0, 255)))
        Image.fromarray(out_img).save(os.path.join(save_path, file), quality=95)
    get_metric(dataset_path, save_path)
if __name__ == "__main__":
    # Script entry point: forward the parsed command-line paths by name.
    run_post_process(
        dataset_path=args.dataset_path,
        save_path=args.save_path,
        bin_path=args.bin_path,
    )