### Cut Rod Problem

Initialize data

In [1]:
from collections import defaultdict

In [44]:
original_price = [1, 5, 8, 9, 10, 17, 17, 20, 24, 30, 35]
price = defaultdict(int)
for i, p in enumerate(original_price):
    price[i+1] = p

In [45]:
price[5]

10

Create a decorator

In [15]:
# from functools import wraps # 用来还原原来的方法名
def memo(f):
    memo.already_computed = {}
    # @wraps(f)
    def wrapper(arg):
        result = None
        
        if arg in memo.already_computed:
            result = memo.already_computed[arg]
        else:
            result = f(arg)
            memo.already_computed[arg] = result
            
        return result
    
    return wrapper

In [30]:
def max_revenue(n):
    """
    Args: n is the iron length
    Returns: the max revenue and the cut method
    """
    solution = {} # record all best cut solution for 1-n length
    
    @memo
    def cut(n):
        max_price, max_split = max(
            [(price[n], 0)] + [(cut(i) + cut(n-i), i) for i in range(1, n)], key=lambda x: x[0]
        )
        solution[n] = (n - max_split, max_split)
        return max_price
    
    def parse_solution(n):
        l_split, r_split = solution[n]
        
        if r_split == 0: return [l_split]
        
        return parse_solution(l_split) + parse_solution(r_split)
    
    revenue = cut(n)
    cutted = parse_solution(n)
    
    # printout the result
    cutted_d = {}
    for c in cutted:
        if c not in cutted_d:
            cutted_d[c] = 1
        else:
            cutted_d[c] += 1
    print('The max revenue of {} length iron is {}'.format(n, revenue))
    print('The cut method is: ' + ','.join(
            ['{}*{}'.format(c, cutted_d[c]) for c in sorted(cutted_d.keys(), reverse=True)])
          )
    
    return revenue, cutted_d

In [46]:
revenue, cutted = max_revenue(20)

The max revenue of 20 length iron is 60
The cut method is: 11*1,6*1,3*1


In [48]:
revenue, cutted = max_revenue(60)

The max revenue of 60 length iron is 188
The cut method is: 11*5,3*1,2*1


In [53]:
revenue, cutted = max_revenue(54)

The max revenue of 54 length iron is 170
The cut method is: 11*4,10*1


### Edit Distance

In [2]:
from functools import lru_cache

In [43]:
def edit_distance(s1, s2):
    
    best_solutions = {}
    
    @lru_cache(maxsize=2**10)
    def ed(s1, s2, tail=''):
        
        if len(s1) == 0: return len(s2)
        if len(s2) == 0: return len(s1)
        
        tail_s1 = s1[-1]
        tail_s2 = s2[-1]
        
        # use tail to record the removed tail
        candidates = [
            (ed(s1[:-1], s2, tail) + 1, 'DEL {} at position {}'.format(tail_s1, len(s1)), tail),
            # string1 delete tail
            (ed(s1, s2[:-1], tail_s2 + tail) + 1, 'ADD {} at position {}'.format(tail_s2, len(s1)+1), tail_s2 + tail)
            # string1 add the tail of string2
        ]
        
        if tail_s1 == tail_s2:
            both_forward = (ed(s1[:-1], s2[:-1], tail_s2 + tail) + 0, '', tail_s2 + tail)
        else:
            both_forward = (ed(s1[:-1], s2[:-1], tail_s2 + tail) + 1, 
                            'SUB {} => {} at position {}'.format(tail_s1, tail_s2, len(s1)), tail_s2 + tail)
            
        candidates.append(both_forward)
        
        min_distance, operation, tail = min(candidates, key=lambda x: x[0])
        
        best_solutions[(s1, s2)] = [operation, tail]
        
        return min_distance
    
    solution = []
    
    def parse_solution(s1, s2):
        if (s1, s2) in best_solutions:
            operation, tail = best_solutions[(s1, s2)]
            if operation.startswith('D'):
                solution.append('({:<10}, {}): {}'.format(s1+tail, s2+tail, operation))
                return parse_solution(s1[:-1], s2)
            elif operation.startswith('A'):
                solution.append('({:<10}, {}): {}'.format(s1+tail[1:], s2+tail[1:], operation))
                return parse_solution(s1, s2[:-1])
            elif operation.startswith('S'):
                solution.append('({:<10}, {}): {}'.format(s1+tail[1:], s2+tail[1:], operation))
                return parse_solution(s1[:-1], s2[:-1])
            else:
                return parse_solution(s1[:-1], s2[:-1])
    
    min_distance = ed(s1, s2)
    parse_solution(s1, s2)
    solution.append('({:<10}, {}): {}'.format(s2, s2, 'Done'))
    
    return min_distance, solution

In [44]:
distance, solution = edit_distance('ABCDE', 'ADEEFG')

In [45]:
distance

5

In [181]:
solution

['(ABCDE     , ADEEFG): ADD G at position 6',
 '(ABCDEG    , ADEEFG): ADD F at position 6',
 '(ABCDEFG   , ADEEFG): ADD E at position 6',
 '(ABCDEEFG  , ADEEFG): DEL C at position 3',
 '(ABDEEFG   , ADEEFG): DEL B at position 2',
 '(ADEEFG    , ADEEFG): Done']

### Pinyin Auto Correction Problem

###### Recode

data processing

In [3]:
chinese_dataset = '../lecture1/article_9k.txt'

In [4]:
data = open(chinese_dataset).read()

In [5]:
import re

In [6]:
def tokens(text):
    # list all the chinese characters
    return ''.join(re.findall('[\u4e00-\u9fff]', text))

In [7]:
CHINESE_CHARACTERS = tokens(data)

In [8]:
len(CHINESE_CHARACTERS)

30365478

In [9]:
import pinyin

In [10]:
pinyin.get('你好', format='strip', delimiter=" ")

'ni hao'

In [11]:
def chinese2pinyin(character):
    return pinyin.get(character, format='strip', delimiter=" ")

In [12]:
CHINESE_PINYIN_COPYS = chinese2pinyin(CHINESE_CHARACTERS)

In [13]:
len(CHINESE_PINYIN_COPYS)

123312338

In [14]:
from collections import Counter, defaultdict

In [15]:
CHINESE_PINYIN_TOKENS = CHINESE_PINYIN_COPYS.split()

In [16]:
len(CHINESE_PINYIN_TOKENS)

30365478

In [17]:
PINYIN_COUNT = Counter(CHINESE_PINYIN_TOKENS)

correct the splitted words

In [18]:
def correct(word):
    # Prefer edit distance 0, then 1, then 2; otherwise default to word itself
    
    candidate = (known(edits0(word)) or
                 known(edits1(word)) or
                 known(edits2(word)) or
                 [word])
    return max(candidate, key=PINYIN_COUNT.get)

In [19]:
def known(words):
    return {w for w in words if w in PINYIN_COUNT}

def edits0(word):
    # return word itself (0 edit distance)
    return {word}

def edits1(word):
    # return all strings that are 1 edit away from this pinyin
    pairs      = splits(word)
    deletes    = [a+b[1:]           for (a, b) in pairs if b]
    transposes = [a+b[1]+b[0]+b[2:] for (a, b) in pairs if len(b) > 1]
    replaces   = [a+c+b[1:]         for (a, b) in pairs for c in alphabet if b]
    inserts    = [a+c+b             for (a, b) in pairs for c in alphabet]
    return set(deletes + transposes + replaces + inserts)

def edits2(word):
    # return all strings that are 2 eidts awat from this pinyin
    return {e2 for e1 in edits1(word) for e2 in edits1(e1)}

def splits(word):
    # return a list of all possible (first, rest) pairs that comprise pinyin
    return [(word[:i], word[i:]) for i in range(len(word)+1)]

alphabet = 'abcdefghijklmnopqrstuvwxyz'

In [20]:
def correct_sequence_pinyin(text_pinyin):
    return ' '.join(map(correct, text_pinyin.split()))

In [21]:
correct_sequence_pinyin('zhe sih yi ge ce sho')

'zhe shi yi ge ce shi'

###### Homework Question ---> auto split the pinyin

In [46]:
def correct_unsplitted_string(string, score_func):
    solution = {}
    @lru_cache(maxsize=2**10)
    def cut_string(string):
        best_split = max(
            [correct(string)] + [cut_string(string[:i])+' '+cut_string(string[i:]) for i in range(1, min(6, len(string)))],
            key = lambda x: score_func(x)
        )
        solution[string] = best_split
        return best_split

    return cut_string(string), solution
        

In [23]:
one_word = CHINESE_PINYIN_TOKENS
two_words = [one_word[i] + ' ' + one_word[i+1] for i in range(len(one_word)-1)]
one_count = PINYIN_COUNT
two_count = Counter(two_words)

In [None]:
# 获取频次的函数
def get_gram_count(word, wc):
    if word in wc: return wc[word]
    else:
        return wc.most_common()[-1][-1]

In [24]:
# two gram language model
def two_gram_model(tokens):
    
    probability = 1
    
    for i in range(len(tokens)-1):
        word = tokens[i]
        next_w = tokens[i+1]
        
        two_gram_c = get_gram_count(word+' '+next_w, two_count)
        one_gram_c = get_gram_count(next_w, one_count)
        pro = two_gram_c / one_gram_c
        
        probability *= pro
        
    return probability

In [431]:
%%time
c, s = correct_unsplitted_string('zhesihyigecesho', score_func=two_gram_model)

CPU times: user 2min 13s, sys: 1.36 s, total: 2min 14s
Wall time: 2min 20s


In [432]:
c

'zi he si shi yi ci shi'

效果不太好，而且时间消耗比较高

In [41]:
two_gram_model('zi he si shi yi ci shi'.split())

9.69602573881772e-13

In [42]:
two_gram_model('zhe shi yi ge ce shi'.split())

6.324734771262851e-10

用2gram模型应该是可以提取出更优的分割+修正情况，添加一个编辑距离的惩罚项

In [47]:
def correct_2(string, score_func):
    solution = {}
    @lru_cache(maxsize=2**10)
    def cut_string(string):
        best_split = max(
            [correct(string)] + [cut_string(string[:i])+' '+cut_string(string[i:]) for i in range(1, min(6, len(string)))],
            key = lambda x: score_func(x) / (edit_distance(''.join(x), string)[0]+1)
        )
        solution[string] = best_split
        return best_split

    return cut_string(string), solution

In [49]:
%%time
correct_2('zhesihyigecesho', two_gram_model)[0]

CPU times: user 1min 55s, sys: 805 ms, total: 1min 55s
Wall time: 1min 57s


'zhi hyige ci shi'

其实我觉得这个切分不应该是通过2-gram模型取最大概率的那个结果，因为根据correct函数，correct('shi')可能会返回'chi', 'si'，而如果切分的话，本来正确的'hua'可能会被切分成'h'+'ua',经过correct变成'ha','ha'，最后的评分可能比正确的更高。因此下面改变一下思路，用以下规则来切分
>1. 总编辑距离最小
2. 切分时优先选择长度大的，比如说'shuang'，可以切分成'shu'+'ang'，但是这里优先切分为'shuang'

In [95]:
def correct_3(string):
    solution = {}
    @lru_cache(maxsize=2**10)  
    def cut_string(string):
        candidates = [cut_string(string[:i])+' '+cut_string(string[i:]) for i in range(1, min(6, len(string)))]
        if len(string) <= 6:
            candidates.append(string)
        splitted = min(candidates, key=lambda x: correct_words(x)[0]+len(x.split()))
        solution[string] = splitted
        return splitted
    
    def correct_words(string):
        distance = 0
        corrected = []
        words = string.split()
        for word in words:
            c, d = correct_one_word(word)
            corrected.append(c)
            distance += d
        return distance, corrected
    
    def correct_one_word(word):
        if word in PINYIN_COUNT:
            return (word, 0)
        else:
            e1 = known(edits1(word))
            if e1:
                return (max(e1, key=PINYIN_COUNT.get), 1)
            else:
                e2 = known(edits2(word))
                if e2:
                    return (max(e2, key=PINYIN_COUNT.get), 2)
                else:
                    return (word, 10)
                
    splitted = cut_string(string)
    
    return ' '.join(correct_words(splitted)[1])

In [96]:
%%time
correct_3('zhesihyigecesho')

CPU times: user 2.24 s, sys: 29.2 ms, total: 2.26 s
Wall time: 2.35 s


'zhe si yi ge ce shi'

In [97]:
%%time
correct_3('zhegozuoyehaonaua')

CPU times: user 1.43 s, sys: 16.9 ms, total: 1.45 s
Wall time: 1.51 s


'zhe guo zuo ye hao na hua'

可以看到，通过对编辑距离的约束，切分效果已经不错，能最大限度保留没有拼错的词，但是对整体的语义没做到很好的筛选。<br>
我试过把2-gram和编辑距离约束结合起来，但是测试了之后花的时间太长，完全不像现有的输入法这样识别能力强并且耗时低。在网上也没有xian