-
Notifications
You must be signed in to change notification settings - Fork 509
/
data_augmentation.py
51 lines (41 loc) · 1.19 KB
/
data_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Created by Marcel Santos - Intel Intelligent Systems Lab - 2021
This script implements several routines for data augmentation.
"""
import tensorflow as tf
import numpy as np
def augment_img(
    img,
    *,
    hue_delta=0.08,
    saturation_range=(0.6, 1.6),
    brightness_delta=0.05,
    contrast_range=(0.7, 1.3),
):
    """Color augmentation.

    Applies, in order: a random hue shift, a random saturation scale, a
    random brightness offset and a random contrast scale. The result is then
    clipped to [0, 1]. The keyword arguments default to the original
    hard-coded strengths, so ``augment_img(img)`` behaves exactly as before.

    Args:
        img: input image. The final clip to [0, 1] suggests a float image
            with values in [0, 1]. The hue/saturation ops presumably also
            need an RGB (3-channel) last dimension — confirm against caller.
        hue_delta: max absolute hue shift passed to ``tf.image.random_hue``.
        saturation_range: (lower, upper) saturation factor bounds passed to
            ``tf.image.random_saturation``.
        brightness_delta: max absolute brightness offset passed to
            ``tf.image.random_brightness``.
        contrast_range: (lower, upper) contrast factor bounds passed to
            ``tf.image.random_contrast``.

    Returns:
        img: augmented image, clipped to [0, 1]
    """
    sat_lower, sat_upper = saturation_range
    con_lower, con_upper = contrast_range
    img = tf.image.random_hue(img, hue_delta)
    img = tf.image.random_saturation(img, sat_lower, sat_upper)
    img = tf.image.random_brightness(img, brightness_delta)
    img = tf.image.random_contrast(img, con_lower, con_upper)
    # The adjustments above (e.g. the additive brightness offset) may push
    # values outside [0, 1]; clamp once at the end.
    img = tf.clip_by_value(img, clip_value_min=0.0, clip_value_max=1.0)
    return img
def augment_cmd(cmd, *, rng=None):
    """
    Command augmentation

    If ``cmd`` is zero, it is replaced with -1.0 with probability 0.25 and
    with 1.0 with probability 0.25. Otherwise (probability 0.5) it is kept.
    Non-zero commands are always returned unchanged.

    Args:
        cmd: input command
        rng: optional ``numpy.random.Generator`` used for the random draw.
            If None (default), a fresh ``np.random.default_rng()`` is created
            for the call, as before. Pass a seeded generator to get
            reproducible augmentation and avoid building a new generator on
            every call.

    Returns:
        cmd: augmented command
    """
    # "Neither positive nor negative": true for zero, and also for NaN
    # (NaN fails both comparisons), so a NaN command is treated like zero.
    if not (cmd > 0 or cmd < 0):
        if rng is None:
            rng = np.random.default_rng()
        coin = rng.uniform(low=0.0, high=1.0, size=None)
        if coin < 0.25:
            cmd = -1.0
        elif coin < 0.5:
            cmd = 1.0
    return cmd
def flip_sample(img, cmd, label, *, rng=None):
    """Randomly mirror a sample horizontally.

    With probability 0.5 the sample is flipped:
      - the image is flipped left-right;
      - the command is negated;
      - the label is reversed along its first axis.
    Otherwise all three inputs are returned unchanged.

    Args:
        img: input image, flipped with ``tf.image.flip_left_right``
        cmd: input command, negated when the sample is flipped
        label: label tensor, reversed along axis 0 when the sample is flipped
            (presumably this swaps left/right-ordered entries — confirm)
        rng: optional ``numpy.random.Generator`` used for the coin flip.
            If None (default), a fresh ``np.random.default_rng()`` is created
            for the call, as before. Pass a seeded generator for reproducible
            augmentation.

    Returns:
        (img, cmd, label): the possibly-flipped sample
    """
    if rng is None:
        rng = np.random.default_rng()
    coin = rng.uniform(low=0.0, high=1.0, size=None)
    if coin < 0.5:
        img = tf.image.flip_left_right(img)
        cmd = -cmd
        label = tf.reverse(label, axis=[0])
    return img, cmd, label