|
20 | 20 | from collections import OrderedDict |
21 | 21 | from contextlib import suppress |
22 | 22 |
|
23 | | -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models |
| 23 | +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm |
24 | 24 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet |
25 | 25 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ |
26 | 26 | decay_batch_step, check_batch_size_retry |
|
117 | 117 | help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") |
118 | 118 | parser.add_argument('--fuser', default='', type=str, |
119 | 119 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") |
| 120 | +parser.add_argument('--fast-norm', default=False, action='store_true', |
| 121 | + help='enable experimental fast-norm') |
120 | 122 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', |
121 | 123 | help='Output csv file for validation results (summary)') |
122 | 124 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', |
@@ -150,6 +152,8 @@ def validate(args): |
150 | 152 |
|
151 | 153 | if args.fuser: |
152 | 154 | set_jit_fuser(args.fuser) |
| 155 | + if args.fast_norm: |
| 156 | + set_fast_norm() |
153 | 157 |
|
154 | 158 | # create model |
155 | 159 | model = create_model( |
|
0 commit comments