Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
Luolc committed Aug 27, 2018
1 parent 82f14ee commit 3f8f31d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion model/auto_encoder.py
Expand Up @@ -7,7 +7,7 @@

class AutoEncoderModel(Seq2SeqModel):
def __init__(self, name: str = 'auto'):
super(AutoEncoderModel, self).__init__(name)
super(AutoEncoderModel, self).__init__()

self.q_dec = None
self.q_target = None
Expand Down
2 changes: 1 addition & 1 deletion model/base.py
Expand Up @@ -14,7 +14,7 @@

class BasicChatbotModel(BasicModel):
def __init__(self, name: str = 'basic_chatbot_model'):
super(BasicChatbotModel, self).__init__(name)
super(BasicChatbotModel, self).__init__()

self.enc_length = None
self.dec_length = None
Expand Down
12 changes: 6 additions & 6 deletions model/seq.py
Expand Up @@ -30,7 +30,7 @@ def __call__(self, x):

class Seq2SeqModel(BasicChatbotModel):
def __init__(self, name: str = 'seq'):
super(Seq2SeqModel, self).__init__(name)
super(Seq2SeqModel, self).__init__()
self.q_enc = None
self.a_dec = None
self.a_target = None
Expand Down Expand Up @@ -166,10 +166,10 @@ def evaluate(self):
self._restore_checkpoint()

print('Start testing ...')
# batches = self._get_batches('test')
batches = self._get_batches('test')
test_samples = [qa for qa, _ in self.dataset['test_samples']]
test_samples_text = [text for _, text in self.dataset['test_samples']]
batches = [self._create_batch(test_samples)]
# batches = [self._create_batch(test_samples)]
all_inputs = []
all_outputs = []
all_references = []
Expand All @@ -185,10 +185,10 @@ def evaluate(self):
all_inputs += np.transpose(np.array(batch.q_enc_seq)).tolist()
all_references += np.transpose(np.array(batch.a_target_seq)).tolist()

self._write_test_samples_literal(test_samples_text, all_outputs)
self._write_test_samples_results(all_outputs)
# self._write_test_samples_literal(test_samples_text, all_outputs)
# self._write_test_samples_results(all_outputs)
# self._wirte_test_literal(all_inputs, all_outputs)
# self._write_evaluation_results(all_outputs, all_references)
self._write_evaluation_results(all_outputs, all_references)

def _write_test_samples_results(self, outputs: List[List[int]]):
results = {}
Expand Down

0 comments on commit 3f8f31d

Please sign in to comment.