-
Notifications
You must be signed in to change notification settings - Fork 949
/
model_fn.py
114 lines (90 loc) · 4.48 KB
/
model_fn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Define the model."""
import tensorflow as tf
def build_model(mode, inputs, params):
    """Build the graph mapping input token ids to per-token tag logits.

    Args:
        mode: (string) 'train', 'eval', etc.
        inputs: (dict) contains the inputs of the graph (features, labels...)
            this can be `tf.placeholder` or outputs of `tf.data`
        params: (Params) contains hyperparameters of the model
            (ex: `params.learning_rate`)

    Returns:
        logits: (tf.Tensor) unnormalized scores over the tag set, one score
            vector per token of each sentence.

    Raises:
        NotImplementedError: if `params.model_version` is not a known model.
    """
    token_ids = inputs['sentence']

    # Guard clause: only the LSTM variant is implemented.
    if params.model_version != 'lstm':
        raise NotImplementedError("Unknown model version: {}".format(params.model_version))

    # Look up a dense embedding vector for every token id.
    embeddings = tf.get_variable(name="embeddings", dtype=tf.float32,
                                 shape=[params.vocab_size, params.embedding_size])
    embedded = tf.nn.embedding_lookup(embeddings, token_ids)

    # Run an LSTM over the embedded sequence; every hidden state is then
    # projected to one score per tag.
    cell = tf.nn.rnn_cell.BasicLSTMCell(params.lstm_num_units)
    hidden_states, _ = tf.nn.dynamic_rnn(cell, embedded, dtype=tf.float32)
    logits = tf.layers.dense(hidden_states, params.number_of_tags)

    return logits
def model_fn(mode, inputs, params, reuse=False):
    """Model function defining the graph operations.

    Args:
        mode: (string) 'train', 'eval', etc.
        inputs: (dict) contains the inputs of the graph (features, labels...)
            this can be `tf.placeholder` or outputs of `tf.data`
        params: (Params) contains hyperparameters of the model
            (ex: `params.learning_rate`)
        reuse: (bool) whether to reuse the weights

    Returns:
        model_spec: (dict) contains the graph operations or nodes needed for
            training / evaluation
    """
    is_training = (mode == 'train')
    labels = inputs['labels']
    sentence_lengths = inputs['sentence_lengths']

    # -----------------------------------------------------------
    # MODEL: define the layers of the model
    with tf.variable_scope('model', reuse=reuse):
        # Compute the output distribution of the model and the predictions
        logits = build_model(mode, inputs, params)
        predictions = tf.argmax(logits, -1)

    # Define loss and accuracy. Sentences are padded to a common length, so
    # a mask restricts both to real tokens only.
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    mask = tf.sequence_mask(sentence_lengths)
    losses = tf.boolean_mask(losses, mask)
    loss = tf.reduce_mean(losses)
    # BUGFIX: accuracy previously averaged over every position, padding
    # included, while the loss was masked — apply the same mask here so the
    # two agree and padding tokens do not distort the reported accuracy.
    correct = tf.cast(tf.equal(labels, predictions), tf.float32)
    accuracy = tf.reduce_mean(tf.boolean_mask(correct, mask))

    # Define training step that minimizes the loss with the Adam optimizer
    if is_training:
        optimizer = tf.train.AdamOptimizer(params.learning_rate)
        global_step = tf.train.get_or_create_global_step()
        train_op = optimizer.minimize(loss, global_step=global_step)

    # -----------------------------------------------------------
    # METRICS AND SUMMARIES
    # Metrics for evaluation using tf.metrics (average over whole dataset)
    with tf.variable_scope("metrics"):
        metrics = {
            # Weight by the padding mask so the streaming accuracy matches
            # the masked batch accuracy above.
            'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions,
                                            weights=tf.cast(mask, tf.float32)),
            'loss': tf.metrics.mean(loss)
        }

    # Group the update ops for the tf.metrics
    update_metrics_op = tf.group(*[op for _, op in metrics.values()])

    # Get the op to reset the local variables used in tf.metrics
    metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="metrics")
    metrics_init_op = tf.variables_initializer(metric_variables)

    # Summaries for training
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)

    # -----------------------------------------------------------
    # MODEL SPECIFICATION
    # Create the model specification and return it
    # It contains nodes or operations in the graph that will be used for
    # training and evaluation.
    # NOTE: `inputs` is deliberately extended in place — the added keys are
    # also visible to the caller through the original `inputs` dict.
    model_spec = inputs
    variable_init_op = tf.group(*[tf.global_variables_initializer(), tf.tables_initializer()])
    model_spec['variable_init_op'] = variable_init_op
    model_spec["predictions"] = predictions
    model_spec['loss'] = loss
    model_spec['accuracy'] = accuracy
    model_spec['metrics_init_op'] = metrics_init_op
    model_spec['metrics'] = metrics
    model_spec['update_metrics'] = update_metrics_op
    model_spec['summary_op'] = tf.summary.merge_all()

    if is_training:
        model_spec['train_op'] = train_op

    return model_spec