游戏的规则如下：

- 当数字为3的倍数时，打印FIZZ。
- 当数字为5的倍数时，打印BUZZ。
- 当数字既是3的倍数，又是5的倍数时，打印FIZZBUZZ。


# 使用numpy实现

In [25]:
import numpy as np

#定义变量
input_size = 10
epochs = 1000
batches = 64
lr = 0.01

#sigmod函数
def sig(val):
    return 1 / (1 + np.exp(-val))

#sigmod求导 
def sig_d(val):
    sig_val = sig(val)
    return sig_val * (1 - sig_val)

#十进制数转二进制数据（长度为input_size）
def binary_enc(num):
    #将num转为二进制， 其中 '{0:b}'.format(num) 将一个数转为二进制
    ret = [int(i) for i in '{0:b}'.format(num)]  
    #返回固定长度的二进制数
    return [0] * (input_size - len(ret)) + ret

#二进制数转十进制数
def binary_dec(array):
    ret = 0
    for i in array:
        ret = ret * 2 + int(i)
    return ret

#将数据划分为测试集和数据集
def training_test_gen(x, y):
    assert len(x) == len(y)
    indices = np.random.permutation(range(len(x)))  
    split_size = int(0.9 * len(indices))
    trX = x[indices[:split_size]]
    trY = y[indices[:split_size]]
    teX = x[indices[split_size:]]
    teY = y[indices[split_size:]]
    return trX, trY, teX, teY

#生产训练集和数据集
def x_y_gen():
    x = []
    y = []
    for i in range(1000):
        x.append(binary_enc(i))
        if i % 15 == 0:
            y.append([1, 0, 0, 0])
        elif i % 5 == 0:
            y.append([0, 1, 0, 0])
        elif i % 3 == 0:
            y.append([0, 0, 1, 0])
        else:
            y.append([0, 0, 0, 1])
    return training_test_gen(np.array(x), np.array(y))


def check_fizbuz(i):
    if i % 15 == 0:
        return 'fizbuz'
    elif i % 5 == 0:
        return 'buz'
    elif i % 3 == 0:
        return 'fiz'
    else:
        return 'number'

#生产训练集和测试集合
#trX : 900 * 10
#trY : 900 * 4
#teX : 100 * 10
#teY : 100 * 4
trX, trY, teX, teY = x_y_gen()

#定义网络参数
w1 = np.random.randn(10, 100)
w2 = np.random.randn(100, 4)
b1 = np.zeros((1, 100))
b2 = np.zeros((1, 4))

#训练的批次
no_of_batches = int(len(trX) / batches)


for epoch in range(epochs):
    for batch in range(no_of_batches):
        # forward 前向传播
        start = batch * batches
        end = start + batches
        x = trX[start:end]
        y = trY[start:end]
        
        #定义一层神经网络
        z1 = x.dot(w1) + b1  #z1 = w1 * x + b1
        a1 = sig(z1)         #a1 = sigmod(z1)
        z2 = a1.dot(w2) + b2 #z2 = w2 * a1 + b2
        a2 = sig(z2)         #a2 = sigmod(z2)
        y_ = a2              #输出
        
        #损失函数
        error = y_ - y
        loss = (error ** 2).mean()  # L(y,y_)

        #反向传播backward 
        outgrad = error * sig_d(z2)
        delta_w2 = a1.T.dot(outgrad)   # dw2 = dL(y,y_)/dw2
        delta_b2 = np.ones([1, batches]).dot(outgrad)  #db2 = dL(y,y_)/db2

        hidden_error = error.dot(w2.T)
        hidden_grad = hidden_error * sig_d(z1)
        delta_w1 = x.T.dot(hidden_grad)   # dw1 = dL(y,y_)/dw1
        delta_b1 = np.ones([1, batches]).dot(hidden_grad) #db1 = dL(y,y_)/db1

        #更新w和b
        #w_ = w - lr * dw
        #b_ = b- lr * db
        w1 -= delta_w1 * lr
        b1 -= delta_b1 * lr
        w2 -= delta_w2 * lr
        b2 -= delta_b2 * lr
    print(epoch, loss)

#测试
z1 = teX.dot(w1) + b1
a1 = sig(z1)
z2 = a1.dot(w2) + b2
y_ = sig(z2)
outli = ['fizbuz', 'buz', 'fiz', 'number']
for i in range(len(teX)):
    num = binary_dec(teX[i])
    print('Number: {} -- Actual: {} -- Prediction: {}'.format(num, check_fizbuz(num), outli[y_[i].argmax()]))

#准确率
print('Test loss: ', np.mean(teY - y_))

0 0.16882853473208226
1 0.16719818514157786
2 0.16689794565146532
3 0.16614141713754452
4 0.1656243765726701
5 0.16533413598718402
6 0.1651735514363386
7 0.16504655853152927
8 0.1648910626780974
9 0.16468495636890731
10 0.16442940776227402
11 0.16413407231820187
12 0.16380989049543984
13 0.16346624473397003
14 0.16311007515562026
15 0.1627471362408877
16 0.16238400742537729
17 0.16202828835687233
18 0.16168677340716076
19 0.1613628948501746
20 0.1610549278124453
21 0.16075647623708633
22 0.1604586019949145
23 0.16015108983565102
24 0.15982242271159044
25 0.15946053743947153
26 0.15905580791062013
27 0.1586049743507792
28 0.15811275862713492
29 0.15758982625710377
30 0.15704904295437527
31 0.1565024005790481
32 0.15595945430692315
33 0.15542697318032123
34 0.1549092386341931
35 0.15440856926332538
36 0.15392584919423508
37 0.15346097364372208
38 0.15301319431834268
39 0.1525813742822897
40 0.15216416899874705
41 0.1517601501287805
42 0.15136788669109896
43 0.1509859958810435
44 0.150613

375 0.03451825873515425
376 0.03437302846898735
377 0.03422859714351544
378 0.03408496832549787
379 0.033942139748088004
380 0.033800102533287156
381 0.033658840701276546
382 0.03351833101158893
383 0.03337854316912012
384 0.03323944042165754
385 0.03310098057405017
386 0.03296311744618949
387 0.03282580280615502
388 0.0326889888143375
389 0.03255263101668781
390 0.032416691922313806
391 0.03228114518851871
392 0.03214598041037618
393 0.032011208467222774
394 0.03187686731126347
395 0.031743027993492115
396 0.031609800616014455
397 0.031477339795779616
398 0.031345849155912535
399 0.031215584374777516
400 0.031086854470349343
401 0.030960021306209576
402 0.0308354977399629
403 0.03071374525818253
404 0.03059527211520967
405 0.030480632650495297
406 0.03037042746271102
407 0.030265302624567773
408 0.03016594463732004
409 0.030073067070888634
410 0.02998738545797717
411 0.02990957924545942
412 0.029840243049974898
413 0.02977983304863051
414 0.029728616650914697
415 0.029686633579633685


752 0.013878889380333739
753 0.01385299429445912
754 0.013827340236244384
755 0.013801927060292133
756 0.013776754555704776
757 0.013751822444585275
758 0.013727130380848713
759 0.013702677949341857
760 0.01367846466526653
761 0.013654489973899446
762 0.013630753250600931
763 0.013607253801101938
764 0.013583990862058425
765 0.0135609636018602
766 0.013538171121680839
767 0.01351561245675392
768 0.013493286577860532
769 0.013471192393012024
770 0.0134493287493121
771 0.01342769443498155
772 0.013406288181529235
773 0.013385108666052949
774 0.013364154513653363
775 0.013343424299945571
776 0.01332291655365193
777 0.013302629759261306
778 0.01328256235973967
779 0.013262712759278074
780 0.013243079326064314
781 0.013223660395065461
782 0.013204454270809317
783 0.013185459230153074
784 0.013166673525029001
785 0.01314809538515709
786 0.013129723020715772
787 0.013111554624962576
788 0.013093588376797163
789 0.013075822443260283
790 0.013058254981962739
791 0.013040884143439093
792 0.01302

# 使用numpy和pytorch混合实现

In [66]:
import numpy as np
import torch as th
from torch.autograd import Variable

input_size = 10
epochs = 1000
batches = 64
lr = 0.01


def binary_enc(num):
    ret = [int(i) for i in '{0:b}'.format(num)]
    return [0] * (input_size - len(ret)) + ret


def binary_dec(array):
    ret = 0
    for i in array:
        ret = ret * 2 + int(i)
    return ret


def training_test_gen(x, y):
    assert len(x) == len(y)
    indices = np.random.permutation(range(len(x)))
    split_size = int(0.9 * len(indices))
    trX = x[indices[:split_size]]
    trY = y[indices[:split_size]]
    teX = x[indices[split_size:]]
    teY = y[indices[split_size:]]
    return trX, trY, teX, teY


def x_y_gen():
    x = []
    y = []
    for i in range(1000):
        x.append(binary_enc(i))
        if i % 15 == 0:
            y.append([1, 0, 0, 0])
        elif i % 5 == 0:
            y.append([0, 1, 0, 0])
        elif i % 3 == 0:
            y.append([0, 0, 1, 0])
        else:
            y.append([0, 0, 0, 1])
    return training_test_gen(np.array(x), np.array(y))


def check_fizbuz(i):
    if i % 15 == 0:
        return 'fizbuz'
    elif i % 5 == 0:
        return 'buz'
    elif i % 3 == 0:
        return 'fiz'
    else:
        return 'number'


trX, trY, teX, teY = x_y_gen()
if th.cuda.is_available():
    dtype = th.cuda.FloatTensor
else:
    dtype = th.FloatTensor
x = Variable(th.from_numpy(trX).type(dtype), requires_grad=False)
y = Variable(th.from_numpy(trY).type(dtype), requires_grad=False)

w1 = Variable(th.randn(10, 100).type(dtype), requires_grad=True)
w2 = Variable(th.randn(100, 4).type(dtype), requires_grad=True)

b1 = Variable(th.zeros(1, 100).type(dtype), requires_grad=True)
b2 = Variable(th.zeros(1, 4).type(dtype), requires_grad=True)

no_of_batches = int(len(trX) / batches)
for epoch in range(epochs):
    for batch in range(no_of_batches):
        start = batch * batches
        end = start + batches
        x_ = x[start:end]
        y_ = y[start:end]

        a2 = x_.mm(w1)
        a2 = a2.add(b1.expand_as(a2))
        h2 = a2.sigmoid()

        a3 = h2.mm(w2)
        a3 = a3.add(b2.expand_as(a3))
        hyp = a3.sigmoid()

        error = hyp - y_
        loss = error.pow(2).sum()
        loss.backward()

        w1.data -= lr * w1.grad.data
        w2.data -= lr * w2.grad.data
        b1.data -= lr * b1.grad.data
        b2.data -= lr * b2.grad.data
        w1.grad.data.zero_()
        w2.grad.data.zero_()
    print(epoch, error.mean().item())
    
# test
a2 = th.cuda.FloatTensor(teX).mm(w1)
a2 = a2.add(b1.expand_as(a2))
h2 = a2.sigmoid()

a3 = h2.mm(w2)
a3 = a3.add(b2.expand_as(a3))
hyp = a3.sigmoid()
outli = ['fizbuz', 'buz', 'fiz', 'number']
for i in range(len(teX)):
    num = binary_dec(teX[i])
    print('Number: {} -- Actual: {} -- Prediction: {}'.format(num, check_fizbuz(num), outli[hyp[i].argmax()]))

#准确率
print('Test loss: ', (th.cuda.FloatTensor(teY) - hyp).mean().item())

0 -0.10500475019216537
1 -0.0749436467885971
2 -0.04602602869272232
3 -0.02914709784090519
4 -0.032776713371276855
5 -0.0439731627702713
6 0.007055184803903103
7 0.018787069246172905
8 0.020539317280054092
9 0.02027551829814911
10 0.018340077251195908
11 0.02017737738788128
12 0.023848040029406548
13 0.027208689600229263
14 0.02830706164240837
15 0.025374483317136765
16 0.020541273057460785
17 0.01539410650730133
18 0.011026573367416859
19 0.007126615848392248
20 0.001954986248165369
21 -0.030187733471393585
22 -0.10060733556747437
23 -0.0017138904659077525
24 -0.03291697800159454
25 -0.09708435833454132
26 -0.08456472307443619
27 -0.09086039662361145
28 -0.08721701800823212
29 -0.0856754332780838
30 -0.08278842270374298
31 -0.07941603660583496
32 -0.07415968924760818
33 -0.06753556430339813
34 -0.06228592246770859
35 -0.05878093093633652
36 -0.05680737644433975
37 -0.05843733251094818
38 0.047982268035411835
39 -0.020162682980298996
40 -0.06413997709751129
41 -0.06482817232608795
42 -

352 -0.06353360414505005
353 -0.06355636566877365
354 -0.06357940286397934
355 -0.06360450387001038
356 -0.06363274157047272
357 -0.06366442143917084
358 -0.06369840353727341
359 -0.06373371183872223
360 -0.06376826018095016
361 -0.06380113214254379
362 -0.0638316422700882
363 -0.06385981291532516
364 -0.06388556212186813
365 -0.06390904635190964
366 -0.06392912566661835
367 -0.06394533812999725
368 -0.06395697593688965
369 -0.0639643445611
370 -0.06396884471178055
371 -0.06397272646427155
372 -0.06397797912359238
373 -0.06398652493953705
374 -0.06399848312139511
375 -0.06401316821575165
376 -0.06402833759784698
377 -0.06404199451208115
378 -0.06405209004878998
379 -0.06405825167894363
380 -0.06406167894601822
381 -0.0640648752450943
382 -0.06407061219215393
383 -0.06408093124628067
384 -0.06409605592489243
385 -0.06411443650722504
386 -0.06413429975509644
387 -0.06415360420942307
388 -0.06417165696620941
389 -0.06418910622596741
390 -0.06420766562223434
391 -0.06422925740480423
392 -0

702 -0.04681394249200821
703 -0.0464799664914608
704 -0.04610133543610573
705 -0.0458902083337307
706 -0.045790351927280426
707 -0.04565006494522095
708 -0.04536886885762215
709 -0.04500526189804077
710 -0.044745564460754395
711 -0.044744305312633514
712 -0.045039884746074677
713 -0.045490264892578125
714 -0.04577997326850891
715 -0.04573363810777664
716 -0.04540944844484329
717 -0.0449535995721817
718 -0.04454673454165459
719 -0.044322334229946136
720 -0.04428412765264511
721 -0.044347263872623444
722 -0.04443492367863655
723 -0.04451604187488556
724 -0.04457765072584152
725 -0.04460597783327103
726 -0.04459432512521744
727 -0.04454532265663147
728 -0.04446601867675781
729 -0.044367335736751556
730 -0.04426231235265732
731 -0.04416012018918991
732 -0.04405999183654785
733 -0.043950170278549194
734 -0.04381667822599411
735 -0.04365261644124985
736 -0.043463334441185
737 -0.04326215386390686
738 -0.043063342571258545
739 -0.04287701100111008
740 -0.0427071638405323
741 -0.04255402833223

Number: 847 -- Actual: number -- Prediction: number
Number: 870 -- Actual: fizbuz -- Prediction: fiz
Number: 200 -- Actual: buz -- Prediction: number
Number: 806 -- Actual: number -- Prediction: number
Number: 929 -- Actual: number -- Prediction: number
Number: 933 -- Actual: fiz -- Prediction: fiz
Number: 169 -- Actual: number -- Prediction: number
Number: 671 -- Actual: number -- Prediction: number
Number: 155 -- Actual: buz -- Prediction: number
Number: 566 -- Actual: number -- Prediction: number
Number: 24 -- Actual: fiz -- Prediction: fiz
Number: 298 -- Actual: number -- Prediction: number
Number: 429 -- Actual: fiz -- Prediction: fiz
Number: 80 -- Actual: buz -- Prediction: fiz
Number: 443 -- Actual: number -- Prediction: number
Number: 742 -- Actual: number -- Prediction: number
Number: 305 -- Actual: buz -- Prediction: number
Number: 272 -- Actual: number -- Prediction: fiz
Number: 304 -- Actual: number -- Prediction: fiz
Number: 779 -- Actual: number -- Prediction: number
Numb