In [65]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision
import matplotlib.pyplot as plt

In [66]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cpu device


In [67]:
class NeuralNetwork(nn.Module):
  def __init__(self):
    super(NeuralNetwork, self).__init__()
    self.flatten = nn.Flatten()
    self.linear_relu_stack = nn.Sequential(
        nn.Linear(28*28, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 10),
        nn.ReLU()
    )

  def forward(self, x):
    x = self.flatten(x)
    return self.linear_relu_stack(x)

In [68]:
model = NeuralNetwork()
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)


In [79]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [70]:
training_data = torchvision.datasets.FashionMNIST(root="/", train = True, download = True, transform = transforms.ToTensor())
testing_data = torchvision.datasets.FashionMNIST(root="/", train = False, download = True, transform = transforms.ToTensor())

In [74]:
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=64)
test_dataloader = torch.utils.data.DataLoader(testing_data, batch_size=64)

In [75]:
epochs = 5

def train_loop(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  for batch, (X, y) in enumerate(dataloader):
    #forward
    pred = model(X)
    loss = loss_fn(pred, y)
    #backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  test_loss, correct = 0, 0

  with torch.no_grad(): #freezes the parameters
    for X, y in dataloader:
      y_pred = model(X)
      test_loss += loss_fn(y_pred, y).item()
      correct += (y_pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= size
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [80]:
for epoch in range(epochs):
  print(f"Epoch: {epoch+1}\n---------------------------")
  train_loop(train_dataloader, model, loss_fn, optimizer)
  test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch: 1
---------------------------
loss: 0.817515  [    0/60000]
loss: 1.010970  [ 6400/60000]
loss: 0.943630  [12800/60000]
loss: 1.292563  [19200/60000]
loss: 1.270513  [25600/60000]
loss: 1.296304  [32000/60000]
loss: 1.190465  [38400/60000]
loss: 1.169772  [44800/60000]
loss: 1.071999  [51200/60000]
loss: 1.449088  [57600/60000]
Test Error: 
 Accuracy: 53.3%, Avg loss: 0.018730 

Epoch: 2
---------------------------
loss: 0.793312  [    0/60000]
loss: 1.008465  [ 6400/60000]
loss: 0.949494  [12800/60000]
loss: 1.273909  [19200/60000]
loss: 1.265860  [25600/60000]
loss: 1.281920  [32000/60000]
loss: 1.184747  [38400/60000]
loss: 1.143161  [44800/60000]
loss: 1.083885  [51200/60000]
loss: 1.440551  [57600/60000]
Test Error: 
 Accuracy: 53.1%, Avg loss: 0.018681 

Epoch: 3
---------------------------
loss: 0.794106  [    0/60000]
loss: 1.003451  [ 6400/60000]
loss: 0.950270  [12800/60000]
loss: 1.265421  [19200/60000]
loss: 1.266934  [25600/60000]
loss: 1.273716  [32000/60000]
loss:

In [82]:
torch.save(model.state_dict(), "./model.pth")

print("Saved PyTorch Model State to model.pth")

Saved PyTorch Model State to model.pth


### We can export our model as a Open Neural Network Exchange Model to train once and perform inference in other hardward and languages.

In [84]:
image.shape

torch.Size([1, 28, 28])

In [86]:
import torch.onnx as onnx
input_image = torch.zeros(image.shape)

In [88]:
!pip install onnx

Collecting onnx
  Downloading onnx-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m45.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.15.0


In [89]:
onnx_model = "./model.onnx"
onnx.export(model, input_image, onnx_model)

In [91]:
test_data = torchvision.datasets.FashionMNIST(root="/", train = True, transform = transforms.ToTensor())

In [92]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
image, label = test_data[0]

In [95]:
!pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.17.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: humanfriendly, coloredlogs, onnxruntime
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.17.1


In [96]:
import onnxruntime

In [100]:
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

result = session.run([output_name], {input_name: image.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[label]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "T-shirt/top", Actual: "Ankle boot"
