In [1]:
import torch
from torch.autograd import Variable as V
from torch.utils.data import Dataset, DataLoader

import numpy as np

In [2]:
class DiabetesDataset(Dataset):
    """Map-style dataset over a headerless, comma-separated float file.

    Every column except the last is treated as an input feature and the last
    column as the target. All values are loaded as float32 tensors.

    Args:
        path: CSV file to load. Defaults to 'data/diabetes.csv' (the original
            hard-coded location), so existing ``DiabetesDataset()`` calls are
            unaffected.
    """

    def __init__(self, path='data/diabetes.csv'):
        # ndmin=2 keeps a one-row file 2-D; without it loadtxt returns a 1-D
        # array and the column slicing below raises IndexError.
        data = np.loadtxt(path, delimiter=',', dtype=np.float32, ndmin=2)
        self.x_data = torch.from_numpy(data[:, 0:-1])
        # The list index [-1] keeps a trailing column axis: y_data is (n_rows, 1).
        self.y_data = torch.from_numpy(data[:, [-1]])
        self.len = data.shape[0]

    def __getitem__(self, index):
        """Return the (features, target) pair for row ``index``; target has shape (1,)."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Number of rows loaded from the file."""
        return self.len

In [3]:
class Model(torch.nn.Module):
    """Fully connected binary classifier with layer widths 8 -> 6 -> 4 -> 1.

    A sigmoid follows every linear layer, including the last, so each output
    lies in (0, 1). That is the form ``torch.nn.BCELoss`` expects.
    """

    def __init__(self):
        super().__init__()
        # Construction order l1, l2, l3 fixes the RNG draws used for weight init.
        self.l1 = torch.nn.Linear(8, 6)  # 8 input features
        self.l2 = torch.nn.Linear(6, 4)
        self.l3 = torch.nn.Linear(4, 1)  # single output unit
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Apply linear+sigmoid three times.

        ``x`` must have a last dimension of 8. The result keeps x's leading
        dimensions and ends in a dimension of size 1.
        """
        hidden = x
        for layer in (self.l1, self.l2, self.l3):
            hidden = self.sigmoid(layer(hidden))
        return hidden

# Seed before building the model so weight init and loader shuffling are reproducible.
torch.manual_seed(0)

model = Model()

# size_average=True is deprecated in PyTorch; reduction='mean' is the equivalent.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

dataset = DiabetesDataset()
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          # Worker processes cannot unpickle a Dataset class defined in a
                          # notebook under spawn-based multiprocessing (Windows/macOS), and
                          # the dataset is small, so load in the main process.
                          num_workers=0)

In [7]:
# Mini-batch SGD over the full dataset, printing (epoch, batch index, loss) per step.
N_EPOCHS = 2
for epoch in range(N_EPOCHS):
    for i, (inputs, labels) in enumerate(train_loader):
        # Tensors carry autograd state directly; the old Variable wrapper is a no-op.
        y_hat = model(inputs)
        loss = criterion(y_hat, labels)
        # The mean-reduced loss is a 0-dim tensor: .item() returns the Python float.
        # loss.data[0] raises IndexError on PyTorch >= 0.4.
        print(epoch, i, loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

0 0 0.8063732385635376
0 1 0.7425640821456909
0 2 0.7506237626075745
0 3 0.710849940776825
0 4 0.7370873689651489
0 5 0.7026951313018799
0 6 0.7203333377838135
0 7 0.7004845142364502
0 8 0.6950200200080872
0 9 0.677420973777771
0 10 0.6726390719413757
0 11 0.6956398487091064
0 12 0.6819764971733093
0 13 0.6473138928413391
0 14 0.6785783171653748
0 15 0.6921446919441223
0 16 0.6647969484329224
0 17 0.6552974581718445
0 18 0.7101877927780151
0 19 0.6698883771896362
0 20 0.644668698310852
0 21 0.6304548978805542
0 22 0.7065246105194092
0 23 0.657615065574646
1 0 0.6448472738265991
1 1 0.6645093560218811
1 2 0.6866055130958557
1 3 0.6753205060958862
1 4 0.6865471601486206
1 5 0.6863905787467957
1 6 0.6647646427154541
1 7 0.631183385848999
1 8 0.6398179531097412
1 9 0.6251775622367859
1 10 0.6224042177200317
1 11 0.6338616013526917
1 12 0.6616095900535583
1 13 0.6471397280693054
1 14 0.6912146210670471
1 15 0.6179956197738647
1 16 0.6614190936088562
1 17 0.646186351776123
1 18 0.61493128538