# Lesson 1

王晓涛<cauwxt@qq.com>


What is PyTorch?
================

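PyTorch is a Python-based scientific computing package with the following features:

- It is similar to NumPy, but it can use GPUs
- It can be used to define deep learning models, and to train and use them flexibly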

Tensors
---------------


Tensors are similar to NumPy's ndarrays. The main difference is that Tensors can also run on a GPU to accelerate computation.


In [1]:
from __future__ import print_function
import torch

Construct a 5x3 matrix, uninitialized:
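A minimal sketch, assuming a recent PyTorch version: `torch.empty` allocates memory without initializing it, so the printed values are whatever happened to be in that memory.

```python
import torch

# Allocate a 5x3 tensor without initializing it; the values are arbitrary
x = torch.empty(5, 3)
print(x)
print(x.shape)  # torch.Size([5, 3])
```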

Construct a randomly initialized matrix:
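One way to do this (a sketch; `torch.rand` is one of several random constructors) is to sample uniformly from [0, 1):

```python
import torch

# torch.rand samples uniformly from [0, 1); torch.randn would sample from N(0, 1)
x = torch.rand(5, 3)
print(x)
```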

Construct a matrix filled with zeros and of dtype long:
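A short sketch: the `dtype` argument selects the element type, and `torch.long` is an alias for 64-bit integers.

```python
import torch

# A 5x3 tensor of zeros with 64-bit integer elements
x = torch.zeros(5, 3, dtype=torch.long)
print(x)
print(x.dtype)  # torch.int64
```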

Construct a tensor directly from data:
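For example, `torch.tensor` copies a Python list into a new tensor (a minimal sketch):

```python
import torch

# torch.tensor copies Python data into a new tensor; the float inputs give float32
x = torch.tensor([5.5, 3])
print(x)  # tensor([5.5000, 3.0000])
```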

You can also create a tensor based on an existing tensor. These methods reuse properties of the input tensor, such as its dtype, unless new values are provided.
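A sketch of two such families, assuming a double-precision source tensor: `new_*` methods take a new size but keep the dtype, while `*_like` functions keep both the size and the dtype.

```python
import torch

x = torch.zeros(5, 3, dtype=torch.double)

# new_* methods take a new size but reuse the dtype (and device) of x
y = x.new_ones(2, 2)
print(y.dtype)  # torch.float64

# *_like functions reuse both the size and the dtype of x unless overridden
z = torch.zeros_like(x)
print(z.shape, z.dtype)  # torch.Size([5, 3]) torch.float64
```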

In [None]:
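x = torch.ones(5, 3, dtype=torch.double)
print(torch.randn_like(x, dtype=torch.float))  # override dtype; result has the same size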

Get the tensor's shape:

<div class="alert alert-info"><h4>Note</h4><p>``torch.Size`` is in fact a tuple (a subclass of ``tuple``), so it supports all tuple operations.</p></div>
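A minimal sketch: the shape can be read with `x.size()` or the `x.shape` attribute, and because `torch.Size` is a tuple subclass it can be unpacked like one.

```python
import torch

x = torch.rand(5, 3)
print(x.size())  # torch.Size([5, 3])
print(x.shape)   # the same value, as an attribute

# Tuple unpacking works because torch.Size subclasses tuple
rows, cols = x.size()
print(rows, cols)  # 5 3
```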

Operations
----------


There are many tensor operations. We start with addition.
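The simplest form is the `+` operator, which adds two tensors of the same shape element by element (a sketch with random inputs):

```python
import torch

x = torch.rand(5, 3)
y = torch.rand(5, 3)

# The + operator performs element-wise addition and returns a new tensor
s = x + y
print(s)
```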



Another way to write addition:
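A sketch of the function form, `torch.add`, which computes the same result as `x + y`:

```python
import torch

x = torch.rand(5, 3)
y = torch.rand(5, 3)

# torch.add is the function form of x + y
s = torch.add(x, y)
print(s)
```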


Addition: providing an output tensor as an argument
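A minimal sketch using the `out=` keyword of `torch.add`: the result is written into a preallocated tensor instead of a newly created one.

```python
import torch

x = torch.rand(5, 3)
y = torch.rand(5, 3)

# Preallocate the output tensor and let torch.add write into it
result = torch.empty(5, 3)
torch.add(x, y, out=result)
print(result)
```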

In-place addition:
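A sketch of the in-place variant: `y.add_(x)` overwrites `y` with `y + x` rather than returning a new tensor.

```python
import torch

x = torch.rand(5, 3)
y = torch.rand(5, 3)
expected = x + y

# add_ (trailing underscore) modifies y in place
y.add_(x)
print(y)
```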

<div class="alert alert-info"><h4>Note</h4><p>Any operation that mutates a tensor in place ends with ``_``.
    For example, ``x.copy_(y)`` and ``x.t_()`` will change ``x``.</p></div>

All the usual NumPy-style indexing works on PyTorch tensors.
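A few common patterns, sketched on a small integer tensor so the results are easy to check: column slicing, row slicing, and boolean masks.

```python
import torch

x = torch.arange(15).reshape(5, 3)  # values 0..14 laid out row by row

col = x[:, 1]      # second column of every row
rows = x[1:3]      # rows 1 and 2
big = x[x > 10]    # boolean mask selects matching elements as a 1-D tensor
print(col, rows, big, sep="\n")
```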


Resizing: if you want to resize or reshape a tensor, use the ``view`` method (``x.view(...)``):
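A sketch of `view`, where `-1` asks PyTorch to infer one dimension from the others. `view` requires memory layout compatible with the new shape; `x.reshape(...)` is an alternative that copies when needed.

```python
import torch

x = torch.randn(4, 4)
y = x.view(16)     # flatten to 16 elements
z = x.view(-1, 8)  # -1 is inferred, giving shape (2, 8)
print(x.size(), y.size(), z.size())
```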

If you have a one-element tensor, use ``.item()`` to get its value as a Python number.
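For example (a minimal sketch):

```python
import torch

x = torch.randn(1)
value = x.item()  # a plain Python float
print(x, value)
```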

**Further reading**


  Tensor operations, including transposing, indexing, slicing,
  mathematical operations, linear algebra, and random numbers, are described at
  `<https://pytorch.org/docs/torch>`.

Converting between NumPy and Tensors
------------------------------------

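Converting between a Torch Tensor and a NumPy array is very easy.

A Torch Tensor on the CPU and the NumPy array converted from it share the same underlying memory, so changing one also changes the other.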

Converting a Torch Tensor to a NumPy array:
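A minimal sketch using `Tensor.numpy()`:

```python
import torch

a = torch.ones(5)
b = a.numpy()  # a NumPy array backed by the same CPU memory as a
print(b)       # [1. 1. 1. 1. 1.]
```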


Change the values and observe that the NumPy array changes as well.
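A sketch of the shared memory in action: an in-place update to the tensor is immediately visible through the NumPy array.

```python
import torch

a = torch.ones(5)
b = a.numpy()

# Modify the tensor in place; b sees the change because the memory is shared
a.add_(1)
print(a)  # tensor([2., 2., 2., 2., 2.])
print(b)  # [2. 2. 2. 2. 2.]
```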

Converting a NumPy ndarray to a Torch Tensor:
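A sketch using `torch.from_numpy`, which shares memory with the array and keeps its dtype (here `float64`):

```python
import numpy as np
import torch

a = np.ones(5)
b = torch.from_numpy(a)  # shares memory with a

np.add(a, 1, out=a)      # modify the array in place
print(a)  # [2. 2. 2. 2. 2.]
print(b)  # tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
```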

Tensors on the CPU can be converted to NumPy arrays and back; a CUDA tensor must first be moved to the CPU.
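Two practical caveats, sketched below: a tensor tracked by autograd (`requires_grad=True`) must be detached before calling `.numpy()`, and a tensor that may live on a GPU should be moved with `.cpu()` first (a no-op for CPU tensors).

```python
import torch

w = torch.ones(3, requires_grad=True)

# w.numpy() would raise a RuntimeError because w requires grad;
# detach() returns a tensor that shares memory but is outside the graph
arr = w.detach().numpy()

# For code that may run on a GPU, move to the CPU before converting
arr2 = w.detach().cpu().numpy()
print(arr, arr2)
```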

CUDA Tensors
------------

Tensors can be moved onto another device with the ``.to`` method.



In [2]:
# let us run this cell only if CUDA is available
# We will use ``torch.device`` objects to move tensors in and out of GPU
x = torch.rand(5, 3)                       # a CPU tensor to move; defined here so the cell runs on its own
if torch.cuda.is_available():
    device = torch.device("cuda")          # a CUDA device object
    y = torch.ones_like(x, device=device)  # directly create a tensor on GPU
    x = x.to(device)                       # or just use strings ``.to("cuda")``
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))       # ``.to`` can also change dtype together!


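Warm-up: a two-layer neural network in NumPy
--------------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to predict y from x using an L2 loss (sum of squared errors).

This implementation uses NumPy alone to compute the forward pass, the loss, and the backward pass by hand.

A NumPy ndarray is a generic n-dimensional array. It knows nothing about deep learning, gradients, or computational graphs; it is just a data structure for performing numeric computations.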
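For reference, the gradients computed by hand in the backward pass of the next cell follow from the chain rule. Writing $X, Y$ for `x, y`, $W_1, W_2$ for `w1, w2`, and $\odot$ for element-wise multiplication:

```latex
\begin{aligned}
H &= X W_1, \qquad H_{\text{relu}} = \max(H, 0), \qquad \hat{Y} = H_{\text{relu}} W_2,\\
L &= \sum_{i,j} \bigl(\hat{Y}_{ij} - Y_{ij}\bigr)^2,\\
\frac{\partial L}{\partial \hat{Y}} &= 2\,(\hat{Y} - Y),\\
\frac{\partial L}{\partial W_2} &= H_{\text{relu}}^{\top}\,\frac{\partial L}{\partial \hat{Y}},
\qquad
\frac{\partial L}{\partial H_{\text{relu}}} = \frac{\partial L}{\partial \hat{Y}}\,W_2^{\top},\\
\frac{\partial L}{\partial H} &= \frac{\partial L}{\partial H_{\text{relu}}} \odot \mathbb{1}[H > 0],\\
\frac{\partial L}{\partial W_1} &= X^{\top}\,\frac{\partial L}{\partial H}.
\end{aligned}
```

These correspond to `grad_y_pred`, `grad_w2`, `grad_h_relu`, `grad_h`, and `grad_w1` in the code. The code zeroes the entries where `h < 0`, which differs from the indicator above only at exactly `h == 0`.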



In [3]:
import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

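    # Backprop to compute gradients of w1 and w2 with respect to loss
    # loss = sum((y_pred - y) ** 2), so d(loss)/d(y_pred) = 2 * (y_pred - y)
    grad_y_pred = 2.0 * (y_pred - y)
    # y_pred = h_relu.dot(w2): propagate to w2 and to the hidden activations
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    # ReLU passes gradients only where its input was non-negative
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)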

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 31260246.411749657
1 26546511.009867523
2 27201281.49060714
3 28674592.88350098
4 27767361.454228733
5 22806055.33962102
6 15660509.826203449
7 9218995.837742995
8 5062534.183504244
9 2811782.736140671
10 1693947.1213708415
11 1134061.0447826665
12 836109.3644645554
13 660223.5000345308
14 544263.390009437
15 460250.8661319723
16 395156.57484794897
17 342553.0477679961
18 298840.6364887414
19 261993.49222282256
20 230540.51706465316
21 203535.172381507
22 180230.70074317648
23 160016.86561235983
24 142424.88174835732
25 127058.46601168797
26 113579.96832661955
27 101739.18111637412
28 91320.0569265062
29 82121.72685544271
30 73975.01005427187
31 66744.38397783542
32 60309.13609170301
33 54572.37766369493
34 49446.49857428443
35 44860.793585826825
36 40751.06049754557
37 37060.831772760816
38 33743.8885846772
39 30757.794058621177
40 28065.694357289358
41 25635.56572224804
42 23438.306482855605
43 21451.405746688444
44 19651.434223201082
45 18019.17854630587
46 16539.101031462636
...

361 0.0014875237212221815
362 0.0014223827491915539
363 0.0013601084729342478
364 0.0013005420977857903
365 0.0012436142545274005
366 0.0011891685622898408
367 0.0011371083712561497
368 0.0010873435478532681
369 0.0010397473475525693
370 0.0009942426368644227
371 0.000950736351995439
372 0.0009091288952312883
373 0.0008693499714153713
374 0.0008313178603471722
375 0.0007949458831992317
376 0.0007601681132996963
377 0.0007269173132895689
378 0.0006951217417162386
379 0.000664715785026663
380 0.0006356463104040624
381 0.000607849411577985
382 0.000581266373378273
383 0.0005558515093470641
384 0.0005315457521936859
385 0.0005083051796237561
386 0.00048608488293530664
387 0.00046483221395580976
388 0.0004445140984700183
389 0.0004250866973767946
390 0.0004065034869699121
391 0.00038873840828832477
392 0.00037175009178639355
393 0.0003555011151679184
394 0.0003399685497143403
395 0.00032511326147982486
396 0.0003109055527556519
397 0.00029732409869727866
398 0.0002843332679776095
...


PyTorch: Tensors
----------------

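This time we use PyTorch tensors to implement the forward pass, the loss, and the backward pass.

A PyTorch Tensor is conceptually very similar to a NumPy ndarray. The biggest difference is that a PyTorch Tensor can run on either the CPU or the GPU. To run on a GPU, create the Tensor on (or move it to) a CUDA device.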
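Instead of editing the `device` line by hand, a common pattern (sketched here) picks the GPU when one is available and falls back to the CPU otherwise:

```python
import torch

# Use the GPU if available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensors can be created directly on the chosen device
x = torch.randn(64, 1000, device=device)
print(x.device)
```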


In [4]:
import torch


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 39992300.0
1 41185612.0
2 44618008.0
3 40547396.0
4 27778012.0
5 14144622.0
6 6349229.5
7 3139666.5
8 1925899.875
9 1400285.875
10 1112990.375
11 920752.25
12 776336.1875
13 661717.0625
14 568255.25
15 491163.5
16 426773.53125
17 372520.8125
18 326517.5625
19 287334.65625
20 253744.28125
21 224813.9375
22 199777.34375
23 178032.875
24 159048.6875
25 142435.765625
26 127861.78125
27 115040.03125
28 103702.25
29 93660.46875
30 84745.96875
31 76813.8359375
32 69731.265625
33 63396.1015625
34 57718.7109375
35 52619.63671875
36 48029.32421875
37 43890.35546875
38 40156.25390625
39 36779.6953125
40 33725.578125
41 30957.22265625
42 28444.310546875
43 26159.615234375
44 24079.646484375
45 22183.228515625
46 20452.427734375
47 18870.9921875
48 17424.841796875
49 16101.4462890625
50 14888.76953125
51 13776.92578125
52 12756.94140625
53 11820.36328125
54 10959.21875
55 10167.6943359375
56 9438.7607421875
57 8767.15625
58 8148.1240234375
59 7577.052734375
60 7050.74951171875
61 6565.3125
...

378 0.0011588472407311201
379 0.0011182907037436962
380 0.0010792964603751898
381 0.0010425273794680834
382 0.001006508362479508
383 0.0009704729309305549
384 0.000938737764954567
385 0.0009080087766051292
386 0.0008777950424700975
387 0.0008497098460793495
388 0.000819699838757515
389 0.0007925801910459995
390 0.0007680226117372513
391 0.0007434476865455508
392 0.0007187966257333755
393 0.0006966863293200731
394 0.0006743226549588144
395 0.0006539690075442195
396 0.0006335899233818054
397 0.0006133493734523654
398 0.0005942268762737513
399 0.0005761347711086273
400 0.000558895873837173
401 0.0005417743232101202
402 0.000525482464581728
403 0.0005094424122944474
404 0.00049453234532848
405 0.0004805568605661392
406 0.00046678134822286665
407 0.00045341558870859444
408 0.0004407507076393813
409 0.0004279001150280237
410 0.0004156938230153173
411 0.0004044497909490019
412 0.00039300232310779393
413 0.0003819395788013935
414 0.0003711057361215353
415 0.00036195587017573416
...

A simple autograd example:

In [5]:
# Create tensors.
x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

# Build a computational graph.
y = w * x + b    # y = 2 * x + 3

# Compute gradients.
y.backward()

# Print out the gradients.
print(x.grad)    # x.grad = 2 
print(w.grad)    # w.grad = 1 
print(b.grad)    # b.grad = 1 

tensor(2.)
tensor(1.)
tensor(1.)



PyTorch: Tensors and autograd
-------------------------------

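One of PyTorch's key features is autograd: once we have defined the forward pass and computed the loss, PyTorch can automatically compute the gradients of the loss with respect to all model parameters.

A PyTorch Tensor represents a node in a computational graph. If ``x`` is a Tensor with ``x.requires_grad=True``, then ``x.grad`` is another Tensor holding the gradient of ``x`` with respect to some scalar value (usually the loss).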
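One detail worth seeing in isolation, sketched below: `backward()` accumulates into `.grad` rather than overwriting it, which is why the training loops in this lesson reset the gradients on every iteration.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# Each iteration builds a new graph; backward() adds into w.grad
for _ in range(2):
    loss = 3 * w  # d(loss)/dw = 3
    loss.backward()
accumulated = w.grad.item()
print(accumulated)  # 6.0

# Reset before the next backward pass
w.grad.zero_()
print(w.grad)  # tensor(0.)
```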


In [6]:
import torch

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
# requires_grad defaults to False: we do not need gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for the weights.
# requires_grad=True means we want gradients with respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y with Tensor operations. This is the same forward pass
    # as before, but we no longer keep the intermediate values, because we do not
    # implement the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute and print the loss.
    # loss is a scalar (0-dimensional) Tensor;
    # loss.item() returns its value as a Python number.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd for the backward pass. backward() computes the gradient of loss with respect
    # to every Tensor with requires_grad=True. Afterwards, w1.grad and w2.grad
    # hold the gradients of the loss with respect to w1 and w2.
    loss.backward()

    # Manually update the weights with gradient descent (an automatic way is introduced later).
    # Wrap the update in torch.no_grad() because w1 and w2 have requires_grad=True,
    # but we do not want autograd to track the update itself.
    # An alternative is to operate on weight.data and weight.grad.data: tensor.data returns
    # a tensor that shares the same memory but does not record computation history.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 27484794.0
1 25186148.0
2 26499672.0
3 27811656.0
4 26498382.0
5 21366974.0
6 14516730.0
7 8542893.0
8 4735615.0
9 2664450.5
10 1622764.875
11 1092537.75
12 806148.5
13 635412.125
14 522389.5
15 440550.6875
16 377392.0625
17 326495.46875
18 284465.375
19 249142.46875
20 219169.796875
21 193523.078125
22 171425.4375
23 152314.140625
24 135731.796875
25 121290.828125
26 108650.5546875
27 97553.484375
28 87782.5546875
29 79149.46875
30 71507.9921875
31 64727.0546875
32 58689.7890625
33 53316.9296875
34 48510.96484375
35 44203.50390625
36 40340.703125
37 36863.23046875
38 33729.79296875
39 30905.220703125
40 28350.140625
41 26035.998046875
42 23934.515625
43 22025.091796875
44 20288.197265625
45 18705.76171875
46 17262.4609375
47 15943.8671875
48 14737.8857421875
49 13633.216796875
50 12620.9921875
51 11692.1796875
52 10837.201171875
53 10052.171875
54 9329.849609375
55 8664.591796875
56 8051.71533203125
57 7486.51025390625
58 6965.64306640625
59 6486.21923828125
60 6043.001953125
...


PyTorch: nn
-----------


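This time we use PyTorch's nn package to build the network.
nn provides ready-made layers (Modules) that hold their own learnable parameters,
while autograd still builds the computational graph and computes the gradients for us.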
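To see what "holding their own parameters" means, a short sketch inspecting a single `nn.Linear` layer (sizes chosen to match the first layer below): its weight is stored as `(out_features, in_features)` plus a bias vector, both registered with `requires_grad=True`.

```python
import torch

layer = torch.nn.Linear(1000, 100)

for name, p in layer.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
# weight (100, 1000) True
# bias (100,) True
```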




In [7]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 620.5733642578125
1 571.293701171875
2 529.0409545898438
3 492.3656311035156
4 459.8135986328125
5 430.5155029296875
6 404.1044006347656
7 380.03692626953125
8 357.8116455078125
9 337.3223571777344
10 318.09051513671875
11 300.06781005859375
12 283.1866455078125
13 267.205078125
14 252.0729522705078
15 237.79379272460938
16 224.301513671875
17 211.51817321777344
18 199.46197509765625
19 188.07485961914062
20 177.29080200195312
21 167.09542846679688
22 157.40750122070312
23 148.26039123535156
24 139.60470581054688
25 131.42649841308594
26 123.69979858398438
27 116.38845825195312
28 109.4806137084961
29 102.9622573852539
30 96.81407165527344
31 91.02054595947266
32 85.57273864746094
33 80.44019317626953
34 75.61373901367188
35 71.07894134521484
36 66.81837463378906
37 62.82241439819336
38 59.07289505004883
39 55.556522369384766
40 52.253807067871094
41 49.15754699707031
42 46.25043869018555
43 43.52094268798828
44 40.962318420410156
45 38.563846588134766
46 36.30937576293945
...

401 7.626089791301638e-05
402 7.415669824695215e-05
403 7.211537740658969e-05
404 7.013075082795694e-05
405 6.82039390085265e-05
406 6.633048178628087e-05
407 6.450699584092945e-05
408 6.273386679822579e-05
409 6.101321923779324e-05
410 5.933863576501608e-05
411 5.771401629317552e-05
412 5.613020766759291e-05
413 5.4596261179540306e-05
414 5.310078267939389e-05
415 5.164902177057229e-05
416 5.023710764362477e-05
417 4.886589522357099e-05
418 4.75314081995748e-05
419 4.623486165655777e-05
420 4.497456757235341e-05
421 4.3751027988037094e-05
422 4.255849853507243e-05
423 4.14010755775962e-05
424 4.0272097976412624e-05
425 3.9181788451969624e-05
426 3.811478745774366e-05
427 3.707895666593686e-05
428 3.607549297157675e-05
429 3.5093129554297775e-05
430 3.414317689021118e-05
431 3.321848635096103e-05
432 3.231843948015012e-05
433 3.144356742268428e-05
434 3.059373557334766e-05
435 2.976709947688505e-05
436 2.896072510338854e-05
437 2.8184394977870397e-05
438 2.7421152481110767e-05
...


PyTorch: optim
--------------

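This time, instead of updating the model's weights by hand, we use the optim package to update the parameters for us.
The optim package provides many optimization algorithms, including SGD with momentum, RMSProp, and Adam.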
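A sketch showing that these optimizers are interchangeable: each is built from `model.parameters()` and driven by the same `zero_grad()` / `backward()` / `step()` sequence. The small model and the hyperparameters here are illustrative assumptions, not tuned values.

```python
import torch

# Factories for a few torch.optim optimizers (illustrative hyperparameters)
make_optimizer = {
    "SGD+momentum": lambda params: torch.optim.SGD(params, lr=1e-2, momentum=0.9),
    "RMSprop": lambda params: torch.optim.RMSprop(params, lr=1e-3),
    "Adam": lambda params: torch.optim.Adam(params, lr=1e-4),
}

x, y = torch.randn(32, 10), torch.randn(32, 1)
changed = {}
for name, factory in make_optimizer.items():
    model = torch.nn.Linear(10, 1)
    optimizer = factory(model.parameters())
    before = model.weight.detach().clone()

    # The same three calls drive every optimizer
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()

    changed[name] = not torch.equal(before, model.weight)
    print(name, "updated weights:", changed[name])
```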


In [8]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers( i.e, not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 619.4041748046875
1 603.1110229492188
2 587.2695922851562
3 571.850830078125
4 556.8591918945312
5 542.2738647460938
6 528.0520629882812
7 514.189208984375
8 500.81378173828125
9 487.809326171875
10 475.2987060546875
11 463.1747741699219
12 451.4615783691406
13 440.12237548828125
14 429.0684814453125
15 418.3021545410156
16 407.7438049316406
17 397.4444580078125
18 387.4076843261719
19 377.6274719238281
20 368.0901184082031
21 358.7952880859375
22 349.7372131347656
23 340.906005859375
24 332.28973388671875
25 323.9122009277344
26 315.7377014160156
27 307.7489929199219
28 299.9654235839844
29 292.38507080078125
30 284.9714660644531
31 277.7470703125
32 270.651123046875
33 263.6981201171875
34 256.8956604003906
35 250.253662109375
36 243.7548370361328
37 237.4071807861328
38 231.18507385253906
39 225.09231567382812
40 219.14073181152344
41 213.31344604492188
42 207.59295654296875
43 201.98304748535156
44 196.4952392578125
45 191.1175994873047
46 185.8562469482422
47 180.7108154296875
...

426 1.4392276259656e-09
427 1.3475736082568801e-09
428 1.2555183559470606e-09
429 1.1850953551828525e-09
430 1.1078031825206835e-09
431 1.0416733031703984e-09
432 9.806724321492766e-10
433 9.197524963866499e-10
434 8.712984778114219e-10
435 8.14494693912593e-10
436 7.637628862688928e-10
437 7.211815589158732e-10
438 6.794754203731657e-10
439 6.457345214094801e-10
440 6.058115675777742e-10
441 5.750903642187666e-10
442 5.411129877508358e-10
443 5.140519121482612e-10
444 4.859382896071907e-10
445 4.586113711013695e-10
446 4.3512876635176667e-10
447 4.1456915678139694e-10
448 3.8931696755284406e-10
449 3.67793517863646e-10
450 3.519370905813446e-10
451 3.337908005551782e-10
452 3.141437110443235e-10
453 3.0254215799274675e-10
454 2.879710359060539e-10
455 2.7465860141795417e-10
456 2.606502791380194e-10
457 2.512381136465791e-10
458 2.386495168149594e-10
459 2.2848578584699908e-10
460 2.1820956153106863e-10
461 2.0644329301600095e-10
462 2.0044046977751862e-10
463 1.9204671186656697e-10
...


PyTorch: Custom nn Modules
--------------------------

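We can define our own model by subclassing nn.Module. Whenever we need a model more complex than a simple sequence of existing Modules (nn.Sequential), we define it as a custom nn.Module subclass.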
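As a hypothetical illustration of something nn.Sequential cannot express directly, the sketch below adds a skip connection in `forward`. The class name and the extra hidden layer are assumptions made for the example; submodules assigned as attributes are registered automatically, so `model.parameters()` finds them.

```python
import torch


class ResidualTwoLayerNet(torch.nn.Module):
    """Hypothetical example: the second hidden layer's output is added to its input."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Attributes holding Modules are registered as submodules
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.out = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h = torch.relu(self.linear1(x))
        h = h + torch.relu(self.linear2(h))  # skip connection
        return self.out(h)


model = ResidualTwoLayerNet(1000, 100, 10)
y_pred = model(torch.randn(64, 1000))
n_params = len(list(model.parameters()))
print(y_pred.shape)  # torch.Size([64, 10])
print(n_params)      # 6: three weights and three biases
```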



In [9]:
import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 687.2293701171875
1 634.19189453125
2 588.5949096679688
3 548.7039184570312
4 513.4244384765625
5 482.0264587402344
6 453.6788330078125
7 427.80926513671875
8 404.1365661621094
9 382.3700866699219
10 362.061279296875
11 342.96490478515625
12 324.990478515625
13 307.9693603515625
14 291.95574951171875
15 276.76947021484375
16 262.1172790527344
17 248.11285400390625
18 234.79629516601562
19 222.14688110351562
20 210.0814208984375
21 198.62844848632812
22 187.7060089111328
23 177.27169799804688
24 167.33160400390625
25 157.83274841308594
26 148.77952575683594
27 140.1869659423828
28 131.96580505371094
29 124.1637954711914
30 116.76571655273438
31 109.747802734375
32 103.10079956054688
33 96.81016540527344
34 90.86135864257812
35 85.23873138427734
36 79.93568420410156
37 74.9471435546875
38 70.24595642089844
39 65.81832885742188
40 61.655059814453125
41 57.732093811035156
42 54.04985809326172
43 50.59822463989258
44 47.36359786987305
45 44.338504791259766
46 41.50608444213867
...

444 4.116254785913043e-06
445 3.993028713011881e-06
446 3.87336285712081e-06
447 3.7574150155705865e-06
448 3.644649268608191e-06
449 3.5359253161004744e-06
450 3.429631988183246e-06
451 3.326909336465178e-06
452 3.2277207537845243e-06
453 3.130404820694821e-06
454 3.0369674277608283e-06
455 2.9461200483638095e-06
456 2.8588967779796803e-06
457 2.773228743535583e-06
458 2.689662778720958e-06
459 2.6094221539096907e-06
460 2.5317474410258e-06
461 2.456156835251022e-06
462 2.382673528700252e-06
463 2.311317075509578e-06
464 2.2424972030421486e-06
465 2.175990857722354e-06
466 2.1103144263179274e-06
467 2.0477159523579758e-06
468 1.9866299680870725e-06
469 1.9275184968137182e-06
470 1.8701000499277143e-06
471 1.814253209886374e-06
472 1.7603261994736386e-06
473 1.7079676126741106e-06
474 1.6568108094361378e-06
475 1.607604417586117e-06
476 1.5597463516314747e-06
477 1.513082111159747e-06
478 1.4682266282761702e-06
479 1.42461726682086e-06
480 1.3822332221025135e-06
...

# FizzBuzz

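FizzBuzz is a simple game. The rules: count up from 1; for multiples of 3 say "fizz", for multiples of 5 say "buzz", for multiples of 15 say "fizzbuzz", and otherwise say the number itself.

We can write a small program that decides whether to return the number itself or fizz, buzz, or fizzbuzz.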

In [10]:
# Encode the desired output as a class index: 0 = number, 1 = "fizz", 2 = "buzz", 3 = "fizzbuzz"
def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0
    
def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

print(fizz_buzz_decode(1, fizz_buzz_encode(1)))
print(fizz_buzz_decode(2, fizz_buzz_encode(2)))
print(fizz_buzz_decode(5, fizz_buzz_encode(5)))
print(fizz_buzz_decode(12, fizz_buzz_encode(12)))
print(fizz_buzz_decode(15, fizz_buzz_encode(15)))

1
2
buzz
fizz
fizzbuzz


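First we define the model's inputs and outputs (the training data). Each number is represented by its binary digits, and we train only on 101 to 1023 so that the numbers 1 to 100 can be held out for testing.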

In [11]:
import numpy as np
import torch

NUM_DIGITS = 10

# Represent each input by an array of its binary digits.
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])

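# Stack into a single ndarray first; converting a list of ndarrays directly is slow
trX = torch.tensor(np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)]), dtype=torch.float)
trY = torch.tensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)], dtype=torch.long)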
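To make the input encoding concrete, a short sketch that repeats `binary_encode` from the cell above: bit `d` of `i` lands at position `d`, so the least significant bit comes first, and 10 digits cover every number up to 1023.

```python
import numpy as np

NUM_DIGITS = 10

def binary_encode(i, num_digits):
    # Position d holds bit d of i (least significant bit first)
    return np.array([i >> d & 1 for d in range(num_digits)])

print(binary_encode(6, NUM_DIGITS))     # [0 1 1 0 0 0 0 0 0 0]
print(binary_encode(1023, NUM_DIGITS))  # all ten bits set
```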

Then we define the model in PyTorch:

In [14]:
# Define the model
NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)

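- To teach our model to play FizzBuzz, we need to define a loss function and an optimization algorithm.
- The optimization algorithm repeatedly adjusts the model's parameters to lower the loss, aiming for the lowest possible loss on this task.
- A low loss usually means the model performs well; a high loss means it performs poorly.
- Since FizzBuzz is essentially a classification problem (four classes), we use the Cross Entropy Loss.
- For the optimizer we use Stochastic Gradient Descent.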
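A small worked check of what `CrossEntropyLoss` computes, using made-up logits for one example: it applies log-softmax to the raw model outputs and takes the negative log-probability of the true class (averaged over the batch).

```python
import torch

logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])  # made-up raw outputs for one input
target = torch.tensor([0])                      # true class: "number"

loss = torch.nn.CrossEntropyLoss()(logits, target)

# Same quantity by hand: negative log-softmax at the target index
manual = -torch.log_softmax(logits, dim=1)[0, 0]
print(loss.item(), manual.item())
```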

In [15]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.05)

The training code is as follows:

In [16]:
# Start training it
BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Find loss on training data
    loss = loss_fn(model(trX), trY).item()
    print('Epoch:', epoch, 'Loss:', loss)

Epoch: 0 Loss: 1.1752445697784424
Epoch: 1 Loss: 1.1527748107910156
Epoch: 2 Loss: 1.1471346616744995
Epoch: 3 Loss: 1.1449143886566162
Epoch: 4 Loss: 1.1437774896621704
Epoch: 5 Loss: 1.143070936203003
Epoch: 6 Loss: 1.1425672769546509
Epoch: 7 Loss: 1.1421717405319214
Epoch: 8 Loss: 1.1418371200561523
Epoch: 9 Loss: 1.1415421962738037
Epoch: 10 Loss: 1.1412750482559204
Epoch: 11 Loss: 1.141027808189392
Epoch: 12 Loss: 1.140796184539795
Epoch: 13 Loss: 1.1405755281448364
Epoch: 14 Loss: 1.1403676271438599
Epoch: 15 Loss: 1.1401653289794922
Epoch: 16 Loss: 1.13997220993042
Epoch: 17 Loss: 1.139786958694458
Epoch: 18 Loss: 1.139608383178711
Epoch: 19 Loss: 1.1394349336624146
Epoch: 20 Loss: 1.1392674446105957
Epoch: 21 Loss: 1.1391061544418335
Epoch: 22 Loss: 1.1389487981796265
Epoch: 23 Loss: 1.1387975215911865
Epoch: 24 Loss: 1.1386489868164062
Epoch: 25 Loss: 1.138502597808838
Epoch: 26 Loss: 1.138362169265747
Epoch: 27 Loss: 1.1382249593734741
Epoch: 28 Loss: 1.1380912065505981
...

Epoch: 233 Loss: 1.0845177173614502
Epoch: 234 Loss: 1.084859848022461
Epoch: 235 Loss: 1.0831843614578247
Epoch: 236 Loss: 1.0825752019882202
Epoch: 237 Loss: 1.0834254026412964
Epoch: 238 Loss: 1.0817445516586304
Epoch: 239 Loss: 1.082134485244751
Epoch: 240 Loss: 1.0805970430374146
Epoch: 241 Loss: 1.0810049772262573
Epoch: 242 Loss: 1.0793569087982178
Epoch: 243 Loss: 1.0797879695892334
Epoch: 244 Loss: 1.0793440341949463
Epoch: 245 Loss: 1.0779162645339966
Epoch: 246 Loss: 1.078298807144165
Epoch: 247 Loss: 1.0769002437591553
Epoch: 248 Loss: 1.0768775939941406
Epoch: 249 Loss: 1.075770378112793
Epoch: 250 Loss: 1.0745978355407715
Epoch: 251 Loss: 1.0754570960998535
Epoch: 252 Loss: 1.0746948719024658
Epoch: 253 Loss: 1.0732218027114868
Epoch: 254 Loss: 1.0733839273452759
Epoch: 255 Loss: 1.0719608068466187
Epoch: 256 Loss: 1.072079062461853
Epoch: 257 Loss: 1.071616530418396
Epoch: 258 Loss: 1.0706380605697632
Epoch: 259 Loss: 1.0706173181533813
Epoch: 260 Loss: 1.069990634918213

Epoch: 464 Loss: 0.9357263445854187
Epoch: 465 Loss: 0.9347310662269592
Epoch: 466 Loss: 0.9330903887748718
Epoch: 467 Loss: 0.9341756701469421
Epoch: 468 Loss: 0.9320069551467896
Epoch: 469 Loss: 0.930217444896698
Epoch: 470 Loss: 0.9315487742424011
Epoch: 471 Loss: 0.928566575050354
Epoch: 472 Loss: 0.9311053156852722
Epoch: 473 Loss: 0.9272609353065491
Epoch: 474 Loss: 0.9271014928817749
Epoch: 475 Loss: 0.9259443283081055
Epoch: 476 Loss: 0.925812840461731
Epoch: 477 Loss: 0.924372673034668
Epoch: 478 Loss: 0.9241588115692139
Epoch: 479 Loss: 0.9230336546897888
Epoch: 480 Loss: 0.9242668151855469
Epoch: 481 Loss: 0.9208613634109497
Epoch: 482 Loss: 0.9199551939964294
Epoch: 483 Loss: 0.9210118651390076
Epoch: 484 Loss: 0.9187893867492676
Epoch: 485 Loss: 0.9198324084281921
Epoch: 486 Loss: 0.916894257068634
Epoch: 487 Loss: 0.9160569310188293
Epoch: 488 Loss: 0.9162460565567017
Epoch: 489 Loss: 0.9155914187431335
Epoch: 490 Loss: 0.9149048924446106
...

Epoch: 693 Loss: 0.7160546779632568
Epoch: 694 Loss: 0.7149877548217773
Epoch: 695 Loss: 0.7129070162773132
Epoch: 696 Loss: 0.7130376696586609
Epoch: 697 Loss: 0.7087226510047913
Epoch: 698 Loss: 0.708712637424469
Epoch: 699 Loss: 0.7086120843887329
Epoch: 700 Loss: 0.7084261178970337
Epoch: 701 Loss: 0.7074194550514221
Epoch: 702 Loss: 0.7073612809181213
Epoch: 703 Loss: 0.7039808034896851
Epoch: 704 Loss: 0.7026575207710266
Epoch: 705 Loss: 0.7006410956382751
Epoch: 706 Loss: 0.7013440132141113
Epoch: 707 Loss: 0.7000851035118103
Epoch: 708 Loss: 0.698195219039917
Epoch: 709 Loss: 0.6985461711883545
Epoch: 710 Loss: 0.6969631314277649
Epoch: 711 Loss: 0.6980965733528137
Epoch: 712 Loss: 0.6946825385093689
Epoch: 713 Loss: 0.6947281360626221
Epoch: 714 Loss: 0.6928823590278625
Epoch: 715 Loss: 0.6917062401771545
Epoch: 716 Loss: 0.6893540620803833
Epoch: 717 Loss: 0.6890638470649719
Epoch: 718 Loss: 0.6883246898651123
Epoch: 719 Loss: 0.6886910796165466
...

Epoch: 923 Loss: 0.4852830469608307
Epoch: 924 Loss: 0.48484867811203003
Epoch: 925 Loss: 0.4828505516052246
Epoch: 926 Loss: 0.4830262362957001
Epoch: 927 Loss: 0.48098206520080566
Epoch: 928 Loss: 0.48141157627105713
Epoch: 929 Loss: 0.48000141978263855
Epoch: 930 Loss: 0.4791625738143921
Epoch: 931 Loss: 0.47828519344329834
Epoch: 932 Loss: 0.4773813486099243
Epoch: 933 Loss: 0.4771408140659332
Epoch: 934 Loss: 0.47639819979667664
Epoch: 935 Loss: 0.47618934512138367
Epoch: 936 Loss: 0.47437945008277893
Epoch: 937 Loss: 0.47371840476989746
Epoch: 938 Loss: 0.47258293628692627
Epoch: 939 Loss: 0.47198134660720825
Epoch: 940 Loss: 0.4710565507411957
Epoch: 941 Loss: 0.4705093502998352
Epoch: 942 Loss: 0.46970996260643005
Epoch: 943 Loss: 0.4700033366680145
Epoch: 944 Loss: 0.46833837032318115
Epoch: 945 Loss: 0.4671805202960968
Epoch: 946 Loss: 0.46617591381073
Epoch: 947 Loss: 0.466474711894989
Epoch: 948 Loss: 0.46504345536231995
Epoch: 949 Loss: 0.4640341103076935
...

Epoch: 1150 Loss: 0.3384862244129181
Epoch: 1151 Loss: 0.338246613740921
Epoch: 1152 Loss: 0.3377518057823181
Epoch: 1153 Loss: 0.3373359739780426
Epoch: 1154 Loss: 0.33667632937431335
Epoch: 1155 Loss: 0.3360970616340637
Epoch: 1156 Loss: 0.3357165455818176
Epoch: 1157 Loss: 0.335123211145401
Epoch: 1158 Loss: 0.3345615267753601
Epoch: 1159 Loss: 0.33499598503112793
Epoch: 1160 Loss: 0.33377426862716675
Epoch: 1161 Loss: 0.3332277238368988
Epoch: 1162 Loss: 0.33313918113708496
Epoch: 1163 Loss: 0.33231309056282043
Epoch: 1164 Loss: 0.3315868675708771
Epoch: 1165 Loss: 0.3310849070549011
Epoch: 1166 Loss: 0.3305002450942993
Epoch: 1167 Loss: 0.33037078380584717
Epoch: 1168 Loss: 0.32957279682159424
Epoch: 1169 Loss: 0.329277902841568
Epoch: 1170 Loss: 0.3290802240371704
Epoch: 1171 Loss: 0.32821452617645264
Epoch: 1172 Loss: 0.3278113007545471
Epoch: 1173 Loss: 0.3276754319667816
Epoch: 1174 Loss: 0.3268769085407257
Epoch: 1175 Loss: 0.32619383931159973
Epoch: 1176 Loss: 0.325750559568

Epoch: 1379 Loss: 0.2494128793478012
Epoch: 1380 Loss: 0.24899470806121826
Epoch: 1381 Loss: 0.24915188550949097
Epoch: 1382 Loss: 0.2483641505241394
Epoch: 1383 Loss: 0.24802225828170776
Epoch: 1384 Loss: 0.24826838076114655
Epoch: 1385 Loss: 0.2479567676782608
Epoch: 1386 Loss: 0.24734529852867126
Epoch: 1387 Loss: 0.24686455726623535
Epoch: 1388 Loss: 0.2467280775308609
Epoch: 1389 Loss: 0.24629411101341248
Epoch: 1390 Loss: 0.24654754996299744
Epoch: 1391 Loss: 0.24556811153888702
Epoch: 1392 Loss: 0.24575264751911163
Epoch: 1393 Loss: 0.24542789161205292
Epoch: 1394 Loss: 0.24530132114887238
Epoch: 1395 Loss: 0.24527348577976227
Epoch: 1396 Loss: 0.24432404339313507
Epoch: 1397 Loss: 0.24452835321426392
Epoch: 1398 Loss: 0.24418877065181732
Epoch: 1399 Loss: 0.24349352717399597
Epoch: 1400 Loss: 0.24310483038425446
Epoch: 1401 Loss: 0.24342599511146545
Epoch: 1402 Loss: 0.24297094345092773
Epoch: 1403 Loss: 0.24224035441875458
Epoch: 1404 Loss: 0.24212516844272614

Epoch: 1609 Loss: 0.19376280903816223
Epoch: 1610 Loss: 0.19359146058559418
Epoch: 1611 Loss: 0.19374045729637146
Epoch: 1612 Loss: 0.19336123764514923
Epoch: 1613 Loss: 0.1931534707546234
Epoch: 1614 Loss: 0.1929088681936264
Epoch: 1615 Loss: 0.19272829592227936
Epoch: 1616 Loss: 0.19259880483150482
Epoch: 1617 Loss: 0.1926836222410202
Epoch: 1618 Loss: 0.1921258419752121
Epoch: 1619 Loss: 0.19192567467689514
Epoch: 1620 Loss: 0.1915798783302307
Epoch: 1621 Loss: 0.19151556491851807
Epoch: 1622 Loss: 0.19122876226902008
Epoch: 1623 Loss: 0.1910504251718521
Epoch: 1624 Loss: 0.19128356873989105
Epoch: 1625 Loss: 0.1908683031797409
Epoch: 1626 Loss: 0.1904984712600708
Epoch: 1627 Loss: 0.19041000306606293
Epoch: 1628 Loss: 0.19036118686199188
Epoch: 1629 Loss: 0.19006437063217163
Epoch: 1630 Loss: 0.1898513287305832
Epoch: 1631 Loss: 0.18964387476444244
Epoch: 1632 Loss: 0.1897091567516327
Epoch: 1633 Loss: 0.18935054540634155
Epoch: 1634 Loss: 0.18904928863048553

Epoch: 1837 Loss: 0.15513856709003448
Epoch: 1838 Loss: 0.15512099862098694
Epoch: 1839 Loss: 0.15479019284248352
Epoch: 1840 Loss: 0.15475432574748993
Epoch: 1841 Loss: 0.15459176898002625
Epoch: 1842 Loss: 0.1544928103685379
Epoch: 1843 Loss: 0.15430259704589844
Epoch: 1844 Loss: 0.1542138010263443
Epoch: 1845 Loss: 0.15404582023620605
Epoch: 1846 Loss: 0.15402069687843323
Epoch: 1847 Loss: 0.15385590493679047
Epoch: 1848 Loss: 0.15371613204479218
Epoch: 1849 Loss: 0.15349648892879486
Epoch: 1850 Loss: 0.15329207479953766
Epoch: 1851 Loss: 0.15313149988651276
Epoch: 1852 Loss: 0.1530849039554596
Epoch: 1853 Loss: 0.15309025347232819
Epoch: 1854 Loss: 0.15277895331382751
Epoch: 1855 Loss: 0.1528526246547699
Epoch: 1856 Loss: 0.15257227420806885
Epoch: 1857 Loss: 0.15238294005393982
Epoch: 1858 Loss: 0.1522604078054428
Epoch: 1859 Loss: 0.15218771994113922
Epoch: 1860 Loss: 0.1519964039325714
Epoch: 1861 Loss: 0.15193051099777222
Epoch: 1862 Loss: 0.15169307589530945

Epoch: 2063 Loss: 0.1278247982263565
Epoch: 2064 Loss: 0.12766648828983307
Epoch: 2065 Loss: 0.12753556668758392
Epoch: 2066 Loss: 0.12754657864570618
Epoch: 2067 Loss: 0.12735521793365479
Epoch: 2068 Loss: 0.12726902961730957
Epoch: 2069 Loss: 0.1271209716796875
Epoch: 2070 Loss: 0.12702256441116333
Epoch: 2071 Loss: 0.1268680989742279
Epoch: 2072 Loss: 0.12674736976623535
Epoch: 2073 Loss: 0.12667782604694366
Epoch: 2074 Loss: 0.12662020325660706
Epoch: 2075 Loss: 0.12644310295581818
Epoch: 2076 Loss: 0.12648504972457886
Epoch: 2077 Loss: 0.1263146996498108
Epoch: 2078 Loss: 0.1261419802904129
Epoch: 2079 Loss: 0.12604224681854248
Epoch: 2080 Loss: 0.12597434222698212
Epoch: 2081 Loss: 0.12587346136569977
Epoch: 2082 Loss: 0.1258019059896469
Epoch: 2083 Loss: 0.12564098834991455
Epoch: 2084 Loss: 0.12555919587612152
Epoch: 2085 Loss: 0.12543262541294098
Epoch: 2086 Loss: 0.1253429800271988
Epoch: 2087 Loss: 0.1252778023481369
Epoch: 2088 Loss: 0.1251889020204544

Epoch: 2295 Loss: 0.10675574839115143
Epoch: 2296 Loss: 0.10671587288379669
Epoch: 2297 Loss: 0.10662904381752014
Epoch: 2298 Loss: 0.10658855736255646
Epoch: 2299 Loss: 0.106516994535923
Epoch: 2300 Loss: 0.10642893612384796
Epoch: 2301 Loss: 0.1063336730003357
Epoch: 2302 Loss: 0.10624126344919205
Epoch: 2303 Loss: 0.10616564750671387
Epoch: 2304 Loss: 0.10603068023920059
Epoch: 2305 Loss: 0.1060151755809784
Epoch: 2306 Loss: 0.10596397519111633
Epoch: 2307 Loss: 0.10585451126098633
Epoch: 2308 Loss: 0.10583542287349701
Epoch: 2309 Loss: 0.10570120066404343
Epoch: 2310 Loss: 0.10556856542825699
Epoch: 2311 Loss: 0.10554815083742142
Epoch: 2312 Loss: 0.10547947138547897
Epoch: 2313 Loss: 0.10537142306566238
Epoch: 2314 Loss: 0.10528701543807983
Epoch: 2315 Loss: 0.10525234788656235
Epoch: 2316 Loss: 0.10516893863677979
Epoch: 2317 Loss: 0.10508767515420914
Epoch: 2318 Loss: 0.10499560087919235
Epoch: 2319 Loss: 0.10491064935922623
Epoch: 2320 Loss: 0.10484758019447327

Epoch: 2529 Loss: 0.08962082117795944
Epoch: 2530 Loss: 0.08958686143159866
Epoch: 2531 Loss: 0.08950037509202957
Epoch: 2532 Loss: 0.08948889374732971
Epoch: 2533 Loss: 0.08937647193670273
Epoch: 2534 Loss: 0.08937077969312668
Epoch: 2535 Loss: 0.08923477679491043
Epoch: 2536 Loss: 0.0892346203327179
Epoch: 2537 Loss: 0.08908083289861679
Epoch: 2538 Loss: 0.08908028155565262
Epoch: 2539 Loss: 0.0890413373708725
Epoch: 2540 Loss: 0.08894630521535873
Epoch: 2541 Loss: 0.08885446935892105
Epoch: 2542 Loss: 0.08876486122608185
Epoch: 2543 Loss: 0.08875048160552979
Epoch: 2544 Loss: 0.08869381994009018
Epoch: 2545 Loss: 0.08859144151210785
Epoch: 2546 Loss: 0.0885453149676323
Epoch: 2547 Loss: 0.08845604956150055
Epoch: 2548 Loss: 0.08840480446815491
Epoch: 2549 Loss: 0.08832751214504242
Epoch: 2550 Loss: 0.08828280866146088
Epoch: 2551 Loss: 0.08821981400251389
Epoch: 2552 Loss: 0.08812179416418076
Epoch: 2553 Loss: 0.08811111003160477
Epoch: 2554 Loss: 0.08798735588788986

Epoch: 2757 Loss: 0.07633712142705917
Epoch: 2758 Loss: 0.07620761543512344
Epoch: 2759 Loss: 0.07617665827274323
Epoch: 2760 Loss: 0.07616022229194641
Epoch: 2761 Loss: 0.0760614201426506
Epoch: 2762 Loss: 0.07600770145654678
Epoch: 2763 Loss: 0.0759783610701561
Epoch: 2764 Loss: 0.07592297345399857
Epoch: 2765 Loss: 0.07587803900241852
Epoch: 2766 Loss: 0.07580304145812988
Epoch: 2767 Loss: 0.07574243098497391
Epoch: 2768 Loss: 0.07571063935756683
Epoch: 2769 Loss: 0.07565588504076004
Epoch: 2770 Loss: 0.07559127360582352
Epoch: 2771 Loss: 0.07555130869150162
Epoch: 2772 Loss: 0.07550982385873795
Epoch: 2773 Loss: 0.07542743533849716
Epoch: 2774 Loss: 0.07540787011384964
Epoch: 2775 Loss: 0.07534832507371902
Epoch: 2776 Loss: 0.07531634718179703
Epoch: 2777 Loss: 0.07523971796035767
Epoch: 2778 Loss: 0.07524499297142029
Epoch: 2779 Loss: 0.07513639330863953
Epoch: 2780 Loss: 0.07510247081518173
Epoch: 2781 Loss: 0.07508388161659241
Epoch: 2782 Loss: 0.07497795671224594

Epoch: 2994 Loss: 0.0652720257639885
Epoch: 2995 Loss: 0.06522078812122345
Epoch: 2996 Loss: 0.06520145386457443
Epoch: 2997 Loss: 0.06514319032430649
Epoch: 2998 Loss: 0.065123051404953
Epoch: 2999 Loss: 0.06507506966590881
Epoch: 3000 Loss: 0.06503170728683472
Epoch: 3001 Loss: 0.0649731457233429
Epoch: 3002 Loss: 0.06496724486351013
Epoch: 3003 Loss: 0.06489584594964981
Epoch: 3004 Loss: 0.06485655158758163
Epoch: 3005 Loss: 0.06483665108680725
Epoch: 3006 Loss: 0.06476350873708725
Epoch: 3007 Loss: 0.06475469470024109
Epoch: 3008 Loss: 0.06471019238233566
Epoch: 3009 Loss: 0.06464783102273941
Epoch: 3010 Loss: 0.06461190432310104
Epoch: 3011 Loss: 0.06459228694438934
Epoch: 3012 Loss: 0.06456786394119263
Epoch: 3013 Loss: 0.06448039412498474
Epoch: 3014 Loss: 0.06444934010505676
Epoch: 3015 Loss: 0.06440253555774689
Epoch: 3016 Loss: 0.0643620640039444
Epoch: 3017 Loss: 0.06433247029781342
Epoch: 3018 Loss: 0.06427686661481857
Epoch: 3019 Loss: 0.06426727026700974

Epoch: 3224 Loss: 0.0564863458275795
Epoch: 3225 Loss: 0.05647127702832222
Epoch: 3226 Loss: 0.05641339346766472
Epoch: 3227 Loss: 0.05639016628265381
Epoch: 3228 Loss: 0.056342095136642456
Epoch: 3229 Loss: 0.05631621927022934
Epoch: 3230 Loss: 0.05628373846411705
Epoch: 3231 Loss: 0.05624439939856529
Epoch: 3232 Loss: 0.056192103773355484
Epoch: 3233 Loss: 0.056160956621170044
Epoch: 3234 Loss: 0.056152455508708954
Epoch: 3235 Loss: 0.056101854890584946
Epoch: 3236 Loss: 0.05607079714536667
Epoch: 3237 Loss: 0.05602863430976868
Epoch: 3238 Loss: 0.05601292848587036
Epoch: 3239 Loss: 0.055970508605241776
Epoch: 3240 Loss: 0.05594473332166672
Epoch: 3241 Loss: 0.0559040792286396
Epoch: 3242 Loss: 0.05589299649000168
Epoch: 3243 Loss: 0.05583217367529869
Epoch: 3244 Loss: 0.0557854101061821
Epoch: 3245 Loss: 0.055793941020965576
Epoch: 3246 Loss: 0.0557340532541275
Epoch: 3247 Loss: 0.05568315088748932
Epoch: 3248 Loss: 0.055672500282526016
Epoch: 3249 Loss: 0.055627092719078064

Epoch: 3453 Loss: 0.04921836033463478
Epoch: 3454 Loss: 0.049183811992406845
Epoch: 3455 Loss: 0.049172379076480865
Epoch: 3456 Loss: 0.04912332445383072
Epoch: 3457 Loss: 0.04910844564437866
Epoch: 3458 Loss: 0.04909813031554222
Epoch: 3459 Loss: 0.049071960151195526
Epoch: 3460 Loss: 0.04901393875479698
Epoch: 3461 Loss: 0.048987820744514465
Epoch: 3462 Loss: 0.04896266385912895
Epoch: 3463 Loss: 0.048933014273643494
Epoch: 3464 Loss: 0.0488925576210022
Epoch: 3465 Loss: 0.0488772876560688
Epoch: 3466 Loss: 0.04884522035717964
Epoch: 3467 Loss: 0.04881761595606804
Epoch: 3468 Loss: 0.048803724348545074
Epoch: 3469 Loss: 0.04877346754074097
Epoch: 3470 Loss: 0.04873742535710335
Epoch: 3471 Loss: 0.04869253188371658
Epoch: 3472 Loss: 0.04867546260356903
Epoch: 3473 Loss: 0.04864583909511566
Epoch: 3474 Loss: 0.04863196611404419
Epoch: 3475 Loss: 0.048601891845464706
Epoch: 3476 Loss: 0.048544980585575104
Epoch: 3477 Loss: 0.048545870929956436
Epoch: 3478 Loss: 0.04851758852601051

Epoch: 3678 Loss: 0.04330158606171608
Epoch: 3679 Loss: 0.0433051735162735
Epoch: 3680 Loss: 0.04329291731119156
Epoch: 3681 Loss: 0.04323983192443848
Epoch: 3682 Loss: 0.04323555901646614
Epoch: 3683 Loss: 0.043224822729825974
Epoch: 3684 Loss: 0.04317733645439148
Epoch: 3685 Loss: 0.04318748414516449
Epoch: 3686 Loss: 0.04314253106713295
Epoch: 3687 Loss: 0.04309610649943352
Epoch: 3688 Loss: 0.043092675507068634
Epoch: 3689 Loss: 0.04307383671402931
Epoch: 3690 Loss: 0.043042007833719254
Epoch: 3691 Loss: 0.04303300008177757
Epoch: 3692 Loss: 0.04299953579902649
Epoch: 3693 Loss: 0.04297053441405296
Epoch: 3694 Loss: 0.04296097904443741
Epoch: 3695 Loss: 0.042925093322992325
Epoch: 3696 Loss: 0.04289475083351135
Epoch: 3697 Loss: 0.042900942265987396
Epoch: 3698 Loss: 0.04285317659378052
Epoch: 3699 Loss: 0.04284612089395523
Epoch: 3700 Loss: 0.04282191023230553
Epoch: 3701 Loss: 0.042767077684402466
Epoch: 3702 Loss: 0.04279349371790886
Epoch: 3703 Loss: 0.04273911938071251

Epoch: 3897 Loss: 0.03855627775192261
Epoch: 3898 Loss: 0.03853856027126312
Epoch: 3899 Loss: 0.038534946739673615
Epoch: 3900 Loss: 0.03849153220653534
Epoch: 3901 Loss: 0.03847488760948181
Epoch: 3902 Loss: 0.038472115993499756
Epoch: 3903 Loss: 0.03844344615936279
Epoch: 3904 Loss: 0.038421742618083954
Epoch: 3905 Loss: 0.0384010449051857
Epoch: 3906 Loss: 0.03838338702917099
Epoch: 3907 Loss: 0.038347575813531876
Epoch: 3908 Loss: 0.03834519535303116
Epoch: 3909 Loss: 0.03833113610744476
Epoch: 3910 Loss: 0.03831006959080696
Epoch: 3911 Loss: 0.038277726620435715
Epoch: 3912 Loss: 0.03826095163822174
Epoch: 3913 Loss: 0.03824489191174507
Epoch: 3914 Loss: 0.03824388608336449
Epoch: 3915 Loss: 0.038190294057130814
Epoch: 3916 Loss: 0.038194380700588226
Epoch: 3917 Loss: 0.03813725337386131
Epoch: 3918 Loss: 0.03815467283129692
Epoch: 3919 Loss: 0.03812309354543686
Epoch: 3920 Loss: 0.03809506446123123
Epoch: 3921 Loss: 0.03809181600809097
Epoch: 3922 Loss: 0.038052354007959366

Epoch: 4114 Loss: 0.03450794517993927
Epoch: 4115 Loss: 0.03448285534977913
Epoch: 4116 Loss: 0.034496672451496124
Epoch: 4117 Loss: 0.03446910157799721
Epoch: 4118 Loss: 0.03443210944533348
Epoch: 4119 Loss: 0.034427352249622345
Epoch: 4120 Loss: 0.03442871943116188
Epoch: 4121 Loss: 0.03439464047551155
Epoch: 4122 Loss: 0.034376807510852814
Epoch: 4123 Loss: 0.03435951843857765
Epoch: 4124 Loss: 0.03432612493634224
Epoch: 4125 Loss: 0.03432537615299225
Epoch: 4126 Loss: 0.03430841118097305
Epoch: 4127 Loss: 0.03429540991783142
Epoch: 4128 Loss: 0.03428221121430397
Epoch: 4129 Loss: 0.034253332763910294
Epoch: 4130 Loss: 0.03424909710884094
Epoch: 4131 Loss: 0.03423842415213585
Epoch: 4132 Loss: 0.03419266641139984
Epoch: 4133 Loss: 0.03418863192200661
Epoch: 4134 Loss: 0.034173235297203064
Epoch: 4135 Loss: 0.03415851667523384
Epoch: 4136 Loss: 0.03414078801870346
Epoch: 4137 Loss: 0.0341218039393425
Epoch: 4138 Loss: 0.034102410078048706
Epoch: 4139 Loss: 0.03406967222690582

Epoch: 4332 Loss: 0.03101331926882267
Epoch: 4333 Loss: 0.030994996428489685
Epoch: 4334 Loss: 0.030985388904809952
Epoch: 4335 Loss: 0.03096829354763031
Epoch: 4336 Loss: 0.030942194163799286
Epoch: 4337 Loss: 0.03093201480805874
Epoch: 4338 Loss: 0.030930981040000916
Epoch: 4339 Loss: 0.030891448259353638
Epoch: 4340 Loss: 0.0308902096003294
Epoch: 4341 Loss: 0.030866069719195366
Epoch: 4342 Loss: 0.03086020052433014
Epoch: 4343 Loss: 0.03085469827055931
Epoch: 4344 Loss: 0.030829066410660744
Epoch: 4345 Loss: 0.030818630009889603
Epoch: 4346 Loss: 0.030804134905338287
Epoch: 4347 Loss: 0.030791085213422775
Epoch: 4348 Loss: 0.030780671164393425
Epoch: 4349 Loss: 0.030738545581698418
Epoch: 4350 Loss: 0.030743664130568504
Epoch: 4351 Loss: 0.03073381446301937
Epoch: 4352 Loss: 0.030702106654644012
Epoch: 4353 Loss: 0.03071211650967598
Epoch: 4354 Loss: 0.030677620321512222
Epoch: 4355 Loss: 0.030671466141939163
Epoch: 4356 Loss: 0.030662231147289276
Epoch: 4357 Loss: 0.03064425848424

Epoch: 4551 Loss: 0.02802441269159317
Epoch: 4552 Loss: 0.02802400104701519
Epoch: 4553 Loss: 0.02800881117582321
Epoch: 4554 Loss: 0.027989143505692482
Epoch: 4555 Loss: 0.027970699593424797
Epoch: 4556 Loss: 0.027974087744951248
Epoch: 4557 Loss: 0.02795853465795517
Epoch: 4558 Loss: 0.027951819822192192
Epoch: 4559 Loss: 0.02792777307331562
Epoch: 4560 Loss: 0.027921311557292938
Epoch: 4561 Loss: 0.027918975800275803
Epoch: 4562 Loss: 0.02789241634309292
Epoch: 4563 Loss: 0.02787201479077339
Epoch: 4564 Loss: 0.027877096086740494
Epoch: 4565 Loss: 0.02785651385784149
Epoch: 4566 Loss: 0.027840329334139824
Epoch: 4567 Loss: 0.027832480147480965
Epoch: 4568 Loss: 0.027818335220217705
Epoch: 4569 Loss: 0.027817388996481895
Epoch: 4570 Loss: 0.027798190712928772
Epoch: 4571 Loss: 0.0277748741209507
Epoch: 4572 Loss: 0.027780015021562576
Epoch: 4573 Loss: 0.02774706669151783
Epoch: 4574 Loss: 0.027748407796025276
Epoch: 4575 Loss: 0.027745358645915985
Epoch: 4576 Loss: 0.0277198031544685

Epoch: 4776 Loss: 0.025430969893932343
Epoch: 4777 Loss: 0.025423970073461533
Epoch: 4778 Loss: 0.025415271520614624
Epoch: 4779 Loss: 0.025398042052984238
Epoch: 4780 Loss: 0.02538481540977955
Epoch: 4781 Loss: 0.02537926472723484
Epoch: 4782 Loss: 0.025369346141815186
Epoch: 4783 Loss: 0.02534906566143036
Epoch: 4784 Loss: 0.025344958528876305
Epoch: 4785 Loss: 0.025346456095576286
Epoch: 4786 Loss: 0.025317581370472908
Epoch: 4787 Loss: 0.025314172729849815
Epoch: 4788 Loss: 0.02530021220445633
Epoch: 4789 Loss: 0.02528516761958599
Epoch: 4790 Loss: 0.025282535701990128
Epoch: 4791 Loss: 0.02526235766708851
Epoch: 4792 Loss: 0.025265181437134743
Epoch: 4793 Loss: 0.025247648358345032
Epoch: 4794 Loss: 0.025241902098059654
Epoch: 4795 Loss: 0.025220220908522606
Epoch: 4796 Loss: 0.025212563574314117
Epoch: 4797 Loss: 0.025201547890901566
Epoch: 4798 Loss: 0.02520003728568554
Epoch: 4799 Loss: 0.025180814787745476
Epoch: 4800 Loss: 0.02516951598227024
Epoch: 4801 Loss: 0.0251677874475

Epoch: 5008 Loss: 0.023124149069190025
Epoch: 5009 Loss: 0.02311285212635994
Epoch: 5010 Loss: 0.0231096763163805
Epoch: 5011 Loss: 0.023096919059753418
Epoch: 5012 Loss: 0.023090006783604622
Epoch: 5013 Loss: 0.023077109828591347
Epoch: 5014 Loss: 0.02306460402905941
Epoch: 5015 Loss: 0.02305598184466362
Epoch: 5016 Loss: 0.02304019220173359
Epoch: 5017 Loss: 0.023040523752570152
Epoch: 5018 Loss: 0.02302410826086998
Epoch: 5019 Loss: 0.023024171590805054
Epoch: 5020 Loss: 0.023017408326268196
Epoch: 5021 Loss: 0.022999770939350128
Epoch: 5022 Loss: 0.023006943985819817
Epoch: 5023 Loss: 0.02299128659069538
Epoch: 5024 Loss: 0.022977255284786224
Epoch: 5025 Loss: 0.02297382615506649
Epoch: 5026 Loss: 0.022954540327191353
Epoch: 5027 Loss: 0.022951487451791763
Epoch: 5028 Loss: 0.022944577038288116
Epoch: 5029 Loss: 0.02293766848742962
Epoch: 5030 Loss: 0.022924916818737984
Epoch: 5031 Loss: 0.022916100919246674
Epoch: 5032 Loss: 0.02290939725935459
Epoch: 5033 Loss: 0.0228947680443525

Epoch: 5230 Loss: 0.02124420367181301
Epoch: 5231 Loss: 0.021235058084130287
Epoch: 5232 Loss: 0.021220644935965538
Epoch: 5233 Loss: 0.02122354879975319
Epoch: 5234 Loss: 0.021207617595791817
Epoch: 5235 Loss: 0.021209081634879112
Epoch: 5236 Loss: 0.02119465172290802
Epoch: 5237 Loss: 0.021181447431445122
Epoch: 5238 Loss: 0.02118297480046749
Epoch: 5239 Loss: 0.02116125263273716
Epoch: 5240 Loss: 0.021161988377571106
Epoch: 5241 Loss: 0.021153854206204414
Epoch: 5242 Loss: 0.021139755845069885
Epoch: 5243 Loss: 0.021144066005945206
Epoch: 5244 Loss: 0.02112978883087635
Epoch: 5245 Loss: 0.021127082407474518
Epoch: 5246 Loss: 0.02111595869064331
Epoch: 5247 Loss: 0.021111521869897842
Epoch: 5248 Loss: 0.021095318719744682
Epoch: 5249 Loss: 0.021085480228066444
Epoch: 5250 Loss: 0.02108636125922203
Epoch: 5251 Loss: 0.021068058907985687
Epoch: 5252 Loss: 0.021073026582598686
Epoch: 5253 Loss: 0.021054591983556747
Epoch: 5254 Loss: 0.02105756476521492
Epoch: 5255 Loss: 0.02104451321065

Epoch: 5458 Loss: 0.019555995240807533
Epoch: 5459 Loss: 0.019551072269678116
Epoch: 5460 Loss: 0.019538166001439095
Epoch: 5461 Loss: 0.019539695233106613
Epoch: 5462 Loss: 0.0195203498005867
Epoch: 5463 Loss: 0.01952405832707882
Epoch: 5464 Loss: 0.01952180080115795
Epoch: 5465 Loss: 0.01951626129448414
Epoch: 5466 Loss: 0.01950380951166153
Epoch: 5467 Loss: 0.019499866291880608
Epoch: 5468 Loss: 0.01948523335158825
Epoch: 5469 Loss: 0.019480472430586815
Epoch: 5470 Loss: 0.01946893520653248
Epoch: 5471 Loss: 0.019469982013106346
Epoch: 5472 Loss: 0.019455652683973312
Epoch: 5473 Loss: 0.01945907063782215
Epoch: 5474 Loss: 0.019456319510936737
Epoch: 5475 Loss: 0.019437259063124657
Epoch: 5476 Loss: 0.019433867186307907
Epoch: 5477 Loss: 0.019427362829446793
Epoch: 5478 Loss: 0.019419703632593155
Epoch: 5479 Loss: 0.019419627264142036
Epoch: 5480 Loss: 0.019403574988245964
Epoch: 5481 Loss: 0.019406359642744064
Epoch: 5482 Loss: 0.01939154975116253
Epoch: 5483 Loss: 0.019381679594516

Epoch: 5692 Loss: 0.018059948459267616
Epoch: 5693 Loss: 0.01804395578801632
Epoch: 5694 Loss: 0.018048491328954697
Epoch: 5695 Loss: 0.01803467981517315
Epoch: 5696 Loss: 0.018024088814854622
Epoch: 5697 Loss: 0.01802702061831951
Epoch: 5698 Loss: 0.018020259216427803
Epoch: 5699 Loss: 0.0180144514888525
Epoch: 5700 Loss: 0.018012629821896553
Epoch: 5701 Loss: 0.018002178519964218
Epoch: 5702 Loss: 0.017998896539211273
Epoch: 5703 Loss: 0.01799062080681324
Epoch: 5704 Loss: 0.017977965995669365
Epoch: 5705 Loss: 0.017974095419049263
Epoch: 5706 Loss: 0.01797020062804222
Epoch: 5707 Loss: 0.017963742837309837
Epoch: 5708 Loss: 0.017954744398593903
Epoch: 5709 Loss: 0.017955169081687927
Epoch: 5710 Loss: 0.01794915460050106
Epoch: 5711 Loss: 0.01793864741921425
Epoch: 5712 Loss: 0.017941029742360115
Epoch: 5713 Loss: 0.017927054315805435
Epoch: 5714 Loss: 0.01791740208864212
Epoch: 5715 Loss: 0.017919311299920082
Epoch: 5716 Loss: 0.01790444552898407
Epoch: 5717 Loss: 0.0179084129631519

Epoch: 5921 Loss: 0.016751818358898163
Epoch: 5922 Loss: 0.016756653785705566
Epoch: 5923 Loss: 0.016740040853619576
Epoch: 5924 Loss: 0.016737204045057297
Epoch: 5925 Loss: 0.016731882467865944
Epoch: 5926 Loss: 0.016729101538658142
Epoch: 5927 Loss: 0.016721906140446663
Epoch: 5928 Loss: 0.016710305586457253
Epoch: 5929 Loss: 0.016709977760910988
Epoch: 5930 Loss: 0.016706807538866997
Epoch: 5931 Loss: 0.01670309714972973
Epoch: 5932 Loss: 0.016695845872163773
Epoch: 5933 Loss: 0.016684118658304214
Epoch: 5934 Loss: 0.016685951501131058
Epoch: 5935 Loss: 0.01667878031730652
Epoch: 5936 Loss: 0.016668392345309258
Epoch: 5937 Loss: 0.01667308621108532
Epoch: 5938 Loss: 0.016665983945131302
Epoch: 5939 Loss: 0.016662387177348137
Epoch: 5940 Loss: 0.016646603122353554
Epoch: 5941 Loss: 0.016655277460813522
Epoch: 5942 Loss: 0.016639186069369316
Epoch: 5943 Loss: 0.01663653925061226
Epoch: 5944 Loss: 0.016636088490486145
Epoch: 5945 Loss: 0.0166319552809
Epoch: 5946 Loss: 0.01661190018057

Epoch: 6146 Loss: 0.015623723156750202
Epoch: 6147 Loss: 0.015616869553923607
Epoch: 6148 Loss: 0.015612672083079815
Epoch: 6149 Loss: 0.01560975331813097
Epoch: 6150 Loss: 0.015606150962412357
Epoch: 6151 Loss: 0.015596314333379269
Epoch: 6152 Loss: 0.015594135038554668
Epoch: 6153 Loss: 0.015585945919156075
Epoch: 6154 Loss: 0.015586042776703835
Epoch: 6155 Loss: 0.015581046231091022
Epoch: 6156 Loss: 0.015574844554066658
Epoch: 6157 Loss: 0.015571017749607563
Epoch: 6158 Loss: 0.015567534603178501
Epoch: 6159 Loss: 0.015558273531496525
Epoch: 6160 Loss: 0.015561145730316639
Epoch: 6161 Loss: 0.015549326315522194
Epoch: 6162 Loss: 0.015550363808870316
Epoch: 6163 Loss: 0.015544939786195755
Epoch: 6164 Loss: 0.015531564131379128
Epoch: 6165 Loss: 0.015535936690866947
Epoch: 6166 Loss: 0.015521916560828686
Epoch: 6167 Loss: 0.015524025075137615
Epoch: 6168 Loss: 0.015517541207373142
Epoch: 6169 Loss: 0.01551157608628273
Epoch: 6170 Loss: 0.015517983585596085
Epoch: 6171 Loss: 0.0155094

Epoch: 6363 Loss: 0.014648516662418842
Epoch: 6364 Loss: 0.014647716656327248
Epoch: 6365 Loss: 0.014639166183769703
Epoch: 6366 Loss: 0.014637434855103493
Epoch: 6367 Loss: 0.014629475772380829
Epoch: 6368 Loss: 0.014627312310039997
Epoch: 6369 Loss: 0.014625532552599907
Epoch: 6370 Loss: 0.014620319940149784
Epoch: 6371 Loss: 0.014619025401771069
Epoch: 6372 Loss: 0.014611893333494663
Epoch: 6373 Loss: 0.014611639082431793
Epoch: 6374 Loss: 0.014597388915717602
Epoch: 6375 Loss: 0.014600465074181557
Epoch: 6376 Loss: 0.01458740048110485
Epoch: 6377 Loss: 0.014590180478990078
Epoch: 6378 Loss: 0.014584089629352093
Epoch: 6379 Loss: 0.014581497758626938
Epoch: 6380 Loss: 0.014579457230865955
Epoch: 6381 Loss: 0.014571878127753735
Epoch: 6382 Loss: 0.014572910033166409
Epoch: 6383 Loss: 0.014564672484993935
Epoch: 6384 Loss: 0.014562253840267658
Epoch: 6385 Loss: 0.014555297791957855
Epoch: 6386 Loss: 0.014550821855664253
Epoch: 6387 Loss: 0.014549577608704567
Epoch: 6388 Loss: 0.014549

Epoch: 6590 Loss: 0.013725397177040577
Epoch: 6591 Loss: 0.013727542944252491
Epoch: 6592 Loss: 0.013723383657634258
Epoch: 6593 Loss: 0.013721131719648838
Epoch: 6594 Loss: 0.013716669753193855
Epoch: 6595 Loss: 0.01371318381279707
Epoch: 6596 Loss: 0.013707313686609268
Epoch: 6597 Loss: 0.013702659867703915
Epoch: 6598 Loss: 0.013699376955628395
Epoch: 6599 Loss: 0.013691329397261143
Epoch: 6600 Loss: 0.0136897973716259
Epoch: 6601 Loss: 0.013693539425730705
Epoch: 6602 Loss: 0.013685053214430809
Epoch: 6603 Loss: 0.013680726289749146
Epoch: 6604 Loss: 0.013678097166121006
Epoch: 6605 Loss: 0.013674041256308556
Epoch: 6606 Loss: 0.01367000862956047
Epoch: 6607 Loss: 0.013662480749189854
Epoch: 6608 Loss: 0.01366428378969431
Epoch: 6609 Loss: 0.013651261106133461
Epoch: 6610 Loss: 0.013652254827320576
Epoch: 6611 Loss: 0.01365151908248663
Epoch: 6612 Loss: 0.01364744734019041
Epoch: 6613 Loss: 0.013642419129610062
Epoch: 6614 Loss: 0.013642163947224617
Epoch: 6615 Loss: 0.013635191135

Epoch: 6812 Loss: 0.012920289300382137
Epoch: 6813 Loss: 0.01291855052113533
Epoch: 6814 Loss: 0.012910999357700348
Epoch: 6815 Loss: 0.012912956066429615
Epoch: 6816 Loss: 0.01290811412036419
Epoch: 6817 Loss: 0.01290250476449728
Epoch: 6818 Loss: 0.01289973221719265
Epoch: 6819 Loss: 0.012897228822112083
Epoch: 6820 Loss: 0.012893661856651306
Epoch: 6821 Loss: 0.012892208993434906
Epoch: 6822 Loss: 0.012884020805358887
Epoch: 6823 Loss: 0.012880561873316765
Epoch: 6824 Loss: 0.012879295274615288
Epoch: 6825 Loss: 0.012875374406576157
Epoch: 6826 Loss: 0.012870498932898045
Epoch: 6827 Loss: 0.012869009748101234
Epoch: 6828 Loss: 0.012867572717368603
Epoch: 6829 Loss: 0.012862615287303925
Epoch: 6830 Loss: 0.012858647853136063
Epoch: 6831 Loss: 0.012859704904258251
Epoch: 6832 Loss: 0.012851349078118801
Epoch: 6833 Loss: 0.0128447525203228
Epoch: 6834 Loss: 0.012848790735006332
Epoch: 6835 Loss: 0.012842528522014618
Epoch: 6836 Loss: 0.012840581126511097
Epoch: 6837 Loss: 0.01283118128

Epoch: 7029 Loss: 0.012206116691231728
Epoch: 7030 Loss: 0.012204141356050968
Epoch: 7031 Loss: 0.012201428413391113
Epoch: 7032 Loss: 0.012195548042654991
Epoch: 7033 Loss: 0.012195627205073833
Epoch: 7034 Loss: 0.01219095941632986
Epoch: 7035 Loss: 0.0121897729113698
Epoch: 7036 Loss: 0.012186652049422264
Epoch: 7037 Loss: 0.012184560298919678
Epoch: 7038 Loss: 0.012179519049823284
Epoch: 7039 Loss: 0.012178524397313595
Epoch: 7040 Loss: 0.01217470970004797
Epoch: 7041 Loss: 0.01217047031968832
Epoch: 7042 Loss: 0.012163612991571426
Epoch: 7043 Loss: 0.012160932645201683
Epoch: 7044 Loss: 0.012162433005869389
Epoch: 7045 Loss: 0.012155238538980484
Epoch: 7046 Loss: 0.01215599849820137
Epoch: 7047 Loss: 0.012150776572525501
Epoch: 7048 Loss: 0.012150784023106098
Epoch: 7049 Loss: 0.012146546505391598
Epoch: 7050 Loss: 0.012143425643444061
Epoch: 7051 Loss: 0.012137660756707191
Epoch: 7052 Loss: 0.012137176468968391
Epoch: 7053 Loss: 0.012131666764616966
Epoch: 7054 Loss: 0.01212816033

Epoch: 7247 Loss: 0.011556421406567097
Epoch: 7248 Loss: 0.011555311270058155
Epoch: 7249 Loss: 0.01155197061598301
Epoch: 7250 Loss: 0.011548751033842564
Epoch: 7251 Loss: 0.011548126116394997
Epoch: 7252 Loss: 0.011541886255145073
Epoch: 7253 Loss: 0.011538490653038025
Epoch: 7254 Loss: 0.011534390039741993
Epoch: 7255 Loss: 0.011531691998243332
Epoch: 7256 Loss: 0.011530510149896145
Epoch: 7257 Loss: 0.011528615839779377
Epoch: 7258 Loss: 0.011526796035468578
Epoch: 7259 Loss: 0.011525196023285389
Epoch: 7260 Loss: 0.011516714468598366
Epoch: 7261 Loss: 0.011516747064888477
Epoch: 7262 Loss: 0.011513330042362213
Epoch: 7263 Loss: 0.011510885320603848
Epoch: 7264 Loss: 0.011508173309266567
Epoch: 7265 Loss: 0.011504068970680237
Epoch: 7266 Loss: 0.011502555571496487
Epoch: 7267 Loss: 0.011495331302285194
Epoch: 7268 Loss: 0.01149759255349636
Epoch: 7269 Loss: 0.011493048630654812
Epoch: 7270 Loss: 0.011490961536765099
Epoch: 7271 Loss: 0.011486063711345196
Epoch: 7272 Loss: 0.0114849

Epoch: 7477 Loss: 0.010925228707492352
Epoch: 7478 Loss: 0.010921364650130272
Epoch: 7479 Loss: 0.010918374173343182
Epoch: 7480 Loss: 0.01091738324612379
Epoch: 7481 Loss: 0.010915214195847511
Epoch: 7482 Loss: 0.010913640260696411
Epoch: 7483 Loss: 0.010906992480158806
Epoch: 7484 Loss: 0.010907614603638649
Epoch: 7485 Loss: 0.010901308618485928
Epoch: 7486 Loss: 0.010904116556048393
Epoch: 7487 Loss: 0.01089754980057478
Epoch: 7488 Loss: 0.010897226631641388
Epoch: 7489 Loss: 0.010894428938627243
Epoch: 7490 Loss: 0.010892399586737156
Epoch: 7491 Loss: 0.010887104086577892
Epoch: 7492 Loss: 0.010884977877140045
Epoch: 7493 Loss: 0.01088007353246212
Epoch: 7494 Loss: 0.010882176458835602
Epoch: 7495 Loss: 0.010878287255764008
Epoch: 7496 Loss: 0.010873678140342236
Epoch: 7497 Loss: 0.010874530300498009
Epoch: 7498 Loss: 0.010870161466300488
Epoch: 7499 Loss: 0.010867386125028133
Epoch: 7500 Loss: 0.01086373720318079
Epoch: 7501 Loss: 0.010863297618925571
Epoch: 7502 Loss: 0.010858768

Epoch: 7695 Loss: 0.010378210805356503
Epoch: 7696 Loss: 0.010377371683716774
Epoch: 7697 Loss: 0.010376145131886005
Epoch: 7698 Loss: 0.010369601659476757
Epoch: 7699 Loss: 0.010370178148150444
Epoch: 7700 Loss: 0.010366469621658325
Epoch: 7701 Loss: 0.010363390669226646
Epoch: 7702 Loss: 0.01036261860281229
Epoch: 7703 Loss: 0.01035712193697691
Epoch: 7704 Loss: 0.010355894453823566
Epoch: 7705 Loss: 0.01035622414201498
Epoch: 7706 Loss: 0.010352535173296928
Epoch: 7707 Loss: 0.010349649004638195
Epoch: 7708 Loss: 0.01034930907189846
Epoch: 7709 Loss: 0.010344191454350948
Epoch: 7710 Loss: 0.010345655493438244
Epoch: 7711 Loss: 0.010341492481529713
Epoch: 7712 Loss: 0.010335741564631462
Epoch: 7713 Loss: 0.010336863808333874
Epoch: 7714 Loss: 0.010328592732548714
Epoch: 7715 Loss: 0.010332542471587658
Epoch: 7716 Loss: 0.010327908210456371
Epoch: 7717 Loss: 0.01032823882997036
Epoch: 7718 Loss: 0.01032449770718813
Epoch: 7719 Loss: 0.0103223267942667
Epoch: 7720 Loss: 0.0103172799572

Epoch: 7922 Loss: 0.009860793128609657
Epoch: 7923 Loss: 0.009858358651399612
Epoch: 7924 Loss: 0.009854636155068874
Epoch: 7925 Loss: 0.009854044765233994
Epoch: 7926 Loss: 0.009852058254182339
Epoch: 7927 Loss: 0.009849149733781815
Epoch: 7928 Loss: 0.009843491949141026
Epoch: 7929 Loss: 0.009843594394624233
Epoch: 7930 Loss: 0.009843842126429081
Epoch: 7931 Loss: 0.009841362945735455
Epoch: 7932 Loss: 0.009837010875344276
Epoch: 7933 Loss: 0.009835745207965374
Epoch: 7934 Loss: 0.009833374060690403
Epoch: 7935 Loss: 0.009830431081354618
Epoch: 7936 Loss: 0.009827980771660805
Epoch: 7937 Loss: 0.009828485548496246
Epoch: 7938 Loss: 0.009824525564908981
Epoch: 7939 Loss: 0.009822606109082699
Epoch: 7940 Loss: 0.0098216962069273
Epoch: 7941 Loss: 0.009818164631724358
Epoch: 7942 Loss: 0.009815463796257973
Epoch: 7943 Loss: 0.0098130377009511
Epoch: 7944 Loss: 0.009810592979192734
Epoch: 7945 Loss: 0.009807784110307693
Epoch: 7946 Loss: 0.009805641137063503
Epoch: 7947 Loss: 0.009807494

Epoch: 8154 Loss: 0.00937653798609972
Epoch: 8155 Loss: 0.009372022934257984
Epoch: 8156 Loss: 0.009368532337248325
Epoch: 8157 Loss: 0.009368368424475193
Epoch: 8158 Loss: 0.009367772378027439
Epoch: 8159 Loss: 0.00936368852853775
Epoch: 8160 Loss: 0.009364339523017406
Epoch: 8161 Loss: 0.009356887079775333
Epoch: 8162 Loss: 0.009360057301819324
Epoch: 8163 Loss: 0.00935930572450161
Epoch: 8164 Loss: 0.009355871938169003
Epoch: 8165 Loss: 0.009352611377835274
Epoch: 8166 Loss: 0.009353701956570148
Epoch: 8167 Loss: 0.009347238577902317
Epoch: 8168 Loss: 0.009347022511065006
Epoch: 8169 Loss: 0.009345192462205887
Epoch: 8170 Loss: 0.00934327207505703
Epoch: 8171 Loss: 0.009341024793684483
Epoch: 8172 Loss: 0.009339911863207817
Epoch: 8173 Loss: 0.00933829601854086
Epoch: 8174 Loss: 0.009332106448709965
Epoch: 8175 Loss: 0.009335930459201336
Epoch: 8176 Loss: 0.009331364184617996
Epoch: 8177 Loss: 0.009328668005764484
Epoch: 8178 Loss: 0.009326968342065811
Epoch: 8179 Loss: 0.0093250311

Epoch: 8378 Loss: 0.008947393856942654
Epoch: 8379 Loss: 0.00894198939204216
Epoch: 8380 Loss: 0.008943016640841961
Epoch: 8381 Loss: 0.008941993117332458
Epoch: 8382 Loss: 0.00893646851181984
Epoch: 8383 Loss: 0.008936138823628426
Epoch: 8384 Loss: 0.008935471065342426
Epoch: 8385 Loss: 0.008932594209909439
Epoch: 8386 Loss: 0.00892980583012104
Epoch: 8387 Loss: 0.00892753154039383
Epoch: 8388 Loss: 0.008926843293011189
Epoch: 8389 Loss: 0.008925186470150948
Epoch: 8390 Loss: 0.008923332206904888
Epoch: 8391 Loss: 0.00892181321978569
Epoch: 8392 Loss: 0.008921463042497635
Epoch: 8393 Loss: 0.00891687348484993
Epoch: 8394 Loss: 0.008915513753890991
Epoch: 8395 Loss: 0.00891310814768076
Epoch: 8396 Loss: 0.008914324454963207
Epoch: 8397 Loss: 0.008911027573049068
Epoch: 8398 Loss: 0.00890566036105156
Epoch: 8399 Loss: 0.008907795883715153
Epoch: 8400 Loss: 0.008906764909625053
Epoch: 8401 Loss: 0.008900690823793411
Epoch: 8402 Loss: 0.008900542743504047
Epoch: 8403 Loss: 0.0089021064341

Epoch: 8604 Loss: 0.008544682525098324
Epoch: 8605 Loss: 0.008544684387743473
Epoch: 8606 Loss: 0.008544040843844414
Epoch: 8607 Loss: 0.008543560281395912
Epoch: 8608 Loss: 0.008538462221622467
Epoch: 8609 Loss: 0.008538593538105488
Epoch: 8610 Loss: 0.008539178408682346
Epoch: 8611 Loss: 0.00853512343019247
Epoch: 8612 Loss: 0.008533341810107231
Epoch: 8613 Loss: 0.008533624932169914
Epoch: 8614 Loss: 0.008529353886842728
Epoch: 8615 Loss: 0.008528451435267925
Epoch: 8616 Loss: 0.008528130128979683
Epoch: 8617 Loss: 0.008524374105036259
Epoch: 8618 Loss: 0.008525579236447811
Epoch: 8619 Loss: 0.008523046970367432
Epoch: 8620 Loss: 0.008521104231476784
Epoch: 8621 Loss: 0.008516672067344189
Epoch: 8622 Loss: 0.008517042733728886
Epoch: 8623 Loss: 0.00851520150899887
Epoch: 8624 Loss: 0.008513791486620903
Epoch: 8625 Loss: 0.008513854816555977
Epoch: 8626 Loss: 0.008505510166287422
Epoch: 8627 Loss: 0.008509596809744835
Epoch: 8628 Loss: 0.008506360463798046
Epoch: 8629 Loss: 0.0085081

Epoch: 8837 Loss: 0.008170551620423794
Epoch: 8838 Loss: 0.008167793974280357
Epoch: 8839 Loss: 0.008164811879396439
Epoch: 8840 Loss: 0.008165650069713593
Epoch: 8841 Loss: 0.00816398486495018
Epoch: 8842 Loss: 0.008159701712429523
Epoch: 8843 Loss: 0.008157346397638321
Epoch: 8844 Loss: 0.008158671669661999
Epoch: 8845 Loss: 0.008155964314937592
Epoch: 8846 Loss: 0.008154930546879768
Epoch: 8847 Loss: 0.008151599206030369
Epoch: 8848 Loss: 0.008153031580150127
Epoch: 8849 Loss: 0.008149291388690472
Epoch: 8850 Loss: 0.008150257170200348
Epoch: 8851 Loss: 0.008146768435835838
Epoch: 8852 Loss: 0.0081437723711133
Epoch: 8853 Loss: 0.008142848499119282
Epoch: 8854 Loss: 0.008142662234604359
Epoch: 8855 Loss: 0.008140676654875278
Epoch: 8856 Loss: 0.008138191886246204
Epoch: 8857 Loss: 0.008138657547533512
Epoch: 8858 Loss: 0.008137345314025879
Epoch: 8859 Loss: 0.008134963922202587
Epoch: 8860 Loss: 0.008132047951221466
Epoch: 8861 Loss: 0.00813095923513174
Epoch: 8862 Loss: 0.008130721

Epoch: 9066 Loss: 0.007825756445527077
Epoch: 9067 Loss: 0.00782162044197321
Epoch: 9068 Loss: 0.00782177783548832
Epoch: 9069 Loss: 0.007820550352334976
Epoch: 9070 Loss: 0.007817714475095272
Epoch: 9071 Loss: 0.007818377576768398
Epoch: 9072 Loss: 0.007815380580723286
Epoch: 9073 Loss: 0.00781499594449997
Epoch: 9074 Loss: 0.007813607342541218
Epoch: 9075 Loss: 0.007809776347130537
Epoch: 9076 Loss: 0.0078086755238473415
Epoch: 9077 Loss: 0.007808041293174028
Epoch: 9078 Loss: 0.007807029411196709
Epoch: 9079 Loss: 0.007805798668414354
Epoch: 9080 Loss: 0.007805380504578352


KeyboardInterrupt: 

Finally, we use the trained model to play FizzBuzz on the numbers 1 to 100.

In [17]:
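# Run the trained model on the numbers 1..100 (no gradients are needed for inference)
testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
with torch.no_grad():
    testY = model(testX)
# testY.max(1)[1] is the index of the highest-scoring class for each number
predictions = zip(range(1, 101), testY.max(1)[1].tolist())

print([fizz_buzz_decode(i, x) for (i, x) in predictions])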

['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', '86', '87', '88', '89', 'fizzbuzz', '91', '92', '93', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
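`testY.max(1)` returns both the maximum score and its index along dimension 1, so `testY.max(1)[1]` is the predicted class for each number; `fizz_buzz_decode` then turns that class back into a string. The pure-Python sketch below mimics these two steps without PyTorch. It assumes the four classes are ordered as "the number itself", fizz, buzz, fizzbuzz (indices 0 to 3); that order, the helper names `argmax_rows` and `decode`, and the toy scores are illustrative assumptions, not the notebook's own definitions or model outputs.

```python
from typing import List, Sequence


def argmax_rows(scores: Sequence[Sequence[float]]) -> List[int]:
    """Index of the largest value in each row, similar to ``scores.max(1)[1]``
    for a 2-D tensor (ties resolve to the first index here)."""
    return [max(range(len(row)), key=lambda j: row[j]) for row in scores]


def decode(i: int, prediction: int) -> str:
    """Hypothetical decoder; assumed class order:
    0 -> the number itself, 1 -> fizz, 2 -> buzz, 3 -> fizzbuzz."""
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]


# Made-up scores for the numbers 1..4 (one row of 4 class scores each).
scores = [
    [2.0, 0.1, 0.3, -1.0],  # 1 -> class 0
    [1.5, 0.2, 0.0, -0.5],  # 2 -> class 0
    [0.1, 3.0, 0.2, 0.4],   # 3 -> class 1 (fizz)
    [0.9, 0.3, 0.1, 0.0],   # 4 -> class 0
]
predictions = zip(range(1, 5), argmax_rows(scores))
print([decode(i, x) for (i, x) in predictions])  # ['1', '2', 'fizz', '4']
```

Note that `zip` returns a one-shot iterator, so `predictions` is consumed by the list comprehension and cannot be reused afterwards; the same holds for `predictions` in the cell above.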


In [18]:
labels = np.array([fizz_buzz_encode(i) for i in range(1, 101)])  # ground-truth classes
correct = testY.max(1)[1].numpy() == labels  # per-number correctness mask
print(np.sum(correct))  # number of correct predictions out of 100
correct

98


array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True, False,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
        True])
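The mask contains two `False` entries (positions 86 and 92, counting from 0). To see which numbers they are, the decoded predictions can be compared against a rule-based FizzBuzz. The pure-Python sketch below does that; `fizz_buzz` and `find_mistakes` are hypothetical helper names, and to keep it short only the predictions for 81 to 100 are copied from the output printed earlier.

```python
from typing import List, Sequence, Tuple


def fizz_buzz(i: int) -> str:
    """Rule-based FizzBuzz answer for a single number."""
    if i % 15 == 0:
        return "fizzbuzz"
    if i % 5 == 0:
        return "buzz"
    if i % 3 == 0:
        return "fizz"
    return str(i)


def find_mistakes(predicted: Sequence[str], start: int = 1) -> List[Tuple[int, str, str]]:
    """Return (number, predicted, expected) for every wrong prediction.
    ``predicted[k]`` is taken to be the answer for the number ``start + k``."""
    mistakes = []
    for k, guess in enumerate(predicted):
        i = start + k
        expected = fizz_buzz(i)
        if guess != expected:
            mistakes.append((i, guess, expected))
    return mistakes


# Decoded model predictions for 81..100, copied from the printed list above.
predicted = ['fizz', '82', '83', 'fizz', 'buzz', '86', '87', '88', '89', 'fizzbuzz',
             '91', '92', '93', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
print(find_mistakes(predicted, start=81))
# [(87, '87', 'fizz'), (93, '93', 'fizz')]
```

Both errors are multiples of 3 that the model labelled as plain numbers, which matches the count of 98 correct answers out of 100.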