
Commit

Minor clean up.
julianser committed Nov 29, 2016
1 parent a2e01b6 commit 0243a03
Showing 1 changed file with 47 additions and 10 deletions.
57 changes: 47 additions & 10 deletions dialog_encdec.py
@@ -1108,7 +1108,7 @@ def build_encoder(self, h, x, xmask=None, latent_variable_mask=None, prev_state=

# Just one step further
else:
-            _res = f_hier(h, xmask, hs_0)
+            _res = f_hier(h_out, xmask, hs_0)

if isinstance(_res, list) or isinstance(_res, tuple):
hs = _res[0]
@@ -1428,7 +1428,7 @@ def build_encoder_function(self):

if self.direct_connection_between_encoders_and_decoder:
hs_dummy = self.dialog_dummy_encoder.build_encoder(h, self.x_data)
-            hs_complete = T.concatenate([hs, h], axis=2)
+            hs_complete = T.concatenate([hs, hs_dummy], axis=2)

else:
hs_complete = hs
@@ -1469,12 +1469,12 @@ def build_encoder_function(self):
if self.condition_latent_variable_on_dcgm_encoder:
logger.debug("Build dcgm encoder")
latent_dcgm_res, latent_dcgm_avg, latent_dcgm_n = self.dcgm_encoder.build_encoder(self.x_data, prev_state=[platent_dcgm_avg, platent_dcgm_n])
-            h_future = self.utterance_encoder_RollLeft.build_encoder( \
+            h_future = self.utterance_encoder_rolledleft.build_encoder( \
latent_dcgm_res, \
self.x_data)

else:
-            h_future = self.utterance_encoder_RollLeft.build_encoder( \
+            h_future = self.utterance_encoder_rolledleft.build_encoder( \
h, \
self.x_data)

@@ -1491,17 +1491,54 @@ def build_encoder_function(self):
latent_utterance_variable_approx_posterior_mean = _posterior_out[1]
latent_utterance_variable_approx_posterior_var = _posterior_out[2]

+        #
+        # NEW STUFF BEGIN
+        #
+        training_y = self.x_data[1:self.x_max_length]
+        if self.direct_connection_between_encoders_and_decoder:
+            logger.debug("Build dialog dummy encoder")
+            hs_dummy = self.dialog_dummy_encoder.build_encoder(h, self.x_data, xmask=training_hs_mask)
+
+            logger.debug("Build decoder (NCE) with direct connection from encoder(s)")
+            if self.add_latent_gaussian_per_utterance:
+                if self.condition_decoder_only_on_latent_variable:
+                    hd_input = latent_utterance_variable_approx_posterior_mean
+                else:
+                    hd_input = T.concatenate([hs, hs_dummy, latent_utterance_variable_approx_posterior_mean], axis=2)
+            else:
+                hd_input = T.concatenate([hs, hs_dummy], axis=2)
+
+            _, hd, _, _ = self.utterance_decoder.build_decoder(hd_input, self.x_data, y=training_y, mode=UtteranceDecoder.EVALUATION, prev_state=self.phd)
+
+        else:
+            if self.add_latent_gaussian_per_utterance:
+                if self.condition_decoder_only_on_latent_variable:
+                    hd_input = latent_utterance_variable_approx_posterior_mean
+                else:
+                    hd_input = T.concatenate([hs, latent_utterance_variable_approx_posterior_mean], axis=2)
+            else:
+                hd_input = hs
+
+            logger.debug("Build decoder (EVAL)")
+            _, hd, _, _ = self.utterance_decoder.build_decoder(hd_input, self.x_data, y=training_y, mode=UtteranceDecoder.EVALUATION, prev_state=self.phd)
+
+        #
+        # NEW STUFF END
+        #



if self.add_latent_gaussian_per_utterance:
self.encoder_fn = theano.function(inputs=[self.x_data, self.x_data_reversed, \
self.x_max_length], \
-                                                outputs=[h, hs_complete], on_unused_input='warn', name="encoder_fn")
+                                                outputs=[h, hs_complete, hd], on_unused_input='warn', name="encoder_fn")
#self.encoder_fn = theano.function(inputs=[self.x_data, self.x_data_reversed, \
# self.x_max_length], \
# outputs=[h, hs_complete, hs_and_h_future, latent_utterance_variable_approx_posterior_mean], on_unused_input='warn', name="encoder_fn")
else:
self.encoder_fn = theano.function(inputs=[self.x_data, self.x_data_reversed, \
self.x_max_length], \
-                                                outputs=[h, hs_complete], on_unused_input='warn', name="encoder_fn")
+                                                outputs=[h, hs_complete, hd], on_unused_input='warn', name="encoder_fn")


return self.encoder_fn
@@ -1753,9 +1790,9 @@ def __init__(self, state):
# Retrieve hidden state at the end of next utterance from the utterance encoders
# (or at the end of the batch, if there are no end-of-token symbols at the end of the batch)
if self.bidirectional_utterance_encoder:
-            self.utterance_encoder_RollLeft = DialogLevelRollLeft(self.state, self.qdim_encoder, self.rng, self)
+            self.utterance_encoder_rolledleft = DialogLevelRollLeft(self.state, self.qdim_encoder, self.rng, self)
else:
-            self.utterance_encoder_RollLeft = DialogLevelRollLeft(self.state, self.qdim_encoder*2, self.rng, self)
+            self.utterance_encoder_rolledleft = DialogLevelRollLeft(self.state, self.qdim_encoder*2, self.rng, self)

if self.condition_latent_variable_on_dcgm_encoder:
logger.debug("Initializing dcgm encoder for conditioning input to the utterance-level latent variable")
@@ -1764,13 +1801,13 @@ def __init__(self, state):
logger.debug("Build dcgm encoder")
latent_dcgm_res, self.latent_dcgm_avg, self.latent_dcgm_n = self.dcgm_encoder.build_encoder(training_x, xmask=training_hs_mask, prev_state=[self.platent_dcgm_avg, self.platent_dcgm_n])

-            self.h_future = self.utterance_encoder_RollLeft.build_encoder( \
+            self.h_future = self.utterance_encoder_rolledleft.build_encoder( \
latent_dcgm_res, \
training_x, \
xmask=training_hs_mask)

else:
-            self.h_future = self.utterance_encoder_RollLeft.build_encoder( \
+            self.h_future = self.utterance_encoder_rolledleft.build_encoder( \
self.h, \
training_x, \
xmask=training_hs_mask)

2 comments on commit 0243a03

@plison commented on 0243a03 on Dec 1, 2016


This commit seems to break the training process: during the initial sampling phase, there is a mismatch in the input dimensions.

For instance, if I run the following training command (where the test prototype uses a batch size of 5):
THEANO_FLAGS=mode=FAST_RUN,device=cuda,floatX=float32,dnn.enabled=True python2 train.py --prototype prototype_test > Model_Output.txt

I then get the following error:
ValueError: Input dimension mis-match. (input[0].shape[0] = 1, input[1].shape[0] = 5)

This error did not occur in previous versions.
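
For reference, a minimal sketch, independent of dialog_encdec.py, of how Theano produces this class of error: element-wise ops require non-broadcastable dimensions to match exactly, so a tensor shaped for a single example cannot be combined with a batch of 5 at run time. The variable names below are purely illustrative.

    import numpy as np
    import theano
    import theano.tensor as T

    a = T.matrix('a')   # e.g. a persistent state initialised for 1 example
    b = T.matrix('b')   # e.g. a batch of 5 examples
    f = theano.function([a, b], a + b)

    # The graph compiles fine; the shape check only happens when the function is evaluated:
    f(np.zeros((1, 10), dtype=theano.config.floatX),
      np.zeros((5, 10), dtype=theano.config.floatX))
    # ValueError: Input dimension mis-match. (input[0].shape[0] = 1, input[1].shape[0] = 5)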

@julianser (owner, author)

Thanks for spotting this! I've made a new commit which should fix the bug.
