This script is based on the code by hunkim:

https://github.com/hunkim/PyTorchZeroToAll/blob/master/05_linear_regression.py

He has a series of video tutorials on YouTube:

https://www.youtube.com/watch?v=113b7O3mabY&index=5&list=PLYhW1ajuwgHdF9tWF_QfjcgGkK4C2rcYv


In [7]:
import torch
from torch.autograd import Variable

# JMT: Our data consists of {x,y} = { (1.0,2.0) , (2.0,4.0) , (3.0,6.0) }.
# One sample per row, shape (3, 1). Plain tensors are used directly: since
# PyTorch 0.4, ``Variable`` is a deprecated no-op wrapper and tensors take
# part in autograd on their own.
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])


class Model(torch.nn.Module):
    """Linear regression model computing ``y = x @ W.T + b``.

    With the default arguments this is a one-input, one-output model,
    i.e. ``y = w * x + b`` with a scalar weight and bias.
    """

    def __init__(self, in_features=1, out_features=1):
        """
        In the constructor we instantiate a single nn.Linear module.

        Args:
            in_features: number of input features per sample (default 1).
            out_features: number of outputs per sample (default 1).
        """
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """
        Apply the linear layer to ``x`` and return the prediction.

        ``x`` is expected to be a float tensor of shape (N, in_features);
        the result has shape (N, out_features).
        """
        y_pred = self.linear(x)
        return y_pred

# our model
model = Model()


# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (weight and
# bias) of the single nn.Linear module which is a member of the model.
# reduction='sum' is the replacement for the deprecated size_average=False:
# the loss is the sum (not the mean) of squared errors over all samples.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training loop
for epoch in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x_data)

    # Compute and print loss. The loss is a 0-dim tensor, so .item() extracts
    # the Python float (indexing it via loss.data[0] raises IndexError on
    # current PyTorch releases).
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# After training
# JMT : Using the model, predict the output if input is x=4.
# no_grad() skips recording an autograd graph, which inference does not need;
# the forward pass is run once and its result reused for printing.
hour_var = torch.tensor([[4.0]])
with torch.no_grad():
    y_pred = model(hour_var)
print("predict (after training)", 4, y_pred.item())

# JMT: Output the weights of the model.
print("\nThe weights are: \n")
for param in model.parameters():
    print(param.data)

0 75.26031494140625
1 33.51506423950195
2 14.931100845336914
3 6.657889366149902
4 2.9747323989868164
5 1.3349382877349854
6 0.6047958135604858
7 0.2796059548854828
8 0.13469198346138
9 0.07003293931484222
10 0.04110413044691086
11 0.028083031997084618
12 0.022145919501781464
13 0.019364256411790848
14 0.017989322543144226
15 0.0172425527125597
16 0.01677735149860382
17 0.01643950678408146
18 0.01616024225950241
19 0.01590879261493683
20 0.01567159779369831
21 0.015442589297890663
22 0.015218967571854591
23 0.014999482780694962
24 0.014783605933189392
25 0.014571011997759342
26 0.014361481182277203
27 0.014155084267258644
28 0.013951650820672512
29 0.013751084916293621
30 0.01355350948870182
31 0.0133587084710598
32 0.013166684657335281
33 0.012977455742657185
34 0.012790997512638569
35 0.01260712742805481
36 0.012425982393324375
37 0.012247397564351559
38 0.01207137480378151
39 0.011897857300937176
40 0.011726931668817997
41 0.011558336205780506
42 0.011392271146178246
43 0.0112285157

398 6.584836955880746e-05
399 6.489916995633394e-05
400 6.396848039003089e-05
401 6.304729322437197e-05
402 6.213965389179066e-05
403 6.124542414909229e-05
404 6.036760896677151e-05
405 5.949703336227685e-05
406 5.864479317096993e-05
407 5.780268838861957e-05
408 5.6973687605932355e-05
409 5.615329428110272e-05
410 5.534836964216083e-05
411 5.455355130834505e-05
412 5.3764051699545234e-05
413 5.299553959048353e-05
414 5.223256812314503e-05
415 5.148180207470432e-05
416 5.0742262828862295e-05
417 5.0014248699881136e-05
418 4.929557326249778e-05
419 4.858453758060932e-05
420 4.7886707761790603e-05
421 4.7201112465700135e-05
422 4.651967537938617e-05
423 4.585263377521187e-05
424 4.519275898928754e-05
425 4.454116788110696e-05
426 4.3902386096306145e-05
427 4.327051283326e-05
428 4.265006282366812e-05
429 4.2036353988805786e-05
430 4.143307887716219e-05
431 4.0836392145138234e-05
432 4.0252518374472857e-05
433 3.9673213905189186e-05
434 3.910392842954025e-05
435 3.854020178550854e-05
436 