In [1]:
!pip install tensorflow-addons

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow-addons
  Downloading tensorflow_addons-0.17.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 5.2 MB/s 
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.17.1


In [2]:
import os

# Set ahead of the TensorFlow import below, presumably so the runtime picks the
# value up when it initialises.
# NOTE(review): "gpu_private" is understood to give GPU kernel launches their
# own dedicated threads — confirm against the TensorFlow GPU performance guide.
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"

import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa 
import numpy as np
import matplotlib.pyplot as plt
import datetime

# Switch on XLA just-in-time compilation for TensorFlow's graph optimizer.
tf.config.optimizer.set_jit(True)

In [3]:
# Fetch CIFAR-10 (downloaded on first use) as (features, labels) pairs for the
# train and test splits.
(train_features, train_labels), (test_features, test_labels) = (
    keras.datasets.cifar10.load_data()
)

# Divide by 255 so pixel intensities become floating-point values.
train_features, test_features = train_features / 255.0, test_features / 255.0

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
# Presumably the number of samples per training batch; the cells that consume
# it are outside this excerpt — TODO confirm.
BATCH_SIZE = 512
# Presumably the side length in pixels of the square images (CIFAR-10 images
# are 32x32) — TODO confirm against the data pipeline.
IMAGE_SIZE = 32

In [5]:
class Augmentation(keras.layers.Layer):
  """Common parent for the stochastic image-augmentation layers.

  Provides `random_execute`, which the subclasses call to decide whether
  their transformation is applied on a given invocation.
  """

  def __init__(self):
    super().__init__()

  @tf.function
  def random_execute(self, prob: float) -> tf.Tensor:
    """Returns a scalar boolean tensor that is True with probability `prob`.

    A single value is drawn uniformly from [0, 1) and compared to `prob`.
    """
    sample = tf.random.uniform(shape=[], minval=0, maxval=1)
    return tf.less(sample, prob)

In [6]:
class RandomToGrayscale(Augmentation):
  """Replaces the image with its grayscale version on roughly 20% of calls.

  The single gray channel is repeated three times along the last axis, so
  the output still has three channels.
  """

  @tf.function
  def call(self, x: tf.Tensor) -> tf.Tensor:
    if self.random_execute(0.2):
      # The three-entry multiples list means the input must be rank 3;
      # only the last axis is repeated.
      x = tf.tile(tf.image.rgb_to_grayscale(x), multiples=[1, 1, 3])
    return x

In [7]:
class RandomColorJitter(Augmentation):
  """Randomly perturbs the colours of the image on roughly 80% of calls.

  When triggered, brightness, contrast, saturation and hue are jittered in
  that order, each by an independently drawn random amount.
  """

  @tf.function
  def call(self, x: tf.Tensor) -> tf.Tensor:
    if self.random_execute(0.8):
      x = tf.image.random_brightness(x, max_delta=0.8)
      x = tf.image.random_contrast(x, lower=0.4, upper=1.6)
      x = tf.image.random_saturation(x, lower=0.4, upper=1.6)
      x = tf.image.random_hue(x, max_delta=0.2)

    return x

In [8]:
class RandomFlip(Augmentation):
  """Mirrors the image left-to-right on roughly half of the calls."""

  @tf.function
  def call(self, x: tf.Tensor) -> tf.Tensor:
    """Returns `x` flipped horizontally with probability 0.5, else unchanged."""
    if self.random_execute(0.5):
      # Deterministic flip: the gate above is the only coin toss.
      # tf.image.random_flip_left_right would toss a second, independent
      # coin and cut the effective flip rate to 25%.
      x = tf.image.flip_left_right(x)

    return x

In [9]:
class RandomResizedCrop(Augmentation):
  """Cuts a random square patch out of the image and resizes it back.

  Args:
    image_size: Side length in pixels of the square output; it also bounds
      the side length of the random crop.
  """

  def __init__(self, image_size):
    # Zero-argument super() follows the full MRO, so Augmentation.__init__
    # runs as well (super(Augmentation, self) skipped it).
    super().__init__()
    self.image_size = image_size

  def call(self, x: tf.Tensor) -> tf.Tensor:
    """Returns a random square crop of `x`, resized to the configured size.

    The crop keeps every channel of the input; only height and width are
    cropped.
    """
    # Integer draws exclude maxval, so the crop side lies in
    # [int(0.75 * image_size), image_size - 1].
    rand_size = tf.random.uniform(
        shape=[],
        minval=int(0.75 * self.image_size),
        maxval=1 * self.image_size,
        dtype=tf.int32,
    )

    # The last entry of the crop size is the channel count, not a spatial
    # size. Prefer the static value so the output keeps a known channel
    # dimension; fall back to the dynamic shape when it is unknown.
    channels = x.shape[-1]
    if channels is None:
      channels = tf.shape(x)[-1]

    crop = tf.image.random_crop(x, (rand_size, rand_size, channels))
    crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))

    return crop_resize

In [10]:
class RandomSolarize(Augmentation):
  """Inverts pixel values at or above a threshold on roughly 20% of calls.

  Args:
    threshold: Values strictly below this are left untouched; all others
      are replaced by `max_value - value`.
    max_value: Upper end of the pixel range, used as the inversion pivot.

  The defaults (10 and 255) match pixels on a 0-255 scale. For inputs
  scaled to [0, 1], every value is below 10 and the layer returns its input
  unchanged; pass e.g. `threshold=10 / 255, max_value=1.0` in that case.
  NOTE(review): the data-loading cell divides the features by 255 — confirm
  which scale the images are on when they reach this layer.
  """

  def __init__(self, threshold=10, max_value=255):
    super().__init__()
    self.threshold = threshold
    self.max_value = max_value

  @tf.function
  def call(self, x: tf.Tensor) -> tf.Tensor:
    """Returns `x` solarized with probability 0.2, else unchanged."""
    if self.random_execute(0.2):
      x = tf.where(x < self.threshold, x, self.max_value - x)

    return x

In [11]:
class RandomBlur(Augmentation):
  """Applies a Gaussian blur to the input on roughly 20% of calls."""

  @tf.function
  def call(self, x: tf.Tensor):
    if self.random_execute(0.2):
      # sigma comes from NumPy's uniform [0, 1) generator.
      # NOTE(review): this is a plain Python call inside a @tf.function, so it
      # presumably runs only while the function is being traced, with the
      # drawn value then fixed in the graph as a constant — confirm, and
      # consider a TensorFlow-side draw if a fresh sigma per call is intended.
      s = np.random.random()
      return tfa.image.gaussian_filter2d(image=x, sigma=s)
    return x