
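# Fitting FizzBuzz with a Neural Network
## FizzBuzz rules
- Count up from 1.
- For a multiple of 3, say "fizz".
- For a multiple of 5, say "buzz".
- For a multiple of both 3 and 5 (that is, of 15), say "fizzbuzz".
- For any other number, say the number itself.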
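The rules above can be written down directly as a plain function. This is a minimal reference sketch; the name `fizzbuzz_say` is hypothetical and does not appear in the notebook, which instead encodes each answer as a class label below.

```python
def fizzbuzz_say(n: int) -> str:
    """Return what a player says for n under the FizzBuzz rules (hypothetical helper)."""
    # Test 15 first: a multiple of both 3 and 5 must not stop at the 3 or 5 branch.
    if n % 15 == 0:
        return "fizzbuzz"
    if n % 3 == 0:
        return "fizz"
    if n % 5 == 0:
        return "buzz"
    return str(n)


print(" ".join(fizzbuzz_say(n) for n in range(1, 16)))
```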

## 1. Define the FizzBuzz data-generation functions

### 1.1. Define the functions

In [1]:
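# Map a number to its class label: 0 = the number itself, 1 = fizz, 2 = buzz, 3 = fizzbuzz.
# The 15 check comes first so multiples of both 3 and 5 are not caught by the 3 or 5 branch.
def fizz_buzz_encode(i):
  if i % 15 == 0: return 3
  elif i % 5 == 0: return 2
  elif i % 3 == 0: return 1
  else: return 0

# Turn a class label (0-3) back into what a player says for the number i.
def fizz_buzz_decode(i, prediction):
  return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

# Print the correct FizzBuzz answer for i.
def helper(i):
  print(fizz_buzz_decode(i, fizz_buzz_encode(i)))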

### 1.2. Test the functions

In [2]:
for i in range(1, 17):
  helper(i)

1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
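Reading the printed list by eye only covers 1 to 16. A hedged sketch of an automated check: it re-declares the notebook's `fizz_buzz_encode` and `fizz_buzz_decode` so the block runs on its own, and compares them with an independent formulation (concatenate "fizz" and "buzz", fall back to the number) over 1 to 1023, the range used for the dataset later. `reference_answer` is a hypothetical name.

```python
def fizz_buzz_encode(i):
    # Copied from the notebook: 0 = number, 1 = fizz, 2 = buzz, 3 = fizzbuzz
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

def reference_answer(n):
    # Independent formulation: build the word from the two divisibility tests,
    # and use the number itself when the word is empty.
    word = ("fizz" if n % 3 == 0 else "") + ("buzz" if n % 5 == 0 else "")
    return word or str(n)

mismatches = [n for n in range(1, 2 ** 10)
              if fizz_buzz_decode(n, fizz_buzz_encode(n)) != reference_answer(n)]
print("mismatches:", mismatches)
```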


## 2. Prepare the data

### 2.1. Convert decimal numbers to binary to make training easier

In [14]:
import numpy as np

def dec_to_bin(i, fixed_length=10):
    # Convert to a binary string with the built-in bin() and strip the '0b' prefix
    binary_str = bin(i)[2:]

    # Refuse numbers whose binary form does not fit in fixed_length bits
    if len(binary_str) > fixed_length:
        raise ValueError(f"binary representation exceeds the fixed length of {fixed_length} bits")

    # Number of leading zeros needed
    padding_length = fixed_length - len(binary_str)

    # Left-pad with zeros
    padded_binary_str = '0' * padding_length + binary_str

    # Convert the digit string into a NumPy array of ints, most significant bit first
    binary_array = np.array(list(map(int, padded_binary_str)))

    return binary_array

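# Convert label y (0-3) into a 4-element one-hot vector; label k sets position 3 - k
def dec_to_dig(i):
  if i == 0:
    return np.array([0,0,0,1])
  elif i == 1:
    return np.array([0,0,1,0])
  elif i == 2:
    return np.array([0,1,0,0])
  elif i == 3:
    return np.array([1,0,0,0])
  else:
    raise ValueError(f"label must be 0, 1, 2 or 3, got {i}")

# Inverse of dec_to_dig: map a one-hot vector back to its label
def dig_to_dec(i):
  if np.array_equal(i, np.array([0, 0, 0, 1])):
    return 0
  elif np.array_equal(i, np.array([0, 0, 1, 0])):
    return 1
  elif np.array_equal(i, np.array([0, 1, 0, 0])):
    return 2
  elif np.array_equal(i, np.array([1, 0, 0, 0])):
    return 3
  else:
    return -1  # default value: the input matches none of the four patterns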
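The two functions above implement a reversed one-hot layout: label k puts its 1 at position 3 - k. A compact sketch of the same mapping with `np.eye` and `np.argmax`, assuming that only exact 0/1 vectors should decode (anything else returns -1, as `dig_to_dec` does); the `_compact` names are hypothetical.

```python
import numpy as np

def dec_to_dig_compact(label):
    """Hypothetical compact dec_to_dig: label k (0-3) -> one-hot vector with its 1 at index 3 - k."""
    if label not in (0, 1, 2, 3):
        raise ValueError(f"label must be 0, 1, 2 or 3, got {label}")
    # Row 3 - k of the 4x4 identity matrix is exactly the vector dec_to_dig returns
    return np.eye(4, dtype=int)[3 - label]

def dig_to_dec_compact(vec):
    """Hypothetical compact dig_to_dec: returns -1 unless vec is an exact 0/1 one-hot vector."""
    vec = np.asarray(vec)
    if vec.shape != (4,) or not np.isin(vec, (0, 1)).all() or vec.sum() != 1:
        return -1
    # The 1 sits at index 3 - k, so invert that to recover k
    return 3 - int(np.argmax(vec))

for k in range(4):
    print(k, dec_to_dig_compact(k), dig_to_dec_compact(dec_to_dig_compact(k)))
```

Because the position of the 1 is simply 3 - k, the label can also be read straight off an argmax index without building the intermediate vector.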

In [29]:
import numpy as np
# Test dec_to_bin: 1022 should be nine 1s followed by a 0
print(dec_to_bin(1022))
# Test the label <-> one-hot conversions
for i in range(1, 17):
  y = fizz_buzz_encode(i)
  y_dig = dec_to_dig(y)
  y_dec = dig_to_dec(y_dig)
  # print(i,y, y_dig, y_dec, fizz_buzz_decode(i, y))
  print(y_dig)

[1 1 1 1 1 1 1 1 1 0]
[0 0 0 1]
[0 0 0 1]
[0 0 1 0]
[0 0 0 1]
[0 1 0 0]
[0 0 1 0]
[0 0 0 1]
[0 0 0 1]
[0 0 1 0]
[0 1 0 0]
[0 0 0 1]
[0 0 1 0]
[0 0 0 1]
[0 0 0 1]
[1 0 0 0]
[0 0 0 1]
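`dec_to_bin` builds each row through a string. An alternative sketch (an assumption, not the notebook's code) extracts the bits arithmetically with right shifts and `& 1`, and encodes a whole batch of numbers at once through NumPy broadcasting; `dec_to_bin_batch` is a hypothetical name. It should agree with the `[1 1 1 1 1 1 1 1 1 0]` printed above for 1022.

```python
import numpy as np

def dec_to_bin_batch(numbers, num_digits=10):
    """Hypothetical vectorized dec_to_bin: one row of bits per number, most significant bit first."""
    numbers = np.asarray(numbers, dtype=np.int64)
    if np.any(numbers < 0) or np.any(numbers >= 2 ** num_digits):
        raise ValueError(f"values must fit in {num_digits} bits")
    shifts = np.arange(num_digits - 1, -1, -1)          # [9, 8, ..., 0]
    # Shift each number right by every amount, then keep the lowest bit
    return (numbers[:, None] >> shifts[None, :]) & 1    # shape (len(numbers), num_digits)

bits = dec_to_bin_batch([1, 5, 1022])
print(bits)

# Cross-check against the string route for every number used in the dataset
all_bits = dec_to_bin_batch(range(1, 2 ** 10))
expected = [[int(c) for c in format(n, "010b")] for n in range(1, 2 ** 10)]
print(all_bits.tolist() == expected)
```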


### 2.2. Build the dataset

In [44]:
import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Number of binary digits per input X
NUM_DIGITS = 10

# All inputs X (binary encodings of 1 .. 2**NUM_DIGITS - 1) and all one-hot labels Y
dX = torch.Tensor(np.arange(1, 2 ** NUM_DIGITS))  # the same numbers in decimal (not used below)
X = torch.Tensor(np.array([dec_to_bin(i) for i in range(1, 2 ** NUM_DIGITS)]))
Y = torch.zeros((2 ** NUM_DIGITS - 1, 4))
for i in range(1, 2 ** NUM_DIGITS):
  # Y stays float: each row is a one-hot vector used as a class-probability target
  Y[i-1] = torch.LongTensor(dec_to_dig(fizz_buzz_encode(i)))

if torch.cuda.is_available():
  X = X.cuda()
  Y = Y.cuda()

# Split into training and test sets, 80/20
X_train, X_test, Y_train, Y_test, indices_train, indices_test = train_test_split(X, Y, np.arange(len(X)), test_size=0.2, random_state=84)
print(f"X_train.shape={X_train.shape}, Y_train.shape={Y_train.shape}, X_test.shape={X_test.shape}, Y_test.shape={Y_test.shape}")


X_train.shape=torch.Size([818, 10]), Y_train.shape=torch.Size([818, 4]), X_test.shape=torch.Size([205, 10]), Y_test.shape=torch.Size([205, 4])
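The four classes are far from balanced over 1 to 1023. A quick count with the notebook's own label function (re-declared so the block runs standalone) shows how rare `buzz` and `fizzbuzz` examples are, which is worth keeping in mind when reading the error list in section 3.

```python
from collections import Counter

NUM_DIGITS = 10

def fizz_buzz_encode(i):
    # Same labels as the notebook: 0 = number, 1 = fizz, 2 = buzz, 3 = fizzbuzz
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

counts = Counter(fizz_buzz_encode(i) for i in range(1, 2 ** NUM_DIGITS))
total = sum(counts.values())
for label, name in enumerate(["number", "fizz", "buzz", "fizzbuzz"]):
    print(f"{name:9s} {counts[label]:4d}  ({counts[label] / total:.1%})")
```

Over the 1,023 inputs this gives 546 numbers, 273 fizz, 136 buzz and 68 fizzbuzz, so a fizzbuzz example is eight times rarer than a plain number.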


### 2.3. Build and train the model with PyTorch

In [87]:
# Hidden layer size
NUM_HIDDEN = 100
# Output layer size: one logit per class (number, fizz, buzz, fizzbuzz)
NUM_OUT = 4

# Define the model: a single hidden ReLU layer
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, NUM_OUT)
)
print(next(model.parameters()).device)  # prints: cpu

if torch.cuda.is_available():
  model = model.cuda()
print(next(model.parameters()).device)

# Loss function and optimizer. Y holds one-hot float rows; CrossEntropyLoss accepts
# such class-probability targets in PyTorch 1.10 and later.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.08)

# Mini-batch gradient descent
BATCH_SIZE = 128
for epoch in range(10000):
  # One parameter update per mini-batch
  for start in range(0, len(X_train), BATCH_SIZE):
    end = start + BATCH_SIZE
    batchX = X_train[start:end]
    batchY = Y_train[start:end]
    # Forward pass
    predY = model(batchX)
    loss = loss_fn(predY, batchY)
    # print(f'Epoch: {epoch: 5d}, Loss: {loss: .3e}')
    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Loss on the full training and test sets; no_grad skips building the autograd graph
  with torch.no_grad():
    loss_train = loss_fn(model(X_train), Y_train).item()
    loss_test = loss_fn(model(X_test), Y_test).item()
  print(f'Epoch: {epoch: 5d}, loss_train: {loss_train: .3e}, Loss_test: {loss_test: .3e}')

Streaming output truncated to the last 5000 lines.
Epoch:  5000, loss_train:  9.261e-03, Loss_test:  3.137e-01
Epoch:  5001, loss_train:  9.256e-03, Loss_test:  3.137e-01
Epoch:  5002, loss_train:  9.253e-03, Loss_test:  3.137e-01
Epoch:  5003, loss_train:  9.249e-03, Loss_test:  3.136e-01
Epoch:  5004, loss_train:  9.246e-03, Loss_test:  3.137e-01
Epoch:  5005, loss_train:  9.244e-03, Loss_test:  3.136e-01
Epoch:  5006, loss_train:  9.241e-03, Loss_test:  3.136e-01
Epoch:  5007, loss_train:  9.236e-03, Loss_test:  3.137e-01
Epoch:  5008, loss_train:  9.235e-03, Loss_test:  3.137e-01
Epoch:  5009, loss_train:  9.232e-03, Loss_test:  3.137e-01
Epoch:  5010, loss_train:  9.227e-03, Loss_test:  3.135e-01
Epoch:  5011, loss_train:  9.226e-03, Loss_test:  3.136e-01
Epoch:  5012, loss_train:  9.222e-03, Loss_test:  3.136e-01
Epoch:  5013, loss_train:  9.218e-03, Loss_test:  3.137e-01
Epoch:  5014, loss_train:  9.217e-03, Loss_test:  3.137e-01
Epoch:  5015, loss_train:  9.215e-0
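The training cell feeds one-hot float rows to `CrossEntropyLoss`. For one-hot targets this loss reduces to the usual class-index form: minus the log of the softmax probability of the correct class, averaged over the batch. The NumPy sketch below is a hand-written re-implementation of the formula with made-up logits, not PyTorch's code; it shows the two forms agree, which suggests integer targets such as `Y_train.argmax(dim=1)` would train the same way.

```python
import numpy as np

def log_softmax(logits):
    # Subtract each row's max before exponentiating, for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def ce_probability_targets(logits, target_probs):
    # Mean over the batch of -sum_c p_c * log softmax(z)_c
    return float(-(target_probs * log_softmax(logits)).sum(axis=1).mean())

def ce_index_targets(logits, target_index):
    # Mean over the batch of -log softmax(z)[correct class]
    rows = np.arange(len(target_index))
    return float(-log_softmax(logits)[rows, target_index].mean())

# Made-up logits for a batch of two samples and four classes
logits = np.array([[2.0, 0.5, -1.0, 0.1],
                   [0.3, 0.2, 1.5, -0.7]])
one_hot = np.array([[1.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0]])
class_index = one_hot.argmax(axis=1)   # positions of the 1s: 0 and 2

loss_one_hot = ce_probability_targets(logits, one_hot)
loss_index = ce_index_targets(logits, class_index)
print(loss_one_hot, loss_index)
```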

## 3. Play FizzBuzz with the trained model

In [90]:
count_error = 0
total_num = 999  # check the numbers 1 .. 999
for i in range(1, total_num + 1):
  x = dec_to_bin(i)
  x = torch.Tensor(x)
  if torch.cuda.is_available():
    x = x.cuda()
  with torch.no_grad():
    y_pred = model(x)
  # Turn the 4 logits into a one-hot vector: 1 at the largest logit, 0 elsewhere
  # Index of the largest logit
  max_y_index = torch.argmax(y_pred)
  y_pred = torch.zeros_like(y_pred)
  # Set that position to 1
  y_pred[max_y_index] = 1
  # Convert the one-hot vector back to a label (0-3)
  y = dig_to_dec(y_pred.cpu().numpy())
  pred_res = fizz_buzz_decode(i, y)

  real_res = fizz_buzz_decode(i, fizz_buzz_encode(i))
  # print(i, real_res, pred_res)
  if real_res != pred_res:
    count_error = count_error + 1
    print(i, real_res, pred_res)
# Note: 1 .. 999 mixes training-split numbers with held-out test numbers,
# so this is not a pure test-set error rate
print(f"Error: {count_error}/{total_num}")

273 fizz buzz
275 buzz 275
336 fizz 336
375 fizzbuzz 375
415 buzz fizz
465 fizzbuzz 465
580 buzz 580
630 fizzbuzz buzz
745 buzz 745
810 fizzbuzz 810
853 853 fizz
930 fizzbuzz 930
935 buzz 935
938 938 fizz
950 buzz 950
Error: 15/999
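The 15 mismatches printed above are not spread evenly across the classes. The short sketch below tallies them by their correct answer and divides by how many numbers of that class occur in 1 to 999, the numbers the loop checks. It only re-reads the printed lines, so it needs no model; `fizzbuzz_class` and `PRINTED_ERRORS` are hypothetical names.

```python
from collections import Counter

# The mismatch lines printed above: "number correct_answer predicted_answer"
PRINTED_ERRORS = [
    "273 fizz buzz", "275 buzz 275", "336 fizz 336", "375 fizzbuzz 375",
    "415 buzz fizz", "465 fizzbuzz 465", "580 buzz 580", "630 fizzbuzz buzz",
    "745 buzz 745", "810 fizzbuzz 810", "853 853 fizz", "930 fizzbuzz 930",
    "935 buzz 935", "938 938 fizz", "950 buzz 950",
]

def fizzbuzz_class(n):
    # Correct class of n, named the way the per-class table below reports it
    if n % 15 == 0: return "fizzbuzz"
    if n % 5 == 0: return "buzz"
    if n % 3 == 0: return "fizz"
    return "number"

errors_by_class = Counter(fizzbuzz_class(int(line.split()[0])) for line in PRINTED_ERRORS)
totals = Counter(fizzbuzz_class(n) for n in range(1, 1000))   # the numbers the loop checked
for cls in ("number", "fizz", "buzz", "fizzbuzz"):
    rate = errors_by_class[cls] / totals[cls]
    print(f"{cls:9s} {errors_by_class[cls]:2d}/{totals[cls]:3d} = {rate:.1%}")
```

On these lines the errors are 2/533 numbers, 2/267 fizz, 6/133 buzz and 5/66 fizzbuzz, so the two rarest classes account for most of the mistakes.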
