Patch summary (free text before the first "diff --git" header):
  * Imports `version` from `packaging` in torchvision/train.py.
  * Builds `nn.CrossEntropyLoss(label_smoothing=...)` only when the installed
    torch version parses as >= 1.10.
  * On older torch, raises ValueError if `args.label_smoothing > 0`, otherwise
    falls back to a plain `nn.CrossEntropyLoss()`.
NOTE(review): line breaks and indentation were reconstructed from a
single-line copy of this patch; confirm context whitespace against the
target file before applying.

diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index 2e888d77bca..7c77ee3e08c 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -28,6 +28,7 @@
 import torch
 import torch.utils.data
 import torchvision
+from packaging import version
 from torch import nn
 from torch.utils.data.dataloader import DataLoader, default_collate
 from torchvision.transforms.functional import InterpolationMode
@@ -408,7 +409,15 @@ def collate_fn(batch):
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+    if version.parse(torch.__version__) >= version.parse("1.10"):
+        criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+    elif args.label_smoothing > 0:
+        raise ValueError(
+            f"`label_smoothing` not supported for {torch.__version__}, "
+            f"try upgrading to at-least 1.10"
+        )
+    else:
+        criterion = nn.CrossEntropyLoss()
 
     custom_keys_weight_decay = []
     if args.bias_weight_decay is not None: