In [None]:
%matplotlib inline


PyTorch: Tensors and autograd
-------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.


A PyTorch Tensor represents a node in a computational graph. If ``x`` is a
Tensor that has ``x.requires_grad=True`` then ``x.grad`` is another Tensor
holding the gradient of ``x`` with respect to some scalar value.



In [1]:
import torch

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y using operations on Tensors; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the a scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call w1.grad and w2.grad will be Tensors holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    # An alternative way is to operate on weight.data and weight.grad.data.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.
    # You can also use torch.optim.SGD to achieve this.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 28381492.0
1 22915732.0
2 22275188.0
3 22675172.0
4 21835244.0
5 18752290.0
6 13996350.0
7 9264500.0
8 5662078.0
9 3391066.0
10 2083559.75
11 1360151.875
12 953874.0
13 715977.625
14 567213.8125
15 467190.3125
16 395062.4375
17 339958.4375
18 296030.03125
19 259944.5
20 229666.359375
21 203892.265625
22 181777.171875
23 162631.625
24 145928.421875
25 131300.34375
26 118417.8046875
27 107039.4921875
28 96969.6328125
29 88039.8515625
30 80113.4609375
31 73027.0234375
32 66678.1796875
33 60973.5
34 55835.40625
35 51200.37890625
36 47008.58203125
37 43215.48828125
38 39774.12109375
39 36645.6796875
40 33802.0546875
41 31211.03125
42 28846.119140625
43 26682.685546875
44 24702.517578125
45 22887.583984375
46 21222.892578125
47 19693.5546875
48 18288.181640625
49 16996.474609375
50 15806.67578125
51 14711.568359375
52 13700.8515625
53 12766.8701171875
54 11903.2705078125
55 11103.830078125
56 10363.6904296875
57 9678.041015625
58 9042.0751953125
59 8451.630859375
60 7903.74560546875
61 739

408 0.002075634431093931
409 0.0020120148546993732
410 0.0019498651381582022
411 0.0018913953099399805
412 0.0018332985928282142
413 0.001777996658347547
414 0.001725560869090259
415 0.0016751803923398256
416 0.0016239627730101347
417 0.001578743802383542
418 0.0015317507786676288
419 0.0014877215726301074
420 0.0014441825915127993
421 0.0014024110278114676
422 0.0013616678770631552
423 0.0013251120690256357
424 0.0012850005878135562
425 0.0012479388387873769
426 0.001212829607538879
427 0.0011789383133873343
428 0.0011464550625532866
429 0.0011148572666570544
430 0.0010833883425220847
431 0.0010518708731979132
432 0.00102336669806391
433 0.0009952818509191275
434 0.0009669457795098424
435 0.0009415525128133595
436 0.0009173309081234038
437 0.0008919470710679889
438 0.00086927943630144
439 0.0008469211752526462
440 0.0008246144861914217
441 0.0008029320160858333
442 0.0007817996665835381
443 0.0007623230339959264
444 0.0007423057686537504
445 0.0007225826848298311
446 0.000704405363649