In [4]:
import numpy as np
import torch
from torch.autograd import Variable

DATA_PATH = 'data/diabetes.csv'

# Load the comma-separated file as float32 to match PyTorch's default parameter dtype.
# ndmin=2 guarantees a 2-D array even if the file holds a single row, so the
# column slicing below never fails on a 1-D result.
xy = np.loadtxt(DATA_PATH, delimiter=',', dtype=np.float32, ndmin=2)

# Every column except the last is an input feature; the last column is the target.
# `Variable` has been a no-op wrapper since PyTorch 0.4, so plain tensors are used.
x_data = torch.from_numpy(xy[:, 0:-1])
# Indexing with a list ([-1]) rather than a scalar keeps the target 2-D, shape
# (n, 1), which matches the (n, 1) output of the model's final Linear(4, 1) layer.
y_data = torch.from_numpy(xy[:, [-1]])

## Design model

In [5]:
class Model(torch.nn.Module):
    """Three-layer fully connected network with a sigmoid after every layer.

    Maps ``in_features`` inputs through two hidden layers to ``out_features``
    outputs. Because the final layer is also followed by a sigmoid, outputs
    lie in the range [0, 1] and can be passed directly to ``BCELoss``.
    """

    def __init__(self, in_features=8, hidden1=6, hidden2=4, out_features=1):
        """Instantiate the three ``nn.Linear`` layers.

        Args:
            in_features: Number of input features per sample (default 8).
            hidden1: Width of the first hidden layer (default 6).
            hidden2: Width of the second hidden layer (default 4).
            out_features: Number of outputs per sample (default 1).
        """
        super(Model, self).__init__()
        self.l1 = torch.nn.Linear(in_features, hidden1)
        self.l2 = torch.nn.Linear(hidden1, hidden2)
        self.l3 = torch.nn.Linear(hidden2, out_features)

        # Sigmoid holds no parameters, so one instance is reused after every layer.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Run a batch of inputs through the network.

        Args:
            x: Float tensor whose last dimension equals ``in_features``
               (typically shaped ``(batch, in_features)``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``out_features``; values are sigmoid outputs in [0, 1].
        """
        out1 = self.sigmoid(self.l1(x))
        out2 = self.sigmoid(self.l2(out1))
        y_pred = self.sigmoid(self.l3(out2))
        return y_pred


# Seed PyTorch's global RNG so the random weight initialisation, and hence
# the training run, is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# our model
model = Model()

## Construct loss and optimizer

In [7]:
# Construct our loss function and an optimizer. The call to model.parameters()
# in the SGD constructor yields the learnable parameters of the three
# nn.Linear modules which are members of the model.
LEARNING_RATE = 0.1

# `size_average=True` is deprecated (and warns); reduction='mean' is the
# equivalent spelling: average the per-element binary cross-entropy.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

## Training cycle

In [8]:
# Training loop: each epoch does one full-batch gradient step over all of x_data.
N_EPOCHS = 100

for epoch in range(N_EPOCHS):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x_data)

    # Compute and print loss. `loss.item()` returns the Python float directly;
    # the legacy `.data` detour is unnecessary.
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 0.7349616289138794
1 0.7260209918022156
2 0.7179750204086304
3 0.710736870765686
4 0.7042251825332642
5 0.6983659267425537
6 0.6930961608886719
7 0.6883549094200134
8 0.6840885281562805
9 0.6802496910095215
10 0.676794707775116
11 0.6736848950386047
12 0.6708842515945435
13 0.6683629155158997
14 0.6660914421081543
15 0.6640451550483704
16 0.6622009873390198
17 0.6605387330055237
18 0.6590403318405151
19 0.6576891541481018
20 0.656470537185669
21 0.6553707718849182
22 0.6543787121772766
23 0.6534829139709473
24 0.6526740193367004
25 0.6519437432289124
26 0.6512842774391174
27 0.6506888270378113
28 0.6501503586769104
29 0.6496636271476746
30 0.6492236852645874
31 0.6488261222839355
32 0.6484665870666504
33 0.6481412649154663
34 0.6478468179702759
35 0.6475803256034851
36 0.6473397612571716
37 0.6471213698387146
38 0.6469240188598633
39 0.6467452645301819
40 0.6465832591056824
41 0.6464365720748901
42 0.6463039517402649
43 0.646183967590332
44 0.6460744738578796
45 0.6459758281707764
46