Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: Add better FID calculator to verify image quality. #69

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Integrated Design Diffusion Model
├── tools
│ ├── deploy.py
│ ├── FID_calculator.py
│ ├── FID_calculator_plus.py
│ ├── generate.py
│ └── train.py
├── utils
Expand Down Expand Up @@ -114,6 +115,7 @@ Integrated Design Diffusion Model
- [x] 10. Reconstruct the overall structure of the model (2023-12-06)
- [x] 11. Write visual webui interface. (2024-01-23)
- [x] 12. Adding PLMS Sampling Method. (2024-03-12)
- [x] 13. Adding FID calculator to verify image quality. (2024-05-06)

### Training

Expand Down
2 changes: 2 additions & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Integrated Design Diffusion Model
├── tools
│ ├── deploy.py
│ ├── FID_calculator.py
│ ├── FID_calculator_plus.py
│ ├── generate.py
│ └── train.py
├── utils
Expand Down Expand Up @@ -113,6 +114,7 @@ Integrated Design Diffusion Model
- [x] 10. 重构model整体结构(2023-12-06)
- [x] 11. 编写可视化webui界面(2024-01-23)
- [x] 12. 增加PLMS采样方法(2024-03-12)
- [x] 13. 增加FID方法验证图像质量(2024-05-06)

### 训练

Expand Down
71 changes: 71 additions & 0 deletions tools/FID_calculator_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2024/5/4 23:36
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import sys
import argparse
import logging

import coloredlogs

from pytorch_fid.fid_score import save_fid_stats, calculate_fid_given_paths
from pytorch_fid.inception import InceptionV3

sys.path.append(os.path.dirname(sys.path[0]))
from utils.initializer import device_initializer

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def main(args):
logger.info(msg=f"[Note]: Input params: {args}")
device_id = args.use_gpu
paths = args.path
batch_size = args.batch_size
num_workers = args.num_workers
dims = args.dims
device = device_initializer(device_id=device_id)
# TODO: Check image size
# Compute fid
if args.save_stats:
save_fid_stats(paths=paths, batch_size=batch_size, device=device, dims=dims, num_workers=num_workers)
return

fid_value = calculate_fid_given_paths(paths=paths, batch_size=batch_size, device=device, dims=dims,
num_workers=num_workers)

logger.info(msg=f"The result of FID: {fid_value}")


if __name__ == "__main__":
# Before calculating
# [Note]: We recommend resizing both sets of images to the same format, the same size, and the same number
parser = argparse.ArgumentParser()
# Function1: Generated image folder and dataset image folder
# Function2: Save stats input path and output path (use `--save_stats`)
parser.add_argument("path", type=str, nargs="*",
default=["/your/generated/image/folder/or/stats/input/path",
"/your/dataset/image/folder/or/stats/output/path"],
help="Paths to the generated images or to .npz statistic files")
# Batch size
parser.add_argument("--batch_size", type=int, default=8,
help="Batch size for calculation.")
# Number of workers
parser.add_argument("--num-workers", type=int, default=0)
# Dimensionality of Inception features to use
# Option: 64/192/768/2048
parser.add_argument("--dims", type=int, default=2048,
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
help="Dimensionality of Inception features to use. By default, uses pool3 features")
parser.add_argument("--save_stats", action="store_true",
help="Generate an npz archive from a directory of samples. "
"The first path is used as input and the second as output.")
# Set the use GPU in normal training (required)
parser.add_argument("--use_gpu", type=int, default=0)
args = parser.parse_args()
main(args)