In [1]:
import torch
import torchvision # provide access to datasets, models, transforms, utils, etc
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# now let's work with FashionMnist

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [3]:
train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=10
    ,shuffle=True
)

In [4]:
print(len(train_set))
print(train_set.train_labels)
print(train_set.train_labels.bincount())

60000
tensor([9, 0, 0,  ..., 3, 0, 5])
tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])




In [5]:
import torch.nn.functional as F
import torch.nn as nn
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        # (1) input layer
        t = t

        # (2) hidden conv layer
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (3) hidden conv layer
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (4) hidden linear layer
        t = t.reshape(-1, 12 * 4 * 4)
        t = self.fc1(t)
        t = F.relu(t)

        # (5) hidden linear layer
        t = self.fc2(t)
        t = F.relu(t)

        # (6) output layer
        t = self.out(t)
        t = F.softmax(t, dim=1)

        return t

In [6]:
def get_num_correct(preds, labels):
  #print(preds, labels)
  #print(preds.argmax(dim=1).eq(labels).sum().item())
  return preds.argmax(dim=1).eq(labels).sum().item()

In [7]:
import torch.optim as optim

torch.set_grad_enabled(True)


<torch.autograd.grad_mode.set_grad_enabled at 0x7fc932df5710>

In [8]:
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)


In [23]:
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.001)

for epoch in range(50):

    total_loss = 0
    total_correct = 0
    for batch in train_loader: # Get Batch
        images, labels = batch 

        preds = network(images) # Pass Batch
        loss = F.cross_entropy(preds, labels) # Calculate Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)

    print(
        "epoch", epoch, 
        "total_correct:", total_correct, 
        "loss:", total_loss,
        "Accuracy:", (total_correct/len(train_set)) * 100
    )

epoch 0 total_correct: 40199 loss: 1079.3872591257095 Accuracy: 66.99833333333333
epoch 1 total_correct: 46976 loss: 1008.9812076091766 Accuracy: 78.29333333333334
epoch 2 total_correct: 48789 loss: 990.209684252739 Accuracy: 81.315
epoch 3 total_correct: 49664 loss: 981.0148547887802 Accuracy: 82.77333333333333
epoch 4 total_correct: 50205 loss: 974.6975557804108 Accuracy: 83.675
epoch 5 total_correct: 50790 loss: 969.2768123149872 Accuracy: 84.65
epoch 6 total_correct: 51133 loss: 965.8692858219147 Accuracy: 85.22166666666666
epoch 7 total_correct: 51446 loss: 962.5946229696274 Accuracy: 85.74333333333334
epoch 8 total_correct: 51514 loss: 961.6579760313034 Accuracy: 85.85666666666667
epoch 9 total_correct: 51754 loss: 959.1803640127182 Accuracy: 86.25666666666667
epoch 10 total_correct: 51896 loss: 957.7703030109406 Accuracy: 86.49333333333334
epoch 11 total_correct: 52060 loss: 956.0620466470718 Accuracy: 86.76666666666667
epoch 12 total_correct: 52153 loss: 955.1941641569138 Accur

In [28]:
# inference 


In [29]:
image.shape

torch.Size([1, 28, 28])

In [30]:
image.unsqueeze(0).shape, label

(torch.Size([1, 1, 28, 28]), 4)

In [31]:
type(preds)

torch.Tensor

In [32]:
s = network(image.unsqueeze(0))

In [20]:

s

tensor([[8.0038e-14, 8.0331e-15, 1.6151e-08, 1.9919e-10, 1.0000e+00, 3.1813e-20,
         7.4116e-12, 1.7748e-17, 3.7619e-08, 1.4401e-11]],
       grad_fn=<SoftmaxBackward>)

In [21]:
s.argmax().item()

4

In [25]:
tcorrect = 0
for image, label in train_set:
  pred_label = network(image.unsqueeze(0))
  pred_label = pred_label.argmax().item()
  if label == pred_label:
    tcorrect += 1
print(f'Accuracy: {(tcorrect/len(train_set)) * 100}')

Accuracy: 91.385


In [26]:
torch.save(network, 'fmnist.onnx')

In [27]:
model_load = torch.load('fmnist.onnx')

In [34]:
pr = model_load(image.unsqueeze(0))

In [35]:
pr

tensor([[9.0785e-22, 2.0790e-28, 5.4591e-11, 4.1819e-20, 1.0000e+00, 2.4126e-36,
         2.1495e-22, 5.5445e-30, 1.0968e-25, 1.5771e-29]],
       grad_fn=<SoftmaxBackward>)

In [36]:
pr.argmax().item()

4

In [38]:
pr.sum().item()

1.0

In [40]:
# inference
image, label = train_set[59890]
image = image.unsqueeze(0)
prediction = model_load(image)
prediction = prediction.argmax().item()
if prediction == label:
  print('correct result', prediction)
else:
  print(prediction, label, "not equal")


correct result 4
