Skip to content

Commit

Permalink
[Minor] Updated ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
zsyzzsoft committed Jul 2, 2020
1 parent f0c1399 commit 96d6d87
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 32 deletions.
10 changes: 5 additions & 5 deletions DiffAugment-biggan-cifar/DiffAugment_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738.pdf
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -36,8 +36,8 @@ def rand_contrast(x):
return x


def rand_translation(x, ratio=(1, 8)):
shift_x, shift_y = x.size(2) * ratio[0] // ratio[1], x.size(3) * ratio[0] // ratio[1]
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
Expand All @@ -52,8 +52,8 @@ def rand_translation(x, ratio=(1, 8)):
return x


def rand_cutout(x, ratio=(1, 2)):
cutout_size = x.size(2) * ratio[0] // ratio[1], x.size(3) * ratio[0] // ratio[1]
def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
Expand Down
24 changes: 13 additions & 11 deletions DiffAugment-stylegan2/DiffAugment_tf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738.pdf
# https://arxiv.org/pdf/2006.10738

import tensorflow as tf

Expand Down Expand Up @@ -37,21 +37,23 @@ def rand_contrast(x):
return x


def rand_translation(x, ratio=(1, 8)):
B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
translation_x = tf.random.uniform([B, 1], -(W * ratio[0] // ratio[1]), (W * ratio[0] // ratio[1]) + 1, dtype=tf.int32)
translation_y = tf.random.uniform([B, 1], -(H * ratio[0] // ratio[1]), (H * ratio[0] // ratio[1]) + 1, dtype=tf.int32)
grid_x = tf.clip_by_value(tf.expand_dims(tf.range(W, dtype=tf.int32), 0) + translation_x + 1, 0, W + 1)
grid_y = tf.clip_by_value(tf.expand_dims(tf.range(H, dtype=tf.int32), 0) + translation_y + 1, 0, H + 1)
x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1), [0, 2, 1, 3])
x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1)
def rand_translation(x, ratio=0.125):
batch_size = tf.shape(x)[0]
image_size = tf.shape(x)[1:3]
shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)
translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)
grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1)
grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1)
x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)
x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3])
return x


def rand_cutout(x, ratio=(1, 2)):
def rand_cutout(x, ratio=0.5):
batch_size = tf.shape(x)[0]
image_size = tf.shape(x)[1:3]
cutout_size = image_size * ratio[0] // ratio[1]
cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32)
offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32)
grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij')
Expand Down
10 changes: 5 additions & 5 deletions DiffAugment_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738.pdf
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -36,8 +36,8 @@ def rand_contrast(x):
return x


def rand_translation(x, ratio=(1, 8)):
shift_x, shift_y = x.size(2) * ratio[0] // ratio[1], x.size(3) * ratio[0] // ratio[1]
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
Expand All @@ -52,8 +52,8 @@ def rand_translation(x, ratio=(1, 8)):
return x


def rand_cutout(x, ratio=(1, 2)):
cutout_size = x.size(2) * ratio[0] // ratio[1], x.size(3) * ratio[0] // ratio[1]
def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
Expand Down
24 changes: 13 additions & 11 deletions DiffAugment_tf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738.pdf
# https://arxiv.org/pdf/2006.10738

import tensorflow as tf

Expand Down Expand Up @@ -37,21 +37,23 @@ def rand_contrast(x):
return x


def rand_translation(x, ratio=(1, 8)):
B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
translation_x = tf.random.uniform([B, 1], -(W * ratio[0] // ratio[1]), (W * ratio[0] // ratio[1]) + 1, dtype=tf.int32)
translation_y = tf.random.uniform([B, 1], -(H * ratio[0] // ratio[1]), (H * ratio[0] // ratio[1]) + 1, dtype=tf.int32)
grid_x = tf.clip_by_value(tf.expand_dims(tf.range(W, dtype=tf.int32), 0) + translation_x + 1, 0, W + 1)
grid_y = tf.clip_by_value(tf.expand_dims(tf.range(H, dtype=tf.int32), 0) + translation_y + 1, 0, H + 1)
x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1), [0, 2, 1, 3])
x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1)
def rand_translation(x, ratio=0.125):
batch_size = tf.shape(x)[0]
image_size = tf.shape(x)[1:3]
shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)
translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)
grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1)
grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1)
x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)
x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3])
return x


def rand_cutout(x, ratio=(1, 2)):
def rand_cutout(x, ratio=0.5):
batch_size = tf.shape(x)[0]
image_size = tf.shape(x)[1:3]
cutout_size = image_size * ratio[0] // ratio[1]
cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32)
offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32)
grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij')
Expand Down

0 comments on commit 96d6d87

Please sign in to comment.