In [1]:
from sklearn import datasets
import torch
from torchmetrics import Accuracy
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

In [2]:
iris = datasets.load_iris()

In [3]:
iris.target

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [4]:
x = iris.data
y = iris.target
x = torch.from_numpy(x).float()
y = torch.from_numpy(y).long()

In [5]:
type(x)

torch.Tensor

In [6]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

In [7]:
class Model(nn.Module):
    def __init__(self, input_features=4, output_features=3, hidden_features=10):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, output_features)
        )
    def forward(self, x):
        return self.network(x)

In [8]:
model = Model()

In [9]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [10]:
for epoch in range(1000):
    model.train()
    y_logits = model(x_train)
    y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)
    y_train = y_train.to(torch.long)
    loss = loss_fn(y_logits, y_pred)
    train_accuracy = Accuracy(task='multiclass', num_classes=3)(y_pred, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    model.eval()
    with torch.inference_mode():
        test_logits = model(x_test)
        test_pred = F.softmax(test_logits, dim=1).argmax(dim=1)
        y_test = y_test.to(torch.long)
        test_loss = loss_fn(test_logits, y_test)
        test_accuracy = Accuracy(task='multiclass', num_classes=3)(test_pred, y_test)
    if epoch % 100 ==0:
        print(f"Epoch: {epoch} | Loss: {loss} | Test Loss: {test_loss} | Test Accuracy: {test_accuracy}")

Epoch: 0 | Loss: 0.5676558613777161 | Test Loss: 1.2487958669662476 | Test Accuracy: 0.30000001192092896
Epoch: 100 | Loss: 0.02406633086502552 | Test Loss: 3.731092691421509 | Test Accuracy: 0.30000001192092896
Epoch: 200 | Loss: 0.004892414435744286 | Test Loss: 5.227289199829102 | Test Accuracy: 0.30000001192092896
Epoch: 300 | Loss: 0.002044660970568657 | Test Loss: 6.063615798950195 | Test Accuracy: 0.30000001192092896
Epoch: 400 | Loss: 0.001114990678615868 | Test Loss: 6.648725986480713 | Test Accuracy: 0.30000001192092896
Epoch: 500 | Loss: 0.0006995636504143476 | Test Loss: 7.100063800811768 | Test Accuracy: 0.30000001192092896
Epoch: 600 | Loss: 0.0004786188364960253 | Test Loss: 7.468376159667969 | Test Accuracy: 0.30000001192092896
Epoch: 700 | Loss: 0.0003472389071248472 | Test Loss: 7.780228137969971 | Test Accuracy: 0.30000001192092896
Epoch: 800 | Loss: 0.0002628119254950434 | Test Loss: 8.051264762878418 | Test Accuracy: 0.30000001192092896
Epoch: 900 | Loss: 0.0002053