In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

In [5]:
import torch

In [2]:
from sklearn import datasets

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2]
Y = iris.target

## Data Generation

In [3]:
print(X.shape)
print(Y.shape)

(150, 2)
(150,)


In [6]:
X = torch.from_numpy(X)
Y = torch.from_numpy(Y)

In [7]:
import torch.utils.data as Data

In [8]:
torch_data = Data.TensorDataset(X, Y)

In [19]:
loader = Data.DataLoader(
    dataset=torch_data,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

## NN Generation

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 32)
        self.fc2 = nn.Linear(32, 8)
        self.fc3 = nn.Linear(8, 3)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [23]:
net = Net()
net = net.float()
print(net)

Net(
  (fc1): Linear(in_features=2, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=8, bias=True)
  (fc3): Linear(in_features=8, out_features=3, bias=True)
)


In [26]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

## Training

In [27]:
for epoch in range(100):
    running_loss = 0.0
    for i, data in enumerate(loader, 0):
        inputs, labels = data
        
        optimizer.zero_grad()
        outputs = net(inputs.float())
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i == 0:
            print(running_loss)
        running_loss = 0.0

0.7966356873512268
0.7801154255867004
0.7522132992744446
0.734838604927063
0.7328603267669678
0.6755171418190002
0.678482174873352
0.6431788206100464
0.6416812539100647
0.6159035563468933
0.5455247759819031
0.5211825370788574
0.5746440291404724
0.5111878514289856
0.5774176120758057
0.5695798993110657
0.48390018939971924
0.5187276005744934
0.5314834117889404
0.4870004951953888
0.5140343904495239
0.4827297031879425
0.5045962333679199
0.4465681314468384
0.5104561448097229
0.5626042485237122
0.45902952551841736
0.4732963442802429
0.5009604692459106
0.5277127027511597
0.4469616413116455
0.6004400253295898
0.44069674611091614
0.5326589345932007
0.5459178686141968
0.44009923934936523
0.5673669576644897
0.4766644835472107
0.5489664673805237
0.46185779571533203
0.5102113485336304
0.5073224306106567
0.37490618228912354
0.4902339577674866
0.4703598916530609
0.489685595035553
0.4858153462409973
0.530290961265564
0.4802493155002594
0.45105990767478943
0.4414791166782379
0.44234728813171387
0.526449

In [28]:
outputs = net(X.float())
result = torch.argmax(outputs, 1).numpy() - Y.numpy()
print('error rate:', sum(1 * (result != 0)) / result.shape[0])

error rate: 0.2866666666666667
