Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/sparseml/pytorch/torchvision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import torch
import torch.utils.data
import torchvision
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader, default_collate
from torchvision.transforms.functional import InterpolationMode
Expand Down Expand Up @@ -408,7 +409,15 @@ def collate_fn(batch):
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
if version.parse(torch.__version__) >= version.parse("1.10"):
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
elif args.label_smoothing > 0:
raise ValueError(
f"`label_smoothing` not supported for {torch.__version__}, "
f"try upgrading to at-least 1.10"
)
else:
criterion = nn.CrossEntropyLoss()

custom_keys_weight_decay = []
if args.bias_weight_decay is not None:
Expand Down