# CNN to ONNX

> CNN with torch: [denev6/deep-learning-codes](https://github.com/denev6/deep-learning-codes/blob/main/cnn_mnist.ipynb)

- learning_rate = 0.003
- epochs = 5
- batch_size = 64

```text
CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=1568, out_features=32, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=32, out_features=10, bias=True)
  )
)
```
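
The `in_features=1568` of the first `Linear` layer follows from the layer parameters printed above. The sketch below works through that arithmetic, assuming a 28×28 single-channel input (the shape of the dummy input used for the export). The helper names `conv2d_out` and `maxpool2d_out` are made up for this illustration and are not part of the notebook.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial axis of a Conv2d layer (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def maxpool2d_out(size, kernel, stride, padding=0):
    """Output length along one spatial axis of a MaxPool2d layer (ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                                   # assumed input height/width
size = conv2d_out(size, kernel=3)                           # conv1: 28 -> 26
size = maxpool2d_out(size, kernel=2, stride=2, padding=1)   # pool1: 26 -> 14
size = conv2d_out(size, kernel=3)                           # conv2: 14 -> 12
size = maxpool2d_out(size, kernel=2, stride=2, padding=1)   # pool2: 12 -> 7

channels = 32                                               # out_channels of conv2
flat_features = channels * size * size
print(flat_features)                                        # 32 * 7 * 7 = 1568
```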

Accuracy: 0.98390

## Save as ONNX

```python
# Move the model to the CPU and save its parameters (state_dict).
model.cpu()
params = model.state_dict()
torch.save(params, f"{model_path}net.prm", pickle_protocol=4)
```
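
To show how a saved parameter file can be loaded back, here is a minimal round-trip sketch. It uses a small stand-in model rather than the trained CNN, and `model_path` is a temporary directory chosen for this example. It ends with a path separator because the notebook joins paths by plain string concatenation.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for the trained CNN; any nn.Module is saved and restored the same way.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model_path = tempfile.mkdtemp() + os.sep  # hypothetical directory with trailing separator

torch.save(model.state_dict(), f"{model_path}net.prm", pickle_protocol=4)

# Rebuild the same architecture first, then load the saved weights into it.
restored = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
state = torch.load(f"{model_path}net.prm", map_location="cpu")
restored.load_state_dict(state)
restored.eval()

# Both models should now produce identical outputs for the same input.
x = torch.randn(2, 1, 28, 28)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
```

A `state_dict` holds only tensors, so the architecture has to be constructed in code before the weights can be loaded.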

```python
import torch.onnx

onnx_path = model_path + "cnn.onnx"
# The dummy input fixes the input shape and dtype of the exported graph.
dummy_input = torch.empty(1, 1, 28, 28, dtype=torch.float32)
torch.onnx.export(model, dummy_input, onnx_path,
                  input_names=["input"], output_names=["output"])
```

```python
import onnx

# Run shape inference and write the annotated model back to the same file.
onnx_model = onnx.load(onnx_path)
onnx.save(onnx.shape_inference.infer_shapes(onnx_model), onnx_path)
```