data.py
import os

import tensorflow as tf
import tensorflow_transform as tft

from forecast.schema import Schema
from forecast.support import list_files


class Data:
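    """Input pipelines over the latest preprocessed data and transforms.

    A minimal sketch of the expected configuration, with the keys taken from
    the code below and purely illustrative values:

        {
            'path': ...,            # pattern matching preprocessed data
            'transform_path': ...,  # pattern matching transform outputs
            'schema': ...,          # passed to forecast.schema.Schema
            'modes': {
                'training': {
                    'transform': 'training',
                    'transformed': True,                      # optional
                    'shuffle_macro': {'buffer_size': 100},    # optional
                    'interleave': {'cycle_length': 4},
                    'shuffle_micro': {'buffer_size': 10000},  # optional
                    'map': {'num_parallel_calls': 4},
                    'batch': {'batch_size': 32},
                    'prefetch': {'buffer_size': 1},           # optional
                    'repeat': {},                             # optional
                },
            },
        }
    """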
def __init__(self, config):
        # Find the path with the latest preprocessed data by listing all
        # entries matching a given pattern and taking the last one, assuming
        # that the alphabetical order is also chronological
        self.path = list_files(config['path'])[-1]
# Find the path of the latest transforms using the same strategy
transform_path = list_files(config['transform_path'])[-1]
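        # As an illustration, if the entries are named after timestamps, such
        # as '2020-01-01T00-00-00', the alphabetical order coincides with the
        # chronological one (the naming scheme is assumed here, not prescribed
        # by this module)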
# Create a schema according to the configuration file
self.schema = Schema(config['schema'])
# Find the names of contextual columns
self.contextual_names = self.schema.select('contextual')
# Find the names of sequential columns
self.sequential_names = self.schema.select('sequential')
self.modes = config['modes']
        # Load all transforms referenced by the modes, loading each distinct
        # one only once
transforms = [mode['transform'] for mode in self.modes.values()]
self.transforms = {
name: tft.TFTransformOutput( \
os.path.join(transform_path, name, 'transform'))
for name in set(transforms)
}

    def create(self, name):
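        """Create a tf.data.Dataset for the mode with the given name.

        The pipeline lists TFRecord files, optionally shuffles them, parses
        and transforms the examples, pads and batches them, and optionally
        prefetches and repeats the result.
        """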
config = self.modes[name]

        def _preprocess_transformed(proto):
spec = self.transforms[config['transform']] \
.transformed_feature_spec()
example = tf.io.parse_single_example(proto, spec)
return (
{name: example[name] for name in self.contextual_names},
{
# Convert the sequential columns from sparse to dense
name: self.schema[name].to_dense(example[name])
for name in self.sequential_names
},
)

        def _preprocess_untransformed(proto):
            spec = self.schema.to_feature_spec()
            example = tf.io.parse_single_example(proto, spec)
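            # The transform operates on batches of values, so each feature
            # gets an extra dimension here, which the reshapes below flatten
            # away again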
for name in self.contextual_names:
example[name] = tf.expand_dims(example[name], -1)
for name in self.sequential_names:
example[name] = tf.sparse.expand_dims(example[name], -1)
example = self.transforms[config['transform']] \
.transform_raw_features(example)
return (
{
name: tf.reshape(example[name], [-1])
for name in self.contextual_names
},
{
# Convert the sequential columns from sparse to dense
name: tf.reshape( \
self.schema[name].to_dense(example[name]), [-1])
for name in self.sequential_names
},
)

        def _postprocess(contextual, sequential):
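            # Padded batching has introduced padding values; going back to
            # sparse is presumed to drop them so that downstream code only
            # sees the real entries (this relies on how Schema.to_sparse is
            # implemented)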
sequential = {
# Convert the sequential columns from dense to sparse
name: self.schema[name].to_sparse(sequential[name])
for name in self.sequential_names
}
return {**contextual, **sequential}

        def _shape():
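            # Contextual columns hold scalars, whereas sequential columns
            # hold variable-length vectors that padded_batch pads per batch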
return (
{name: tf.TensorShape([]) for name in self.contextual_names},
{
name: tf.TensorShape([None])
for name in self.sequential_names
},
)

        # List all files matching a given pattern
pattern = [self.path, name, 'examples', 'part-*']
dataset = tf.data.Dataset.list_files(os.path.join(*pattern))
# Shuffle the files if needed
if 'shuffle_macro' in config:
dataset = dataset.shuffle(**config['shuffle_macro'])
# Convert the files into datasets of examples stored as TFRecords and
# amalgamate these datasets into one dataset of examples
dataset = dataset \
.interleave(tf.data.TFRecordDataset, **config['interleave'])
# Shuffle the examples if needed
if 'shuffle_micro' in config:
dataset = dataset.shuffle(**config['shuffle_micro'])
# Preprocess the examples with respect to a given spec, pad the examples
# and form batches of different sizes, and postprocess the batches
if config.get('transformed', False):
_preprocess = _preprocess_transformed
else:
_preprocess = _preprocess_untransformed
dataset = dataset \
.map(_preprocess, **config['map']) \
.padded_batch(padded_shapes=_shape(), **config['batch']) \
.map(_postprocess, **config['map'])
# Prefetch the batches if needed
if 'prefetch' in config:
dataset = dataset.prefetch(**config['prefetch'])
# Repeat the data once the source is exhausted if needed
if 'repeat' in config:
dataset = dataset.repeat(**config['repeat'])
return dataset

    def create_feature_columns(self, scope, transform_name):
        # Convert the columns in the given scope ('contextual' or
        # 'sequential') into feature columns using the transform with the
        # given name
        transform = self.transforms[transform_name]
        def _process(name):
            return self.schema[name].to_feature_column(transform)
        return list(map(_process, getattr(self, scope + '_names')))
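

# A minimal usage sketch, assuming a configuration dictionary shaped as in
# the class docstring above and a mode named 'training' (both illustrative):
#
#   data = Data(config)
#   dataset = data.create('training')
#   for batch in dataset.take(1):
#       # Each batch maps column names to dense contextual tensors and
#       # sparse sequential tensors
#       print(batch.keys())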