Skip to content

Commit

Permalink
Formatting, wrap a few really long lines
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Apr 16, 2023
1 parent 56f2521 commit 0b3aa4a
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions src/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def expand_urls(urls, weights=None):
if isinstance(urls, str):
urllist = urls.split("::")
weights = weights.split('::')
assert len(weights) == len(urllist), f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match."
assert len(weights) == len(urllist),\
f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match."
weights = [float(weight) for weight in weights]
all_urls, all_weights = [], []
for url, weight in zip(urllist, weights):
Expand Down Expand Up @@ -291,7 +292,8 @@ def __init__(
self.urls = urls
self.weights = weights
if self.weights is not None:
assert len(self.urls) == len(self.weights), f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
assert len(self.urls) == len(self.weights),\
f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.rng = random.Random()
Expand Down Expand Up @@ -345,9 +347,15 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc

if resampled:
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)]
pipeline = [ResampledShards2(
input_shards,
weights=args.train_data_upsampling_factors,
deterministic=True,
epoch=shared_epoch,
)]
else:
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)."
assert args.train_data_upsampling_factors is None,\
"--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
pipeline = [wds.SimpleShardList(input_shards)]

# at this point we have an iterator over all the shards
Expand Down Expand Up @@ -466,7 +474,14 @@ def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):

class SyntheticDataset(Dataset):

def __init__(self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None):
def __init__(
self,
transform=None,
image_size=(224, 224),
caption="Dummy caption",
dataset_size=100,
tokenizer=None,
):
self.transform = transform
self.image_size = image_size
self.caption = caption
Expand Down

0 comments on commit 0b3aa4a

Please sign in to comment.