# **Implementing FizzBuzz with PyTorch**

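FizzBuzz rules

Count upward from 1. Say "fizz" for every multiple of 3, "buzz" for every multiple of 5, and "fizzbuzz" for every multiple of 15; otherwise say the number itself.

CrossEntropyLoss

CUDA acceleration

Mini-batch training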


In [48]:
def FizzBuzzEncode(i):
    if i % 15 == 0:
        return 3
    elif i % 5 == 0:
        return 2
    elif i % 3 == 0:
        return 1
    else:
        return 0

def FizzBuzzDecode(i, prediction):
    # prediction (a class label 0-3) is the index into the list below
    return [str(i),"fizz","buzz","fizzbuzz"][prediction]

def FizzBuzzPrint(i):
    print(FizzBuzzDecode(i,FizzBuzzEncode(i)))

for i in range(1,16):
    FizzBuzzPrint(i)


1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz


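First, prepare the training data.

x is the binary encoding of each number; y is one of four classes (0, 1, 2, 3).


As a small piece of feature engineering, each number is converted to its binary digits.

x has shape [923, 10]; y has shape [923].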

In [54]:
import numpy as np
import torch

def BinaryEncode(n):
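    # Convert n (n > 0) to a list of binary digits
    binarylist = []
    while n > 0:
        binarylist.append(n % 2)  # remainder of division by 2 is the next lowest bit
        n = n // 2
    # Pad to a fixed length of 10 bits, e.g. 3 becomes 0000000011
    binarylist += [0] * (10 - len(binarylist))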
    return binarylist[::-1]


trainX = torch.Tensor([BinaryEncode(i) for i in range(101,2**10)])
trainY = torch.LongTensor([FizzBuzzEncode(i) for i in range(101,2**10)])
# Note: class labels must be stored in a LongTensor
trainY.shape


torch.Size([923])
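
As a quick sanity check of the encoding, the sketch below compares a loop-based encoder with Python's built-in `format(n, "010b")`. The helper names `binary_encode_loop` and `binary_encode_format` are made up for this illustration and are not part of the notebook.

```python
def binary_encode_loop(n, num_bits=10):
    """Loop-based encoding, mirroring BinaryEncode above (hypothetical name)."""
    bits = []
    while n > 0:
        bits.append(n % 2)  # collect bits from least to most significant
        n //= 2
    bits += [0] * (num_bits - len(bits))  # pad with zeros up to num_bits
    return bits[::-1]  # most significant bit first


def binary_encode_format(n, num_bits=10):
    """Same encoding via format(); assumes 0 <= n < 2**num_bits."""
    return [int(c) for c in format(n, f"0{num_bits}b")]


print(binary_encode_loop(3))  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]

# Both encoders should agree on every value that fits in 10 bits
assert all(binary_encode_loop(i) == binary_encode_format(i) for i in range(1, 2**10))
```

Ten bits cover the values 0 to 1023, which is why the training range stops at `2**10`.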

## Defining the model in PyTorch

FizzBuzz is essentially a classification task, so we use cross-entropy loss.

Note that the target passed to CrossEntropyLoss is a class index, not a one-hot vector:

input=torch.Tensor([[-0.7715, -0.6205,-0.2562]])

entropy=torch.nn.CrossEntropyLoss()

target = torch.tensor([0])

output = entropy(input, target)
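
To make the class-index point concrete, here is a minimal sketch that computes the same loss twice: once with `torch.nn.CrossEntropyLoss` and once by hand as the negative log-softmax at the target index. The variable names `logits` and `manual` are chosen for this illustration only.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[-0.7715, -0.6205, -0.2562]])  # raw scores, shape [1, 3]
target = torch.tensor([0])                            # class index, not one-hot

# CrossEntropyLoss applies log-softmax to the raw scores internally
loss = torch.nn.CrossEntropyLoss()(logits, target)

# By hand: negative log-softmax, read off at the target class
manual = -F.log_softmax(logits, dim=1)[0, target.item()]

print(loss.item(), manual.item())
assert torch.allclose(loss, manual)
```

Because the softmax happens inside the loss, the model's final layer should output raw scores rather than probabilities.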


In [0]:
NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(10, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN,4) # 4 raw class scores (logits); CrossEntropyLoss applies softmax internally
)
if torch.cuda.is_available():
    model = model.cuda()

lossFn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.05)
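
An alternative to repeating `torch.cuda.is_available()` checks is to choose a `torch.device` once and move both model and data with `.to(device)` or `device=`. This is a sketch of that common pattern, not what the notebook does; the dummy input `x` is an assumption for illustration.

```python
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Sequential(
    torch.nn.Linear(10, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 4),
).to(device)

x = torch.zeros(2, 10, device=device)  # dummy batch of two 10-bit inputs
logits = model(x)
print(logits.shape)  # torch.Size([2, 4])
```

The same script then runs unchanged on machines with or without a GPU.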


In [57]:
BATCH_SIZE=128

for epochIndex in range(10000):
    # Iterate over mini-batches
    for start in range(0,len(trainX),BATCH_SIZE):  # range(start, stop, step)
        end = start + BATCH_SIZE
        batchX = trainX[start:end]  # one batch of binary-encoded inputs, shape [batch, 10]
        batchY = trainY[start:end]

        if torch.cuda.is_available():
            batchX = batchX.cuda()  # move the data to the GPU
            batchY = batchY.cuda()

        y_pred = model(batchX)
        loss = lossFn(y_pred,batchY)


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epochIndex % 200 == 0:
        print("epoch:",epochIndex, loss.item())


epoch: 0 0.0038208167534321547
epoch: 200 0.0034803373273462057
epoch: 400 0.003206288442015648
epoch: 600 0.0029756580479443073
epoch: 800 0.0027484893798828125
epoch: 1000 0.0025650307070463896
epoch: 1200 0.0023951884359121323
epoch: 1400 0.0022538502234965563
epoch: 1600 0.0021014567464590073
epoch: 1800 0.001979704247787595
epoch: 2000 0.0018657224718481302
epoch: 2200 0.0017675469862297177
epoch: 2400 0.0016561438096687198
epoch: 2600 0.0015852185897529125
epoch: 2800 0.0015105671482160687
epoch: 3000 0.0014337963657453656
epoch: 3200 0.0013674983056262136
epoch: 3400 0.0013079113559797406
epoch: 3600 0.001247264677658677
epoch: 3800 0.0011898676166310906
epoch: 4000 0.0011505373986437917
epoch: 4200 0.0011011053575202823
epoch: 4400 0.0010530507424846292
epoch: 4600 0.0010189657332375646
epoch: 4800 0.000979617820121348
epoch: 5000 0.0009453561506234109
epoch: 5200 0.0009111121762543917
epoch: 5400 0.0008814070024527609
epoch: 5600 0.0008509423932991922
epoch: 5800 0.00082600559
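
The loop above slices batches by hand and visits them in the same order every epoch. A commonly used alternative is `torch.utils.data.DataLoader`, which can also shuffle the samples each epoch. The sketch below uses random stand-in tensors `X` and `Y` with the same shapes as `trainX` and `trainY`; they are assumptions for illustration, not the notebook's data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data with the same shapes as trainX / trainY (random, illustration only)
X = torch.rand(923, 10)
Y = torch.randint(0, 4, (923,))

# shuffle=True reorders the samples at the start of every epoch
loader = DataLoader(TensorDataset(X, Y), batch_size=128, shuffle=True)

batch_sizes = [len(batchX) for batchX, batchY in loader]
print(len(batch_sizes), batch_sizes[-1])  # 8 batches; the last one holds 27 samples
```

With 923 samples and a batch size of 128, the last batch is smaller than the rest, just as it is with manual slicing.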


### Testing

Play FizzBuzz on the numbers 1 to 100, none of which appear in the training set (101 to 1023).




In [77]:
testX = torch.Tensor([BinaryEncode(i) for i in range(1,101)])
if torch.cuda.is_available():
    testX = testX.cuda()

with torch.no_grad():
    testY = model(testX)

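# testY is a [100, 4] matrix: one row of 4 class scores (logits) per number


predlist=torch.max(testY,dim=1)[1].data.tolist()  # torch.max returns (values, indices); [1] selects the argmax
# .data extracts the underlying tensor; tolist() converts it to a Python list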

predictions = zip(range(1,101),predlist)

for i,pred in predictions:
    print(FizzBuzzDecode(i,pred))



1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz
fizz
22
23
fizz
buzz
26
fizz
28
29
fizzbuzz
31
32
fizz
fizz
buzz
fizz
37
38
fizz
buzz
41
fizz
43
44
fizzbuzz
46
47
fizz
49
buzz
fizz
52
53
fizz
buzz
56
fizz
58
59
fizzbuzz
61
62
fizz
64
buzz
fizz
67
68
fizz
buzz
71
fizz
73
74
fizzbuzz
76
77
fizz
79
buzz
fizz
82
83
fizz
buzz
86
fizz
88
89
fizzbuzz
91
92
93
94
buzz
fizz
97
98
fizz
buzz


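The accuracy is quite high: in the output above, only 34 (predicted "fizz") and 93 (predicted as a plain number) are wrong, so 98 of the 100 predictions are correct.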
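
Rather than reading the output by eye, the accuracy can be measured by comparing each prediction with the true class. The sketch below is self-contained: `fizz_buzz_encode` restates the notebook's `FizzBuzzEncode`, while `accuracy` and the `demo` list are hypothetical names introduced here. In the notebook, one would pass `predlist` instead of `demo`.

```python
def fizz_buzz_encode(i):
    """Ground-truth class: 3 = fizzbuzz, 2 = buzz, 1 = fizz, 0 = the number itself."""
    if i % 15 == 0:
        return 3
    if i % 5 == 0:
        return 2
    if i % 3 == 0:
        return 1
    return 0


def accuracy(predlist, start=1):
    """Return (fraction correct, list of misclassified numbers) for start, start+1, ..."""
    truth = [fizz_buzz_encode(i) for i in range(start, start + len(predlist))]
    wrong = [i for i, (p, t) in enumerate(zip(predlist, truth), start) if p != t]
    return 1 - len(wrong) / len(predlist), wrong


# Illustrative predictions: perfect, except for two deliberately flipped entries
demo = [fizz_buzz_encode(i) for i in range(1, 101)]
demo[33] = 1  # number 34 mislabeled as fizz
demo[92] = 0  # number 93 mislabeled as a plain number

acc, wrong = accuracy(demo)
print(acc, wrong)
```

Returning the misclassified numbers alongside the score makes it easy to see where the model fails.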