### Decorator

In [1]:
import time

In [2]:
def timer(func):
    def wrapper(*args):
        start = time.time()
        result = func(*args)
        end = time.time()
        print('Time cost: {}'.format(end-start))
        return result
    return wrapper

### Fibonacci

#### Naive

In [13]:
@timer
def get_fib_naive(n):
    return fib_naive(n)

In [14]:
def fib_naive(n):
    if n <=2:
        return 1
    else:
        return fib_naive(n-1) + fib_naive(n-2)

In [15]:
get_fib_naive(40)

Time cost: 28.81243920326233


102334155

#### Memo

In [16]:
from collections import defaultdict

In [17]:
import pdb

In [18]:
@timer
def get_fib_memo(n):
    memo = defaultdict()
    return fib_memo(n, memo)

In [19]:
def fib_memo(n, memo):
    if n in memo:
        return memo[n]
    else:
        if n <= 2:
            memo[n] = 1
            return 1
        else:
            result = fib_memo(n-1, memo) + fib_memo(n-2, memo)
            memo[n] = result
            return result

In [20]:
get_fib_memo(1000)

Time cost: 0.002000570297241211


43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875

#### Memo Decorator

In [21]:
from functools import wraps

In [22]:
def memo(func):
    memo.already_computed = {}
    @wraps(func)
    def _wrap(arg):
        if arg in memo.already_computed:
            result = memo.already_computed[arg]
        else:
            result = func(arg)
            memo.already_computed[arg] = result
        return result
    return _wrap

In [23]:
@timer
def get_fib_memo_dec(n):
    return fib_memo_decorated(n)

In [24]:
@memo
def fib_memo_decorated(n):
    if n <= 2:
        return 1
    else:
        result = fib_memo_decorated(n-1) + fib_memo_decorated(n-2)
        return result

In [25]:
get_fib_memo_dec(1000)

Time cost: 0.0010001659393310547


43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875

#### bottum up

In [26]:
@timer
def fib_bottum_up(n):
    if n <= 2:
        return 1
    fibs = [1] * n
    for i in range(2,n):
        fibs[i] = fibs[i - 1] + fibs[i - 2]
    return fibs[-1]

In [27]:
fib_bottum_up(1000)

Time cost: 0.0010004043579101562


43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875

### Cut Rod

#### Defaultdict

In [28]:
from collections import defaultdict

In [29]:
price_list = [1, 5, 8, 9, 10, 17, 17, 20, 24, 30, 35]

In [30]:
price = defaultdict(int)

In [31]:
for i,v in enumerate(price_list):
    price[i + 1] = v

In [32]:
price[10]

30

In [33]:
price[100]

0

#### Naive

In [34]:
def cut_naive(n):
    return max([price[n]] + [cut_naive(i) + cut_naive(n-i) for i in range(1,n)])

In [35]:
@timer
def get_max_cut_naive(n):
    return cut_naive(n)

In [36]:
get_max_cut_naive(15)

Time cost: 5.201864719390869


45

#### Memo Decorator

In [37]:
solutions = {}

In [38]:
@memo
def cut_memo(n):
    max_price, split = max([(price[n], 0)] + [(cut_memo(i) + cut_memo(n-i), i) for i in range(1, n)], key=lambda x: x[0])
    solutions[n] = (n - split, split)
    return max_price

In [39]:
@timer
def get_max_cut_memo(n):
    return cut_memo(n)

In [40]:
get_max_cut_memo(15)

Time cost: 0.0


45

In [41]:
solutions

{1: (1, 0),
 2: (2, 0),
 3: (3, 0),
 4: (2, 2),
 5: (3, 2),
 6: (6, 0),
 7: (6, 1),
 8: (6, 2),
 9: (6, 3),
 10: (10, 0),
 11: (11, 0),
 12: (11, 1),
 13: (11, 2),
 14: (11, 3),
 15: (13, 2)}

#### parse solution

In [44]:
def parse_solution_memo(n):
    left, right = solutions[n]
    if right == 0: return [left]
    return parse_solution_memo(left) + parse_solution_memo(right)

In [45]:
parse_solution_memo(15)

[11, 2, 2]

#### bottom up

In [65]:
@timer
def cut_bottom_up(n):
    max_prices = [(price[i],(i,0)) for i in range(n+1)]
    for i in range(1, n+1):
        temp = max_prices[i]
        for j in range(1, (i>>1) + 1):
            # 这里只循环到 i/2，因为后面的切法与前半段重复
            temp = max([temp, (max_prices[j][0] + max_prices[i - j][0], (j, i-j))], key=lambda x:x[0])
        max_prices[i] = temp
    return max_prices

In [66]:
max_prices = cut_bottom_up(15)
max_prices

Time cost: 0.0


[(0, (0, 0)),
 (1, (1, 0)),
 (5, (2, 0)),
 (8, (3, 0)),
 (10, (2, 2)),
 (13, (2, 3)),
 (17, (6, 0)),
 (18, (1, 6)),
 (22, (2, 6)),
 (25, (3, 6)),
 (30, (10, 0)),
 (35, (11, 0)),
 (36, (1, 11)),
 (40, (2, 11)),
 (43, (3, 11)),
 (45, (2, 13))]

#### parse solution

In [67]:
def parse_solution_bottom_up(n):
    left, right = max_prices[n][1]
    if right == 0: return [left]
    return parse_solution_bottom_up(left) + parse_solution_bottom_up(right)

In [68]:
parse_solution_memo(15)

[11, 2, 2]

### Edit Distance

#### Recursion

In [69]:
from functools import lru_cache

In [70]:
operations = {}

In [127]:
@lru_cache(maxsize=2**10)
def edit_distance(str1, str2):
    if not str1 and not str2:
        return 0
    if not str1:
        return len(str2)
    if not str2:
        return len(str1)
    
    tail1 = str1[-1]
    tail2 = str2[-1]
    candidates = [
        (edit_distance(str1[:-1], str2) + 1, 'DEL {}'.format(tail1)),
        (edit_distance(str1, str2[:-1]) + 1, 'ADD {}'.format(tail2))
    ]
    if tail1 == tail2:
        candidates.append((edit_distance(str1[:-1], str2[:-1]), 'CONTINUE'))
    else:
        candidates.append((edit_distance(str1[:-1], str2[:-1]) + 1, 'SUB {} => {}'.format(tail1, tail2)))
        
    min_distance, operation = min(candidates, key=lambda x:x[0])
    
    operations[(str1, str2)] = operation
    return min_distance

In [128]:
edit_distance('ABCDECG','ABCCEF')

3

In [129]:
operations

{('A', 'A'): 'CONTINUE',
 ('A', 'AB'): 'ADD B',
 ('A', 'ABC'): 'ADD C',
 ('A', 'ABCC'): 'ADD C',
 ('A', 'ABCCE'): 'ADD E',
 ('A', 'ABCCEF'): 'ADD F',
 ('AB', 'A'): 'DEL B',
 ('AB', 'AB'): 'CONTINUE',
 ('AB', 'ABC'): 'ADD C',
 ('AB', 'ABCC'): 'ADD C',
 ('AB', 'ABCCE'): 'ADD E',
 ('AB', 'ABCCEF'): 'ADD F',
 ('ABC', 'A'): 'DEL C',
 ('ABC', 'AB'): 'DEL C',
 ('ABC', 'ABC'): 'CONTINUE',
 ('ABC', 'ABCC'): 'ADD C',
 ('ABC', 'ABCCE'): 'ADD E',
 ('ABC', 'ABCCEF'): 'ADD F',
 ('ABCD', 'A'): 'DEL D',
 ('ABCD', 'AB'): 'DEL D',
 ('ABCD', 'ABC'): 'DEL D',
 ('ABCD', 'ABCC'): 'SUB D => C',
 ('ABCD', 'ABCCE'): 'ADD E',
 ('ABCD', 'ABCCEF'): 'ADD F',
 ('ABCDE', 'A'): 'DEL E',
 ('ABCDE', 'AB'): 'DEL E',
 ('ABCDE', 'ABC'): 'DEL E',
 ('ABCDE', 'ABCC'): 'DEL E',
 ('ABCDE', 'ABCCE'): 'CONTINUE',
 ('ABCDE', 'ABCCEF'): 'ADD F',
 ('ABCDEC', 'A'): 'DEL C',
 ('ABCDEC', 'AB'): 'DEL C',
 ('ABCDEC', 'ABC'): 'DEL C',
 ('ABCDEC', 'ABCC'): 'CONTINUE',
 ('ABCDEC', 'ABCCE'): 'DEL C',
 ('ABCDEC', 'ABCCEF'): 'SUB C => F',
 ('

#### parse solution

In [130]:
def parse_solution_edit_distance(str1, str2):
    if not str1 and not str2:
        print("END")
        return
    if not str1:
        print("''=>{}, ADD {}".format(str2, ",".join([s for s in str2[::-1]])))
        return
    if not str2:
        print("{}=>'', DEL {}".format(str1, ",".join([s for s in str1[::-1]])))
        return
    
    operation = operations[(str1, str2)]
    print("{}=>{}, {}".format(str1, str2, operation))
    if "ADD" in operation:
        parse_solution_edit_distance(str1, str2[:-1])
    elif "DEL" in operation:
        parse_solution_edit_distance(str1[:-1], str2)
    else:
        parse_solution_edit_distance(str1[:-1], str2[:-1])

In [131]:
parse_solution_edit_distance('ABCDECG','ABCCEF')

ABCDECG=>ABCCEF, DEL G
ABCDEC=>ABCCEF, SUB C => F
ABCDE=>ABCCE, CONTINUE
ABCD=>ABCC, SUB D => C
ABC=>ABC, CONTINUE
AB=>AB, CONTINUE
A=>A, CONTINUE
END


#### Loop

In [206]:
@timer
def edit_distance_loop(str1, str2):
    if not str1:
        return len(str2)
    if not str2:
        return len(str1)
    n = len(str1)
    m = len(str2)
    dp = [[0 for i in range(m + 1)] for j in range(n + 1)]
    dp[0][0] = (0, "CONTINUE")
    for i in range(1, n + 1):
        dp[i][0] = (i, "DEL {}".format(str1[i-1]))
    for j in range(1, m + 1):
        dp[0][j] = (j, "ADD {}".format(str2[j-1]))
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            candidates = [
                (dp[i - 1][j][0] + 1, "DEL {}".format(str1[i-1])),
                (dp[i][j - 1][0] + 1, "ADD {}".format(str2[j-1])),
            ]
            if str1[i - 1] == str2[j - 1]:
                candidates.append((dp[i - 1][j - 1][0], "CONTINUE"))
            else:
                candidates.append((dp[i - 1][j - 1][0] + 1, "SUB {}=>{}".format(str1[i-1], str2[j-1])))
            dp[i][j] = min(candidates, key=lambda x:x[0])
    return dp

In [214]:
str1 = 'ABCDECG'
str2 = 'ABCCEF'
ops = edit_distance_loop(str1,str2)

Time cost: 0.0


#### parse solution

In [215]:
def print_ops(str1, str2, ops):
    str1 = "_"+str1
    str2 = " _"+str2
    for j in range(len(str2)):
        print(str2[j], end=" "*12)
    print("")
    for i in range(len(str1)):
        print([str1[i]]+ops[i])

In [216]:
print_ops(str1, str2, ops)

             _            A            B            C            C            E            F            
['_', (0, 'CONTINUE'), (1, 'ADD A'), (2, 'ADD B'), (3, 'ADD C'), (4, 'ADD C'), (5, 'ADD E'), (6, 'ADD F')]
['A', (1, 'DEL A'), (0, 'CONTINUE'), (1, 'ADD B'), (2, 'ADD C'), (3, 'ADD C'), (4, 'ADD E'), (5, 'ADD F')]
['B', (2, 'DEL B'), (1, 'DEL B'), (0, 'CONTINUE'), (1, 'ADD C'), (2, 'ADD C'), (3, 'ADD E'), (4, 'ADD F')]
['C', (3, 'DEL C'), (2, 'DEL C'), (1, 'DEL C'), (0, 'CONTINUE'), (1, 'ADD C'), (2, 'ADD E'), (3, 'ADD F')]
['D', (4, 'DEL D'), (3, 'DEL D'), (2, 'DEL D'), (1, 'DEL D'), (1, 'SUB D=>C'), (2, 'ADD E'), (3, 'ADD F')]
['E', (5, 'DEL E'), (4, 'DEL E'), (3, 'DEL E'), (2, 'DEL E'), (2, 'DEL E'), (1, 'CONTINUE'), (2, 'ADD F')]
['C', (6, 'DEL C'), (5, 'DEL C'), (4, 'DEL C'), (3, 'DEL C'), (2, 'CONTINUE'), (2, 'DEL C'), (2, 'SUB C=>F')]
['G', (7, 'DEL G'), (6, 'DEL G'), (5, 'DEL G'), (4, 'DEL G'), (3, 'DEL G'), (3, 'DEL G'), (3, 'DEL G')]


In [217]:
def parse_ops(str1, str2):
    if not str1 and not str2:
        return
    if not str1:
        print("''=>{}, ADD {}".format(str2, ",".join([s for s in str2[::-1]])))
        return
    if not str2:
        print("{}=>'', DEL {}".format(str1, ",".join([s for s in str1[::-1]])))
        return
    op = ops[len(str1)][len(str2)][1]
    print("{}=>{}, {}".format(str1, str2, op))
    if "ADD" in op:
        parse_ops(str1, str2[:-1])
    elif "DEL" in op:
        parse_ops(str1[:-1], str2)
    else:
        parse_ops(str1[:-1], str2[:-1])

In [218]:
parse_ops(str1, str2)

ABCDECG=>ABCCEF, DEL G
ABCDEC=>ABCCEF, SUB C=>F
ABCDE=>ABCCE, CONTINUE
ABCD=>ABCC, SUB D=>C
ABC=>ABC, CONTINUE
AB=>AB, CONTINUE
A=>A, CONTINUE


### Pinyin auto-correction

In [3]:
import pinyin
import re
from collections import Counter, defaultdict

In [4]:
chinese_dataset = './article_9k.txt'
CHINESE_CHARACTERS = open(chinese_dataset, encoding='utf-8').read()

In [5]:
CHINESE_CHARACTERS[:40]

'此外自本周6月12日起除小米手机6等15款机型外其余机型已暂停更新发布含开发版体'

In [6]:
pinyin.get(CHINESE_CHARACTERS[:40], format='strip',delimiter=' ')

'ci wai zi ben zhou 6 yue 1 2 ri qi chu xiao mi shou ji 6 deng 1 5 kuan ji xing wai qi yu ji xing yi zan ting geng xin fa bu han kai fa ban ti'

In [7]:
def to_pinyin(characters):
    return pinyin.get(characters, format='strip', delimiter=' ')

In [8]:
CHINESE_PINYIN = to_pinyin(CHINESE_CHARACTERS)

In [9]:
len(CHINESE_PINYIN)

129412578

In [10]:
CHINESE_PINYIN[:100]

'ci wai zi ben zhou 6 yue 1 2 ri qi chu xiao mi shou ji 6 deng 1 5 kuan ji xing wai qi yu ji xing yi '

In [11]:
def tokens(text):
    return re.findall('[a-z]+', text.lower())

In [12]:
tokens(CHINESE_PINYIN[:100])

['ci',
 'wai',
 'zi',
 'ben',
 'zhou',
 'yue',
 'ri',
 'qi',
 'chu',
 'xiao',
 'mi',
 'shou',
 'ji',
 'deng',
 'kuan',
 'ji',
 'xing',
 'wai',
 'qi',
 'yu',
 'ji',
 'xing',
 'yi']

In [13]:
CHINESE_PINYIN_TOKENS = tokens(CHINESE_PINYIN)

In [14]:
PINYIN_COUNT = Counter(CHINESE_PINYIN_TOKENS)

In [15]:
len(PINYIN_COUNT)

420

In [16]:
alphabet = 'abcdefghijklmnopqrstuvwxyz'

def splits(word):
    # 依次对拼音中每个位置进行切分
    return [(word[:i], word[i:]) for i in range(len(word)+1)]

def known(words, count, edits):
    # 返回输入词中所有在count中的词语，所有词语均不在其中时，返回空
    known_words = {w for w in words if w in count}
    if known_words:
        return known_words, edits
    else:
        return None

def edits0(word):
    # 返回编辑距离为0的词的集合，即原词
    return {word}

def edits1(word):
    # 返回编辑距离为1的词的集合
    pairs = splits(word)
    deletes = [a+b[1:] for (a,b) in pairs if b]
    transposes = [a+b[1]+b[0]+b[2:] for (a,b) in pairs if len(b) > 1]
    replaces = [a+c+b[1:] for (a,b) in pairs for c in alphabet if b]
    inserts = [a+c+b for (a,b) in pairs for c in alphabet]
    return set(deletes + transposes + replaces + inserts)

def edits2(word):
    # 返回编辑距离为2的词的集合
    return {e2 for e1 in edits1(word) for e2 in edits1(e1)}

In [17]:
def correct(word, count):
    '''
    依据编辑距离，找出输入内容对应最有可能的正确拼音
    '''
    # 用or依次判断每个集合是否为空，返回第一个不为空的集合
    # 如果当前集合为空，则判断下一个集合，否则返回当前集合
    # 可以达到优先选取的效果，顺序为edits0->edits1->edits2->原词
    candidates = (known(edits0(word), count, 0) or
                  known(edits1(word), count, 1) or
                  known(edits2(word), count, 2) or
                  ([word],3))
    return max(candidates[0],key=count.get), candidates[1]

In [18]:
word = 'pin'

In [19]:
known(edits1(word), PINYIN_COUNT, 1)

({'bin',
  'jin',
  'lin',
  'min',
  'nin',
  'pan',
  'pen',
  'pi',
  'pian',
  'pie',
  'pin',
  'ping',
  'qin',
  'xin',
  'yin'},
 1)

In [20]:
correct('pin', PINYIN_COUNT)

('pin', 0)

In [21]:
correct('pign', PINYIN_COUNT)

('ping', 1)

In [22]:
correct('pinnag', PINYIN_COUNT)

('ping', 2)

In [23]:
def correct_sequence_pinyin(text):
    return ' '.join([w[0] for w in map(lambda x:correct(x, PINYIN_COUNT), text.split())])

In [24]:
correct_sequence_pinyin('zhe sih yi ge ce sho')

'zhe shi yi ge ce shi'

In [25]:
correct_sequence_pinyin('wo xiang shagn qinng hua da xeu')

'wo xiang shang qing hua da xue'

In [26]:
correct_sequence_pinyin('zhe jiang gogn ye da xue')

'zhe jiang gong ye da xue'

### 思考题-homework
#### 如何在不带空格的时候完成自动修整？--> 如何完成拼音的自动分割？   
###### 提示：使用第一节课提到的语言模型!

In [27]:
from functools import lru_cache

In [28]:
PINYIN_TOKEN_2_GRAM = [''.join(CHINESE_PINYIN_TOKENS[i:i+2]) for i in range(len(CHINESE_PINYIN_TOKENS[:-1]))]

In [29]:
PINYIN_COUNT_2_GRAM = Counter(PINYIN_TOKEN_2_GRAM)

In [238]:
def prob(word, penalty):
    total = sum(PINYIN_COUNT.values())
    if word in PINYIN_COUNT:
        return (PINYIN_COUNT[word] + 1) / total
    else:
        corrected, edit_dist = correct(word, PINYIN_COUNT)
        if corrected in PINYIN_COUNT:
            return penalty[edit_dist] * (PINYIN_COUNT[corrected] + 1) / total
        return penalty[-1] / total

In [239]:
def prob_2(w1, w2, penalty):
    total = sum(PINYIN_COUNT_2_GRAM.values()) + len(PINYIN_COUNT_2_GRAM)
    if not w1:
        return prob(w2, penalty)
    if not w2:
        return prbo(w1, penalty)
    w1 = correct(w1, PINYIN_COUNT)[0]
    combine = w1 + w2
    if combine in PINYIN_COUNT_2_GRAM:
        return (PINYIN_COUNT_2_GRAM[combine] + 1) / total / prob(w1, penalty)
    else:
        corrected, edit_dist = correct(combine, PINYIN_COUNT_2_GRAM)
        if corrected in PINYIN_COUNT_2_GRAM:
            return penalty[edit_dist] * (PINYIN_COUNT_2_GRAM[corrected] + 1) / total / prob(w1, penalty)
        return 1 / total

In [240]:
# 对编辑距离纠正施加惩罚，对应编辑距离为0，1, 2的情况，以及超过2的情况
penalty=[1,1e-5,1e-7,1e-9]

In [241]:
cut_points = {}

In [256]:
@lru_cache(maxsize=2**10)
def seperate(text, max_len=7):
    # max_len 是单个中文拼音的最大长度+1，如果考虑edits2，可以设为9
    candidates = [(prob(text, penalty), 0)]
    for i in range(1, max_len):
        if i >= len(text):
            break
        candidates.append((seperate(text[:i], max_len) * seperate(text[i:], max_len), i))
    max_prob, cut_point = max(candidates, key=lambda x: x[0])
#     print(text, candidates)
    cut_points[text] = cut_point
    return max_prob

def parse_cut_points(text):
    i = cut_points[text]
    if not i:
        return [text]
    return parse_cut_points(text[:i]) + parse_cut_points(text[i:])

In [250]:
@timer
def seperate_correct(text):
    seperate(text)
    result = " ".join(parse_cut_points(text))
    print(result)
    return correct_sequence_pinyin(result)

In [251]:
penalty=[1,0.1,1e-7,1e-9]
seperate_correct("zehsihyigeecesho")

zeh sih yi ge ece sho
Time cost: 5.333781003952026


'zhe shi yi ge ce shi'

In [257]:
penalty=[1,0.1,1e-7,1e-9]
seperate_correct("beijignminhang")

bei jign min hang
Time cost: 3.4814696311950684


'bei jin min hang'

In [255]:
penalty=[1,1e-5,1e-7,1e-9]
seperate_correct("qinnghuadaxue")

qinng hua da xue
Time cost: 2.8271310329437256


'qing hua da xue'

切分结果不太好，与对纠正词的惩罚系数关系比较大，不清楚是不是概率计算有问题。。

添加2gram

In [51]:
# TODO
def seperate_2_gram(text, penalty, max_len=7):
    candidates = []
    for i in range(1, max_len):
        if i >= len(text):
            break
        for j in range(i, i + max_len):
            if j >= len(text):
                break
            candidates.append((prob_2(text[:i], text[i:j], penalty), i, j))
    if not candidates:
        return [text[pre:]]
    candidates.append((prob(text, penalty), 0, -1))
    max_prob, cut_point_1, cut_point_2 = max(candidates, key=lambda x: x[0])
    result = [text[:cut_point_1]]
    if cut_point_2 > 0:
        result = result + [text[cut_point_1:cut_point_2]] + seperate(text[cut_point_2:], penalty, max_len)
    return result

#### 训练序列标注模型，预测切分点

In [27]:
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

In [28]:
@timer
def generate_data(tokens, change_ratio = 0.25):
    X = []
    y = []
    for token in tokens:
        sep_flag = True
        change = np.random.random() < change_ratio
        if change:
            token = np.random.choice(list(edits1(token)))
        for w in token:
            if sep_flag:
                y.append(1)
                sep_flag = False
            else:
                y.append(0)
            X.append(w)
    return X, y

In [29]:
len(CHINESE_PINYIN_TOKENS)

31266106

In [None]:
X, y = generate_data(CHINESE_PINYIN_TOKENS[:1000000])

In [51]:
len(y)

98107944

In [52]:
X = np.array(X)[:98103456]

In [53]:
y = y[:98103456]

In [54]:
X = X.reshape(-1, 1)

In [None]:
encoder = OneHotEncoder(sparse=False)
X = encoder.fit_transform(X) 

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.01, shuffle=False)

In [43]:
import tensorflow as tf
from tensorflow.keras import layers, Sequential
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

In [None]:
model = Sequential([
    layers.Bidirectional(layers.GRU(64, return_sequences=True), input_shape=(32,26)),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)
])

In [None]:
model.compile(loss=BinaryCrossentropy(from_logits=True),
              optimizer=Adam(1e-4),
              metrics=['accuracy'])

In [None]:
# 模型结构
model.summary()

In [None]:
# 模型训练，50个epoch
history = model.fit(X_train,
                    y_train,
                    batch_size=100,
                    epochs=50)

In [None]:
# 获取模型在测试集上的评价指标
loss, accuracy = model.evaluate(X_test, y_test)

In [None]:
# 保存模型结构和参数，文件较大，下次可以直接load_model继续训练
model.save("./output/model_epoch_50.h5")

In [None]:
# 重新加载模型和参数
model = tf.keras.models.load_model("./output/model_epoch_50.h5")