Skip to content
Permalink
Browse files

fix sample generation

  • Loading branch information...
yookoon committed Apr 27, 2018
1 parent 6f9be81 commit 29460b24813c16dbbc57da8c23f139aecaa4e85d
@@ -1,5 +1,5 @@
/etc/
/datasets/
datasets/
/cornell_movie_dialogue/
*.orig
*.lprof
0 LICENSE 100644 → 100755
No changes.
0 Readme.md 100644 → 100755
No changes.
0 cornell_preprocess.py 100644 → 100755
No changes.
@@ -442,6 +442,8 @@ def train(self):
self.model.train()
n_total_words = 0

self.evaluate()

for batch_i, (conversations, conversation_length, sentence_length) \
in enumerate(tqdm(self.train_data_loader, ncols=80)):
# conversations: (batch_size) list of conversations
@@ -537,7 +539,7 @@ def train(self):
return epoch_loss_history

def generate_sentence(self, sentences, sentence_length,
input_conversation_length, target_sentences):
input_conversation_length, input_sentences, target_sentences):
"""Generate output of decoder (single batch)"""
self.model.eval()

@@ -549,7 +551,8 @@ def generate_sentence(self, sentences, sentence_length,
target_sentences,
decode=True)

input_sentences = sentences[:-1]
# input_sentences = sentences[:-1]
import ipdb; ipdb.set_trace()

# write output to file
with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f:
@@ -599,9 +602,13 @@ def evaluate(self):
target_sentence_length = to_var(torch.LongTensor(target_sentence_length), eval=True)

if batch_i == 0:
input_conversations = [conv[:-1] for conv in conversations]
input_sentences = [sent for conv in input_conversations for sent in conv]
input_sentences = to_var(torch.LongTensor(input_sentences), eval=True)
self.generate_sentence(sentences,
sentence_length,
input_conversation_length,
input_sentences,
target_sentences)

sentence_logits, kl_div, _, _ = self.model(
0 requirements.txt 100644 → 100755
No changes.
0 ubuntu_preprocess.py 100644 → 100755
No changes.

0 comments on commit 29460b2

Please sign in to comment.
You can’t perform that action at this time.