forked from mlforcada/naive-automatic-postediting
/
extract_types.py
166 lines (120 loc) · 4.51 KB
/
extract_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import argparse
import ast
from collections import Counter, defaultdict
"""
===================== Global arguments section =================================
"""
def _build_parser():
    """Create the CLI parser: a postedits file and a language-pair prefix."""
    cli = argparse.ArgumentParser(description='Algorithm for classifying extracted postedits')
    cli.add_argument('postedits', help='File with extracted postedits')
    cli.add_argument('prefix', help='Prefix for language pair, e.g. bel-rus')
    return cli


parser = _build_parser()
# The command line is parsed once, at import time; main() reads
# args.postedits and args.prefix from this namespace.
args = parser.parse_args()

# ===================== Main code section =================================


def distance(a, b):
    """Return the Levenshtein (edit) distance between sequences *a* and *b*.

    Classic dynamic programme that keeps only the previous and the current
    row. The shorter input is laid along the row, so the row holds
    ``min(len(a), len(b)) + 1`` cells.
    """
    if len(a) <= len(b):
        shorter, longer = a, b
    else:
        shorter, longer = b, a
    row = list(range(len(shorter) + 1))
    for i, long_item in enumerate(longer, start=1):
        above, row = row, [i]
        for j, short_item in enumerate(shorter, start=1):
            substitution = above[j - 1] + (1 if short_item != long_item else 0)
            # insertion, deletion, substitution
            row.append(min(above[j] + 1, row[j - 1] + 1, substitution))
    return row[-1]
#приведи это говно в порядок, смотреть больно
def find_bidix(postedits, prefix):
bidix_entries = []
grammar_entries = []
grammar_context = {}
other_entries = {}
other_context = {}
for elem in postedits:
elem = elem.split('\',')
for i in range(len(elem)):
elem[i] = elem[i].strip('(')
elem[i] = elem[i].strip(')')
elem[i] = elem[i].strip(' ')
elem[i] = elem[i].strip('\'')
try:
source = elem[0].split(' ')
mt = elem[1].split(' ')
target = elem[2].split(' ')
except:
pass
#хуйня с апострофами
if len(source) == len(mt) and len(mt) == len(target):
for i in range(len(mt)):
if '*' in mt[i]:
bidix_entrie = '%s\t%s\t%s\t' % (source[i], mt[i], target[i])
bidix_entries.append(bidix_entrie)
continue
dis = distance(mt[i], target[i])
letters = len(target[i])
if letters != 0:
edits_percent = ((letters - dis) / letters) * 100
else:
edits_percent = 0
if edits_percent >= 50 and edits_percent < 100:
grammar_entrie = '%s\t%s\t%s\t' % (source[i], mt[i], target[i])
grammar_entries.append(grammar_entrie)
context = '%s\t%s\t%s\t' % (' '.join(source), ' '.join(mt), ' '.join(target))
if grammar_entrie in grammar_context.keys():
grammar_context[grammar_entrie].append(context)
else:
grammar_context[grammar_entrie] = [context]
elif mt[i] != target[i]:
other_entrie = '%s\t%s\t%s\t' % (source[i], mt[i], target[i])
if source[i] not in other_entries.keys():
other_entries[source[i]] = {}
other_entries[source[i]][mt[i]] = Counter()
other_entries[source[i]][mt[i]][target[i]] = 1
else:
if mt[i] not in other_entries[source[i]].keys():
other_entries[source[i]][mt[i]] = Counter()
other_entries[source[i]][mt[i]][target[i]] = 1
else:
if target[i] not in other_entries[source[i]][mt[i]].keys():
other_entries[source[i]][mt[i]][target[i]] = 1
else:
other_entries[source[i]][mt[i]][target[i]] += 1
else:
continue
bidix_counter = Counter()
grammar_counter = Counter()
for entrie in bidix_entries:
bidix_counter[entrie] += 1
for entrie in grammar_entries:
grammar_counter[entrie] += 1
with open(prefix + '-bidix_entries.txt', 'w', encoding='utf-8') as file:
for elem in bidix_counter.most_common():
file.write('%s%s\n' % (elem[0], str(elem[1])))
with open(prefix + '-grammar_entries.txt', 'w', encoding='utf-8') as file:
for elem in grammar_counter.most_common():
file.write('%s%s\n' % (elem[0], str(elem[1])))
with open(prefix + '-grammar_context.txt', 'w', encoding='utf-8') as file:
for key in grammar_context.keys():
file.write('KEY\n%s\n' % (key))
for value in grammar_context[key]:
file.write('%s\n' % (value))
file.write('\n\n\n')
with open(prefix + '-other_entries.txt', 'w', encoding='utf-8') as file:
for key, value in other_entries.items():
if len(value.keys()) > 1:
for sec_key, sec_value in value.items():
if len(sec_value.keys()) > 1:
v = sum(sec_value.values())
mc = sec_value.most_common(1)
if mc[0][1] * 100 / v > 30 and v > 7:
file.write('%s\t%s\t%s\t%s\t%s\n' % (key, sec_key, mc[0][0], mc[0][1], v))
def main():
    """Command-line entry point.

    Reads the postedits file named on the command line (one postedit per
    line) and passes its lines to find_bidix() with the given prefix.
    """
    path, prefix = args.postedits, args.prefix
    with open(path, 'r', encoding='utf-8') as handle:
        # Drop leading/trailing newlines so the file's final newline does not
        # produce an empty trailing entry.
        lines = handle.read().strip('\n').split('\n')
    find_bidix(lines, prefix)


if __name__ == '__main__':
    main()