diff --git a/train.py b/train.py
index 96b3c2fdc516..91bcd1e1e2e8 100644
--- a/train.py
+++ b/train.py
@@ -212,7 +212,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                               hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
                                               workers=workers, image_weights=opt.image_weights, quad=opt.quad,
-                                              prefix=colorstr('train: '))
+                                              prefix=colorstr('train: '), shuffle=True)
     mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
     nb = len(train_loader)  # number of batches
     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
diff --git a/utils/datasets.py b/utils/datasets.py
index f153db0d7104..3504998b125d 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 import yaml
 from PIL import ExifTags, Image, ImageOps
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm
 
 from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
@@ -93,13 +93,15 @@ def exif_transpose(image):
 
 
 def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
-                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
-    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
-    with torch_distributed_zero_first(rank):
+                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
+    if rect and shuffle:
+        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
+        shuffle = False
+    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
-                                      augment=augment,  # augment images
-                                      hyp=hyp,  # augmentation hyperparameters
-                                      rect=rect,  # rectangular training
+                                      augment=augment,  # augmentation
+                                      hyp=hyp,  # hyperparameters
+                                      rect=rect,  # rectangular batches
                                       cache_images=cache,
                                       single_cls=single_cls,
                                       stride=int(stride),
@@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
 
     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers])  # number of workers
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
-    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
-    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
-    dataloader = loader(dataset,
-                        batch_size=batch_size,
-                        num_workers=nw,
-                        sampler=sampler,
-                        pin_memory=True,
-                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
-    return dataloader, dataset
-
-
-class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
+    return loader(dataset,
+                  batch_size=batch_size,
+                  shuffle=shuffle and sampler is None,
+                  num_workers=nw,
+                  sampler=sampler,
+                  pin_memory=True,
+                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
     """ Dataloader that reuses workers
 
     Uses same syntax as vanilla DataLoader
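
For readers skimming the diff: the new `shuffle` flag is passed through from `train.py`, silently disabled when `--rect` is set (rectangular batching depends on an aspect-ratio-sorted image order), and routed either to the `DataLoader` itself (single-process) or to the `DistributedSampler` (DDP), since PyTorch raises a `ValueError` if both `shuffle=True` and a `sampler` are given. A minimal standalone sketch of that branching; `make_loader` and the toy `TensorDataset` are illustrative assumptions, not part of this PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, distributed


def make_loader(dataset, batch_size, rank=-1, rect=False, shuffle=True):
    # Mirrors the shuffle handling added to create_dataloader() above
    if rect and shuffle:  # rect batching needs a fixed, aspect-ratio-sorted order
        print('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    # In DDP (rank != -1), shuffling is delegated to the sampler; the
    # DataLoader itself must then be constructed with shuffle=False
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle and sampler is None,
                      sampler=sampler)


loader = make_loader(TensorDataset(torch.arange(8.)), batch_size=2)  # single-process path
for (batch,) in loader:
    print(batch)  # batch order re-randomizes on each pass over the loader
```

The re-parenting of `InfiniteDataLoader` onto `dataloader.DataLoader` only follows the shortened imports; the class keeps its worker-reuse behavior, which (per the docstring shown) comes from wrapping the batch sampler in a sampler that repeats forever, so iteration never exhausts and workers survive across epochs. Roughly, the pattern looks like this (a sketch of the repeat-sampler idiom, which may differ in detail from the actual class body in `utils/datasets.py`):

```python
from torch.utils.data import dataloader


class _RepeatSampler:
    """Sampler that repeats forever, so DataLoader workers are never torn down."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class InfiniteDataLoader(dataloader.DataLoader):
    """Dataloader that reuses workers; same syntax as vanilla DataLoader."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # object.__setattr__ bypasses DataLoader's guard against reassigning
        # batch_sampler after construction
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()  # one persistent iterator, workers stay alive

    def __len__(self):
        return len(self.batch_sampler.sampler)  # number of batches per epoch

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)
```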