From 6ec5cd6a99911f2ae771536fc9dd382243069105 Mon Sep 17 00:00:00 2001 From: Jerome Rony Date: Thu, 17 Nov 2022 11:53:29 -0500 Subject: [PATCH 01/18] Use in-place operations for EMA --- timm/utils/model_ema.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 073d5c5ea1..0213d582a3 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -117,10 +117,15 @@ def _update(self, model, update_fn): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) - ema_v.copy_(update_fn(ema_v, model_v)) + update_fn(ema_v, model_v) def update(self, model): - self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def ema_update(e, m): + if m.is_floating_point(): + e.mul_(self.decay).add_(m, alpha=1 - self.decay) + + self._update(model, update_fn=ema_update) def set(self, model): - self._update(model, update_fn=lambda e, m: m) + self._update(model, update_fn=lambda e, m: e.copy_(m)) From 3491506fecf5adc85e0e4da724c6c2d43d3f9904 Mon Sep 17 00:00:00 2001 From: Jerome Rony Date: Wed, 30 Nov 2022 14:06:58 -0500 Subject: [PATCH 02/18] Add foreach option for faster EMA --- timm/utils/model_ema.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 0213d582a3..5cefe08bf3 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -102,30 +102,34 @@ class ModelEmaV2(nn.Module): This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers. """ - def __init__(self, model, decay=0.9999, device=None): + def __init__(self, model, decay=0.9999, device=None, foreach=False): super(ModelEmaV2, self).__init__() # make a copy of the model for accumulating moving average of weights self.module = deepcopy(model) self.module.eval() self.decay = decay + self.foreach = foreach self.device = device # perform ema on different device from model if set - if self.device is not None: + if self.device is not None and device != next(model.parameters()).device: + self.foreach = False # cannot use foreach methods with different devices self.module.to(device=device) - def _update(self, model, update_fn): - with torch.no_grad(): - for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - if self.device is not None: - model_v = model_v.to(device=self.device) - update_fn(ema_v, model_v) - + @torch.no_grad() def update(self, model): + ema_params = tuple(self.module.parameters()) + model_params = tuple(model.parameters()) + if self.foreach: + torch._foreach_mul_(ema_params, scalar=self.decay) + torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay) + else: + for ema_p, model_p in zip(ema_params, model_params): + ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay) - def ema_update(e, m): - if m.is_floating_point(): - e.mul_(self.decay).add_(m, alpha=1 - self.decay) - - self._update(model, update_fn=ema_update) + # copy buffers instead of EMA + for ema_b, model_b in zip(self.module.buffers(), model.buffers()): + ema_b.copy_(model_b.to(device=self.device)) + @torch.no_grad() def set(self, model): - self._update(model, update_fn=lambda e, m: e.copy_(m)) + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + 
ema_v.copy_(model_v.to(device=self.device)) From a48ab818f509ed6e58399e6b41316a5d4fd8b7e2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 20 Jan 2024 15:10:20 -0800 Subject: [PATCH 03/18] Improving device flexibility in train. Fix #2081 --- timm/utils/distributed.py | 91 +++++++++++++++++++++++++++------------ train.py | 34 ++++++++++----- 2 files changed, 87 insertions(+), 38 deletions(-) diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index ee9a358cf4..95655f2c94 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -2,18 +2,17 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import logging import os +from typing import Optional import torch from torch import distributed as dist -try: - import horovod.torch as hvd -except ImportError: - hvd = None - from .model import unwrap_model +_logger = logging.getLogger(__name__) + def reduce_tensor(tensor, n): rt = tensor.clone() @@ -84,9 +83,38 @@ def init_distributed_device(args): args.world_size = 1 args.rank = 0 # global rank args.local_rank = 0 + result = init_distributed_device_so( + device=getattr(args, 'device', 'cuda'), + dist_backend=getattr(args, 'dist_backend', None), + dist_url=getattr(args, 'dist_url', None), + ) + args.device = result['device'] + args.world_size = result['world_size'] + args.rank = result['global_rank'] + args.local_rank = result['local_rank'] + device = torch.device(args.device) + return device + + +def init_distributed_device_so( + device: str = 'cuda', + dist_backend: Optional[str] = None, + dist_url: Optional[str] = None, +): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + distributed = False + world_size = 1 + global_rank = 0 + local_rank = 0 + if dist_backend is None: + # FIXME sane defaults for other device backends? + dist_backend = 'nccl' if 'cuda' in device else 'gloo' + dist_url = dist_url or 'env://' # TBD, support horovod? 
# if args.horovod: + # import horovod.torch as hvd # assert hvd is not None, "Horovod is not installed" # hvd.init() # args.local_rank = int(hvd.local_rank()) @@ -96,42 +124,51 @@ def init_distributed_device(args): # os.environ['LOCAL_RANK'] = str(args.local_rank) # os.environ['RANK'] = str(args.rank) # os.environ['WORLD_SIZE'] = str(args.world_size) - dist_backend = getattr(args, 'dist_backend', 'nccl') - dist_url = getattr(args, 'dist_url', 'env://') if is_distributed_env(): if 'SLURM_PROCID' in os.environ: # DDP via SLURM - args.local_rank, args.rank, args.world_size = world_info_from_env() + local_rank, global_rank, world_size = world_info_from_env() # SLURM var -> torch.distributed vars in case needed - os.environ['LOCAL_RANK'] = str(args.local_rank) - os.environ['RANK'] = str(args.rank) - os.environ['WORLD_SIZE'] = str(args.world_size) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['RANK'] = str(global_rank) + os.environ['WORLD_SIZE'] = str(world_size) torch.distributed.init_process_group( backend=dist_backend, init_method=dist_url, - world_size=args.world_size, - rank=args.rank, + world_size=world_size, + rank=global_rank, ) else: # DDP via torchrun, torch.distributed.launch - args.local_rank, _, _ = world_info_from_env() + local_rank, _, _ = world_info_from_env() torch.distributed.init_process_group( backend=dist_backend, init_method=dist_url, ) - args.world_size = torch.distributed.get_world_size() - args.rank = torch.distributed.get_rank() - args.distributed = True + world_size = torch.distributed.get_world_size() + global_rank = torch.distributed.get_rank() + distributed = True - if torch.cuda.is_available(): - if args.distributed: - device = 'cuda:%d' % args.local_rank - else: - device = 'cuda:0' + if 'cuda' in device: + assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.' + + if distributed and device != 'cpu': + device, device_idx = device.split(':', maxsplit=1) + + # Ignore manually specified device index in distributed mode and + # override with resolved local rank, fewer headaches in most setups. 
+ if device_idx: + _logger.warning(f'device index {device_idx} removed from specified ({device}).') + + device = f'{device}:{local_rank}' + + if device.startswith('cuda:'): torch.cuda.set_device(device) - else: - device = 'cpu' - args.device = device - device = torch.device(device) - return device + return dict( + device=device, + global_rank=global_rank, + local_rank=local_rank, + world_size=world_size, + distributed=distributed, + ) diff --git a/train.py b/train.py index ed74a72016..12eab69039 100755 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ import argparse +import importlib import json import logging import os @@ -168,6 +169,24 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor', help="Enable compilation w/ specified backend (default: inductor).") +# Device & distributed +group = parser.add_argument_group('Device parameters') +group.add_argument('--device', default='cuda', type=str, + help="Device (accelerator) to use.") +group.add_argument('--amp', action='store_true', default=False, + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') +group.add_argument('--amp-dtype', default='float16', type=str, + help='lower precision AMP dtype (default: float16)') +group.add_argument('--amp-impl', default='native', type=str, + help='AMP impl to use, "native" or "apex" (default: native)') +group.add_argument('--no-ddp-bb', action='store_true', default=False, + help='Force broadcast buffers for native DDP to off.') +group.add_argument('--synchronize-step', action='store_true', default=False, + help='torch.cuda.synchronize() end of each step') +group.add_argument("--local_rank", default=0, type=int) +parser.add_argument('--device-modules', default=None, type=str, nargs='+', + help="Python imports for device backend modules.") + # Optimizer parameters group = parser.add_argument_group('Optimizer parameters') group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', @@ -350,16 +369,6 @@ help='how many training processes to use (default: 4)') group.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') -group.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') -group.add_argument('--amp-dtype', default='float16', type=str, - help='lower precision AMP dtype (default: float16)') -group.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') -group.add_argument('--no-ddp-bb', action='store_true', default=False, - help='Force broadcast buffers for native DDP to off.') -group.add_argument('--synchronize-step', action='store_true', default=False, - help='torch.cuda.synchronize() end of each step') group.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') group.add_argument('--no-prefetcher', action='store_true', default=False, @@ -372,7 +381,6 @@ help='Best metric (default: "top1"') group.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') -group.add_argument("--local_rank", default=0, type=int) group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') group.add_argument('--log-wandb', action='store_true', default=False, @@ -400,6 +408,10 @@ def main(): utils.setup_default_logging() args, args_text = _parse_args() + if args.device_modules: + for module in args.device_modules: + importlib.import_module(module) + if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True From dd84ef2cd5212ff8a7c1ea141e2f8bc159c1788a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 2 Feb 2024 09:45:04 -0800 Subject: [PATCH 04/18] ModelEmaV3 and MESA experiments --- timm/utils/__init__.py | 2 +- timm/utils/model_ema.py | 136 +++++++++++++++++++++++++++++++++++----- train.py | 33 +++++++++- 3 files changed, 153 insertions(+), 18 deletions(-) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 63fcf4c5b4..4c6a00cad5 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -10,6 +10,6 @@ from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg, ParseKwargs from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model -from .model_ema import ModelEma, ModelEmaV2 +from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3 from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 5cefe08bf3..968f4f580c 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -5,6 +5,7 @@ import logging from collections import OrderedDict from copy import deepcopy +from typing import Optional import torch import torch.nn as nn @@ -102,32 +103,139 @@ class ModelEmaV2(nn.Module): This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers. """ - def __init__(self, model, decay=0.9999, device=None, foreach=False): - super(ModelEmaV2, self).__init__() + def __init__(self, model, decay=0.9999, device=None): + super().__init__() # make a copy of the model for accumulating moving average of weights self.module = deepcopy(model) self.module.eval() self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) + + +class ModelEmaV3(nn.Module): + """ Model Exponential Moving Average V3 + + Keep a moving average of everything in the model state_dict (parameters and buffers). + V3 of this module leverages for_each and in-place operations for faster performance. + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. 
Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__( + self, + model, + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_warmup: bool = False, + warmup_gamma: float = 1.0, + warmup_power: float = 2/3, + device: Optional[torch.device] = None, + foreach: bool = True, + exclude_buffers: bool = False, + ): + super().__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_warmup = use_warmup + self.warmup_gamma = warmup_gamma + self.warmup_power = warmup_power self.foreach = foreach self.device = device # perform ema on different device from model if set + self.exclude_buffers = exclude_buffers if self.device is not None and device != next(model.parameters()).device: self.foreach = False # cannot use foreach methods with different devices self.module.to(device=device) - @torch.no_grad() - def update(self, model): - ema_params = tuple(self.module.parameters()) - model_params = tuple(model.parameters()) - if self.foreach: - torch._foreach_mul_(ema_params, scalar=self.decay) - torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay) + def get_decay(self, step: Optional[int] = None) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + if step is None: + return self.decay + + step = max(0, step - self.update_after_step - 1) + if step <= 0: + return 0.0 + + if self.use_warmup: + decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power + decay = max(min(decay, self.decay), self.min_decay) else: - for ema_p, model_p in zip(ema_params, model_params): - ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay) + decay = self.decay + + return decay - # copy buffers instead of EMA - for ema_b, model_b in zip(self.module.buffers(), model.buffers()): - ema_b.copy_(model_b.to(device=self.device)) + @torch.no_grad() + def update(self, model, step: Optional[int] = None): + decay = self.get_decay(step) + + if self.exclude_buffers: + # interpolate parameters + ema_params = tuple(self.module.parameters()) + model_params = tuple(model.parameters()) + if self.foreach: + if hasattr(torch, '_foreach_lerp_'): + torch._foreach_lerp_(ema_params, model_params, weight=1. - decay) + else: + torch._foreach_mul_(ema_params, scalar=decay) + torch._foreach_add_(ema_params, model_params, alpha=1 - decay) + else: + for ema_p, model_p in zip(ema_params, model_params): + ema_p.lerp_(model_p, weight=1. - decay) + + # copy buffers instead of EMA + for ema_b, model_b in zip(self.module.buffers(), model.buffers()): + ema_b.copy_(model_b.to(device=self.device)) + else: + # interpolate parameters and buffers + if self.foreach: + ema_lerp_values = [] + model_lerp_values = [] + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if ema_v.is_floating_point(): + ema_lerp_values.append(ema_v) + model_lerp_values.append(model_v) + else: + ema_v.copy_(model_v) + + if hasattr(torch, '_foreach_lerp_'): + torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay) + else: + torch._foreach_mul_(ema_lerp_values, scalar=decay) + torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. 
- decay) + else: + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if ema_v.is_floating_point(): + ema_v.lerp_(model_v, weight=1. - decay) + else: + ema_v.copy_(model_v) @torch.no_grad() def set(self, model): diff --git a/train.py b/train.py index ba917773a0..e3b3e037a6 100755 --- a/train.py +++ b/train.py @@ -586,8 +586,12 @@ def main(): model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper - model_ema = utils.ModelEmaV2( - model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) + model_ema = utils.ModelEmaV3( + model, + decay=args.model_ema_decay, + use_warmup=True, + device='cpu' if args.model_ema_force_cpu else None, + ) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) @@ -847,6 +851,7 @@ def main(): loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + num_updates_total=num_epochs * updates_per_epoch, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -935,6 +940,7 @@ def train_one_epoch( loss_scaler=None, model_ema=None, mixup_fn=None, + num_updates_total=None, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -981,6 +987,27 @@ def _forward(): with amp_autocast(): output = model(input) loss = loss_fn(output, target) + + if num_updates / num_updates_total > 0.25: + with torch.no_grad(): + output_mesa = model_ema.module(input) + + # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits( + # output, + # torch.sigmoid(output_mesa).detach(), + # reduction='none', + # ).mean() + + # loss_mesa = loss_fn( + # output, torch.sigmoid(output_mesa).detach()) + + loss_mesa = torch.nn.functional.kl_div( + (output / 5).log_softmax(-1), + (output_mesa / 5).log_softmax(-1).detach(), + log_target=True, + reduction='none').sum(-1).mean() + loss += 5 * loss_mesa + if accum_steps > 1: loss /= accum_steps return loss @@ -1026,7 +1053,7 @@ def _backward(_loss): num_updates += 1 optimizer.zero_grad() if model_ema is not None: - model_ema.update(model) + model_ema.update(model, step=num_updates) if args.synchronize_step and device.type == 'cuda': torch.cuda.synchronize() From bee0471f91708e60fcab0ffe05c3866136ec8383 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 3 Feb 2024 16:24:45 -0800 Subject: [PATCH 05/18] forward() pass through for ema model, flag for ema warmup, comment about warmup --- timm/utils/model_ema.py | 87 ++++++++++++++++++++++++----------------- train.py | 16 +++++--- 2 files changed, 62 insertions(+), 41 deletions(-) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 968f4f580c..3e4916756d 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -126,6 +126,9 @@ def update(self, model): def set(self, model): self._update(model, update_fn=lambda e, m: m) + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + class ModelEmaV3(nn.Module): """ Model Exponential Moving Average V3 @@ -133,6 +136,13 @@ class ModelEmaV3(nn.Module): Keep a moving average of everything in the model state_dict (parameters and buffers). V3 of this module leverages for_each and in-place operations for faster performance. + Decay warmup based on code by @crowsonkb, her comments: + If inv_gamma=1 and power=1, implements a simple average. 
inv_gamma=1, power=2/3 are + good values for models you plan to train for a million or more steps (reaches decay + factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models + you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at + 215.4k steps). + This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage @@ -195,49 +205,56 @@ def get_decay(self, step: Optional[int] = None) -> float: @torch.no_grad() def update(self, model, step: Optional[int] = None): decay = self.get_decay(step) - if self.exclude_buffers: - # interpolate parameters - ema_params = tuple(self.module.parameters()) - model_params = tuple(model.parameters()) - if self.foreach: - if hasattr(torch, '_foreach_lerp_'): - torch._foreach_lerp_(ema_params, model_params, weight=1. - decay) + self.apply_update_no_buffers_(model, decay) + else: + self.apply_update_(model, decay) + + def apply_update_(self, model, decay: float): + # interpolate parameters and buffers + if self.foreach: + ema_lerp_values = [] + model_lerp_values = [] + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if ema_v.is_floating_point(): + ema_lerp_values.append(ema_v) + model_lerp_values.append(model_v) else: - torch._foreach_mul_(ema_params, scalar=decay) - torch._foreach_add_(ema_params, model_params, alpha=1 - decay) - else: - for ema_p, model_p in zip(ema_params, model_params): - ema_p.lerp_(model_p, weight=1. - decay) + ema_v.copy_(model_v) - # copy buffers instead of EMA - for ema_b, model_b in zip(self.module.buffers(), model.buffers()): - ema_b.copy_(model_b.to(device=self.device)) + if hasattr(torch, '_foreach_lerp_'): + torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay) + else: + torch._foreach_mul_(ema_lerp_values, scalar=decay) + torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay) else: - # interpolate parameters and buffers - if self.foreach: - ema_lerp_values = [] - model_lerp_values = [] - for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - if ema_v.is_floating_point(): - ema_lerp_values.append(ema_v) - model_lerp_values.append(model_v) - else: - ema_v.copy_(model_v) - - if hasattr(torch, '_foreach_lerp_'): - torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay) + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if ema_v.is_floating_point(): + ema_v.lerp_(model_v, weight=1. - decay) else: - torch._foreach_mul_(ema_lerp_values, scalar=decay) - torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay) + ema_v.copy_(model_v) + + def apply_update_no_buffers_(self, model, decay: float): + # interpolate parameters, copy buffers + ema_params = tuple(self.module.parameters()) + model_params = tuple(model.parameters()) + if self.foreach: + if hasattr(torch, '_foreach_lerp_'): + torch._foreach_lerp_(ema_params, model_params, weight=1. - decay) else: - for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - if ema_v.is_floating_point(): - ema_v.lerp_(model_v, weight=1. - decay) - else: - ema_v.copy_(model_v) + torch._foreach_mul_(ema_params, scalar=decay) + torch._foreach_add_(ema_params, model_params, alpha=1 - decay) + else: + for ema_p, model_p in zip(ema_params, model_params): + ema_p.lerp_(model_p, weight=1. 
- decay) + + for ema_b, model_b in zip(self.module.buffers(), model.buffers()): + ema_b.copy_(model_b.to(device=self.device)) @torch.no_grad() def set(self, model): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): ema_v.copy_(model_v.to(device=self.device)) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) \ No newline at end of file diff --git a/train.py b/train.py index a773d855de..5e02722965 100755 --- a/train.py +++ b/train.py @@ -349,11 +349,13 @@ # Model Exponential Moving Average group = parser.add_argument_group('Model exponential moving average parameters') group.add_argument('--model-ema', action='store_true', default=False, - help='Enable tracking moving average of model weights') + help='Enable tracking moving average of model weights.') group.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') group.add_argument('--model-ema-decay', type=float, default=0.9998, - help='decay factor for model weights moving average (default: 0.9998)') + help='Decay factor for model weights moving average (default: 0.9998)') +group.add_argument('--model-ema-warmup', action='store_true', + help='Enable warmup for model EMA decay.') # Misc group = parser.add_argument_group('Miscellaneous parameters') @@ -601,11 +603,13 @@ def main(): model_ema = utils.ModelEmaV3( model, decay=args.model_ema_decay, - use_warmup=True, + use_warmup=args.model_ema_warmup, device='cpu' if args.model_ema_force_cpu else None, ) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) + if args.torchcompile: + model_ema = torch.compile(model_ema, backend=args.torchcompile) # setup distributed training if args.distributed: @@ -885,7 +889,7 @@ def main(): utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( - model_ema.module, + model_ema, loader_eval, validate_loss_fn, args, @@ -1002,7 +1006,7 @@ def _forward(): if num_updates / num_updates_total > 0.25: with torch.no_grad(): - output_mesa = model_ema.module(input) + output_mesa = model_ema(input) # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits( # output, @@ -1018,7 +1022,7 @@ def _forward(): (output_mesa / 5).log_softmax(-1).detach(), log_target=True, reduction='none').sum(-1).mean() - loss += 5 * loss_mesa + loss += 10 * loss_mesa if accum_steps > 1: loss /= accum_steps From a08b57e801ab936ee7a7f54ac605f6257a5ee32f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 3 Feb 2024 16:26:15 -0800 Subject: [PATCH 06/18] Fix distributed flag bug w/ flex device handling --- timm/utils/distributed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index 95655f2c94..92b8a6b853 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -92,6 +92,7 @@ def init_distributed_device(args): args.world_size = result['world_size'] args.rank = result['global_rank'] args.local_rank = result['local_rank'] + args.distributed = args.world_size > 1 device = torch.device(args.device) return device From c7ac37693d5eca561b016db750ef0af95d672dbd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 4 Feb 2024 10:14:57 -0800 Subject: [PATCH 07/18] Add device arg to validate() calls in train.py --- train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train.py b/train.py index 5e02722965..39e889ede2 100755 --- a/train.py +++ b/train.py @@ -881,6 +881,7 @@ def 
main(): loader_eval, validate_loss_fn, args, + device=device, amp_autocast=amp_autocast, ) @@ -893,6 +894,7 @@ def main(): loader_eval, validate_loss_fn, args, + device=device, amp_autocast=amp_autocast, log_suffix=' (EMA)', ) From 5a58f4d3dcabf51b13771f3a3830d6ced71f1913 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 14:38:01 -0800 Subject: [PATCH 08/18] Remove test MESA support, no signal that it's helpful so far --- train.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/train.py b/train.py index 39e889ede2..539dff3d77 100755 --- a/train.py +++ b/train.py @@ -1005,27 +1005,6 @@ def _forward(): with amp_autocast(): output = model(input) loss = loss_fn(output, target) - - if num_updates / num_updates_total > 0.25: - with torch.no_grad(): - output_mesa = model_ema(input) - - # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits( - # output, - # torch.sigmoid(output_mesa).detach(), - # reduction='none', - # ).mean() - - # loss_mesa = loss_fn( - # output, torch.sigmoid(output_mesa).detach()) - - loss_mesa = torch.nn.functional.kl_div( - (output / 5).log_softmax(-1), - (output_mesa / 5).log_softmax(-1).detach(), - log_target=True, - reduction='none').sum(-1).mean() - loss += 10 * loss_mesa - if accum_steps > 1: loss /= accum_steps return loss From 7d121ac2ef4233b22a78921224ec5fc652dcf6bb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 14:57:40 -0800 Subject: [PATCH 09/18] Small tweak of timm ToTensor for clarity --- timm/data/transforms.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 822983fed3..02a069bd70 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -32,16 +32,12 @@ def __call__(self, pil_img): class ToTensor: - + """ ToTensor with no rescaling of values""" def __init__(self, dtype=torch.float32): self.dtype = dtype def __call__(self, pil_img): - np_img = np.array(pil_img, dtype=np.uint8) - if np_img.ndim < 3: - np_img = np.expand_dims(np_img, axis=-1) - np_img = np.rollaxis(np_img, 2) # HWC to CHW - return torch.from_numpy(np_img).to(dtype=self.dtype) + return F.pil_to_tensor(pil_img).to(dtype=self.dtype) # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in From 7bc7798d0eedea5c16acacfe47e1924d2573ac6c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 14:57:58 -0800 Subject: [PATCH 10/18] Type annotation correctness for create_act --- timm/layers/create_act.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/create_act.py b/timm/layers/create_act.py index c473c5a95b..93bcbf0e48 100644 --- a/timm/layers/create_act.py +++ b/timm/layers/create_act.py @@ -148,7 +148,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): return _ACT_LAYER_DEFAULT[name] -def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): +def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs): act_layer = get_act_layer(name) if act_layer is None: return None From 7d3c2dc993162618b56caf0916d1f33c9a8312ba Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 14:58:45 -0800 Subject: [PATCH 11/18] Add group_matcher for DaViT --- timm/models/davit.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/timm/models/davit.py b/timm/models/davit.py index f00cf73384..e7f2ed0eb3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -547,6 +547,17 @@ def _init_weights(self, m): if 
isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)), + ] + ) + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable From 87fec3dc148ebfda1f3440ddcd7831ae02e0880a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 16:05:58 -0800 Subject: [PATCH 12/18] Update experimental vit model configs --- timm/models/vision_transformer.py | 35 +++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 380e3a6405..0bd0cc7779 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1723,7 +1723,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 256, 256)), 'vit_medium_patch16_reg4_gap_256': _cfg( input_size=(3, 256, 256)), - 'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)), + 'vit_base_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_so150m_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_so150m_patch16_reg4_map_256': _cfg( + input_size=(3, 256, 256)), } _quick_gelu_cfgs = [ @@ -2623,13 +2628,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio @register_model -def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: +def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, - no_embed_class=True, global_pool='avg', reg_tokens=8, + no_embed_class=True, global_pool='avg', reg_tokens=4, + ) + model = _create_vision_transformer( + 'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False, ) model = _create_vision_transformer( - 'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model From d6c2cc91af13e89990c05979638e8ffaa0335b5e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 16:25:33 -0800 Subject: [PATCH 13/18] Make NormMlpClassifier head reset args consistent with ClassifierHead --- timm/layers/classifier.py | 8 ++++---- timm/models/davit.py | 2 +- timm/models/tiny_vit.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 2eb4ec2eb6..71e45c87a7 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py 
@@ -180,10 +180,10 @@ def __init__( self.drop = nn.Dropout(drop_rate) self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def reset(self, num_classes, global_pool=None): - if global_pool is not None: - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + def reset(self, num_classes, pool_type=None): + if pool_type is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() self.use_conv = self.global_pool.is_identity() linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear if self.hidden_size: diff --git a/timm/models/davit.py b/timm/models/davit.py index e7f2ed0eb3..d4d6ad690a 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -569,7 +569,7 @@ def get_classifier(self): return self.head.fc def reset_classifier(self, num_classes, global_pool=None): - self.head.reset(num_classes, global_pool=global_pool) + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 96a88db7f3..b4b2964810 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -535,7 +535,7 @@ def get_classifier(self): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - self.head.reset(num_classes, global_pool=global_pool) + self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): x = self.patch_embed(x) From 0737cf231d3e006a9edcf0ad14fa6fc596c6537d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 17:05:16 -0800 Subject: [PATCH 14/18] Add Next-ViT --- timm/models/__init__.py | 1 + timm/models/nextvit.py | 685 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 686 insertions(+) create mode 100644 timm/models/nextvit.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0eb9561d54..6b6963dc37 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -39,6 +39,7 @@ from .mvitv2 import * from .nasnet import * from .nest import * +from .nextvit import * from .nfnet import * from .pit import * from .pnasnet import * diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py new file mode 100644 index 0000000000..fcbd1ab9fe --- /dev/null +++ b/timm/models/nextvit.py @@ -0,0 +1,685 @@ +""" Next-ViT + +As described in https://arxiv.org/abs/2207.05501 + +Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-ViT, original copyright below +""" +# Copyright (c) ByteDance Inc. All rights reserved. +from functools import partial + +import torch +import torch.nn.functional as F +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn +from timm.layers import ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq +from ._registry import generate_default_cfgs, register_model + + +def merge_pre_bn(module, pre_bn_1, pre_bn_2=None): + """ Merge pre BN to reduce inference runtime. 
+ """ + weight = module.weight.data + if module.bias is None: + zeros = torch.zeros(module.out_chs, device=weight.device).type(weight.type()) + module.bias = nn.Parameter(zeros) + bias = module.bias.data + if pre_bn_2 is None: + assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False" + assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False" + + scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5) + extra_weight = scale_invstd * pre_bn_1.weight + extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd + else: + assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False" + assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False" + + assert pre_bn_2.track_running_stats is True, "Unsupported bn_module.track_running_stats is False" + assert pre_bn_2.affine is True, "Unsupported bn_module.affine is False" + + scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5) + scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5) + + extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight + extra_bias = ( + scale_invstd_2 * pre_bn_2.weight + * (pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean) + + pre_bn_2.bias + ) + + if isinstance(module, nn.Linear): + extra_bias = weight @ extra_bias + weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight)) + elif isinstance(module, nn.Conv2d): + assert weight.shape[2] == 1 and weight.shape[3] == 1 + weight = weight.reshape(weight.shape[0], weight.shape[1]) + extra_bias = weight @ extra_bias + weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight)) + weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1) + bias.add_(extra_bias) + + module.weight.data = weight + module.bias.data = bias + + +class ConvNormAct(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=3, + stride=1, + groups=1, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(ConvNormAct, self).__init__() + self.conv = nn.Conv2d( + in_chs, out_chs, kernel_size=kernel_size, stride=stride, + padding=1, groups=groups, bias=False) + self.norm = norm_layer(out_chs) + self.act = act_layer() + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + return x + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class PatchEmbed(nn.Module): + def __init__(self, + in_chs, + out_chs, + stride=1, + norm_layer = nn.BatchNorm2d, + ): + super(PatchEmbed, self).__init__() + + if stride == 2: + self.pool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False) + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False) + self.norm = norm_layer(out_chs) + elif in_chs != out_chs: + self.pool = nn.Identity() + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False) + self.norm = norm_layer(out_chs) + else: + self.pool = nn.Identity() + self.conv = nn.Identity() + self.norm = nn.Identity() + + def forward(self, x): + return self.norm(self.conv(self.pool(x))) + + +class ConvAttention(nn.Module): + """ + Multi-Head Convolutional Attention + """ + + def __init__(self, out_chs, head_dim, norm_layer = nn.BatchNorm2d, act_layer = nn.ReLU): + super(ConvAttention, self).__init__() + self.group_conv3x3 = nn.Conv2d( + out_chs, out_chs, + kernel_size=3, stride=1, padding=1, groups=out_chs // head_dim, bias=False + ) + self.norm = norm_layer(out_chs) + self.act = act_layer() + self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False) + + def forward(self, x): + out = self.group_conv3x3(x) + out = self.norm(out) + out = self.act(out) + out = self.projection(out) + return out + +class NextConvBlock(nn.Module): + """ + Next Convolution Block + """ + + def __init__( + self, + in_chs, + out_chs, + stride=1, + drop_path=0., + drop=0., + head_dim=32, + mlp_ratio=3., + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU + ): + super(NextConvBlock, self).__init__() + self.in_chs = in_chs + self.out_chs = out_chs + assert out_chs % head_dim == 0 + + self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer) + self.mhca = ConvAttention( + out_chs, + head_dim, + norm_layer=norm_layer, + act_layer=act_layer, + ) + self.attn_drop_path = DropPath(drop_path) + + self.norm = norm_layer(out_chs) + self.mlp = ConvMlp( + out_chs, + hidden_features=int(out_chs * mlp_ratio), + drop=drop, + bias=True, + act_layer=act_layer, + ) + self.mlp_drop_path = DropPath(drop_path) + self.is_fused = False + + @torch.no_grad() + def reparameterize(self): + if not self.is_fused: + merge_pre_bn(self.mlp.fc1, self.norm) + self.norm = None + self.is_fused = True + + def forward(self, x): + x = self.patch_embed(x) + x = x + self.attn_drop_path(self.mhca(x)) + + out = self.norm(x) + x = x + self.mlp_drop_path(self.mlp(out)) + return x + + +class EfficientAttention(nn.Module): + """ + Efficient Multi-Head Self Attention + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim, + out_dim=None, + head_dim=32, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + sr_ratio=1, + norm_layer=nn.BatchNorm1d, + ): + super().__init__() + self.dim = dim + self.out_dim = out_dim if out_dim is not None else dim + self.num_heads = self.dim // head_dim + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.q = nn.Linear(dim, self.dim, bias=qkv_bias) + self.k = nn.Linear(dim, self.dim, bias=qkv_bias) + self.v = nn.Linear(dim, self.dim, bias=qkv_bias) + self.proj = nn.Linear(self.dim, self.out_dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + self.N_ratio = sr_ratio ** 2 + if sr_ratio > 1: + self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio) + self.norm = norm_layer(dim) + else: + self.sr 
= None + self.norm = None + + def forward(self, x): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + if self.sr is not None: + x = self.sr(x.transpose(1, 2)) + x = self.norm(x).transpose(1, 2) + + k = self.k(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-1, -2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class NextTransformerBlock(nn.Module): + """ + Next Transformer Block + """ + + def __init__( + self, + in_chs, + out_chs, + drop_path, + stride=1, + sr_ratio=1, + mlp_ratio=2, + head_dim=32, + mix_block_ratio=0.75, + attn_drop=0., + drop=0., + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(NextTransformerBlock, self).__init__() + self.in_chs = in_chs + self.out_chs = out_chs + self.mix_block_ratio = mix_block_ratio + + self.mhsa_out_chs = _make_divisible(int(out_chs * mix_block_ratio), 32) + self.mhca_out_chs = out_chs - self.mhsa_out_chs + + self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride) + self.norm1 = norm_layer(self.mhsa_out_chs) + self.e_mhsa = EfficientAttention( + self.mhsa_out_chs, + head_dim=head_dim, + sr_ratio=sr_ratio, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.mhsa_drop_path = DropPath(drop_path * mix_block_ratio) + + self.projection = PatchEmbed(self.mhsa_out_chs, self.mhca_out_chs, stride=1, norm_layer=norm_layer) + self.mhca = ConvAttention( + self.mhca_out_chs, + head_dim=head_dim, + norm_layer=norm_layer, + act_layer=act_layer, + ) + self.mhca_drop_path = DropPath(drop_path * (1 - mix_block_ratio)) + + self.norm2 = norm_layer(out_chs) + self.mlp = ConvMlp( + out_chs, + hidden_features=int(out_chs * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.mlp_drop_path = DropPath(drop_path) + self.is_fused = False + + @torch.no_grad() + def reparameterize(self): + if not self.is_fused: + merge_pre_bn(self.e_mhsa.q, self.norm1) + if self.e_mhsa.norm is not None: + merge_pre_bn(self.e_mhsa.k, self.norm1, self.e_mhsa.norm) + merge_pre_bn(self.e_mhsa.v, self.norm1, self.e_mhsa.norm) + self.e_mhsa.norm = nn.Identity() + else: + merge_pre_bn(self.e_mhsa.k, self.norm1) + merge_pre_bn(self.e_mhsa.v, self.norm1) + self.norm1 = nn.Identity() + + merge_pre_bn(self.mlp.fc1, self.norm2) + self.norm2 = nn.Identity() + self.is_fused = True + + def forward(self, x): + x = self.patch_embed(x) + B, C, H, W = x.shape + + out = self.norm1(x) + out = out.reshape(B, C, -1).transpose(-1, -2) + out = self.mhsa_drop_path(self.e_mhsa(out)) + x = x + out.transpose(-1, -2).reshape(B, C, H, W) + + out = self.projection(x) + out = out + self.mhca_drop_path(self.mhca(out)) + x = torch.cat([x, out], dim=1) + + out = self.norm2(x) + x = x + self.mlp_drop_path(self.mlp(out)) + return x + + +class NextStage(nn.Module): + + def __init__( + self, + in_chs, + block_chs, + block_types, + stride=2, + sr_ratio=1, + mix_block_ratio=1.0, + drop=0., + attn_drop=0., + drop_path=0., + head_dim=32, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super().__init__() + self.grad_checkpointing = False + + blocks = [] + for block_idx, block_type in enumerate(block_types): + stride = stride if block_idx == 0 else 
1 + out_chs = block_chs[block_idx] + block_type = block_types[block_idx] + dpr = drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path + if block_type is NextConvBlock: + layer = NextConvBlock( + in_chs, + out_chs, + stride=stride, + drop_path=dpr, + drop=drop, + head_dim=head_dim, + norm_layer=norm_layer, + act_layer=act_layer, + ) + blocks.append(layer) + elif block_type is NextTransformerBlock: + layer = NextTransformerBlock( + in_chs, + out_chs, + drop_path=dpr, + stride=stride, + sr_ratio=sr_ratio, + head_dim=head_dim, + mix_block_ratio=mix_block_ratio, + attn_drop=attn_drop, + drop=drop, + norm_layer=norm_layer, + act_layer=act_layer, + ) + blocks.append(layer) + in_chs = out_chs + + self.blocks = nn.Sequential(*blocks) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + def forward(self, x): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class NextViT(nn.Module): + def __init__( + self, + in_chans, + num_classes=1000, + global_pool='avg', + stem_chs=(64, 32, 64), + depths=(3, 4, 10, 3), + strides=(1, 2, 2, 2), + sr_ratios=(8, 4, 2, 1), + drop_path_rate=0.1, + attn_drop_rate=0., + drop_rate=0., + head_dim=32, + mix_block_ratio=0.75, + norm_layer=nn.BatchNorm2d, + act_layer=None, + ): + super(NextViT, self).__init__() + self.grad_checkpointing = False + self.num_classes = num_classes + norm_layer = get_norm_layer(norm_layer) + if act_layer is None: + act_layer = partial(nn.ReLU, inplace=True) + else: + act_layer = get_act_layer(act_layer) + + self.stage_out_chs = [ + [96] * (depths[0]), + [192] * (depths[1] - 1) + [256], + [384, 384, 384, 384, 512] * (depths[2] // 5), + [768] * (depths[3] - 1) + [1024] + ] + self.feature_info = [dict( + num_chs=sc[-1], + reduction=2**(i + 2), + module=f'stages.{i}' + ) for i, sc in enumerate(self.stage_out_chs)] + + # Next Hybrid Strategy + self.stage_block_types = [ + [NextConvBlock] * depths[0], + [NextConvBlock] * (depths[1] - 1) + [NextTransformerBlock], + [NextConvBlock, NextConvBlock, NextConvBlock, NextConvBlock, NextTransformerBlock] * (depths[2] // 5), + [NextConvBlock] * (depths[3] - 1) + [NextTransformerBlock]] + + self.stem = nn.Sequential( + ConvNormAct(in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer), + ConvNormAct(stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer), + ConvNormAct(stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer), + ConvNormAct(stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer), + ) + in_chs = out_chs = stem_chs[-1] + stages = [] + idx = 0 + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + for stage_idx in range(len(depths)): + stage = NextStage( + in_chs=in_chs, + block_chs=self.stage_out_chs[stage_idx], + block_types=self.stage_block_types[stage_idx], + stride=strides[stage_idx], + sr_ratio=sr_ratios[stage_idx], + mix_block_ratio=mix_block_ratio, + head_dim=head_dim, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[stage_idx], + norm_layer=norm_layer, + act_layer=act_layer, + ) + in_chs = out_chs = self.stage_out_chs[stage_idx][-1] + stages += [stage] + idx += depths[stage_idx] + self.num_features = out_chs + self.stages = nn.Sequential(*stages) + self.norm = norm_layer(out_chs) + self.head = 
ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes) + + self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))] + self._initialize_weights() + + def _initialize_weights(self): + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + for stage in self.stages: + stage.set_grad_checkpointing(enable=enable) + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'head.fc.weight' in state_dict: + return state_dict # non-original + + D = model.state_dict() + out_dict = {} + # remap originals based on order + for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()): + out_dict[ka] = vb + + return out_dict + + +def _create_nextvit(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + NextViT, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'nextvit_small.bd_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_base.bd_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_large.bd_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_small.bd_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_base.bd_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_large.bd_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + + 'nextvit_small.bd_ssld_6m_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_base.bd_ssld_6m_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_large.bd_ssld_6m_in1k': _cfg( + hf_hub_id='timm/', + ), + 
'nextvit_small.bd_ssld_6m_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_base.bd_ssld_6m_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_large.bd_ssld_6m_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), +}) + + +@register_model +def nextvit_small(pretrained=False, **kwargs): + model_args = dict(depths=(3, 4, 10, 3), drop_path_rate=0.1) + model = _create_nextvit( + 'nextvit_small', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def nextvit_base(pretrained=False, **kwargs): + model_args = dict(depths=(3, 4, 20, 3), drop_path_rate=0.2) + model = _create_nextvit( + 'nextvit_base', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def nextvit_large(pretrained=False, **kwargs): + model_args = dict(depths=(3, 4, 30, 3), drop_path_rate=0.2) + model = _create_nextvit( + 'nextvit_large', pretrained=pretrained, **dict(model_args, **kwargs)) + return model From 935950cc11de3103912ab2f00756fa0ef012e178 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 20:14:47 -0800 Subject: [PATCH 15/18] Fix F.sdpa attn drop prob --- timm/models/nextvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index fcbd1ab9fe..7ef56a3828 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -263,7 +263,7 @@ def forward(self, x): if self.fused_attn: x = F.scaled_dot_product_attention( q, k, v, - dropout_p=self.attn_drop if self.training else 0., + dropout_p=self.attn_drop.p if self.training else 0., ) else: q = q * self.scale From ac1b08deb66416cb2d664c6c800e9c1b578de4ce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 20:15:37 -0800 Subject: [PATCH 16/18] fix_init on vit & relpos vit --- timm/models/vision_transformer.py | 16 ++- timm/models/vision_transformer_relpos.py | 125 +++++++++++++---------- 2 files changed, 85 insertions(+), 56 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 0bd0cc7779..70f91d588b 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -421,6 +421,7 @@ def __init__( attn_drop_rate: float = 0., drop_path_rate: float = 0., weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', + fix_init: bool = False, embed_layer: Callable = PatchEmbed, norm_layer: Optional[LayerType] = None, act_layer: Optional[LayerType] = None, @@ -449,6 +450,7 @@ def __init__( attn_drop_rate: Attention dropout rate. drop_path_rate: Stochastic depth rate. weight_init: Weight initialization scheme. + fix_init: Apply weight initialization fix (scaling w/ layer index). embed_layer: Patch embedding layer. norm_layer: Normalization layer. act_layer: MLP activation layer. 
@@ -536,8 +538,18 @@ def __init__( if weight_init != 'skip': self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() - def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None: + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = '') -> None: assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) @@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None: module.init_weights() -def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None: +def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: if 'jax' in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif 'moco' in mode: diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 2cd37cfe7e..f2054167f5 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -7,7 +7,11 @@ import logging import math from functools import partial -from typing import Optional, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal import torch import torch.nn as nn @@ -15,9 +19,11 @@ from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn +from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType from ._builder import build_model_with_cfg +from ._manipulate import named_apply from ._registry import generate_default_cfgs, register_model +from .vision_transformer import get_init_weights_vit __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this @@ -215,59 +221,61 @@ class VisionTransformerRelPos(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='avg', - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=True, - qk_norm=False, - init_values=1e-6, - class_token=False, - fc_norm=False, - rel_pos_type='mlp', - rel_pos_dim=None, - shared_rel_pos=False, - drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - weight_init='skip', - embed_layer=PatchEmbed, - norm_layer=None, - act_layer=None, - block_fn=RelPosBlock + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal['', 'avg', 'token', 'map'] = 'avg', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = 1e-6, + class_token: bool = False, + fc_norm: bool = False, + rel_pos_type: str = 'mlp', + rel_pos_dim: Optional[int] = None, + shared_rel_pos: bool = False, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip', + fix_init: bool = False, + embed_layer: 
Type[nn.Module] = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = RelPosBlock ): """ Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - global_pool (str): type of global pooling for final sequence (default: 'avg') - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - qk_norm (bool): Enable normalization of query and key in attention - init_values: (float): layer-scale init values - class_token (bool): use class token (default: False) - fc_norm (bool): use pre classifier norm instead of pre-pool - rel_pos_ty pe (str): type of relative position - shared_rel_pos (bool): share relative pos across all blocks - drop_rate (float): dropout rate - proj_drop_rate (float): projection dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - weight_init (str): weight init scheme - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer - act_layer: (nn.Module): MLP activation layer + img_size: input image size + patch_size: patch size + in_chans: number of input channels + num_classes: number of classes for classification head + global_pool: type of global pooling for final sequence (default: 'avg') + embed_dim: embedding dimension + depth: depth of transformer + num_heads: number of attention heads + mlp_ratio: ratio of mlp hidden dim to embedding dim + qkv_bias: enable bias for qkv if True + qk_norm: Enable normalization of query and key in attention + init_values: layer-scale init values + class_token: use class token (default: False) + fc_norm: use pre classifier norm instead of pre-pool + rel_pos_type: type of relative position + shared_rel_pos: share relative pos across all blocks + drop_rate: dropout rate + proj_drop_rate: projection dropout rate + attn_drop_rate: attention dropout rate + drop_path_rate: stochastic depth rate + weight_init: weight init scheme + fix_init: apply weight initialization fix (scaling w/ layer index) + embed_layer: patch embedding layer + norm_layer: normalization layer + act_layer: MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') @@ -332,13 +340,22 @@ def __init__( if weight_init != 'skip': self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() def init_weights(self, mode=''): assert mode in ('jax', 'moco', '') if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) - # FIXME weight init scheme using PyTorch defaults curently - #named_apply(get_init_weights_vit(mode, head_bias), self) + named_apply(get_init_weights_vit(mode), self) + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) @torch.jit.ignore def no_weight_decay(self): From 59239d9df5223b09f8273818b844c50b483ad238 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 21:40:57 -0800 Subject: [PATCH 17/18] Cleanup imports for vit relpos --- timm/models/vision_transformer_relpos.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index f2054167f5..ea8cf0ea1d 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -7,7 +7,8 @@ import logging import math from functools import partial -from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List +from typing import Optional, Tuple, Type, Union + try: from typing import Literal except ImportError: From 47c9bc4dc675daeaea337ded67fd02ecf415f5bf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Feb 2024 21:41:14 -0800 Subject: [PATCH 18/18] Fix device idx split --- timm/utils/distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index 92b8a6b853..286e8ba499 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -92,7 +92,7 @@ def init_distributed_device(args): args.world_size = result['world_size'] args.rank = result['global_rank'] args.local_rank = result['local_rank'] - args.distributed = args.world_size > 1 + args.distributed = result['distributed'] device = torch.device(args.device) return device @@ -154,12 +154,12 @@ def init_distributed_device_so( assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.' if distributed and device != 'cpu': - device, device_idx = device.split(':', maxsplit=1) + device, *device_idx = device.split(':', maxsplit=1) # Ignore manually specified device index in distributed mode and # override with resolved local rank, fewer headaches in most setups. if device_idx: - _logger.warning(f'device index {device_idx} removed from specified ({device}).') + _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).') device = f'{device}:{local_rank}'
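
A minimal sketch of the parsing behavior the final patch corrects, assuming only the Python standard library (no torch; the helper name `split_device` is illustrative, not part of timm): two-target unpacking of `device.split(':', maxsplit=1)` raises ValueError when the string carries no index (e.g. 'cuda'), while star-unpacking simply leaves the index list empty, which is what the patched warning branch relies on.

    # Hypothetical helper, not from timm: demonstrates the star-unpacking fix.
    def split_device(device: str):
        # 'cuda' -> ['cuda'], 'cuda:1' -> ['cuda', '1']
        device, *device_idx = device.split(':', maxsplit=1)
        return device, (device_idx[0] if device_idx else None)

    assert split_device('cuda') == ('cuda', None)
    assert split_device('cuda:1') == ('cuda', '1')
    # The previous form, `device, device_idx = device.split(':', maxsplit=1)`,
    # fails with "not enough values to unpack" for a bare 'cuda' string.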