-
Notifications
You must be signed in to change notification settings - Fork 674
/
dual_encoder.py
85 lines (68 loc) · 2.8 KB
/
dual_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import tensorflow as tf
import numpy as np
from models import helpers
# Module-level handle to TensorFlow's global command-line flags object.
# NOTE(review): not referenced by the functions in this file section; kept for
# callers/other code that may rely on it.
FLAGS = tf.flags.FLAGS
def get_embeddings(hparams):
    """Create the ``word_embeddings`` variable.

    When both ``hparams.glove_path`` and ``hparams.vocab_path`` are set, the
    matrix is initialized from pre-trained GloVe vectors restricted to the
    vocabulary; otherwise it is initialized uniformly at random in
    [-0.25, 0.25].

    Args:
        hparams: Hyper-parameter object. Reads ``glove_path``, ``vocab_path``,
            ``vocab_size`` and ``embedding_dim``.

    Returns:
        The variable returned by ``tf.get_variable("word_embeddings", ...)``
        with shape ``[hparams.vocab_size, hparams.embedding_dim]``.
    """
    if hparams.glove_path and hparams.vocab_path:
        tf.logging.info("Loading Glove embeddings...")
        vocab_array, vocab_dict = helpers.load_vocab(hparams.vocab_path)
        glove_vectors, glove_dict = helpers.load_glove_vectors(
            hparams.glove_path, vocab=set(vocab_array))
        initial_matrix = helpers.build_initial_embedding_matrix(
            vocab_dict, glove_dict, glove_vectors, hparams.embedding_dim)
        # tf.get_variable needs a *callable* initializer when `shape` is passed
        # explicitly; a plain matrix value is not one, so wrap it.
        # NOTE(review): assumes build_initial_embedding_matrix returns an
        # array-like of shape [vocab_size, embedding_dim] — confirm in helpers.
        initializer = tf.constant_initializer(initial_matrix)
    else:
        tf.logging.info("No glove/vocab path specified, starting with random embeddings.")
        initializer = tf.random_uniform_initializer(-0.25, 0.25)
    return tf.get_variable(
        "word_embeddings",
        shape=[hparams.vocab_size, hparams.embedding_dim],
        initializer=initializer)
def dual_encoder_model(
        hparams,
        mode,
        context,
        context_len,
        utterance,
        utterance_len,
        targets):
    """Build the dual-encoder scoring graph for (context, utterance) pairs.

    Contexts and utterances are embedded with one shared embedding matrix,
    encoded by one shared LSTM, and scored as ``sigmoid((c * M) * r)`` where
    ``c`` and ``r`` are the final LSTM hidden states of the context and the
    utterance and ``M`` is a learned ``[rnn_dim, rnn_dim]`` matrix.

    NOTE(review): written against pre-1.0 TensorFlow signatures
    (``tf.concat(axis, values)``, ``tf.split(axis, num, value)``,
    ``tf.batch_matmul``, positional ``sigmoid_cross_entropy_with_logits``);
    porting to TF >= 1.0 requires updating those calls.

    Args:
        hparams: Hyper-parameters. Reads ``rnn_dim`` here, plus whatever
            ``get_embeddings`` reads.
        mode: A ``tf.contrib.learn.ModeKeys`` value; in ``INFER`` mode no loss
            is built.
        context: Word-id tensor for the contexts (presumably
            [batch, max_time] — TODO confirm).
        context_len: Sequence lengths for ``context``.
        utterance: Word-id tensor for the candidate responses. Must have the
            same shape as ``context``: the batch-axis concat needs matching
            non-batch dimensions and the even split needs equal batch sizes.
        utterance_len: Sequence lengths for ``utterance``.
        targets: Per-pair labels; cast to float32 for the loss.

    Returns:
        ``(probs, mean_loss)``: ``probs`` is the sigmoid of the logits
        (shape [batch, 1]) and ``mean_loss`` the scalar mean binary
        cross-entropy; ``(probs, None)`` in INFER mode.
    """
    # Initialize embeddings randomly or with pre-trained vectors if available
    embeddings_W = get_embeddings(hparams)

    # Embed the context and the utterance with the shared embedding matrix
    context_embedded = tf.nn.embedding_lookup(
        embeddings_W, context, name="embed_context")
    utterance_embedded = tf.nn.embedding_lookup(
        embeddings_W, utterance, name="embed_utterance")

    # Build the RNN
    with tf.variable_scope("rnn"):
        # We use an LSTM Cell
        cell = tf.nn.rnn_cell.LSTMCell(
            hparams.rnn_dim,
            forget_bias=2.0,
            use_peepholes=True,
            state_is_tuple=True)

        # Run contexts and utterances through the same RNN in a single call by
        # stacking them along the batch axis (axis 0). Only the final state is
        # used, so the per-step outputs are discarded.
        _, rnn_states = tf.nn.dynamic_rnn(
            cell,
            tf.concat(0, [context_embedded, utterance_embedded]),
            sequence_length=tf.concat(0, [context_len, utterance_len]),
            dtype=tf.float32)
        # Undo the stacking: the first half of the final hidden states belongs
        # to the contexts, the second half to the utterances.
        encoding_context, encoding_utterance = tf.split(0, 2, rnn_states.h)

    with tf.variable_scope("prediction"):
        M = tf.get_variable(
            "M",
            shape=[hparams.rnn_dim, hparams.rnn_dim],
            initializer=tf.truncated_normal_initializer())

        # "Predict" a response: c * M
        generated_response = tf.matmul(encoding_context, M)
        generated_response = tf.expand_dims(generated_response, 2)
        encoding_utterance = tf.expand_dims(encoding_utterance, 2)

        # Dot product between generated response and actual response
        # (c * M) * r. adj_x=True transposes each [rnn_dim, 1] generated
        # response to [1, rnn_dim], so the product is [batch, 1, 1].
        logits = tf.batch_matmul(
            generated_response, encoding_utterance, adj_x=True)
        logits = tf.squeeze(logits, [2])

        # Apply sigmoid to convert logits to probabilities
        probs = tf.sigmoid(logits)

        if mode == tf.contrib.learn.ModeKeys.INFER:
            return probs, None

        # Calculate the binary cross-entropy loss
        losses = tf.nn.sigmoid_cross_entropy_with_logits(
            logits, tf.to_float(targets))

    # Mean loss across the batch of examples
    mean_loss = tf.reduce_mean(losses, name="mean_loss")
    return probs, mean_loss