-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_rev_entailment_accs.py
63 lines (50 loc) · 2.22 KB
/
get_rev_entailment_accs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import sys
from constants import *
def _default_input_fn():
    """Return the interpreter's line-reading prompt function (Python 2 or 3)."""
    try:
        return raw_input  # Python 2
    except NameError:
        return input  # Python 3


def _read_lines(path):
    """Return the lines of *path* with trailing newline characters removed."""
    # ``with`` guarantees the file is closed. rstrip (instead of slicing off
    # the last character) keeps the final character of a last line that has
    # no trailing newline.
    with open(path) as f:
        return [line.rstrip('\r\n') for line in f]


def _first_sentence(text):
    """Return *text* up to and including its first '.', or all of it if none."""
    period_location = text.find('.')
    if period_location >= 0:
        return text[:period_location + 1]
    return text


def _ask_is_valid(input_fn):
    """Prompt until a yes/no answer is given; return True for a 'yes' answer."""
    valid_responses = ('yes', 'y', 'no', 'n')
    while True:
        response = input_fn("Valid entailment? (y/n) ").strip().lower()
        if response in valid_responses:
            return response.startswith('y')
        print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")


def get_entailment_accuracies(infile1, infile2, sample_inds=None, input_fn=None):
    """Interactively judge reverse-entailment pairs and report the accuracy.

    Line i of ``infile1`` is paired with line i of ``infile2``. Pairing stops
    at the end of the shorter file. Only pairs whose 0-based index is in
    ``sample_inds`` are shown, in file order.

    The display roles are reversed relative to argument order:
    - The line from ``infile2``, truncated after its first '.' if it has one,
      is shown as the PREMISE.
    - The line from ``infile1`` is shown as the HYPOTHESIS.

    For each pair the user answers whether the hypothesis is a valid
    entailment of the premise. Accepted answers are 'yes'/'y' or 'no'/'n'
    (case-insensitive, surrounding whitespace ignored); anything else
    re-prompts. The fraction of 'yes' answers is printed at the end.

    Args:
        infile1: Path to the file whose lines are shown as hypotheses.
        infile2: Path to the file whose lines are shown as premises.
        sample_inds: Iterable of pair indices to evaluate. Defaults to the
            ``SAMPLE_INDS`` constant (imported via ``from constants import *``).
        input_fn: Callable taking a prompt string and returning the user's
            reply. Defaults to ``raw_input`` on Python 2 and ``input`` on
            Python 3.

    Returns:
        The accuracy as a float in [0, 1], or None if no pair was selected.
    """
    if sample_inds is None:
        sample_inds = SAMPLE_INDS
    if input_fn is None:
        input_fn = _default_input_fn()

    # Build the lookup set once so each membership test is O(1).
    wanted = set(sample_inds)
    pairs = zip(_read_lines(infile1), _read_lines(infile2))
    selected = [pair for i, pair in enumerate(pairs) if i in wanted]
    # The total counts only pairs that actually exist, not raw sample indices,
    # so out-of-range or duplicate indices do not inflate the progress counter.
    total = len(selected)
    if total == 0:
        # Avoids a ZeroDivisionError when no index falls inside the files.
        print("\nNo sentence pairs matched the requested sample indices.")
        return None

    num_valid = 0
    for num, (hypothesis, generated) in enumerate(selected, 1):
        print('\nSentence ' + str(num) + '/' + str(total))
        # The generated sentence (second file) is the premise here.
        print("PREMISE: {}".format(_first_sentence(generated)))
        print("HYPOTHESIS: {}".format(hypothesis))
        if _ask_is_valid(input_fn):
            num_valid += 1

    final_accuracy = float(num_valid) / total
    print("\nDone!")
    print("Final accuracy: {}\n".format(final_accuracy))
    return final_accuracy
# Usage: python get_rev_entailment_accs.py SPLIT MODEL
#   SPLIT is 'train', 'valid', or 'test'.
#   MODEL is one of 'beam_search', 'output_aware', 'glob_attn', 'loc_attn',
#   'char_model'.
# Hypotheses are read from reverse_ent/<SPLIT>_hypothesis.txt and the model's
# generated sentences from reverse_ent/<MODEL>/translations_<SPLIT>.txt.
if __name__ == '__main__':
    import os

    # Fail with a usage message instead of a bare IndexError when arguments
    # are missing; extra trailing arguments are ignored, as before.
    if len(sys.argv) < 3:
        sys.exit("Usage: python {} SPLIT MODEL\n"
                 "  SPLIT: 'train', 'valid', or 'test'\n"
                 "  MODEL: 'beam_search', 'output_aware', 'glob_attn', "
                 "'loc_attn', or 'char_model'".format(sys.argv[0]))
    split, model = sys.argv[1], sys.argv[2]
    infile1 = os.path.join("reverse_ent", split + '_hypothesis.txt')
    infile2 = os.path.join("reverse_ent", model, 'translations_' + split + '.txt')
    get_entailment_accuracies(infile1, infile2)