diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 23e94bc679a6..3fd3fc11ad19 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -740,6 +740,10 @@ def preprocess_train(examples): # Resize. combined_im = train_resize(combined_im) + # Flipping. + if not args.no_flip and random.random() < 0.5: + combined_im = train_flip(combined_im) + # Cropping. if not args.random_crop: y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0))) @@ -749,11 +753,6 @@ def preprocess_train(examples): y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution)) combined_im = crop(combined_im, y1, x1, h, w) - # Flipping. - if random.random() < 0.5: - x1 = combined_im.shape[2] - x1 - combined_im = train_flip(combined_im) - crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) combined_im = normalize(combined_im)