-
Notifications
You must be signed in to change notification settings - Fork 0
/
schema.py
57 lines (43 loc) · 1.79 KB
/
schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import tensorflow as tf
from collections import OrderedDict
class Field:
    """A single named field of a :class:`Schema`.

    ``kind`` names a TensorFlow dtype attribute (e.g. ``'float32'``) and is
    resolved with ``getattr(tf, ...)``.  Passing it wrapped in a list
    (e.g. ``['int64']``) marks the field's scope as ``'sequential'``; a bare
    string marks it ``'contextual'``.
    """

    def __init__(self, name, kind, transform=None):
        if isinstance(kind, list):
            self.scope = 'sequential'
            kind = getattr(tf, kind[0])
        else:
            self.scope = 'contextual'
            kind = getattr(tf, kind)
        self.name = name
        self.kind = kind  # the resolved tf dtype object, not its name
        self.transform = transform

    def _missing_value(self):
        """Return the sentinel used for absent entries of this field's dtype."""
        # A string tensor cannot hold the integer -1, so strings use '-1'.
        return '-1' if self.kind == tf.string else -1

    def to_dense(self, tensor):
        """Densify a sparse *tensor*, filling absent positions with the sentinel."""
        return tf.sparse.to_dense(tensor, default_value=self._missing_value())

    def to_sparse(self, tensor):
        """Convert a dense *tensor* to a ``tf.SparseTensor``.

        Entries equal to the sentinel (``-1`` / ``'-1'``) are treated as absent
        and dropped. This is the inverse of :meth:`to_dense`.
        """
        missing = self._missing_value()
        index = tf.where(tf.not_equal(tensor, missing))
        return tf.SparseTensor(
            index,
            tf.gather_nd(tensor, index),
            tf.shape(tensor, out_type=tf.int64),
        )

    def to_feature_column(self, transform=None):
        """Build a numeric feature column for this field.

        Sequential fields map to ``sequence_numeric_column``; contextual
        fields map to ``numeric_column``.

        Args:
            transform: Currently unused. It is accepted for compatibility with
                existing callers.

        Raises:
            NotImplementedError: If the field was created with a ``transform``.
                The previous ``assert False`` was stripped under ``python -O``,
                which made the method silently return ``None``.
        """
        if self.transform is not None:
            raise NotImplementedError(
                'feature columns for transformed fields are not supported'
            )
        if self.scope == 'sequential':
            function = tf.feature_column.sequence_numeric_column
        else:
            function = tf.feature_column.numeric_column
        return function(self.name)
class Schema(OrderedDict):
    """Ordered mapping from field name to :class:`Field`.

    *fields* is a sequence of keyword-option dicts. Each dict is unpacked
    into ``Field(**options)``, and entries keep the order of the input.
    """

    def __init__(self, fields):
        # Build every Field eagerly first, so a bad entry fails before any
        # key is inserted.
        built = [Field(**options) for options in fields]
        super().__init__((field.name, field) for field in built)

    def select(self, scope):
        """Return, in order, the names of fields whose scope equals *scope*."""
        matching = []
        for name, field in self.items():
            if field.scope == scope:
                matching.append(name)
        return matching

    def to_feature_spec(self):
        """Return a parsing spec mapping each field name to a tf.io feature.

        Sequential fields become ``VarLenFeature``. All other fields become a
        scalar ``FixedLenFeature``.
        """
        spec = {}
        for name, field in self.items():
            if field.scope == 'sequential':
                spec[name] = tf.io.VarLenFeature(field.kind)
            else:
                spec[name] = tf.io.FixedLenFeature([], field.kind)
        return spec