In [1]:
import nltk
import re
import numpy as np
from ngram import generate_ngram_models
from nltk.corpus import reuters
from tqdm import tqdm
from project.edit_distance import generate_candidates

# 读取字典库、语料库

In [2]:
vocab = {line.rstrip() for line in open('vocab.txt')}
#用set来存储不用list是因为查找的时候set时间复杂度是O(1),List是O(n)

In [3]:
nltk.download('reuters')
nltk.download('punkt')

[nltk_data] Downloading package reuters to
[nltk_data]     C:\Users\komusama\AppData\Roaming\nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\komusama\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

# 据语料库生成ngram模型

In [4]:
categories = reuters.categories()  # 路透社语料库的类别
corpus = reuters.sents(categories=categories)  # sents()指定分类中的句子

# 构建语言模型：bigram
term_count, bigram_count=generate_ngram_models(corpus,2)

# 从count_1edit.txt获取channel probability

In [5]:
# 用户打错的概率统计 - channel probability
# 创建一个字典来存储channel probabilities
channel_prob = {}
total_errors = 0

i = 0
# 解析错误数据并计算总错误次数
for line in open('count_1edit.txt'):
    i += 1
    print(i)
    # Step1:解析数据
    # 正则表达式找到错误次数
    count = re.findall(r'\d+', line)[-1]
    
    # 从末尾剥离数字
    line = line.replace(count, "")
    # 剥离制表符
    if "\t" in line:
        line = line.replace("\t", "")
    # 判断空格在不在后段
    first, last = line.split("|")

    if " " in last:
        # 去除多个空格为一个    
        if re.match(r" {2,}", line):
            multi_spaces = re.findall(r" {2,}", line)
            for space in multi_spaces:
                line = line.replace(space, " ")
    
    # 正常情况
    correct, mistake = line.split("|")

    count = int(count)
    # Step2:计算错误次数
    if correct not in channel_prob:
        channel_prob[correct] = {}

    channel_prob[correct][mistake] = count
    total_errors += count

# 计算每种错误的概率
for correct in channel_prob:
    for mistake in channel_prob[correct]:
        channel_prob[correct][mistake] /= total_errors

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277


# 测试

In [7]:
import numpy as np
import re
from tqdm import tqdm
from project.edit_distance import generate_candidates

def count_lines(filename):
    with open(filename, 'r') as f:
        return sum(1 for line in f)

# 计算文件的行数
file_path = "testdata.txt"
total_lines = count_lines(file_path)
V = len(term_count.keys())
# 打开文件
with open(file_path, "r") as file:
    results = []
    i = 1
    bar = tqdm(file, total=total_lines, desc="Processing lines")
    
    for line in bar:
        line = re.sub(r"([,.'])", r" \1 ", line)
        items = line.rstrip().split("\t")
        line = items[2].split()
        corrected_line = line
        j = 0
        
        for word in line:
            if word not in vocab:
                # 需要替换word成正确的单词
                # Step1: 生成所有的(valid)候选集合
                # 获得编辑距离小于2的候选列表
                candidates1 = generate_candidates(word, vocab, 1)
                candidates2 = generate_candidates(word, vocab, 2)
                candidates = list(candidates1.union(candidates2))
                probs = []
                
                # 对于每一个candidate, 计算它的score
                # score = p(correct)*p(mistake|correct)
                #       = log p(correct) + log p(mistake|correct)
                # 返回score最大的candidate
                for candi in candidates:
                    prob = 0
                    # 计算channel probability
                    if candi in channel_prob and word in channel_prob[candi]:
                        prob += np.log(channel_prob[candi][word])
                    else:
                        prob += np.log(0.0001)
                    
                    # 计算语言模型的概率
                    """
                   比如s=I like playing football.
                   line=['I','like','playing','football']
                   word为playing时
                   """
                    forward_word = (
                        line[j - 1] + " " + candi
                    )  # 考虑前一个单词,出现like playing的概率

                    if forward_word in bigram_count and line[j - 1] in term_count:
                        prob += np.log(
                            (bigram_count[forward_word] + 1.0)
                            / (term_count[line[j - 1]] + V)
                        )
                    else:
                        prob += np.log(1.0 / V)

                    if j + 1 < len(line):  # 考虑后一个单词，出现playing football的概率
                        backward_word = candi + " " + line[j + 1]
                        if backward_word in bigram_count and candi in term_count:
                            prob += np.log(
                                (bigram_count[backward_word] + 1.0)
                                / (term_count[candi] + V)
                            )
                        else:
                            prob += np.log(1.0 / V)
                    probs.append(prob)

                if probs:
                    max_idx = probs.index(max(probs))
                    if len(word) == 1:
                        corrected_line[j] = word  # 不替换单个字母
                    else:
                        corrected_line[j] = candidates[max_idx]
            j += 1

        corrected_sentence = " ".join(corrected_line)
        corrected_sentence = re.sub(r"\s(['])\s", r"\1", corrected_sentence)  # 去除标点前的空格
        corrected_sentence = re.sub(r"\s([.])\s", r"\1", corrected_sentence)  # 去除标点前的空格
        corrected_sentence = re.sub(r"\s([,])", r"\1", corrected_sentence)  # 去除标点前的空格
        results.append(f"{i}\t{corrected_sentence}")
        i += 1

Processing lines:   0%|          | 3/1000 [00:20<1:51:52,  6.73s/it]


KeyboardInterrupt: 

In [None]:
with open("result.txt", "w") as file:
    file.write("\n".join(results))