Tweak logic for num_samples / num_shards, bit less redundancy, always check # shards in non-resampled
rwightman committed Apr 16, 2023
1 parent 313930c commit 56f2521
Showing 1 changed file with 3 additions and 3 deletions: src/training/data.py
@@ -328,12 +328,13 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
     assert input_shards is not None
     resampled = getattr(args, 'dataset_resampled', False) and is_train
 
+    num_shards = None
     if is_train:
         if args.train_num_samples is not None:
             num_samples = args.train_num_samples
         else:
             num_samples, num_shards = get_dataset_size(input_shards)
-            if num_samples is None:
+            if not num_samples:
                 raise RuntimeError(
                     'Currently, the number of dataset samples must be specified for the training dataset. '
                     'Please specify it via `--train-num-samples` if no dataset length info is present.')
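
The guard now uses a falsy check, so a dataset that reports a size of zero is rejected the same way as one that reports no size at all. A standalone illustration of the difference (not part of the diff):

# 'if num_samples is None' lets a reported size of 0 through;
# 'if not num_samples' rejects both None and 0
for num_samples in (None, 0, 1000):
    print(num_samples, num_samples is None, not num_samples)
# None -> True, True; 0 -> False, True; 1000 -> False, False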
@@ -389,8 +390,7 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
 
     if is_train:
         if not resampled:
-            if is_train and args.train_num_samples is not None:
-                num_shards = get_dataset_size(input_shards)[1]
+            num_shards = num_shards or len(expand_urls(input_shards)[0])
             assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
         # roll over and repeat a few samples to get same number of full batches on each node
         round_fn = math.floor if floor else math.ceil
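
With this change, non-resampled training always resolves a shard count: it reuses num_shards when get_dataset_size() already found one, and otherwise counts the expanded shard URLs, so the workers-per-shard assert can always run. A minimal sketch of the fallback, assuming expand_urls() expands brace patterns and returns an (urls, weights) pair; the shard spec below is hypothetical:

from braceexpand import braceexpand

def expand_urls(urls, weights=None):
    # simplified single-source version of the helper in this file
    return list(braceexpand(urls)), weights

input_shards = 'train-{0000..0009}.tar'  # hypothetical shard spec
num_shards = None  # e.g. --train-num-samples was given, so get_dataset_size() never ran
num_shards = num_shards or len(expand_urls(input_shards)[0])
print(num_shards)  # 10 -> the assert against args.workers * args.world_size can proceed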
