<a href="https://colab.research.google.com/github/neel04/ML_PlayGround/blob/master/mixup_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf

def _mix_batch(values, mix_weight, shuffle_index):
    """Blend every sample of ``values`` with the sample at ``shuffle_index``.

    Computes ``w * values + (1 - w) * values[shuffle_index]`` with one weight
    per sample. ``mix_weight`` has shape ``[batch]`` and is reshaped to
    ``[batch, 1, ..., 1]`` so it broadcasts over all trailing axes of
    ``values``, whatever their rank.

    Floating inputs are blended in their own dtype. Integer inputs are blended
    in float32, rounded and cast back, so the output dtype always matches the
    input dtype. Casting the weights to an integer dtype would truncate them
    to 0/1 and disable mixing.
    """
    values = tf.convert_to_tensor(values)
    dtype = values.dtype
    compute_dtype = dtype if dtype.is_floating else tf.float32
    blend_in = tf.cast(values, compute_dtype)

    rank = blend_in.shape.rank
    if rank is None:
        # Rank unknown at trace time: build the [-1, 1, ..., 1] shape dynamically.
        weight_shape = tf.concat([[-1], tf.ones([tf.rank(blend_in) - 1], tf.int32)], axis=0)
    else:
        weight_shape = [-1] + [1] * (rank - 1)
    weight = tf.reshape(tf.cast(mix_weight, compute_dtype), weight_shape)

    mixed = blend_in * weight + tf.gather(blend_in, shuffle_index) * (1 - weight)
    if compute_dtype != dtype:
        mixed = tf.cast(tf.round(mixed), dtype)
    return mixed


def mixup(images, labels, token_label=None, alpha=0.4, min_mix_weight=0):
    """Apply mixup augmentation to a batch.

    Each sample ``i`` is blended with a partner ``j`` chosen by a random
    permutation of the batch: ``x_i' = w_i * x_i + (1 - w_i) * x_j``. The same
    per-sample weight ``w_i`` is used for images, labels and token labels.

    Requires ``sample_beta_distribution``, which is defined in a later notebook
    cell; that cell must be run before this function is called.

    Args:
        images: Batched images; axis 0 is the batch axis. The output keeps
            the input dtype.
        labels: Batched labels, blended as float32. Mixing is meaningful for
            dense/one-hot labels such as ``[batch, num_classes]``.
        token_label: Optional per-token labels, e.g.
            ``[batch, patch_height, patch_width, num_classes]``. The output
            keeps the input dtype.
        alpha: Concentration of the symmetric Beta(alpha, alpha) distribution
            that the mix weights are drawn from.
        min_mix_weight: If > 0, weights above ``1 - min_mix_weight`` are
            snapped to 1, so that sample is left unmixed.

    Returns:
        ``(images, labels)``, or ``(images, labels, token_label)`` when
        ``token_label`` is given.
    """
    # Equivalent to tfp.distributions.Beta(alpha, alpha).sample([batch_size]).
    batch_size = tf.shape(images)[0]
    mix_weight = sample_beta_distribution([batch_size], alpha, alpha)
    # Fold the weights into [0.5, 1] so each sample stays the dominant component.
    mix_weight = tf.maximum(mix_weight, 1.0 - mix_weight)

    # For min_mix_weight=0.1, weights > 0.9 are treated as "no mixup". The
    # resulting no-mixup probability shrinks as alpha grows:
    # alpha: no_mixup --> {0.1: 0.8128, 0.2: 0.6736, 0.4: 0.4793, 0.6: 0.3521, 0.8: 0.2636, 1.0: 0.2000}
    if min_mix_weight > 0:
        mix_weight = tf.where(mix_weight > 1 - min_mix_weight, tf.ones_like(mix_weight), mix_weight)

    # One shared permutation so images and labels are mixed with the same partner.
    shuffle_index = tf.random.shuffle(tf.range(batch_size))

    images = _mix_batch(images, mix_weight, shuffle_index)
    labels = _mix_batch(tf.cast(labels, "float32"), mix_weight, shuffle_index)
    if token_label is None:
        return images, labels

    token_label = _mix_batch(token_label, mix_weight, shuffle_index)
    return images, labels, token_label

In [3]:
def sample_beta_distribution(shape, concentration_0=0.4, concentration_1=0.4):
    """Draw Beta samples of the given ``shape`` using two Gamma draws.

    If ``X ~ Gamma(concentration_1)`` and ``Y ~ Gamma(concentration_0)``, then
    ``X / (X + Y)`` follows ``Beta(concentration_1, concentration_0)``. This
    matches tfp's ``Beta(concentration1, concentration0)`` naming without
    depending on TensorFlow Probability.
    """
    numerator = tf.random.gamma(shape=shape, alpha=concentration_1)
    denominator = numerator + tf.random.gamma(shape=shape, alpha=concentration_0)
    return numerator / denominator

In [None]:
# Dummy batch: 6 images of 512x512 with 6 channels. Mixup blends label
# vectors, so use one-hot labels (10 classes) rather than a rank-1 tensor.
# With rank-1 labels, mixup would return a [6, 6] label tensor.
imgs = tf.zeros((6, 512, 512, 6))
labels = tf.one_hot(tf.zeros((6,), dtype=tf.int32), depth=10)

output = mixup(imgs, labels)

In [7]:
# Mixup must not change the image batch shape; expect [6, 512, 512, 6].
mixed_images = output[0]
mixed_images.get_shape()

TensorShape([6, 512, 512, 6])