-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CodeCamp2023-647] Add new configs of EG3D (#1985)
* add_new_config_eg3d * minor fix * fix fid import --------- Co-authored-by: LeoXing1996 <xingzn1996@hotmail.com>
- Loading branch information
1 parent
0476c39
commit 7525478
Showing
4 changed files
with
379 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.hooks import CheckpointHook, LoggerHook | ||
from mmengine.model import MMSeparateDistributedDataParallel | ||
|
||
from mmagic.engine import (IterTimerHook, LogProcessor, | ||
MultiOptimWrapperConstructor, MultiTestLoop, | ||
MultiValLoop) | ||
from mmagic.evaluation import Evaluator | ||
from mmagic.visualization import VisBackend, Visualizer | ||
|
||
default_scope = 'mmagic' | ||
|
||
randomness = dict(seed=2022, diff_rank_seed=True) | ||
# env settings | ||
dist_params = dict(backend='nccl') | ||
# disable opencv multithreading to avoid system being overloaded | ||
opencv_num_threads = 0 | ||
# set multi-process start method as `fork` to speed up the training | ||
mp_start_method = 'fork' | ||
|
||
# configure for default hooks | ||
default_hooks = dict( | ||
# record time of every iteration. | ||
timer=dict(type=IterTimerHook), | ||
# print log every 100 iterations. | ||
logger=dict(type=LoggerHook, interval=100, log_metric_by_epoch=False), | ||
# save checkpoint per 10000 iterations | ||
checkpoint=dict( | ||
type=CheckpointHook, | ||
interval=10000, | ||
by_epoch=False, | ||
max_keep_ckpts=20, | ||
less_keys=['FID-Full-50k/fid', 'FID-50k/fid', 'swd/avg'], | ||
greater_keys=['IS-50k/is', 'ms-ssim/avg'], | ||
save_optimizer=True)) | ||
|
||
# config for environment | ||
env_cfg = dict( | ||
# whether to enable cudnn benchmark. | ||
cudnn_benchmark=True, | ||
# set multi process parameters. | ||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), | ||
# set distributed parameters. | ||
dist_cfg=dict(backend='nccl')) | ||
|
||
# set log level | ||
log_level = 'INFO' | ||
log_processor = dict(type=LogProcessor, by_epoch=False) | ||
|
||
# load from which checkpoint | ||
load_from = None | ||
|
||
# whether to resume training from the loaded checkpoint | ||
resume = None | ||
|
||
# config for model wrapper | ||
model_wrapper_cfg = dict( | ||
type=MMSeparateDistributedDataParallel, | ||
broadcast_buffers=False, | ||
find_unused_parameters=False) | ||
|
||
# set visualizer | ||
vis_backends = [dict(type=VisBackend)] | ||
visualizer = dict(type=Visualizer, vis_backends=vis_backends) | ||
|
||
# config for training | ||
train_cfg = dict(by_epoch=False, val_begin=1, val_interval=10000) | ||
|
||
# config for val | ||
val_cfg = dict(type=MultiValLoop) | ||
val_evaluator = dict(type=Evaluator) | ||
|
||
# config for test | ||
test_cfg = dict(type=MultiTestLoop) | ||
test_evaluator = dict(type=Evaluator) | ||
|
||
# config for optim_wrapper_constructor | ||
optim_wrapper = dict(constructor=MultiOptimWrapperConstructor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa | ||
# mmcv >= 2.0.1 | ||
# mmengine >= 0.8.0 | ||
|
||
from mmengine.config import read_base | ||
from mmengine.dataset import DefaultSampler | ||
|
||
from mmagic.datasets import BasicConditionalDataset | ||
from mmagic.datasets.transforms import LoadImageFromFile, PackInputs | ||
from mmagic.engine import VisualizationHook | ||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.data_preprocessors import DataPreprocessor | ||
from mmagic.models.editors.eg3d import EG3D, GaussianCamera, TriplaneGenerator | ||
|
||
with read_base(): | ||
from .._base_.gen_default_runtime import * # noqa: F401,F403 | ||
|
||
model = dict( | ||
type=EG3D, | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
generator=dict( | ||
type=TriplaneGenerator, | ||
out_size=512, | ||
triplane_channels=32, | ||
triplane_size=256, | ||
num_mlps=2, | ||
sr_add_noise=False, | ||
sr_in_size=128, | ||
neural_rendering_resolution=128, | ||
renderer_cfg=dict( | ||
ray_start=2.25, | ||
ray_end=3.3, | ||
box_warp=1, | ||
depth_resolution=48, | ||
depth_resolution_importance=48, | ||
white_back=False, | ||
), | ||
rgb2bgr=True), | ||
camera=dict( | ||
type=GaussianCamera, | ||
horizontal_mean=3.14 / 2, | ||
horizontal_std=0.35, | ||
vertical_mean=3.14 / 2 - 0.05, | ||
vertical_std=0.25, | ||
radius=2.7, | ||
fov=18.837, | ||
look_at=[0, 0, 0.2])) | ||
|
||
train_cfg = train_dataloader = optim_wrapper = None | ||
val_cfg = val_dataloader = val_evaluator = None | ||
|
||
inception_pkl = './work_dirs/inception_pkl/eg3d_afhq.pkl' | ||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full', | ||
fake_nums=50000, | ||
inception_pkl=inception_pkl, | ||
need_cond_input=True, | ||
sample_model='orig'), | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Random-Camera', | ||
fake_nums=50000, | ||
inception_pkl=inception_pkl, | ||
sample_model='orig') | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type=LoadImageFromFile, key='img', color_type='color'), | ||
dict(type=PackInputs) | ||
] | ||
test_dataset = dict( | ||
type=BasicConditionalDataset, | ||
data_root='./data/eg3d/afhq', | ||
ann_file='afhq.json', | ||
pipeline=test_pipeline) | ||
test_dataloader = dict( | ||
# NOTE: `batch_size = 4` cost nearly **9.5GB** of GPU memory, | ||
# modification this param by yourself corresponding to your own GPU. | ||
batch_size=4, | ||
persistent_workers=False, | ||
drop_last=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
num_workers=9, | ||
dataset=test_dataset) | ||
|
||
test_evaluator = dict(metrics=metrics) | ||
|
||
custom_hooks = [ | ||
dict( | ||
type=VisualizationHook, | ||
interval=5000, | ||
fixed_input=True, | ||
vis_kwargs_list=dict(type='GAN', name='fake_img')) | ||
] |
101 changes: 101 additions & 0 deletions
101
mmagic/configs/eg3d/eg3d_cvt-official-rgb_ffhq-512x512.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa | ||
# mmcv >= 2.0.1 | ||
# mmengine >= 0.8.0 | ||
|
||
from mmengine.config import read_base | ||
from mmengine.dataset import DefaultSampler | ||
|
||
from mmagic.datasets import BasicConditionalDataset | ||
from mmagic.datasets.transforms import LoadImageFromFile, PackInputs | ||
from mmagic.engine import VisualizationHook | ||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.data_preprocessors import DataPreprocessor | ||
from mmagic.models.editors.eg3d import EG3D, GaussianCamera, TriplaneGenerator | ||
|
||
with read_base(): | ||
from .._base_.gen_default_runtime import * # noqa: F401,F403 | ||
|
||
model = dict( | ||
type=EG3D, | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
generator=dict( | ||
type=TriplaneGenerator, | ||
out_size=512, | ||
triplane_channels=32, | ||
triplane_size=256, | ||
num_mlps=2, | ||
neural_rendering_resolution=128, | ||
sr_add_noise=False, | ||
sr_in_size=128, | ||
# NOTE: double hidden channels and out channels for FFHQ-512 | ||
sr_hidden_channels=256, | ||
sr_out_channels=128, | ||
renderer_cfg=dict( | ||
ray_start=2.25, | ||
ray_end=3.3, | ||
box_warp=1, | ||
depth_resolution=48, | ||
depth_resolution_importance=48, | ||
white_back=False, | ||
), | ||
rgb2bgr=True), | ||
camera=dict( | ||
type=GaussianCamera, | ||
horizontal_mean=3.14 / 2, | ||
horizontal_std=0.35, | ||
vertical_mean=3.14 / 2 - 0.05, | ||
vertical_std=0.25, | ||
radius=2.7, | ||
fov=18.837, | ||
look_at=[0, 0, 0.2])) | ||
|
||
train_cfg = train_dataloader = optim_wrapper = None | ||
val_cfg = val_dataloader = val_evaluator = None | ||
|
||
inception_pkl = './work_dirs/inception_pkl/eg3d_ffhq_512.pkl' | ||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full', | ||
fake_nums=50000, | ||
inception_pkl=inception_pkl, | ||
need_cond_input=True, | ||
sample_model='orig'), | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Random-Camera', | ||
fake_nums=50000, | ||
inception_pkl=inception_pkl, | ||
sample_model='orig') | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type=LoadImageFromFile, key='img', color_type='color'), | ||
dict(type=PackInputs) | ||
] | ||
test_dataset = dict( | ||
type=BasicConditionalDataset, | ||
data_root='./data/eg3d/ffhq_512', | ||
ann_file='ffhq_512.json', | ||
pipeline=test_pipeline) | ||
test_dataloader = dict( | ||
# NOTE: `batch_size = 4` cost nearly **9.5GB** of GPU memory, | ||
# modification this param by yourself corresponding to your own GPU. | ||
batch_size=4, | ||
persistent_workers=False, | ||
drop_last=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
num_workers=9, | ||
dataset=test_dataset) | ||
|
||
test_evaluator = dict(metrics=metrics) | ||
|
||
custom_hooks = [ | ||
dict( | ||
type=VisualizationHook, | ||
interval=5000, | ||
fixed_input=True, | ||
vis_kwargs_list=dict(type='GAN', name='fake_img')) | ||
] |
102 changes: 102 additions & 0 deletions
102
mmagic/configs/eg3d/eg3d_cvt-official-rgb_shapenet-128x128.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa | ||
# mmcv >= 2.0.1 | ||
# mmengine >= 0.8.0 | ||
|
||
from mmengine.config import read_base | ||
from mmengine.dataset import DefaultSampler | ||
|
||
from mmagic.datasets import BasicConditionalDataset | ||
from mmagic.datasets.transforms import LoadImageFromFile, PackInputs | ||
from mmagic.engine import VisualizationHook | ||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.data_preprocessors import DataPreprocessor | ||
from mmagic.models.editors.eg3d import EG3D, TriplaneGenerator, UniformCamera | ||
|
||
with read_base(): | ||
from .._base_.gen_default_runtime import * # noqa: F401,F403 | ||
|
||
model = dict( | ||
type=EG3D, | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
generator=dict( | ||
type=TriplaneGenerator, | ||
out_size=128, | ||
zero_cond_input=True, | ||
cond_scale=0, | ||
sr_in_size=64, | ||
renderer_cfg=dict( | ||
# Official implementation set ray_start, ray_end and box_warp as | ||
# 0.1, 2.6 and 1.6 respectively, and FID is 7.2441 | ||
# ray_start=0.1, | ||
# ray_end=2.6, | ||
# box_warp=1.6, | ||
ray_start=0.4, | ||
ray_end=2.0, | ||
box_warp=1.7, | ||
depth_resolution=64, | ||
depth_resolution_importance=64, | ||
white_back=True, | ||
), | ||
rgb2bgr=True), | ||
camera=dict( | ||
type=UniformCamera, | ||
horizontal_mean=3.141, | ||
horizontal_std=3.141, | ||
vertical_mean=3.141 / 2, | ||
vertical_std=3.141 / 2, | ||
focal=1.025390625, | ||
up=[0, 0, 1], | ||
radius=1.2), | ||
) | ||
|
||
train_cfg = train_dataloader = optim_wrapper = None | ||
val_cfg = val_dataloader = val_evaluator = None | ||
|
||
inception_pkl = './work_dirs/inception_pkl/eg3d_shapenet.pkl' | ||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full', | ||
fake_nums=50000, | ||
inception_pkl=inception_pkl, | ||
need_cond_input=True, | ||
sample_model='orig'), | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Random-Camera', | ||
fake_nums=50000, | ||
inception_pkl=inception_pkl, | ||
sample_model='orig'), | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type=LoadImageFromFile, key='img', color_type='color'), | ||
dict(type=PackInputs) | ||
] | ||
test_dataset = dict( | ||
type=BasicConditionalDataset, | ||
data_root='./data/eg3d/shapenet-car', | ||
ann_file='shapenet.json', | ||
pipeline=test_pipeline) | ||
test_dataloader = dict( | ||
# NOTE: `batch_size = 16` cost nearly **12GB** of GPU memory, | ||
# modification this param by yourself corresponding to your own GPU. | ||
batch_size=16, | ||
persistent_workers=False, | ||
drop_last=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
num_workers=9, | ||
dataset=test_dataset) | ||
|
||
test_evaluator = dict(metrics=metrics) | ||
|
||
custom_hooks = [ | ||
dict( | ||
type=VisualizationHook, | ||
interval=5000, | ||
fixed_input=True, | ||
# save_at_test=False, | ||
vis_kwargs_list=dict(type='GAN', name='fake_img')) | ||
] |