
## HMM模型训练

使用PKU98语料训练HMM模型的初始概率向量，状态转移概率矩阵，发射概率矩阵和每个字可能的状态矩阵

In [3]:
import os
PKU98 = "pku98"
PKU199801 = os.path.join(PKU98, '199801.txt')
PKU199801_TRAIN = os.path.join(PKU98, '199801-train.txt')
PKU199801_TEST = os.path.join(PKU98, '199801-test.txt')

MIN_FLOAT = -3.14e100

### 转换语料格式

将语料中的每个句子的分词结果进行转换，每句话放到一个元组中，第一个元素是每个字符构成的列表，第二个元素是每个字符的分词标签。所有的句子放到一个列表中。

In [5]:
# 转换前去掉预料中标注的命名实体
import re 
fw = open("pku98/199801-train-cleaned.txt","w",encoding="utf-8")
with open(PKU199801_TRAIN,encoding="utf-8") as fr:
    for line in fr:
        line = re.sub(r"\[","",line)
        line = re.sub(r"\]/[a-z]","",line)
        fw.write(line)
fw.close()

In [6]:
def character_tagging(input_file):
    input_data = open(input_file, 'r', encoding='utf-8')
    ret_list = []
    for line in input_data.readlines():
        char_list = []
        tag_list = []
        word_list = line.strip().split()
        if len(word_list) == 0:
            continue
        for word in word_list:
            word,pos = word.split("/")
            if len(word) == 1:
                char_list.append(word)
                tag_list.append(("S",pos))
            elif len(word) >= 2:
                char_list.append(word[0])
                tag_list.append(("B",pos))
                for w in word[1: len(word)-1]:
                    char_list.append(w)
                    tag_list.append(("M",pos))
                tag_list.append(("E",pos))
                char_list.append(word[-1])
                # 声明两个列表的长度相等，不相等报错
                assert(len(char_list)==len(tag_list))
        ret_list.append((char_list,tag_list))
    input_data.close()
    return ret_list

In [8]:
ret_list = character_tagging("199801-train-cleaned.txt")

In [9]:
print(ret_list[0])

(['迈', '向', '充', '满', '希', '望', '的', '新', '世', '纪', '—', '—', '一', '九', '九', '八', '年', '新', '年', '讲', '话', '（', '附', '图', '片', '１', '张', '）'], [('B', 'v'), ('E', 'v'), ('B', 'v'), ('E', 'v'), ('B', 'n'), ('E', 'n'), ('S', 'u'), ('S', 'a'), ('B', 'n'), ('E', 'n'), ('B', 'w'), ('E', 'w'), ('B', 't'), ('M', 't'), ('M', 't'), ('M', 't'), ('E', 't'), ('B', 't'), ('E', 't'), ('B', 'n'), ('E', 'n'), ('S', 'w'), ('S', 'v'), ('B', 'n'), ('E', 'n'), ('S', 'm'), ('S', 'q'), ('S', 'w')])


### 统计频率

统计每个字可能的状态，初始状态发生的频率，状态转移频率和发射频率，分别放到四个字典中

In [12]:
import traceback

def create_empty_dict(keys):
    ret_dict = {}
    for key in keys:
        ret_dict[key] = 0
    return ret_dict    

char_list = []
for item in ret_list:
    char_list.extend(item[0])
char_set = set(char_list)

state_list = []
for item in ret_list:
    state_list.extend(item[1])
state_set = set(state_list)

# 每个字可能的状态
char_state_dict = {}
# 每个状态作为初始状态的频率
initial_state_dict = create_empty_dict(state_set)
# 状态转移频率
transition_dict = {}
# 每个状态的发射频率
emission_dict = {}

for item in ret_list:
    try:
        # 统计每个字可能的状态
        for i in range(len(item[1])):
            char = item[0][i]
            state = item[1][i]
            if char not in char_state_dict:
                # 使用字典保存可能的state，可以自动去重
                char_state_dict[char] = {}
            char_state_dict[char][state] = 1
        
        # 计算初始状态频率
        first_state = item[1][0]
        initial_state_dict[first_state] += 1
        
        # 计算状态转移频率
        for i in range(len(item[1])-1):
            left_state = item[1][i]
            if left_state not in transition_dict:
                transition_dict[left_state] = {}
            right_state = item[1][i+1]
            transition_dict[left_state][right_state] = transition_dict[left_state].setdefault(right_state, 0) + 1
        
        # 计算发射频率
        for i in range(len(item[1])):
            state = item[1][i]
            if state not in emission_dict:
                emission_dict[state] = {}
            char = item[0][i]
            if char not in emission_dict[state]:
                emission_dict[state][char] = 0
            emission_dict[state][char] += 1      
    except:
        print(item)
        print(traceback.format_exc())

In [14]:
for key in char_state_dict:
    char_state_dict[key] = tuple(char_state_dict[key])

In [15]:
print(char_state_dict["迈"])

(('B', 'v'), ('S', 'v'), ('M', 'nr'), ('B', 'nr'), ('E', 'ns'), ('E', 'nr'), ('E', 'a'), ('B', 'ns'), ('E', 'z'), ('E', 'v'), ('M', 'ns'))


In [62]:
print(initial_state_dict[('B', 'v')])

1163


In [63]:
import pprint
pprint.pprint(transition_dict)

{('B', 'a'): {('E', 'a'): 19535, ('M', 'a'): 58},
 ('B', 'ad'): {('E', 'ad'): 5034, ('M', 'ad'): 2},
 ('B', 'an'): {('E', 'an'): 2458},
 ('B', 'b'): {('E', 'b'): 5615, ('M', 'b'): 626},
 ('B', 'c'): {('E', 'c'): 5732, ('M', 'c'): 299},
 ('B', 'd'): {('E', 'd'): 15346, ('M', 'd'): 1033},
 ('B', 'e'): {('E', 'e'): 4},
 ('B', 'f'): {('E', 'f'): 4049, ('M', 'f'): 72},
 ('B', 'i'): {('M', 'i'): 4498},
 ('B', 'j'): {('E', 'j'): 2994, ('M', 'j'): 2142},
 ('B', 'jt'): {('E', 'jt'): 552, ('M', 'jt'): 314},
 ('B', 'jz'): {('E', 'jz'): 4, ('M', 'jz'): 24},
 ('B', 'l'): {('E', 'l'): 23, ('M', 'l'): 5353},
 ('B', 'ls'): {('M', 'ls'): 1},
 ('B', 'lt'): {('M', 'lt'): 73},
 ('B', 'm'): {('E', 'm'): 12117, ('M', 'm'): 8354},
 ('B', 'n'): {('E', 'n'): 166289, ('M', 'n'): 26937},
 ('B', 'na'): {('M', 'na'): 1},
 ('B', 'nr'): {('E', 'nr'): 4033, ('M', 'nr'): 13647},
 ('B', 'nrt'): {('E', 'nrt'): 2},
 ('B', 'ns'): {('E', 'ns'): 16937, ('M', 'ns'): 8560},
 ('B', 'nss'): {('E', 'nss'): 4, ('M', 'nss'): 12},


 ('S', 'j'): {('B', 'a'): 31,
              ('B', 'ad'): 8,
              ('B', 'an'): 4,
              ('B', 'b'): 12,
              ('B', 'c'): 2,
              ('B', 'd'): 19,
              ('B', 'f'): 45,
              ('B', 'i'): 3,
              ('B', 'j'): 28,
              ('B', 'jt'): 11,
              ('B', 'l'): 10,
              ('B', 'm'): 22,
              ('B', 'n'): 744,
              ('B', 'ns'): 19,
              ('B', 'nt'): 45,
              ('B', 'nz'): 67,
              ('B', 'nzt'): 1,
              ('B', 'p'): 4,
              ('B', 'r'): 9,
              ('B', 's'): 22,
              ('B', 't'): 20,
              ('B', 'v'): 341,
              ('B', 'vd'): 3,
              ('B', 'vn'): 74,
              ('S', 'Ng'): 9,
              ('S', 'Vg'): 1,
              ('S', 'a'): 13,
              ('S', 'b'): 2,
              ('S', 'c'): 18,
              ('S', 'd'): 49,
              ('S', 'f'): 22,
              ('S', 'j'): 919,
              ('S', 'm'): 206,
     

### 计算概率

注意这里计算概率值的同时取了对数

#### 计算初始状态概率向量

In [64]:
import math
prob_start = {}
for key in initial_state_dict:
    total_freq = sum(initial_state_dict.values())
    if initial_state_dict[key] != 0:
        prob_start[key] = math.log(initial_state_dict[key]/total_freq)
    else:
        prob_start[key] = MIN_FLOAT

In [65]:
prob_start

{('B', 'f'): -5.900753151666388,
 ('B', 'z'): -6.636459946645129,
 ('B', 'r'): -2.4317673272541627,
 ('B', 'w'): -5.454466049037968,
 ('M', 'v'): -3.14e+100,
 ('S', 'Ag'): -8.67334187390617,
 ('B', 'an'): -7.692512620894442,
 ('E', 'vnt'): -3.14e+100,
 ('E', 'vn'): -3.14e+100,
 ('E', 'nst'): -3.14e+100,
 ('B', 'a'): -5.108515068462212,
 ('S', 'Ng'): -7.374058889775908,
 ('E', 'q'): -3.14e+100,
 ('S', 'w'): -2.6852162280637013,
 ('S', 'c'): -6.306218259774552,
 ('S', 'Tg'): -7.692512620894442,
 ('S', 'j'): -4.904419712118696,
 ('M', 's'): -3.14e+100,
 ('E', 't'): -3.14e+100,
 ('M', 'p'): -3.14e+100,
 ('S', 'Mg'): -3.14e+100,
 ('B', 'vnt'): -3.14e+100,
 ('E', 'lt'): -3.14e+100,
 ('B', 'vn'): -4.837480229443587,
 ('E', 'ad'): -3.14e+100,
 ('B', 'nst'): -3.14e+100,
 ('M', 'nzt'): -3.14e+100,
 ('M', 'jz'): -3.14e+100,
 ('B', 'q'): -9.771954162574279,
 ('S', 'd'): -4.415367887902266,
 ('M', 'b'): -3.14e+100,
 ('M', 'jt'): -3.14e+100,
 ('E', 'nt'): -3.14e+100,
 ('M', 'nzz'): -3.14e+100,
 ('E'

#### 计算状态转移概率矩阵

In [66]:
prob_trans = {}
for key in transition_dict:
    total_freq = sum(transition_dict[key].values())
    prob_trans[key] = {}
    for state in transition_dict[key]:
        if transition_dict[key][state] != 0:
            prob_trans[key][state] = math.log(transition_dict[key][state]/total_freq)
        else:
            prob_trans[key][state] = MIN_FLOAT

In [67]:
import pprint
pprint.pprint(prob_trans)

{('B', 'a'): {('E', 'a'): -0.002964631081595905,
              ('M', 'a'): -5.822484628024347},
 ('B', 'ad'): {('E', 'ad'): -0.00039721946897661186,
               ('M', 'ad'): -7.831220214604293},
 ('B', 'an'): {('E', 'an'): 0.0},
 ('B', 'b'): {('E', 'b'): -0.10569883776149944,
              ('M', 'b'): -2.2995453338339447},
 ('B', 'c'): {('E', 'c'): -0.050848324502924334,
              ('M', 'c'): -3.0042245400603003},
 ('B', 'd'): {('E', 'd'): -0.06514517274476904, ('M', 'd'): -2.763532836362741},
 ('B', 'e'): {('E', 'e'): 0.0},
 ('B', 'f'): {('E', 'f'): -0.017625915306656272,
              ('M', 'f'): -4.047185012322762},
 ('B', 'i'): {('M', 'i'): 0.0},
 ('B', 'j'): {('E', 'j'): -0.5396642803902233, ('M', 'j'): -0.874534594362103},
 ('B', 'jt'): {('E', 'jt'): -0.4503368622853398,
               ('M', 'jt'): -1.0144919226541818},
 ('B', 'jz'): {('E', 'jz'): -1.9459101490553135,
               ('M', 'jz'): -0.15415067982725836},
 ('B', 'l'): {('E', 'l'): -5.454205666273836,
         

               ('S', 'Ng'): -3.454606491687882,
               ('S', 'Tg'): -6.632660322035828,
               ('S', 'Vg'): -5.716369590161673,
               ('S', 'a'): -3.770459441106359,
               ('S', 'an'): -7.325807502595773,
               ('S', 'c'): -5.128582925259553,
               ('S', 'd'): -4.330075229041782,
               ('S', 'f'): -5.534048033367718,
               ('S', 'j'): -4.067710964574291,
               ('S', 'k'): -6.632660322035828,
               ('S', 'm'): -3.9585116726092986,
               ('S', 'n'): -3.0631276255544573,
               ('S', 'nr'): -6.632660322035828,
               ('S', 'p'): -3.1986731175506815,
               ('S', 'q'): -6.632660322035828,
               ('S', 'r'): -3.6622458564661264,
               ('S', 'u'): -3.4971661061066777,
               ('S', 'v'): -2.4582730521401905,
               ('S', 'w'): -1.2413080895206015,
               ('S', 'y'): -6.227195213927663},
 ('S', 'Yg'): {('S', 'w'): 0.0},
 ('S', 'a'): {

#### 计算发射概率矩阵

In [68]:
prob_emit = {}
for key in emission_dict:
    total_freq = sum(emission_dict[key].values())
    prob_emit[key] = {}
    for char in emission_dict[key]:
        if emission_dict[key][char] != 0:
            prob_emit[key][char] = math.log(emission_dict[key][char]/total_freq)
        else:
            prob_emit[key][char] = MIN_FLOAT

In [69]:
prob_emit[('B', 'v')]

{'迈': -7.057000328837131,
 '充': -6.764013204155658,
 '发': -3.502564537442916,
 '来': -5.2441874504112445,
 '致': -6.144973995749434,
 '继': -5.281326997360701,
 '建': -4.486339001541101,
 '推': -4.710703802752669,
 '前': -6.476331131703876,
 '恢': -6.556869994187889,
 '行': -6.709945982885381,
 '自': -6.327295552543389,
 '保': -4.878869410316492,
 '召': -6.070866023595713,
 '高': -6.556869994187889,
 '总': -7.06769561795388,
 '展': -5.9871670047190655,
 '制': -5.482197913065788,
 '向': -7.66846947838281,
 '获': -5.689498467066608,
 '深': -5.7169727226188565,
 '改': -4.577427025024494,
 '扩': -5.850902125198882,
 '关': -5.641870418077354,
 '相': -5.414086487206641,
 '有': -5.479997692156185,
 '得': -5.44331612552158,
 '取': -4.980221904576779,
 '互': -7.351799869057777,
 '确': -5.923541308838853,
 '加': -4.208263543592544,
 '参': -4.914434164038776,
 '符': -6.83812117630938,
 '顺': -9.202399838308764,
 '走': -6.11549817761648,
 '多': -7.7936326213368154,
 '促': -5.497736516493567,
 '作': -4.960419277280599,
 '开': -3.9539

### HMM词性标注模型

In [70]:
MIN_FLOAT = -3.14e100
MIN_INF = float("-inf")

In [71]:
def viterbi(obs, states, start_p, trans_p, emit_p):
    V = [{}]  # tabular
    mem_path = [{}]
    all_states = trans_p.keys()
    for y in states.get(obs[0], all_states):  # 遍历第一个字可能的隐状态
        V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
        mem_path[0][y] = ''
    for t in range(1, len(obs)):
        V.append({})
        mem_path.append({})
        # 获取前一个时刻可能的状态
        prev_states = [x for x in mem_path[t - 1].keys() if len(trans_p[x]) > 0]
        
        # 获取前一时刻的状态可以转移到的所有状态
        prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys()))
        
        # 将利用前一时刻的状态计算出的当前时刻的状态，与当前字可能的状态求交集，进一步对可能的状态进行过滤
        obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next

        if not obs_states:
            obs_states = prev_states_expect_next if prev_states_expect_next else all_states

        for y in obs_states:
            prob, state = max((V[t - 1][y0] + trans_p[y0].get(y, MIN_INF) +
                               emit_p[y].get(obs[t], MIN_FLOAT), y0) for y0 in prev_states)
            V[t][y] = prob
            mem_path[t][y] = state

    last = [(V[-1][y], y) for y in mem_path[-1].keys()]
    # if len(last)==0:
    #     print obs
    prob, state = max(last)

    route = [None] * len(obs)
    i = len(obs) - 1
    while i >= 0:
        route[i] = state
        state = mem_path[i][state]
        i -= 1
    return (prob, route)

def __cut(sentence, char_state_tab_P, start_P, trans_P, emit_P):
    prob, pos_list = viterbi(
        sentence, char_state_tab_P, start_P, trans_P, emit_P)
    begin, nexti = 0, 0

    for i, char in enumerate(sentence):
        pos = pos_list[i][0]
        if pos == 'B':
            begin = i
        elif pos == 'E':
            yield (sentence[begin:i + 1], pos_list[i][1])
            nexti = i + 1
        elif pos == 'S':
            yield (char, pos_list[i][1])
            nexti = i + 1
    if nexti < len(sentence):
        yield (sentence[nexti:], pos_list[nexti][1])


In [72]:
sentence = "扬帆远东做与中国合作的先行"
list(__cut(sentence, char_state_dict, prob_start, prob_trans, prob_emit))

[('扬帆远', 'i'),
 ('东', 'f'),
 ('做', 'v'),
 ('与', 'p'),
 ('中国', 'ns'),
 ('合作', 'vn'),
 ('的', 'u'),
 ('先行', 'a')]