In [1]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, such as SGD with momentum, RMSProp, and Adam.



In [5]:
import torch

# Seed PyTorch's global RNG before anything random is created, so that the
# random inputs/targets and the initial layer weights are the same on every
# fresh-kernel run of the notebook.
SEED = 0
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs (N x D_in) and outputs (N x D_out).
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function: a linear layer
# from D_in to H, a ReLU, and a linear layer from H to D_out, scored with
# MSELoss using its default settings.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss()



In [6]:

# Let the optim package perform the weight updates instead of writing them by
# hand. Adam is used here; the optim package contains many other optimization
# algorithms. The first argument to the Adam constructor tells the optimizer
# which Tensors it should update (here, the model's parameters).
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(500):
    # Zero the gradients of every Tensor the optimizer updates before this
    # iteration's backward pass. By default, .backward() accumulates into the
    # existing gradient buffers rather than overwriting them (see the docs of
    # torch.autograd.backward), so the previous iteration's values must be
    # cleared first.
    optimizer.zero_grad()

    # Forward pass: feed x through the model, score the prediction against y,
    # and print the iteration index with the scalar loss.
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Backward pass: compute the gradient of the loss with respect to the
    # model parameters.
    loss.backward()

    # One optimizer step: update the parameters using the gradients just
    # computed.
    optimizer.step()

0 1.0284990072250366
1 1.0025169849395752
2 0.9771362543106079
3 0.9525332450866699
4 0.928604781627655
5 0.9052971005439758
6 0.8826870918273926
7 0.8607214689254761
8 0.8393694162368774
9 0.818602442741394
10 0.7983651757240295
11 0.7786996960639954
12 0.7596547603607178
13 0.7411739826202393
14 0.7232431173324585
15 0.705828070640564
16 0.6888147592544556
17 0.6722877621650696
18 0.656289279460907
19 0.6407888531684875
20 0.6256147027015686
21 0.6108340620994568
22 0.5964106321334839
23 0.5823883414268494
24 0.5687198638916016
25 0.5553773641586304
26 0.542304277420044
27 0.5295599102973938
28 0.5170992612838745
29 0.5048671960830688
30 0.4928366541862488
31 0.4810017943382263
32 0.4694449007511139
33 0.45813775062561035
34 0.447124719619751
35 0.43640127778053284
36 0.4259372651576996
37 0.4156671464443207
38 0.40561598539352417
39 0.3957681953907013
40 0.38617274165153503
41 0.37678927183151245
42 0.36760368943214417
43 0.3586505651473999
44 0.34989601373672485
45 0.34135442972183

390 1.383912717756175e-07
391 1.3024371980918659e-07
392 1.225586458986072e-07
393 1.1530819676863757e-07
394 1.0847075060382849e-07
395 1.0202715117202388e-07
396 9.594869965212638e-08
397 9.022335234476486e-08
398 8.48176142653756e-08
399 7.97256163309612e-08
400 7.493131448654822e-08
401 7.041236926852434e-08
402 6.615881176230687e-08
403 6.215179126911607e-08
404 5.837598138214162e-08
405 5.4819217609747284e-08
406 5.147760617774111e-08
407 4.832840971857877e-08
408 4.536687825407171e-08
409 4.258074426388703e-08
410 3.995696218339617e-08
411 3.748910870626787e-08
412 3.5172853074527666e-08
413 3.299091844155555e-08
414 3.0937787443008347e-08
415 2.901238360664138e-08
416 2.719698599662479e-08
417 2.5497062239310253e-08
418 2.3896314260696272e-08
419 2.23926299725008e-08
420 2.0983396353813077e-08
421 1.965456064567661e-08
422 1.8412556812563707e-08
423 1.7243973360336895e-08
424 1.6144159786790624e-08
425 1.5116729201736234e-08
426 1.4149892813009046e-08
427 1.3242996033113741e-08