# Dynamic Programming

动态规划

1. 最优子结构
2. 子问题重复
3. 解决子问题(Memo)

## fibonacci

In [1]:
from functools import lru_cache

In [2]:
@lru_cache(2**10)
def fib(n):
    """Return the n-th Fibonacci number; any n <= 1 is returned unchanged."""
    # Memoisation via lru_cache means each distinct n is computed only once.
    return n if n <= 1 else fib(n - 1) + fib(n - 2)

In [3]:
fib(10)

55

## 钢条切割

In [4]:
import collections

In [5]:
original_price = [1, 5, 8, 9, 10, 17, 17, 20, 24, 30, 33]

In [6]:
# Map each rod length (starting at 1) to its listed price; a length that is
# not in the table reads as 0 because of the int default factory.
price = collections.defaultdict(int)
price.update(enumerate(original_price, start=1))

In [7]:
price

defaultdict(int,
            {1: 1,
             2: 5,
             3: 8,
             4: 9,
             5: 10,
             6: 17,
             7: 17,
             8: 20,
             9: 24,
             10: 30,
             11: 33})

In [8]:
@lru_cache(2**10)
def r(n):
    """Return the best revenue for a rod of length n.

    The uncut price ``price[n]`` is compared against every way of splitting
    the rod into two parts, each of which is solved recursively.
    """
    best = price[n]
    for cut in range(1, n):
        revenue = r(cut) + r(n - cut)
        if revenue > best:
            best = revenue
    return best

In [9]:
r(20)

60

In [10]:
solutions = collections.defaultdict()


@lru_cache(2**10)
def r(n):
    """Return the best revenue for a rod of length n and record the split.

    The chosen split is stored in ``solutions[n]`` as ``(left, right)``;
    ``left == 0`` means the rod is sold uncut.
    """
    # Start from the uncut rod.
    best_price, left, right = price[n], 0, n

    for cut in range(1, n):
        revenue = r(cut) + r(n - cut)
        # Strict comparison keeps the earliest option when revenues tie.
        if revenue > best_price:
            best_price, left, right = revenue, cut, n - cut

    solutions[n] = (left, right)

    return best_price

In [11]:
r(100)

300

In [12]:
solutions

defaultdict(None,
            {1: (0, 1),
             2: (0, 2),
             3: (0, 3),
             4: (2, 2),
             5: (2, 3),
             6: (0, 6),
             7: (1, 6),
             8: (2, 6),
             9: (3, 6),
             10: (0, 10),
             11: (0, 11),
             12: (2, 10),
             13: (2, 11),
             14: (3, 11),
             15: (2, 13),
             16: (6, 10),
             17: (6, 11),
             18: (2, 16),
             19: (2, 17),
             20: (10, 10),
             21: (10, 11),
             22: (11, 11),
             23: (2, 21),
             24: (2, 22),
             25: (3, 22),
             26: (6, 20),
             27: (6, 21),
             28: (6, 22),
             29: (2, 27),
             30: (10, 20),
             31: (10, 21),
             32: (10, 22),
             33: (11, 22),
             34: (2, 32),
             35: (2, 33),
             36: (3, 33),
             37: (6, 31),
             38: (6, 32),
     

In [13]:
def parse_solution(n):
    """Expand the split recorded for length n into a list of piece lengths.

    Follows ``solutions`` until every part is marked as uncut (left == 0).
    A length without an entry in ``solutions`` raises KeyError.
    """
    pieces = []
    pending = [n]  # lengths still to expand; the top of the stack is handled first

    while pending:
        left, right = solutions[pending.pop()]
        if left == 0:
            pieces.append(right)
        else:
            # Push right before left so the left part is expanded first.
            pending.append(right)
            pending.append(left)

    return pieces

In [14]:
parse_solution(34)

[2, 10, 11, 11]

## Edit Distance

In [15]:
# Tags naming the operation recorded for each step of an edit script.
OP_NONE = 'NONE'  # tail characters already match
OP_SUB = 'SUB'    # replace the tail character
OP_DEL = 'DEL'    # remove the tail character of the source string
OP_ADD = 'ADD'    # add the tail character of the target string

In [16]:
ed_solution = {}


@lru_cache(2**10)
def edit_distance(s1, s2):
    """
    Return the edit distance between s1 and s2.

    Works on the tail characters: delete the tail of s1, add the tail of s2,
    or match/substitute both tails. The operation chosen for each pair of
    non-empty strings is stored in ``ed_solution[(s1, s2)]``.
    """
    if not s1 or not s2:
        # At least one side is empty, so the sum is the length of the other.
        return len(s1) + len(s2)

    last1, last2 = s1[-1], s2[-1]

    # Evaluated in a fixed order (delete, add, diagonal).
    delete_cost = edit_distance(s1[:-1], s2) + 1
    add_cost = edit_distance(s1, s2[:-1]) + 1
    diagonal_cost = edit_distance(s1[:-1], s2[:-1])

    if last1 == last2:
        diagonal = (diagonal_cost + 0, '{}'.format(OP_NONE))
    else:
        diagonal = (diagonal_cost + 1, '{} {} => {}'.format(OP_SUB, last1, last2))

    options = [
        (delete_cost, '{} {}'.format(OP_DEL, last1)),
        (add_cost, '{} {}'.format(OP_ADD, last2)),
        diagonal,
    ]

    # min() returns the first of several equally cheap options.
    distance, operation = min(options, key=lambda option: option[0])

    ed_solution[(s1, s2)] = operation

    return distance

In [17]:
edit_distance('tets', 'test')

2

In [23]:
def parse_solution(origin_s1, origin_s2):
    """
    Replay the edit script recorded in ``ed_solution`` as a list of strings.

    ``edit_distance(origin_s1, origin_s2)`` must have been called beforehand
    so that ``ed_solution`` holds the entries for this pair; if an entry is
    missing, the walk stops at that point.

    Returns a list that starts with origin_s1 and then contains the working
    string after every ADD / DEL / SUB step (NONE steps add no entry).
    """

    def traverse_solution(s1, s2):
        """Collect (s1, s2, op) steps, walking from the tails to the heads."""
        steps = []

        while s1 and s2:
            if (s1, s2) not in ed_solution:
                return steps

            op = ed_solution[(s1, s2)].split()[0]
            steps.append((s1, s2, op))

            if op == OP_DEL:
                s1 = s1[:-1]
            elif op == OP_ADD:
                s2 = s2[:-1]
            else:
                # SUB and NONE both consume one character from each string.
                s1, s2 = s1[:-1], s2[:-1]

        # One string is exhausted: emit one step per leftover character so
        # that the history really ends at origin_s2 (a single step would only
        # handle the last leftover character).
        while s2:
            steps.append((s1, s2, OP_ADD))
            s2 = s2[:-1]
        while s1:
            steps.append((s1, s2, OP_DEL))
            s1 = s1[:-1]

        return steps

    steps = traverse_solution(origin_s1, origin_s2)

    # Apply the steps to a working copy. Everything before index len(s1) is
    # still the untouched prefix of origin_s1; everything after it has already
    # been converted.
    update_s1 = origin_s1
    correct_history = [update_s1]

    for s1, s2, op in steps:
        cut = len(s1)

        if op == OP_ADD:
            update_s1 = update_s1[:cut] + s2[-1] + update_s1[cut:]
        elif op == OP_DEL:
            update_s1 = update_s1[:cut - 1] + update_s1[cut:]
        elif op == OP_SUB:
            update_s1 = update_s1[:cut - 1] + s2[-1] + update_s1[cut:]
        else:
            # NONE: the characters already match, nothing to record.
            continue

        correct_history.append(update_s1)

    return correct_history

In [24]:
distance = edit_distance('teset', 'test')
correct_history = parse_solution('teset', 'test')

print('distance: {}, correct history: {}'.format(distance, '->'.join(correct_history)))

distance: 1, correct history: teset->test


In [25]:
def debug_edit_distance(s1, s2):
    """Print the edit distance from s1 to s2 and the intermediate strings."""
    total = edit_distance(s1, s2)
    trail = '->'.join(parse_solution(s1, s2))

    print('distance: {}, correct history: {}'.format(total, trail))

In [26]:
debug_edit_distance('ierachecal', 'hierarchical')

distance: 3, correct history: ierachecal->ierachical->ierarchical->hierarchical


## Pinyin Auto Correction Problem

In [1]:
chinese_dataset = 'article_9k.txt'

In [2]:
# Read the whole corpus into memory; the context manager closes the file
# handle as soon as the text has been read.
# NOTE(review): no encoding is passed, so the platform default applies.
# Confirm the file's encoding and pass it explicitly.
with open(chinese_dataset) as corpus_file:
    CHINESE_CHARACTERS = corpus_file.read()

In [67]:
import pinyin

In [7]:
pinyin.get('你好', format='strip', delimiter=' ')

'ni hao'

In [13]:
def chinese_to_pinyin(character):
    """Convert text with pinyin.get, using format='strip' and a space delimiter."""
    options = {'format': 'strip', 'delimiter': ' '}
    return pinyin.get(character, **options)

In [15]:
chinese_to_pinyin('我喜欢Python')

'wo xi huan P y t h o n'

In [16]:
CHINESE_PINYIN_COPY = chinese_to_pinyin(CHINESE_CHARACTERS)

In [19]:
CHINESE_PINYIN_COPY[:10]

'ci wai zi '

In [17]:
import re

In [23]:
def tokens(text):
    """Return every run of the letters a-z found in the lower-cased text."""
    lowered = text.lower()
    return [match.group() for match in re.finditer('[a-z]+', lowered)]

In [26]:
tokens(CHINESE_PINYIN_COPY[:10])

['ci', 'wai', 'zi']

In [27]:
from collections import Counter, defaultdict

In [28]:
PINYIN_COUNT = Counter(tokens(CHINESE_PINYIN_COPY))

In [48]:
def splits(word):
    """Return every (first, rest) pair obtained by cutting word at each position.

    Both ends are included, so the list has len(word) + 1 pairs.
    """
    pairs = []
    for cut in range(len(word) + 1):
        pairs.append((word[:cut], word[cut:]))
    return pairs

In [113]:
splits('pinyin')

[('', 'pinyin'),
 ('p', 'inyin'),
 ('pi', 'nyin'),
 ('pin', 'yin'),
 ('piny', 'in'),
 ('pinyi', 'n'),
 ('pinyin', '')]

In [116]:
def edist1(word):
    """
    Return the set of strings one edit away from word.

    An edit is a single-character deletion, a swap of two adjacent
    characters, a replacement by a lowercase letter, or an insertion of a
    lowercase letter.
    """
    letters = 'abcdefghijklmnopqrstuvwxyz'
    results = set()

    for head, tail in splits(word):
        if tail:
            # Deletion of the first character of tail.
            results.add(head + tail[1:])
            # Replacement of that character by every letter.
            for letter in letters:
                results.add(head + letter + tail[1:])

        if len(tail) > 1:
            # Transposition of the first two characters of tail.
            results.add(head + tail[1] + tail[0] + tail[2:])

        # Insertion of every letter between head and tail.
        for letter in letters:
            results.add(head + letter + tail)

    return results

In [51]:
def known(words):
    """Return, in iteration order, the words that have an entry in PINYIN_COUNT."""
    found = []
    for candidate in words:
        if candidate in PINYIN_COUNT:
            found.append(candidate)
    return found

In [52]:
def edist0(word):
    """Return a set holding only word itself (zero edits applied)."""
    unchanged = set()
    unchanged.add(word)
    return unchanged

def edist2(word):
    """Return the set obtained by applying edist1 to every result of edist1(word)."""
    twice_edited = set()
    for once_edited in edist1(word):
        twice_edited.update(edist1(once_edited))
    return twice_edited

def edist3(word):
    """Return the set obtained by applying edist1 to every result of edist2(word).

    Previously this only copied edist2(word), so no third edit was applied.
    The result grows very quickly with the length of word, so expect this to
    be slow for anything but short inputs.
    """
    return {e3 for e2 in edist2(word) for e3 in edist1(e2)}

In [104]:
def correct(word):
    """
    Return the most likely pinyin for word based on edit distance.

    Candidates are the known strings at the smallest edit distance that
    yields any: the word itself, then one edit away, then two edits away.
    If nothing is known, the word is returned unchanged. Among the
    candidates, the one with the highest count in PINYIN_COUNT wins.
    """

    # `or` short-circuits, so the larger edit sets are only built when the
    # cheaper ones contain no known word.
    candidates = (known(edist0(word)) or
                  known(edist1(word)) or   # was `edits1`, which is not defined in this file
                  known(edist2(word)) or
                  [word])

    return max(candidates, key=PINYIN_COUNT.get)    # return the most frequent candidate

In [105]:
correct('pign')

'ping'

In [63]:
def correct_sequence_pinyin(pinyin_seq):
    """Correct each whitespace-separated token of pinyin_seq and rejoin with single spaces."""
    fixed = [correct(token) for token in pinyin_seq.split()]
    return ' '.join(fixed)

In [72]:
correct_sequence_pinyin('zhe sih yi ge ce shi')

'zhe shi yi ge ce shi'

In [108]:
correct_sequence_pinyin('bai du shi yi jia ren gogn zhi nng gong si')

'bai du shi yi jia ren gong zhi neng gong si'