From 08deee3a9d710f0f7fe7f1ae2585155c8cb71994 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 31 Jan 2023 10:12:33 -0500 Subject: [PATCH 1/2] Bugfix: use label smoothing only when torch version is >= 1.10 --- src/sparseml/pytorch/torchvision/train.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 2e888d77bca..6ac5bd40e0e 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -28,6 +28,7 @@ import torch import torch.utils.data import torchvision +from packaging import version from torch import nn from torch.utils.data.dataloader import DataLoader, default_collate from torchvision.transforms.functional import InterpolationMode @@ -408,7 +409,14 @@ def collate_fn(batch): if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + min_torch_version_with_label_smoothing = version.parse("1.10") + label_smoothing_supported = ( + version.parse(torch.__version__) >= min_torch_version_with_label_smoothing + ) + if label_smoothing_supported: + criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + else: + criterion = nn.CrossEntropyLoss() custom_keys_weight_decay = [] if args.bias_weight_decay is not None: From f333861f30ca9ab8f71cdac4c1b7d0b6ac08b1a9 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 31 Jan 2023 14:14:26 -0500 Subject: [PATCH 2/2] Apply suggestions from code review --- src/sparseml/pytorch/torchvision/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 6ac5bd40e0e..7c77ee3e08c 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -409,12 +409,13 @@ def collate_fn(batch): if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - min_torch_version_with_label_smoothing = version.parse("1.10") - label_smoothing_supported = ( - version.parse(torch.__version__) >= min_torch_version_with_label_smoothing - ) - if label_smoothing_supported: + if version.parse(torch.__version__) >= version.parse("1.10"): criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + elif args.label_smoothing > 0: + raise ValueError( + f"`label_smoothing` not supported for {torch.__version__}, " + f"try upgrading to at-least 1.10" + ) else: criterion = nn.CrossEntropyLoss()