Skip to content

Commit

Permalink
update calibrate_dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
humu789 committed Apr 23, 2023
1 parent 30e053b commit 8747e03
Show file tree
Hide file tree
Showing 11 changed files with 26 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py'
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py'
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py'
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py'
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py'
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
'../../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501
]

val_dataloader = dict(batch_size=32)
_base_.val_dataloader.batch_size = 32

test_cfg = dict(
type='mmrazor.PTQLoop',
calibrate_dataloader=val_dataloader,
calibrate_dataloader=_base_.val_dataloader,
calibrate_steps=32,
)

Expand Down
10 changes: 6 additions & 4 deletions mmrazor/engine/runner/quantization_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,12 @@ def __init__(self,
# Determine whether or not different ranks use different seed.
diff_rank_seed = runner._randomness_cfg.get(
'diff_rank_seed', False)
self.dataloader = runner.build_dataloader(
dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
self.calibrate_dataloader = runner.build_dataloader(
calibrate_dataloader,
seed=runner.seed,
diff_rank_seed=diff_rank_seed)
else:
self.dataloader = dataloader
self.calibrate_dataloader = calibrate_dataloader

self.calibrate_steps = calibrate_steps
self.only_val = only_val
Expand All @@ -350,7 +352,7 @@ def run(self) -> dict:
self.runner.model.apply(enable_observer)

print_log('Star calibratiion...')
for idx, data_batch in enumerate(self.dataloader):
for idx, data_batch in enumerate(self.calibrate_dataloader):
if idx == self.calibrate_steps:
break
self.run_iter(idx, data_batch)
Expand Down

0 comments on commit 8747e03

Please sign in to comment.