-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
49 lines (39 loc) · 1.39 KB
/
model.py
File metadata and controls
49 lines (39 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
class Model:
    """Hold the ``data`` and ``config`` objects and build a model from them.

    Model construction itself lives in the module-level ``_create`` function;
    this class only stores its two inputs and forwards them on ``create()``.
    """

    def __init__(self, data, config):
        # Both are kept as public attributes; ``create`` reads them back.
        self.data, self.config = data, config

    def create(self):
        """Return the result of ``_create(self.data, self.config)``."""
        builder_args = (self.data, self.config)
        return _create(*builder_args)
def _create(data, config):
    """Build the Keras model from the inputs described by *data*.

    Creates the contextual and sequential ``tf.keras.Input`` tensors together
    with their feature-column outputs. These are the starting point for the
    model body, which has not been written yet.

    Args:
        data: Object consumed by ``_create_contextual_inputs`` and
            ``_create_sequential_inputs`` (feature names, schema and
            feature-column factories).
        config: Model configuration, forwarded to both input builders.

    Raises:
        NotImplementedError: Always, until the model construction below is
            filled in. Failing loudly here prevents callers such as
            ``Model.create()`` from silently receiving ``None`` in place of
            a model.
    """
    contextual_inputs, contextual_outputs = \
        _create_contextual_inputs(data, config)
    # The last element is the SequenceFeatures layer's second output (per the
    # TF docs, the per-example sequence lengths); it is not needed here.
    sequential_inputs, sequential_outputs, _ = \
        _create_sequential_inputs(data, config)
    # TODO: construct the model from contextual_outputs and
    # sequential_outputs, for example:
    #
    #     outputs = ...  # combine contextual_outputs and sequential_outputs
    #     model = tf.keras.Model(
    #         inputs={**contextual_inputs, **sequential_inputs},
    #         outputs=outputs,
    #     )
    #     model.compile(...)
    #     return model
    raise NotImplementedError(
        'model construction is not implemented; '
        'fill in the model body in _create()'
    )
def _create_contextual_inputs(data, config):
    """Create scalar Keras inputs for the contextual features and densify them.

    Every name in ``data.contextual_feature_names`` gets a ``tf.keras.Input``
    with ``shape=()``, i.e. one scalar per example (Keras adds the batch
    dimension). Its dtype is taken from ``data.schema[name].kind``. The
    feature columns returned by ``data.create_feature_columns('contextual')``
    are then applied to these inputs through a ``DenseFeatures`` layer.

    Args:
        data: Object providing ``contextual_feature_names``, ``schema`` and
            ``create_feature_columns``.
        config: Currently unused by this function.

    Returns:
        A pair ``(inputs, dense_output)``. ``inputs`` maps each feature name
        to its ``tf.keras.Input``. ``dense_output`` is the result of calling
        the ``DenseFeatures`` layer on that dict.
    """
    contextual_inputs = {
        feature_name: tf.keras.Input(
            name=feature_name,
            shape=(),
            dtype=data.schema[feature_name].kind,
        )
        for feature_name in data.contextual_feature_names
    }
    columns = data.create_feature_columns('contextual')
    dense_layer = tf.keras.layers.DenseFeatures(feature_columns=columns)
    dense_output = dense_layer(contextual_inputs)
    return contextual_inputs, dense_output
def _create_sequential_inputs(data, config):
    """Create sparse Keras inputs for the sequential features and embed them.

    Every name in ``data.sequential_feature_names`` gets a sparse
    ``tf.keras.Input`` with ``shape=(None,)`` (one dimension of unspecified
    length per example). Its dtype is taken from ``data.schema[name].kind``.
    The feature columns returned by ``data.create_feature_columns('sequential')``
    are applied to these inputs through ``tf.keras.experimental.SequenceFeatures``.

    Args:
        data: Object providing ``sequential_feature_names``, ``schema`` and
            ``create_feature_columns``.
        config: Currently unused by this function.

    Returns:
        A flat tuple: the ``{name: Input}`` dict first, followed by every
        element the SequenceFeatures layer returns. Per the TF documentation,
        those elements are the dense sequence tensor and the per-example
        sequence lengths.
    """
    def _sparse_sequence_input(feature_name):
        # One sparse, variable-length input per sequential feature.
        return tf.keras.Input(
            name=feature_name,
            shape=(None,),
            dtype=data.schema[feature_name].kind,
            sparse=True,
        )

    sequential_inputs = {
        feature_name: _sparse_sequence_input(feature_name)
        for feature_name in data.sequential_feature_names
    }
    columns = data.create_feature_columns('sequential')
    sequence_layer = tf.keras.experimental.SequenceFeatures(
        feature_columns=columns)
    layer_outputs = sequence_layer(sequential_inputs)
    # Splice the layer's outputs in after the inputs dict.
    return (sequential_inputs, *layer_outputs)