Skip to content

Commit

Permalink
[CodeCamp2023-647] Add new configs of EG3D (#1985)
Browse files Browse the repository at this point in the history
* add_new_config_eg3d

* minor fix

* fix fid import

---------

Co-authored-by: LeoXing1996 <xingzn1996@hotmail.com>
  • Loading branch information
RangeKing and LeoXing1996 authored Aug 23, 2023
1 parent 0476c39 commit 7525478
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 0 deletions.
78 changes: 78 additions & 0 deletions mmagic/configs/_base_/gen_default_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import CheckpointHook, LoggerHook
from mmengine.model import MMSeparateDistributedDataParallel

from mmagic.engine import (IterTimerHook, LogProcessor,
MultiOptimWrapperConstructor, MultiTestLoop,
MultiValLoop)
from mmagic.evaluation import Evaluator
from mmagic.visualization import VisBackend, Visualizer

default_scope = 'mmagic'

randomness = dict(seed=2022, diff_rank_seed=True)
# env settings
dist_params = dict(backend='nccl')
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'

# configure for default hooks
default_hooks = dict(
# record time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 100 iterations.
logger=dict(type=LoggerHook, interval=100, log_metric_by_epoch=False),
# save checkpoint per 10000 iterations
checkpoint=dict(
type=CheckpointHook,
interval=10000,
by_epoch=False,
max_keep_ckpts=20,
less_keys=['FID-Full-50k/fid', 'FID-50k/fid', 'swd/avg'],
greater_keys=['IS-50k/is', 'ms-ssim/avg'],
save_optimizer=True))

# config for environment
env_cfg = dict(
# whether to enable cudnn benchmark.
cudnn_benchmark=True,
# set multi process parameters.
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters.
dist_cfg=dict(backend='nccl'))

# set log level
log_level = 'INFO'
log_processor = dict(type=LogProcessor, by_epoch=False)

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = None

# config for model wrapper
model_wrapper_cfg = dict(
type=MMSeparateDistributedDataParallel,
broadcast_buffers=False,
find_unused_parameters=False)

# set visualizer
vis_backends = [dict(type=VisBackend)]
visualizer = dict(type=Visualizer, vis_backends=vis_backends)

# config for training
train_cfg = dict(by_epoch=False, val_begin=1, val_interval=10000)

# config for val
val_cfg = dict(type=MultiValLoop)
val_evaluator = dict(type=Evaluator)

# config for test
test_cfg = dict(type=MultiTestLoop)
test_evaluator = dict(type=Evaluator)

# config for optim_wrapper_constructor
optim_wrapper = dict(constructor=MultiOptimWrapperConstructor)
98 changes: 98 additions & 0 deletions mmagic/configs/eg3d/eg3d_cvt-official-rgb_afhq-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base
from mmengine.dataset import DefaultSampler

from mmagic.datasets import BasicConditionalDataset
from mmagic.datasets.transforms import LoadImageFromFile, PackInputs
from mmagic.engine import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors import DataPreprocessor
from mmagic.models.editors.eg3d import EG3D, GaussianCamera, TriplaneGenerator

with read_base():
from .._base_.gen_default_runtime import * # noqa: F401,F403

model = dict(
type=EG3D,
data_preprocessor=dict(type=DataPreprocessor),
generator=dict(
type=TriplaneGenerator,
out_size=512,
triplane_channels=32,
triplane_size=256,
num_mlps=2,
sr_add_noise=False,
sr_in_size=128,
neural_rendering_resolution=128,
renderer_cfg=dict(
ray_start=2.25,
ray_end=3.3,
box_warp=1,
depth_resolution=48,
depth_resolution_importance=48,
white_back=False,
),
rgb2bgr=True),
camera=dict(
type=GaussianCamera,
horizontal_mean=3.14 / 2,
horizontal_std=0.35,
vertical_mean=3.14 / 2 - 0.05,
vertical_std=0.25,
radius=2.7,
fov=18.837,
look_at=[0, 0, 0.2]))

train_cfg = train_dataloader = optim_wrapper = None
val_cfg = val_dataloader = val_evaluator = None

inception_pkl = './work_dirs/inception_pkl/eg3d_afhq.pkl'
metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full',
fake_nums=50000,
inception_pkl=inception_pkl,
need_cond_input=True,
sample_model='orig'),
dict(
type=FrechetInceptionDistance,
prefix='FID-Random-Camera',
fake_nums=50000,
inception_pkl=inception_pkl,
sample_model='orig')
]

test_pipeline = [
dict(type=LoadImageFromFile, key='img', color_type='color'),
dict(type=PackInputs)
]
test_dataset = dict(
type=BasicConditionalDataset,
data_root='./data/eg3d/afhq',
ann_file='afhq.json',
pipeline=test_pipeline)
test_dataloader = dict(
# NOTE: `batch_size = 4` cost nearly **9.5GB** of GPU memory,
# modification this param by yourself corresponding to your own GPU.
batch_size=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
num_workers=9,
dataset=test_dataset)

test_evaluator = dict(metrics=metrics)

custom_hooks = [
dict(
type=VisualizationHook,
interval=5000,
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
101 changes: 101 additions & 0 deletions mmagic/configs/eg3d/eg3d_cvt-official-rgb_ffhq-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base
from mmengine.dataset import DefaultSampler

from mmagic.datasets import BasicConditionalDataset
from mmagic.datasets.transforms import LoadImageFromFile, PackInputs
from mmagic.engine import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors import DataPreprocessor
from mmagic.models.editors.eg3d import EG3D, GaussianCamera, TriplaneGenerator

with read_base():
from .._base_.gen_default_runtime import * # noqa: F401,F403

model = dict(
type=EG3D,
data_preprocessor=dict(type=DataPreprocessor),
generator=dict(
type=TriplaneGenerator,
out_size=512,
triplane_channels=32,
triplane_size=256,
num_mlps=2,
neural_rendering_resolution=128,
sr_add_noise=False,
sr_in_size=128,
# NOTE: double hidden channels and out channels for FFHQ-512
sr_hidden_channels=256,
sr_out_channels=128,
renderer_cfg=dict(
ray_start=2.25,
ray_end=3.3,
box_warp=1,
depth_resolution=48,
depth_resolution_importance=48,
white_back=False,
),
rgb2bgr=True),
camera=dict(
type=GaussianCamera,
horizontal_mean=3.14 / 2,
horizontal_std=0.35,
vertical_mean=3.14 / 2 - 0.05,
vertical_std=0.25,
radius=2.7,
fov=18.837,
look_at=[0, 0, 0.2]))

train_cfg = train_dataloader = optim_wrapper = None
val_cfg = val_dataloader = val_evaluator = None

inception_pkl = './work_dirs/inception_pkl/eg3d_ffhq_512.pkl'
metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full',
fake_nums=50000,
inception_pkl=inception_pkl,
need_cond_input=True,
sample_model='orig'),
dict(
type=FrechetInceptionDistance,
prefix='FID-Random-Camera',
fake_nums=50000,
inception_pkl=inception_pkl,
sample_model='orig')
]

test_pipeline = [
dict(type=LoadImageFromFile, key='img', color_type='color'),
dict(type=PackInputs)
]
test_dataset = dict(
type=BasicConditionalDataset,
data_root='./data/eg3d/ffhq_512',
ann_file='ffhq_512.json',
pipeline=test_pipeline)
test_dataloader = dict(
# NOTE: `batch_size = 4` cost nearly **9.5GB** of GPU memory,
# modification this param by yourself corresponding to your own GPU.
batch_size=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
num_workers=9,
dataset=test_dataset)

test_evaluator = dict(metrics=metrics)

custom_hooks = [
dict(
type=VisualizationHook,
interval=5000,
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
102 changes: 102 additions & 0 deletions mmagic/configs/eg3d/eg3d_cvt-official-rgb_shapenet-128x128.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base
from mmengine.dataset import DefaultSampler

from mmagic.datasets import BasicConditionalDataset
from mmagic.datasets.transforms import LoadImageFromFile, PackInputs
from mmagic.engine import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors import DataPreprocessor
from mmagic.models.editors.eg3d import EG3D, TriplaneGenerator, UniformCamera

with read_base():
from .._base_.gen_default_runtime import * # noqa: F401,F403

model = dict(
type=EG3D,
data_preprocessor=dict(type=DataPreprocessor),
generator=dict(
type=TriplaneGenerator,
out_size=128,
zero_cond_input=True,
cond_scale=0,
sr_in_size=64,
renderer_cfg=dict(
# Official implementation set ray_start, ray_end and box_warp as
# 0.1, 2.6 and 1.6 respectively, and FID is 7.2441
# ray_start=0.1,
# ray_end=2.6,
# box_warp=1.6,
ray_start=0.4,
ray_end=2.0,
box_warp=1.7,
depth_resolution=64,
depth_resolution_importance=64,
white_back=True,
),
rgb2bgr=True),
camera=dict(
type=UniformCamera,
horizontal_mean=3.141,
horizontal_std=3.141,
vertical_mean=3.141 / 2,
vertical_std=3.141 / 2,
focal=1.025390625,
up=[0, 0, 1],
radius=1.2),
)

train_cfg = train_dataloader = optim_wrapper = None
val_cfg = val_dataloader = val_evaluator = None

inception_pkl = './work_dirs/inception_pkl/eg3d_shapenet.pkl'
metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full',
fake_nums=50000,
inception_pkl=inception_pkl,
need_cond_input=True,
sample_model='orig'),
dict(
type=FrechetInceptionDistance,
prefix='FID-Random-Camera',
fake_nums=50000,
inception_pkl=inception_pkl,
sample_model='orig'),
]

test_pipeline = [
dict(type=LoadImageFromFile, key='img', color_type='color'),
dict(type=PackInputs)
]
test_dataset = dict(
type=BasicConditionalDataset,
data_root='./data/eg3d/shapenet-car',
ann_file='shapenet.json',
pipeline=test_pipeline)
test_dataloader = dict(
# NOTE: `batch_size = 16` cost nearly **12GB** of GPU memory,
# modification this param by yourself corresponding to your own GPU.
batch_size=16,
persistent_workers=False,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
num_workers=9,
dataset=test_dataset)

test_evaluator = dict(metrics=metrics)

custom_hooks = [
dict(
type=VisualizationHook,
interval=5000,
fixed_input=True,
# save_at_test=False,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]

0 comments on commit 7525478

Please sign in to comment.