New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support K-fold cross-validation #563
Conversation
Codecov Report
@@ Coverage Diff @@
## dev #563 +/- ##
==========================================
+ Coverage 81.78% 82.10% +0.31%
==========================================
Files 118 118
Lines 6820 6855 +35
Branches 1174 1181 +7
==========================================
+ Hits 5578 5628 +50
+ Misses 1082 1063 -19
- Partials 160 164 +4
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
目前还不支持K折交叉吗? |
这个 PR 目前正在开发中,我们会在近期完成支持。 _base_ = [
'../_base_/models/resnet18_cifar.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
# dataset settings
dataset_type = 'CIFAR10'
img_norm_cfg = dict(
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
to_rgb=False)
train_pipeline = [
dict(type='RandomCrop', size=32, padding=4),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(
type='KFoldDataset',
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10',
pipeline=train_pipeline),
fold=0,
num_splits=5),
val=dict(
type='KFoldDataset',
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10',
pipeline=test_pipeline),
fold=0,
num_splits=5),
test=dict(
type='KFoldDataset',
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10',
pipeline=test_pipeline),
fold=0,
num_splits=5)) |
* Support to use `indices` to specify which samples to evaluate. * Add KFoldDataset wrapper * Rename 'K' to 'num_splits' accroding to sklearn * Add `kfold-cross-valid.py` * Add unit tests * Add help doc and docstring
* Support to use `indices` to specify which samples to evaluate. * Add KFoldDataset wrapper * Rename 'K' to 'num_splits' accroding to sklearn * Add `kfold-cross-valid.py` * Add unit tests * Add help doc and docstring
Motivation
K-Fold cross-validation is commonly used in small dataset training. Here we add the support of the K-fold dataset wrapper.
Closing #560
Modification
KFoldDataset
.indices
to specify which samples to evaluate.BC-breaking (Optional)
No
Use cases (Optional)
Here we modify the resnet18_8xb32_in1k.py as an example.
Checklist
Before PR:
After PR: