<a href="https://colab.research.google.com/github/near129/othello/blob/feature%2Ffix_alphazero/pytorch_and_onnx_quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [None]:
class Model(torch.nn.Module):
    """Small CNN with a policy head and a value head for an 8x8 board.

    Input is assumed to be a (batch, 2, 8, 8) float tensor (two input planes).
    The two unpadded 3x3 convolutions shrink the 8x8 board to 4x4 before the
    fully connected trunk.

    ``forward`` returns ``(policy, value)``: ``policy`` is a (batch, 64)
    softmax over board squares and ``value`` a (batch, 1) tanh score.
    """

    def __init__(self):
        super().__init__()
        self.height = self.width = 8
        self.output_size = self.height * self.width
        self.dropout_rate = 0.5
        # Backward-compatible aliases for the previously misspelled attributes.
        self.ouput_size = self.output_size
        self.dropout_late = self.dropout_rate
        in_channels = 2
        channels = 64

        # One parameter-free ReLU instance is reused throughout; sharing it
        # ties no weights and keeps the state_dict layout unchanged.
        self.relu = nn.ReLU()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            self.relu,
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            self.relu,
            # Unpadded convs: each trims one cell per side (8 -> 6 -> 4).
            nn.Conv2d(channels, channels, 3),
            nn.BatchNorm2d(channels),
            self.relu,
            nn.Conv2d(channels, channels, 3),
            nn.BatchNorm2d(channels),
            self.relu,
        )

        self.fc_input = channels * (self.width - 4) * (self.height - 4)
        # Not in-place: this dropout follows a non-in-place ReLU whose output
        # autograd saves for backward; overwriting it in place makes
        # loss.backward() raise during training.
        self.dropout = nn.Dropout(self.dropout_rate)
        self.layer2 = nn.Sequential(
            nn.Linear(self.fc_input, 512),
            nn.BatchNorm1d(512),
            self.relu,
            self.dropout,
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            self.relu,
            self.dropout,
        )

        self.fc3 = nn.Linear(256, self.output_size)  # policy logits, one per square
        self.fc4 = nn.Linear(256, 1)                  # scalar value
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return ``(softmax policy, tanh value)`` for a batch of boards ``x``."""
        x = self.layer1(x)
        # flatten(1) keeps the batch dimension explicit; view(-1, fc_input)
        # would silently fold a wrongly-sized board into extra batch rows.
        x = torch.flatten(x, 1)
        x = self.layer2(x)
        policy = self.fc3(x)
        value = self.fc4(x)
        return self.softmax(policy), self.tanh(value)


In [None]:
# Install the ONNX / ONNX Runtime / timm stack used below.
# NOTE(review): versions are unpinned; the recorded run resolved timm-0.4.12 — consider pinning for reproducibility.
!pip install onnxruntime onnx timm

Collecting timm
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |████████████████████████████████| 376 kB 5.3 MB/s 
Installing collected packages: timm
Successfully installed timm-0.4.12


In [None]:
# Dynamically quantize the Linear layers in PyTorch, then try an ONNX export.
# In the recorded run this export raised AttributeError, which is why the
# ONNX Runtime quantizer is used on the fp32 ONNX graph later instead.
model_fp32 = Model().eval()
model_i8 = torch.quantization.quantize_dynamic(model_fp32, {nn.Linear}, dtype=torch.qint8)
dummy_input = torch.randn(1, 2, 8, 8)
with torch.no_grad():
    _ = model_i8(dummy_input)  # the quantized model runs in PyTorch itself
try:
    torch.onnx.export(model_i8, dummy_input, 'model_test.onnx', opset_version=11)
except (AttributeError, RuntimeError) as err:
    # Report the failure instead of leaving an erroring cell in the notebook.
    print(f'ONNX export of the dynamically quantized model failed: {err!r}')

AttributeError: ignored

In [None]:
# Install onnx-simplifier (provides `python -m onnxsim`, used further down) and onnxoptimizer.
# NOTE(review): versions are unpinned; the recorded run resolved onnxoptimizer-0.2.6.
!pip install onnxoptimizer onnx-simplifier

Collecting onnxoptimizer
  Downloading onnxoptimizer-0.2.6-cp37-cp37m-manylinux2014_x86_64.whl (466 kB)
[?25l[K     |▊                               | 10 kB 27.0 MB/s eta 0:00:01[K     |█▍                              | 20 kB 29.0 MB/s eta 0:00:01[K     |██                              | 30 kB 21.0 MB/s eta 0:00:01[K     |██▉                             | 40 kB 16.1 MB/s eta 0:00:01[K     |███▌                            | 51 kB 5.6 MB/s eta 0:00:01[K     |████▏                           | 61 kB 6.1 MB/s eta 0:00:01[K     |█████                           | 71 kB 5.5 MB/s eta 0:00:01[K     |█████▋                          | 81 kB 6.2 MB/s eta 0:00:01[K     |██████▎                         | 92 kB 6.1 MB/s eta 0:00:01[K     |███████                         | 102 kB 5.3 MB/s eta 0:00:01[K     |███████▊                        | 112 kB 5.3 MB/s eta 0:00:01[K     |████████▍                       | 122 kB 5.3 MB/s eta 0:00:01[K     |█████████▏                      |

In [None]:
# Export the float32 CNN to ONNX as the reference graph for ONNX Runtime.
dummy_input = torch.randn(1, 2, 8, 8)   # one example board used for tracing
torch.onnx.export(
    Model(),
    dummy_input,
    'model.onnx',
    opset_version=11,
)

NameError: ignored

In [None]:
import onnx
import onnxruntime
import numpy as np

In [None]:
# Build the int8 ONNX graph from the fp32 export before loading both. No other
# cell writes 'modeli8.onnx', so loading it alone raised FileNotFoundError.
# Aliased so it cannot be confused with torch.quantization.quantize_dynamic.
from onnxruntime.quantization import QuantType
from onnxruntime.quantization import quantize_dynamic as ort_quantize_dynamic

ort_quantize_dynamic('model.onnx', 'modeli8.onnx', weight_type=QuantType.QUInt8)
model = onnx.load('model.onnx')
modeli8 = onnx.load('modeli8.onnx')


FileNotFoundError: ignored

In [None]:
# Validate both loaded graphs; check_model raises if a graph is invalid.
for _onnx_proto in (model, modeli8):
    onnx.checker.check_model(_onnx_proto)

In [None]:
# Pin the CPU provider: ORT >= 1.9 builds with several execution providers
# refuse to create a session without `providers=`, and this keeps runs on CPU.
ort_session = onnxruntime.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])

In [None]:
# Pin the CPU provider explicitly (required on ORT >= 1.9 multi-provider builds).
ort_sessioni8 = onnxruntime.InferenceSession('modeli8.onnx', providers=['CPUExecutionProvider'])

In [None]:
import timm
class EfficientNet(torch.nn.Module):
    """timm backbone followed by the same policy/value heads as ``Model``.

    Args:
        backbone: timm model name passed to ``timm.create_model``.
        features: width of the backbone's classifier output; both heads read it.
        dropout: dropout probability applied to the shared features.

    ``forward`` returns ``(policy, value)``: a softmax over 64 squares and a
    tanh-squashed scalar.
    """

    def __init__(self, backbone='mixnet_s', features=256, dropout=0.3):
        super().__init__()
        self.features = features
        self.dropout_rate = dropout
        # The backbone's classifier is resized to emit `features` values and its
        # stem to accept 2 input channels.
        self.backbone = timm.create_model(
            backbone, num_classes=self.features, in_chans=2, exportable=True
        )

        self.bn = nn.BatchNorm1d(self.features)
        self.relu = nn.ReLU()
        # Not in-place: the preceding ReLU saves its output for backward, so an
        # in-place dropout on it makes loss.backward() raise during training.
        self.dropout = nn.Dropout(self.dropout_rate)

        # Heads are sized from `features`; hard-coding 256 broke any other width.
        self.fc3 = nn.Linear(self.features, 64)
        self.fc4 = nn.Linear(self.features, 1)
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return ``(softmax policy, tanh value)`` for a batch of boards ``x``."""
        x = self.backbone(x)
        x = self.dropout(self.relu(self.bn(x)))
        policy = self.fc3(x)
        value = self.fc4(x)
        return self.softmax(policy), self.tanh(value)
        
# Export the EfficientNet instance itself (not `Model`) so every
# efficientnet*.onnx file below actually contains the timm backbone.
efficientnet = EfficientNet().eval()
dummy_input = torch.randn(1, 2, 8, 8)
with torch.no_grad():
    _ = efficientnet(dummy_input)  # sanity check: forward pass on one board
torch.onnx.export(
    efficientnet,
    dummy_input,
    'efficientnet.onnx',
    opset_version=13,
    # Pin the graph input name: the ONNX Runtime benchmark feeds 'input.1'.
    input_names=['input.1'],
    output_names=['policy', 'value'],
)

In [None]:
model = Model().eval()
model_i8 = torch.quantization.quantize_dynamic(model).eval()
dummy_input = [torch.randn(1, 2, 8, 8) for _ in range(500)]
%time _=[model(x) for x in dummy_input]
%time _=[model_i8(x) for x in dummy_input]

CPU times: user 623 ms, sys: 8.79 ms, total: 632 ms
Wall time: 633 ms
CPU times: user 880 ms, sys: 49.2 ms, total: 929 ms
Wall time: 931 ms


In [None]:
from onnxruntime.quantization import QuantType, quantize_dynamic, quantize_qat

# ONNX Runtime dynamic quantization of the exported fp32 graph (weights -> QUInt8).
model_input, model_output = "efficientnet.onnx", "efficientnet_i8.onnx"
quantize_dynamic(
    model_input,
    model_output,
    weight_type=QuantType.QUInt8,
)

In [None]:
# Run onnx-simplifier on the fp32 graph and on the int8 graph; results are written to *_opt.onnx.
!python -m onnxsim efficientnet.onnx efficientnet_opt.onnx
!python -m onnxsim efficientnet_i8.onnx efficientnet_i8_opt.onnx

Simplifying...
Checking 0/3...
Checking 1/3...
Checking 2/3...
Ok!
Simplifying...
Checking 0/3...
Checking 1/3...
Checking 2/3...
Ok!


In [None]:
# Quantize the onnxsim-simplified fp32 graph too (weights -> QUInt8).
model_input, model_output = "efficientnet_opt.onnx", "efficientnet_opt_i8.onnx"
quantize_dynamic(
    model_input,
    model_output,
    weight_type=QuantType.QUInt8,
)

In [None]:
# List the generated ONNX files and their sizes (fp32 vs. int8, raw vs. simplified).
# NOTE(review): in the recorded output, efficientnet.onnx (3.1M) matches a `Model` export
# (~0.79M float32 parameters), consistent with the export cell passing `Model()`; re-check after re-running.
!ls -lh

total 15M
-rw-r--r-- 1 root root 791K Sep 18 14:27 efficientnet_i8.onnx
-rw-r--r-- 1 root root 787K Sep 18 14:27 efficientnet_i8_opt.onnx
-rw-r--r-- 1 root root 3.1M Sep 18 14:26 efficientnet.onnx
-rw-r--r-- 1 root root 785K Sep 18 14:27 efficientnet_opt_i8.onnx
-rw-r--r-- 1 root root 3.1M Sep 18 14:27 efficientnet_opt.onnx
-rw-r--r-- 1 root root 3.1M Sep 18 14:27 efficientnet-opt.onnx
-rw-r--r-- 1 root root 3.1M Sep 18 14:27 efficientnet_opt-opt.onnx
drwxr-xr-x 1 root root 4.0K Sep 16 13:40 sample_data


In [None]:
ort_session = onnxruntime.InferenceSession('efficientnet.onnx')
ort_session_opt = onnxruntime.InferenceSession('efficientnet_opt.onnx')
ort_session_i8 = onnxruntime.InferenceSession('efficientnet_i8.onnx')
# ort_session_opt_i = onnxruntime.InferenceSession('efficientnet_opt_i8.onnx')
ort_session_i8_opt = onnxruntime.InferenceSession('efficientnet_i8_opt.onnx')
dummy_input = [torch.randn(1, 2, 8, 8).numpy().astype(np.float32) for _ in range(500)]
%time _=[ort_session.run(None, {'input.1': x}) for x in dummy_input]
%time _=[ort_session_opt.run(None, {'input.1': x}) for x in dummy_input]
%time _=[ort_session_i8.run(None, {'input.1': x}) for x in dummy_input]
# %time _=[ort_session_opt_i8.run(None, {'input.1': x}) for x in dummy_input]
%time _=[ort_session_i8_opt.run(None, {'input.1': x}) for x in dummy_input]

CPU times: user 167 ms, sys: 0 ns, total: 167 ms
Wall time: 168 ms
CPU times: user 164 ms, sys: 0 ns, total: 164 ms
Wall time: 163 ms
CPU times: user 238 ms, sys: 0 ns, total: 238 ms
Wall time: 239 ms
CPU times: user 233 ms, sys: 0 ns, total: 233 ms
Wall time: 234 ms
