Permalink
Switch branches/tags
Nothing to show
Find file Copy path
3c1c528 Sep 27, 2015
1 contributor

Users who have contributed to this file

126 lines (114 sloc) 4.26 KB
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import sys
import gzip
import marshal
from math import log
from ..utils import frequency
class CharacterBasedGenerativeModel(object):
    """Trigram tagger over ``(character, tag)`` pairs.

    Unigram, bigram and trigram counts of ``(character, tag)`` pairs are
    collected by :meth:`train` and combined in :meth:`log_prob` as a
    weighted sum using the weights ``l1``, ``l2`` and ``l3``.
    :meth:`tag` then searches for the highest-scoring tag sequence.

    NOTE(review): the tag set ``('b', 'm', 'e', 's')`` presumably stands
    for begin / middle / end / single-character word positions -- confirm
    against the callers of this class.
    """

    def __init__(self):
        # Interpolation weights for the unigram, bigram and trigram
        # estimates; they stay at zero until train() computes them.
        self.l1 = 0.0
        self.l2 = 0.0
        self.l3 = 0.0
        # Tags that tag() may assign to each input item.
        self.status = ('b', 'm', 'e', 's')
        # Count tables keyed by a pair, a 2-tuple of pairs and a 3-tuple
        # of pairs respectively.
        self.uni = frequency.NormalProb()
        self.bi = frequency.NormalProb()
        self.tri = frequency.NormalProb()

    def save(self, fname, iszip=True):
        """Serialize the model state to ``fname`` with ``marshal``.

        Attributes that carry their own ``__dict__`` (the count tables)
        are stored as that dictionary; everything else is stored as-is.
        On Python 3 the suffix ``'.3'`` is appended to ``fname``.

        :param fname: destination path (without the Python 3 suffix).
        :param iszip: when true (default) the payload is gzip-compressed.
        """
        d = {k: (v.__dict__ if hasattr(v, '__dict__') else v)
             for k, v in self.__dict__.items()}
        if sys.version_info[0] == 3:
            # Presumably because marshal data is not portable between
            # major Python versions -- each gets its own file.
            fname = fname + '.3'
        # Context managers guarantee the handle is closed even when
        # serialization or writing fails.
        if not iszip:
            with open(fname, 'wb') as f:
                marshal.dump(d, f)
        else:
            with gzip.open(fname, 'wb') as f:
                f.write(marshal.dumps(d))

    def load(self, fname, iszip=True):
        """Restore model state previously written by :meth:`save`.

        On Python 3 the suffix ``'.3'`` is appended to ``fname``.  When
        ``iszip`` is true and reading the file as gzip fails with
        ``IOError``, the file is re-read as an uncompressed marshal dump.

        SECURITY: ``marshal`` is not safe against erroneous or malicious
        data; only load model files from a trusted source.

        :param fname: source path (without the Python 3 suffix).
        :param iszip: when true (default) try gzip first.
        """
        if sys.version_info[0] == 3:
            fname = fname + '.3'
        if not iszip:
            with open(fname, 'rb') as f:
                d = marshal.load(f)
        else:
            try:
                with gzip.open(fname, 'rb') as f:
                    d = marshal.loads(f.read())
            except IOError:
                # Not a gzip file (or unreadable as one): fall back to a
                # plain marshal dump.  The gzip handle is already closed.
                with open(fname, 'rb') as f:
                    d = marshal.loads(f.read())
        for k, v in d.items():
            # .get() tolerates stored keys the instance does not have yet
            # instead of raising KeyError.
            current = self.__dict__.get(k)
            if hasattr(current, '__dict__'):
                # Nested object (count table): restore its state in place.
                current.__dict__ = v
            else:
                self.__dict__[k] = v

    def div(self, v1, v2):
        """Return ``v1 / v2`` as a float, or ``0`` when ``v2`` is zero."""
        if v2 == 0:
            return 0
        return float(v1)/v2

    def train(self, data):
        """Accumulate n-gram counts from ``data`` and set ``l1``..``l3``.

        :param data: iterable of sentences, each an iterable of
            ``(word, tag)`` pairs.

        Every sentence is prefixed with two ``('', 'BOS')`` markers so the
        first real item already has a full trigram context.  Afterwards
        each observed trigram votes, with its own count, for whichever of
        the trigram / bigram / unigram relative frequencies (each count
        reduced by one) is largest; the normalized vote totals become the
        interpolation weights.  This appears to follow the deleted
        interpolation scheme used by the TnT tagger.
        """
        bos = ('', 'BOS')
        for sentence in data:
            now = [bos, bos]
            self.bi.add((bos, bos), 1)
            self.uni.add(bos, 2)
            for word, tag in sentence:
                now.append((word, tag))
                self.uni.add((word, tag), 1)
                self.bi.add(tuple(now[1:]), 1)
                self.tri.add(tuple(now), 1)
                # Slide the window: keep only the last two pairs.
                now.pop(0)
        tl1 = 0.0
        tl2 = 0.0
        tl3 = 0.0
        # NOTE(review): get() is assumed to return a tuple whose second
        # element is the stored count -- confirm in frequency.NormalProb.
        samples = sorted(self.tri.samples(), key=lambda x: self.tri.get(x)[1])
        uni_total = self.uni.getsum()
        for now in samples:
            tri_count = self.tri.get(now)[1]
            c3 = self.div(tri_count-1, self.bi.get(now[:2])[1]-1)
            c2 = self.div(self.bi.get(now[1:])[1]-1,
                          self.uni.get(now[1])[1]-1)
            c1 = self.div(self.uni.get(now[2])[1]-1, uni_total-1)
            # Ties are resolved in favour of the higher-order estimate.
            if c3 >= c1 and c3 >= c2:
                tl3 += tri_count
            elif c2 >= c1 and c2 >= c3:
                tl2 += tri_count
            elif c1 >= c2 and c1 >= c3:
                tl1 += tri_count
        total = tl1+tl2+tl3
        self.l1 = self.div(tl1, total)
        self.l2 = self.div(tl2, total)
        self.l3 = self.div(tl3, total)

    def log_prob(self, s1, s2, s3):
        """Return the log of the interpolated estimate for ``s3``.

        :param s1: pair two positions before ``s3``.
        :param s2: pair immediately before ``s3``.
        :param s3: pair being scored.
        :returns: ``log(uni + bi + tri)`` where the three terms are the
            weighted unigram, bigram and trigram estimates, or
            ``float('-inf')`` when their sum is zero.
        """
        uni = self.l1*self.uni.freq(s3)
        bi = self.div(self.l2*self.bi.get((s2, s3))[1], self.uni.get(s2)[1])
        tri = self.div(self.l3*self.tri.get((s1, s2, s3))[1],
                       self.bi.get((s1, s2))[1])
        if uni+bi+tri == 0:
            return float('-inf')
        return log(uni+bi+tri)

    def tag(self, data):
        """Assign one tag from ``self.status`` to every item of ``data``.

        :param data: sequence of items (e.g. the characters of a string).
        :returns: ``zip(data, tags)`` for the best-scoring tag sequence.

        Candidate paths are indexed by their last two ``(item, tag)``
        pairs, and for each such index only the best-scoring path is
        kept.  An item that has zero unigram frequency under every tag
        adds nothing to the score of any path.
        """
        # Each entry: ((pair two back, previous pair), score, tags so far).
        now = [((('', 'BOS'), ('', 'BOS')), 0.0, [])]
        for w in data:
            known = any(self.uni.freq((w, s)) != 0 for s in self.status)
            stage = {}
            for s in self.status:
                cur = (w, s)
                for (prev2, prev1), score, path in now:
                    if known:
                        score = score+self.log_prob(prev2, prev1, cur)
                    key = (prev1, cur)
                    # Keep the best predecessor for this state.  This is
                    # applied to unknown items as well, so the surviving
                    # path never depends on iteration order.
                    if key not in stage or score > stage[key][0]:
                        stage[key] = (score, path+[s])
            now = [(key, score, path)
                   for key, (score, path) in stage.items()]
        return zip(data, max(now, key=lambda x: x[1])[2])