Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] add new config for _base_ dir #2053

Merged
merged 3 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions mmagic/configs/_base_/datasets/basicvsr_test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler

from mmagic.datasets import BasicFramesDataset
from mmagic.datasets.transforms import (GenerateSegmentIndices,
LoadImageFromFile, MirrorSequence,
PackInputs)
from mmagic.engine.runner import MultiTestLoop
from mmagic.evaluation import PSNR, SSIM

# configs for REDS4
reds_data_root = 'data/REDS'

reds_pipeline = [
dict(type=GenerateSegmentIndices, interval_list=[1]),
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=PackInputs)
]

reds_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='reds_reds4', task_name='vsr'),
data_root=reds_data_root,
data_prefix=dict(img='train_sharp_bicubic/X4', gt='train_sharp'),
ann_file='meta_info_reds4_val.txt',
depth=1,
num_input_frames=100,
fixed_seq_len=100,
pipeline=reds_pipeline))

reds_evaluator = [
dict(type=PSNR, prefix='REDS4-BIx4-RGB'),
dict(type=SSIM, prefix='REDS4-BIx4-RGB')
]

# configs for vimeo90k-bd and vimeo90k-bi
vimeo_90k_data_root = 'data/vimeo90k'
vimeo_90k_file_list = [
'im1.png', 'im2.png', 'im3.png', 'im4.png', 'im5.png', 'im6.png', 'im7.png'
]

vimeo_90k_pipeline = [
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=MirrorSequence, keys=['img']),
dict(type=PackInputs)
]

vimeo_90k_bd_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'),
data_root=vimeo_90k_data_root,
data_prefix=dict(img='BDx4', gt='GT'),
ann_file='meta_info_Vimeo90K_test_GT.txt',
depth=2,
num_input_frames=7,
fixed_seq_len=7,
load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']),
pipeline=vimeo_90k_pipeline))

vimeo_90k_bi_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'),
data_root=vimeo_90k_data_root,
data_prefix=dict(img='BIx4', gt='GT'),
ann_file='meta_info_Vimeo90K_test_GT.txt',
depth=2,
num_input_frames=7,
fixed_seq_len=7,
load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']),
pipeline=vimeo_90k_pipeline))

vimeo_90k_bd_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'),
]

vimeo_90k_bi_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'),
]

# config for UDM10 (BDx4)
udm10_data_root = 'data/UDM10'

udm10_pipeline = [
dict(
type=GenerateSegmentIndices,
interval_list=[1],
filename_tmpl='{:04d}.png'),
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=PackInputs)
]

udm10_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='udm10', task_name='vsr'),
data_root=udm10_data_root,
data_prefix=dict(img='BDx4', gt='GT'),
pipeline=udm10_pipeline))

udm10_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='UDM10-BDx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='UDM10-BDx4-Y')
]

# config for vid4
vid4_data_root = 'data/Vid4'

vid4_pipeline = [
dict(type=GenerateSegmentIndices, interval_list=[1]),
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=PackInputs)
]
vid4_bd_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vid4', task_name='vsr'),
data_root=vid4_data_root,
data_prefix=dict(img='BDx4', gt='GT'),
ann_file='meta_info_Vid4_GT.txt',
depth=1,
pipeline=vid4_pipeline))

vid4_bi_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vid4', task_name='vsr'),
data_root=vid4_data_root,
data_prefix=dict(img='BIx4', gt='GT'),
ann_file='meta_info_Vid4_GT.txt',
depth=1,
pipeline=vid4_pipeline))

vid4_bd_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='VID4-BDx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='VID4-BDx4-Y'),
]
vid4_bi_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='VID4-BIx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='VID4-BIx4-Y'),
]

# config for test
test_cfg = dict(type=MultiTestLoop)
test_dataloader = [
reds_dataloader,
vimeo_90k_bd_dataloader,
vimeo_90k_bi_dataloader,
udm10_dataloader,
vid4_bd_dataloader,
vid4_bi_dataloader,
]
test_evaluator = [
reds_evaluator,
vimeo_90k_bd_evaluator,
vimeo_90k_bi_evaluator,
udm10_evaluator,
vid4_bd_evaluator,
vid4_bi_evaluator,
]
48 changes: 48 additions & 0 deletions mmagic/configs/_base_/datasets/celeba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler, InfiniteSampler

from mmagic.evaluation import MAE, PSNR, SSIM

# Base config for CelebA-HQ dataset

# dataset settings
dataset_type = 'BasicImageDataset'
data_root = 'data/CelebA-HQ'

train_dataloader = dict(
num_workers=4,
persistent_workers=False,
sampler=dict(type=InfiniteSampler, shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(gt=''),
ann_file='train_celeba_img_list.txt',
test_mode=False,
))

val_dataloader = dict(
num_workers=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(gt=''),
ann_file='val_celeba_img_list.txt',
test_mode=True,
))

test_dataloader = val_dataloader

val_evaluator = [
dict(type=MAE, mask_key='mask', scaling=100),
# By default, compute with pixel value from 0-1
# scale=2 to align with 1.0
# scale=100 seems to align with readme
dict(type=PSNR),
dict(type=SSIM),
]

test_evaluator = val_evaluator
45 changes: 45 additions & 0 deletions mmagic/configs/_base_/datasets/cifar10_nopad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler, InfiniteSampler

from mmagic.datasets import CIFAR10
from mmagic.datasets.transforms import Flip, PackInputs

cifar_pipeline = [
dict(type=Flip, keys=['gt'], flip_ratio=0.5, direction='horizontal'),
dict(type=PackInputs)
]
cifar_dataset = dict(
type=CIFAR10,
data_root='./data',
data_prefix='cifar10',
test_mode=False,
pipeline=cifar_pipeline)

# test dataset do not use flip
cifar_pipeline_test = [dict(type=PackInputs)]
cifar_dataset_test = dict(
type=CIFAR10,
data_root='./data',
data_prefix='cifar10',
test_mode=False,
pipeline=cifar_pipeline_test)

train_dataloader = dict(
num_workers=2,
dataset=cifar_dataset,
sampler=dict(type=InfiniteSampler, shuffle=True),
persistent_workers=True)

val_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=cifar_dataset_test,
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

test_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=cifar_dataset_test,
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)
45 changes: 45 additions & 0 deletions mmagic/configs/_base_/datasets/comp1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler, InfiniteSampler

from mmagic.evaluation import SAD, ConnectivityError, GradientError, MattingMSE

# Base config for Composition-1K dataset

# dataset settings
dataset_type = 'AdobeComp1kDataset'
data_root = 'data/adobe_composition-1k'

train_dataloader = dict(
num_workers=4,
persistent_workers=False,
sampler=dict(type=InfiniteSampler, shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='training_list.json',
test_mode=False,
))

val_dataloader = dict(
num_workers=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='test_list.json',
test_mode=True,
))

test_dataloader = val_dataloader

# TODO: matting
val_evaluator = [
dict(type=SAD),
dict(type=MattingMSE),
dict(type=GradientError),
dict(type=ConnectivityError),
]

test_evaluator = val_evaluator
Loading
Loading