In [1]:
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader

import torch
from torch import nn, optim
from torch.nn import functional as F

from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events
from ignite.metrics import Accuracy, MeanSquaredError

# import onnx

In [2]:
trans = transforms.Compose([transforms.ToTensor()])

data_path = "../../datasets/mnist/"
data = datasets.MNIST(data_path, transform=trans) # download=True)

batch_size = 60
train_data = DataLoader(dataset=data, batch_size=batch_size*10, shuffle=True)

In [3]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5, stride=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, stride=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        
        self.fc1 = nn.Linear(32, 50)
        self.fc2 = nn.Linear(50, 10)
        
    def forward(self, x):
        
        x = self.conv1(x)
        x = self.conv2(x)
#         print(x.shape)
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [4]:
cnn = CNN()

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=1e-3)

trainer = create_supervised_trainer(cnn, optimizer, criterion)
evaluator = create_supervised_evaluator(
    cnn,
    metrics={
        "accuracy": Accuracy()
    }
)

In [6]:
@trainer.on(Events.EPOCH_COMPLETED)
def progress(trainer):
    evaluator.run(train_data)
    accuracy = evaluator.state.metrics["accuracy"]
    epoch = trainer.state.epoch
    print(f"{epoch:<2} ~> loss: {trainer.state.output:.3f} | accuracy: {accuracy:.3f}")

In [7]:
state = trainer.run(train_data, max_epochs=10)

1  ~> loss: 0.544 | accuracy: 0.840
2  ~> loss: 0.354 | accuracy: 0.902
3  ~> loss: 0.266 | accuracy: 0.924
4  ~> loss: 0.194 | accuracy: 0.936
5  ~> loss: 0.178 | accuracy: 0.947
6  ~> loss: 0.186 | accuracy: 0.955
7  ~> loss: 0.124 | accuracy: 0.960
8  ~> loss: 0.151 | accuracy: 0.964
9  ~> loss: 0.130 | accuracy: 0.968
10 ~> loss: 0.107 | accuracy: 0.970


In [6]:
frag_iter = iter(train_data)
frag = frag_iter.next() # um fragmento do dataset


torch.onnx.export(cnn, frag[0], "convnet.onnx", verbose=True)

In [25]:
ls -lh | grep onnx

-rw-r--r-- 1 lincoln users  97K jan  7 14:00 convnet.onnx
-rw-r--r-- 1 lincoln users  28K jan  7 13:59 onnx_view.ipynb


In [9]:
import onnx

In [10]:
convnet_onnx = onnx.load("convnet.onnx")
out = onnx.helper.printable_graph(convnet_onnx.graph)
print(out)

graph torch-jit-export (
  %0[FLOAT, 600x1x28x28]
) initializers (
  %1[FLOAT, 16x1x5x5]
  %2[FLOAT, 16]
  %3[FLOAT, 32x16x5x5]
  %4[FLOAT, 32]
  %5[FLOAT, 50x32]
  %6[FLOAT, 50]
  %7[FLOAT, 10x50]
  %8[FLOAT, 10]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [2, 2]](%0, %1, %2)
  %10 = Relu(%9)
  %11 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%10)
  %12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [2, 2]](%11, %3, %4)
  %13 = Relu(%12)
  %14 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%13)
  %15 = Constant[value = <Scalar Tensor []>]()
  %16 = Shape(%14)
  %17 = Gather[axis = 0](%16, %15)
  %18 = Constant[value = <Scalar Tensor []>]()
  %19 = Unsqueeze[axes = [0]](%17)
  %20 = Unsqueeze[axes = [0]](%18)
  %21 = Concat[axis = 0](%19, %20)
  %22 = Reshape(%14, %21)
  %23 = Gemm[alpha = 1, beta = 1, transB = 1](%22, %5, %6)
  %24 =

In [12]:
%%bash

python3 /usr/lib/python3.7/site-packages/onnx/tools/net_drawer.py --input convnet.onnx --output convnet.dot

dot -Grankdir=TB -Tpng convnet.dot -o convnet.png

In [22]:
%%html
<img src="convnet.png" width=30%>