In [1]:
import torch
import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_wine

In [2]:
# Load the UCI wine dataset as an sklearn Bunch (features in .data, class labels in .target).
wine = load_wine(return_X_y=False)

In [3]:
# Number of leading columns of wine.data used as model inputs (sklearn's wine data has 13).
features = 13
# Guard the slice wine.data[:, :features]: a larger value would silently keep only the
# available columns while the network input width (n_input = features) would not match.
assert 1 <= features <= wine.data.shape[1], (
    f"features must be in [1, {wine.data.shape[1]}], got {features}")

In [4]:
# Hold out 30% of the samples for testing.
# random_state makes the split identical on every Restart & Run All;
# stratify keeps the class proportions the same in the train and test parts.
X_train, X_test, y_train, y_test = train_test_split(
    wine.data[:, :features],
    wine.target,
    test_size=0.3,
    shuffle=True,
    stratify=wine.target,
    random_state=0)

In [5]:
# Convert the numpy splits to tensors: float32 features for nn.Linear and int64 class
# indices as CrossEntropyLoss targets. torch.as_tensor replaces the legacy
# FloatTensor/LongTensor constructors and also accepts tensors, so re-running is safe.
X_train = torch.as_tensor(X_train, dtype=torch.float32)
X_test = torch.as_tensor(X_test, dtype=torch.float32)
y_train = torch.as_tensor(y_train, dtype=torch.long)
y_test = torch.as_tensor(y_test, dtype=torch.long)

In [6]:
# (n_train_samples, n_features) of the training matrix; `.data` is not needed for shape.
X_train.shape

torch.Size([124, 13])

In [7]:
class WineNet(torch.nn.Module):
  """Fully connected classifier: two sigmoid hidden layers and a linear output layer.

  Args:
    n_input: number of input features per sample.
    n_hidden_neurons: width of both hidden layers.
    n_classes: number of output classes (defaults to 3, the previously hard-coded value).
  """

  def __init__(self, n_input, n_hidden_neurons, n_classes=3):
    super().__init__()
    # Layer attributes are created in the original order so state_dict keys are unchanged.
    self.fc1 = torch.nn.Linear(n_input, n_hidden_neurons)
    self.activ1 = torch.nn.Sigmoid()
    self.fc2 = torch.nn.Linear(n_hidden_neurons, n_hidden_neurons)
    self.activ2 = torch.nn.Sigmoid()
    self.fc3 = torch.nn.Linear(n_hidden_neurons, n_classes)
    # Softmax is only applied in inference(); forward() returns raw logits.
    self.sm = torch.nn.Softmax(dim=1)

  def forward(self, x):
    """Return unnormalised class scores (logits) for the batch x.

    Logits, not probabilities, are returned because torch.nn.CrossEntropyLoss
    applies log-softmax internally.
    """
    x = self.activ1(self.fc1(x))
    x = self.activ2(self.fc2(x))
    return self.fc3(x)

  def inference(self, x):
    """Return class probabilities: softmax over dim=1 of forward(x).

    Assumes x is 2-D (batch, n_input), since the softmax runs over dim=1.
    """
    return self.sm(self.forward(x))

In [14]:
# Network size: one input per selected feature, 50 units in each of the two hidden layers.
n_input, n_hidden = features, 50

In [15]:
# Seed torch before building the network so the random weight initialisation
# is reproducible on Restart & Run All.
torch.manual_seed(0)
wine_net = WineNet(n_input, n_hidden)

In [16]:
# WineNet.forward returns raw logits, which is what CrossEntropyLoss expects
# (it applies log-softmax itself). Adam updates every parameter of the network.
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=wine_net.parameters(), lr=0.001)

In [17]:
# Number of training samples per optimisation step.
batch_size: int = 10

In [18]:
# Mini-batch training. Shuffling uses a seeded numpy Generator so batch order is
# reproducible; test accuracy is recorded every `eval_every` epochs for monitoring.
n_epochs = 2000
eval_every = 10
shuffle_rng = np.random.default_rng(0)
test_accuracy_history = []  # (epoch, test accuracy) pairs


def _predict_classes(model, x):
  """Return argmax class indices of model(x), computed in eval mode without autograd."""
  model.eval()
  with torch.no_grad():
    predicted = model(x).argmax(dim=1)
  model.train()
  return predicted


for epoch in range(n_epochs):
  order = shuffle_rng.permutation(len(X_train))
  for start_index in range(0, len(X_train), batch_size):
    optimizer.zero_grad()

    batch_indexes = order[start_index:start_index + batch_size]
    x_batch = X_train[batch_indexes]
    y_batch = y_train[batch_indexes]

    # Call the module rather than .forward so any registered hooks run.
    preds = wine_net(x_batch)
    loss_value = loss(preds, y_batch)
    loss_value.backward()
    optimizer.step()

  if epoch % eval_every == 0:
    epoch_preds = _predict_classes(wine_net, X_test)
    epoch_accuracy = (epoch_preds == y_test).float().mean().item()
    test_accuracy_history.append((epoch, epoch_accuracy))

# Predictions of the fully trained model, computed after the loop so they reflect every
# training epoch rather than only the last periodic check.
test_preds = _predict_classes(wine_net, X_test)

In [21]:
# Report the network's input width and whether test accuracy exceeds 0.8.
# .item() turns the 0-d tensor into a Python float; no numpy round-trip is needed.
test_accuracy = (test_preds == y_test).float().mean().item()
print(wine_net.fc1.in_features, test_accuracy > 0.8)

13 True


In [22]:
# Report the network's input width and the test accuracy (fraction of correct predictions).
print(wine_net.fc1.in_features, torch.eq(test_preds, y_test).to(torch.float32).mean().numpy())

13 0.9074074
