Skip to content

Commit ff6a919

Browse files
committed
Add --fast-norm arg to benchmark.py, train.py, validate.py
1 parent 769ab4b commit ff6a919

File tree

4 files changed

+17
-4
lines changed

4 files changed

+17
-4
lines changed

benchmark.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch.nn.parallel
2020

2121
from timm.data import resolve_data_config
22-
from timm.models import create_model, is_model, list_models
22+
from timm.models import create_model, is_model, list_models, set_fast_norm
2323
from timm.optim import create_optimizer_v2
2424
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
2525

@@ -109,7 +109,8 @@
109109
help='convert model torchscript for inference')
110110
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
111111
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
112-
112+
scripting_group.add_argument('--fast-norm', default=False, action='store_true',
113+
help='enable experimental fast-norm')
113114

114115
# train optimizer parameters
115116
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -598,6 +599,9 @@ def main():
598599
model_cfgs = []
599600
model_names = []
600601

602+
if args.fast_norm:
603+
set_fast_norm()
604+
601605
if args.model_list:
602606
args.model = ''
603607
with open(args.model_list) as f:

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,6 @@
6969
from .layers import TestTimePoolHead, apply_test_time_pool
7070
from .layers import convert_splitbn_model, convert_sync_batchnorm
7171
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
72+
from .layers import set_fast_norm
7273
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
7374
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
3434
LabelSmoothingCrossEntropy
3535
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
36-
convert_splitbn_model, convert_sync_batchnorm, model_parameters
36+
convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
3737
from timm.optim import create_optimizer_v2, optimizer_kwargs
3838
from timm.scheduler import create_scheduler
3939
from timm.utils import ApexScaler, NativeScaler
@@ -135,6 +135,8 @@
135135
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
136136
group.add_argument('--fuser', default='', type=str,
137137
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
138+
group.add_argument('--fast-norm', default=False, action='store_true',
139+
help='enable experimental fast-norm')
138140
group.add_argument('--grad-checkpointing', action='store_true', default=False,
139141
help='Enable gradient checkpointing through model blocks/stages')
140142

@@ -395,6 +397,8 @@ def main():
395397

396398
if args.fuser:
397399
utils.set_jit_fuser(args.fuser)
400+
if args.fast_norm:
401+
set_fast_norm()
398402

399403
model = create_model(
400404
args.model,

validate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections import OrderedDict
2121
from contextlib import suppress
2222

23-
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
23+
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
2424
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
2525
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
2626
decay_batch_step, check_batch_size_retry
@@ -117,6 +117,8 @@
117117
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
118118
parser.add_argument('--fuser', default='', type=str,
119119
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
120+
parser.add_argument('--fast-norm', default=False, action='store_true',
121+
help='enable experimental fast-norm')
120122
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
121123
help='Output csv file for validation results (summary)')
122124
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@@ -150,6 +152,8 @@ def validate(args):
150152

151153
if args.fuser:
152154
set_jit_fuser(args.fuser)
155+
if args.fast_norm:
156+
set_fast_norm()
153157

154158
# create model
155159
model = create_model(

0 commit comments

Comments
 (0)