-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
49 lines (41 loc) · 1.79 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
import tensorflow_datasets as tfds
from os.path import join
def get_image_from_coco(coco, short_side=512, crop_size=256):
    """Extract, rescale and randomly crop the image of one COCO example.

    The image is resized (aspect ratio preserved) so that its shorter side
    is ``short_side`` pixels. A random ``crop_size`` x ``crop_size`` patch
    is then taken and scaled from [0, 255] to [0, 1].

    Args:
        coco: A single element of the TFDS ``coco/2017`` dataset. Only its
            ``'image'`` entry is used. It is assumed to be an HxWx3 image
            tensor as produced by TFDS (TODO confirm).
        short_side: Target length of the shorter image side after resizing.
        crop_size: Side length of the square random crop. Must not exceed
            ``short_side``, otherwise ``tf.image.random_crop`` fails.

    Returns:
        A float32 tensor of shape ``(crop_size, crop_size, 3)``.
    """
    image = tf.cast(coco['image'], tf.float32)
    image_size = tf.shape(image)[:2]
    min_length = tf.reduce_min(image_size)
    # Integer rescale that keeps the aspect ratio; the shorter side becomes
    # exactly `short_side` (min_length * short_side // min_length).
    image_size = image_size * short_side // min_length
    image = tf.image.resize(image, image_size)
    image = tf.image.random_crop(image, [crop_size, crop_size, 3]) / 255.0
    return image
def get_coco_training_set():
    """Return the COCO 2017 training split as a dataset of preprocessed images.

    Each element is the result of ``get_image_from_coco`` applied to one
    TFDS example; the map runs with autotuned parallelism.
    """
    train_examples = tfds.load(name='coco/2017', split=tfds.Split.TRAIN)
    parallelism = tf.data.experimental.AUTOTUNE
    return train_examples.map(get_image_from_coco, num_parallel_calls=parallelism)
def get_coco_test_set():
    """Return the COCO 2017 test split as a dataset of preprocessed images.

    Each element is the result of ``get_image_from_coco`` applied to one
    TFDS example; the map runs with autotuned parallelism.
    """
    test_examples = tfds.load(name='coco/2017', split=tfds.Split.TEST)
    parallelism = tf.data.experimental.AUTOTUNE
    return test_examples.map(get_image_from_coco, num_parallel_calls=parallelism)
def get_image_from_wikiart(filename, short_side=512, crop_size=256):
    """Load one style image from disk, rescale it and take a random crop.

    The file is decoded as a 3-channel JPEG and resized (aspect ratio
    preserved) so that its shorter side is ``short_side`` pixels. A random
    ``crop_size`` x ``crop_size`` patch is then taken and scaled from
    [0, 255] to [0, 1].

    Args:
        filename: Scalar string tensor holding the path of a JPEG file.
        short_side: Target length of the shorter image side after resizing.
        crop_size: Side length of the square random crop. Must not exceed
            ``short_side``, otherwise ``tf.image.random_crop`` fails.

    Returns:
        A float32 tensor of shape ``(crop_size, crop_size, 3)``.
    """
    image = tf.io.read_file(filename)
    # channels=3 forces RGB output, even for grayscale JPEGs.
    image = tf.image.decode_jpeg(image, channels=3)
    image_size = tf.shape(image)[:2]
    min_length = tf.reduce_min(image_size)
    # Integer rescale that keeps the aspect ratio; the shorter side becomes
    # exactly `short_side`.
    image_size = image_size * short_side // min_length
    # tf.image.resize returns float32, so the division below yields floats.
    image = tf.image.resize(image, image_size)
    image = tf.image.random_crop(image, [crop_size, crop_size, 3]) / 255.0
    return image
def get_wikiart_set(file_dir):
    """Build a dataset of preprocessed style images found under ``file_dir``.

    Files matching ``<file_dir>/**/**/*.jpg`` are listed and each one is
    decoded and cropped by ``get_image_from_wikiart``. Elements whose
    preprocessing raises an error (e.g. undecodable files) are silently
    dropped instead of aborting iteration.

    NOTE(review): TF file globbing is presumably not recursive, so ``**``
    likely behaves like ``*`` here (matching exactly two directory levels
    below ``file_dir``). Confirm this against the dataset layout.
    """
    pattern = join(file_dir, "**/**/*.jpg")
    filenames = tf.data.Dataset.list_files(pattern)
    decoded = filenames.map(
        get_image_from_wikiart,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    return decoded.apply(tf.data.experimental.ignore_errors())
def get_training_set(style_dir):
    """Pair COCO training content images with WikiArt style images.

    Args:
        style_dir: Root directory searched by ``get_wikiart_set`` for style
            images.

    Returns:
        A dataset of ``(content_image, style_image)`` tuples. ``Dataset.zip``
        stops as soon as either input dataset is exhausted.
    """
    content_and_style = (get_coco_training_set(), get_wikiart_set(style_dir))
    return tf.data.Dataset.zip(content_and_style)
def get_test_set(style_dir):
    """Pair COCO test-split content images with WikiArt style images.

    Style images come from the same ``get_wikiart_set`` pipeline used for
    training. For a held-out evaluation, pass a ``style_dir`` distinct from
    the training one.

    Args:
        style_dir: Root directory searched by ``get_wikiart_set`` for style
            images.

    Returns:
        A dataset of ``(content_image, style_image)`` tuples. ``Dataset.zip``
        stops as soon as either input dataset is exhausted.
    """
    # Named *_test: these are evaluation datasets, not training ones.
    coco_test = get_coco_test_set()
    wikiart_test = get_wikiart_set(style_dir)
    return tf.data.Dataset.zip((coco_test, wikiart_test))