-
Notifications
You must be signed in to change notification settings - Fork 1
/
metrics.py
40 lines (32 loc) · 1.15 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import argparse
import os

import pandas as pd
import piq
import torch
from lpips import LPIPS
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor
from tqdm import tqdm
def _load_rgb_unit(path):
    """Load the image at *path* as a 1xCxHxW float tensor scaled to [0, 1]."""
    # pil_to_tensor yields uint8 values in [0, 255]; dividing by 255 gives a
    # float tensor usable by piq (data_range=1.0) and by LPIPS (normalize=True).
    # The context manager closes the underlying file handle.
    with Image.open(path) as img:
        return pil_to_tensor(img.convert("RGB")).unsqueeze(dim=0).float().div(255.0)


def main(args):
    """Compute PSNR, SSIM and LPIPS for same-named image pairs in two directories.

    Every regular file in ``args.d1`` (sorted by name) is paired with the file
    of the same name in ``args.d2``; both are loaded as RGB and compared.
    One row per file is written as CSV with columns ``label``, ``psnr``,
    ``ssim`` and ``lpips``.

    Args:
        args: Namespace with ``d1`` and ``d2`` directory paths. If it also has
            an ``output`` attribute, that path is used for the CSV; otherwise
            ``./results.csv`` is written.

    Raises:
        FileNotFoundError: if a file in ``d1`` has no counterpart in ``d2``.
        NOTE(review): piq presumably raises when a pair differs in size —
            confirm against the piq version in use.
    """
    output = getattr(args, "output", "./results.csv")

    # Sort for a deterministic row order and skip subdirectories, which
    # Image.open cannot read.
    labels = sorted(
        name
        for name in os.listdir(args.d1)
        if os.path.isfile(os.path.join(args.d1, name))
    )

    lpips_model = LPIPS()
    rows = []
    # Pure evaluation: no gradients are needed, so skip building autograd graphs.
    with torch.no_grad():
        for label in tqdm(labels):
            img1 = _load_rgb_unit(os.path.join(args.d1, label))
            img2 = _load_rgb_unit(os.path.join(args.d2, label))
            rows.append(
                {
                    "label": label,
                    "psnr": piq.psnr(img1, img2, data_range=1.0).item(),
                    "ssim": piq.ssim(img1, img2, data_range=1.0).item(),
                    # LPIPS expects [-1, 1]; normalize=True rescales our [0, 1] input.
                    "lpips": lpips_model(img1, img2, normalize=True).item(),
                }
            )

    # Build the frame once from float rows. Assigning floats into int-initialised
    # columns triggers pandas' incompatible-dtype deprecation.
    df = pd.DataFrame(rows, columns=["label", "psnr", "ssim", "lpips"])
    df.to_csv(output, index=False)
if __name__ == "__main__":
    # Command-line entry: both directories are mandatory and are compared
    # file-by-file by name.
    cli = argparse.ArgumentParser()
    cli.add_argument("-d1", required=True, help="Directory one.")
    cli.add_argument("-d2", required=True, help="Directory two.")
    main(cli.parse_args())