In [None]:
# pytorch_vit_to_onnx_int8.py
import os, glob, random
from dataclasses import dataclass
from typing import Optional, Tuple, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

In [None]:
import kagglehub
# Fetch the "imsparsh/flowers-dataset" dataset through kagglehub and keep the
# returned value; the next cell joins "train" / "test" onto it, so it is used
# as the dataset's root directory.
DATA_ROOT = kagglehub.dataset_download("imsparsh/flowers-dataset")
print(DATA_ROOT)

Using Colab cache for faster access to the 'flowers-dataset' dataset.
/kaggle/input/flowers-dataset


In [None]:
# Split directories located under the downloaded dataset root.
TRAIN_DIR = os.path.join(DATA_ROOT, "train")
TEST_DIR  = os.path.join(DATA_ROOT, "test")

# Hyper-parameters shared by the model, data pipeline and training loop.
IMG = 224          # images are resized to IMG x IMG
BATCH = 1          # default DataLoader batch size; TinyViT.forward reshapes with a hard-coded batch of 1
EPOCHS = 2         # number of passes made by the training loop
NUM_CLASSES = 5    # width of the classifier head passed to TinyViT


# ====== 1) RMSNorm (PyTorch) ======
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Every vector along the last axis is divided by ``sqrt(mean(x**2) + eps)``
    and multiplied by the learnable ``gamma``. When ``use_bias`` is true a
    learnable ``beta`` is added afterwards.

    Args:
        dim: size of the last dimension of the inputs.
        eps: constant added to the mean square before taking the square root.
        use_bias: whether to create and apply the additive ``beta`` parameter.
    """

    def __init__(self, dim: int, eps: float = 1e-6, use_bias: bool = False):
        super().__init__()
        self.eps = eps
        self.use_bias = use_bias
        self.gamma = nn.Parameter(torch.ones(dim))
        if use_bias:
            self.beta = nn.Parameter(torch.zeros(dim))
        else:
            self.beta = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` (shape ``(..., dim)``) and return a tensor of the same shape."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalised = x / torch.sqrt(mean_square + self.eps)
        scaled = normalised * self.gamma
        return scaled if self.beta is None else scaled + self.beta


# ====== 2) Transformer Encoder 블록 ======
class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder block.

    The block normalises its input, applies multi-head self-attention with a
    residual connection, then normalises again and applies a two-layer ReLU
    MLP with a second residual connection.

    Args:
        dim: token embedding size.
        heads: number of attention heads.
        mlp_dim: hidden width of the MLP.
        dropout: dropout probability used in attention, after attention and
            inside the MLP.
        use_layernorm: use ``nn.LayerNorm`` when true, ``RMSNorm`` otherwise.
    """

    def __init__(self, dim: int, heads: int, mlp_dim: int, dropout: float = 0.1, use_layernorm: bool = False):
        super().__init__()

        def make_norm() -> nn.Module:
            # Both normalisation layers of the block are built the same way.
            return nn.LayerNorm(dim, eps=1e-6) if use_layernorm else RMSNorm(dim, eps=1e-6)

        # Attention sub-layer.
        self.norm1 = make_norm()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)

        # Feed-forward sub-layer.
        self.norm2 = make_norm()
        feed_forward = [
            nn.Linear(dim, mlp_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape ``[B, N, dim]`` and return the same shape."""
        attn_in = self.norm1(x)
        # Query, key and value are all the normalised input (self-attention).
        attn_out = self.attn(attn_in, attn_in, attn_in, need_weights=False)[0]
        x = x + self.drop1(attn_out)
        return x + self.mlp(self.norm2(x))


# ====== 3) ViT (간단 버전, CLS 없이 GAP) ======
class TinyViT(nn.Module):
    """Small Vision Transformer classifier that mean-pools tokens instead of using a CLS token.

    The image is split into non-overlapping ``patch`` x ``patch`` tiles by a
    strided convolution, a learned positional embedding is added, the tokens
    pass through ``depth`` encoder blocks and a final norm, and the mean over
    tokens feeds a linear classification head.

    Args:
        image_size: side length of the square input; must be divisible by ``patch``.
        patch: side length of one patch.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of encoder blocks.
        heads: attention heads per block.
        mlp_dim: hidden width of each block's MLP.
        dropout: dropout probability forwarded to the encoder blocks.
        use_layernorm: use ``nn.LayerNorm`` when true, ``RMSNorm`` otherwise.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch: int = 16,
        num_classes: int = 5,
        dim: int = 128,
        depth: int = 5,
        heads: int = 8,
        mlp_dim: int = 256,
        dropout: float = 0.1,
        use_layernorm: bool = False,
    ):
        super().__init__()
        assert image_size % patch == 0
        patches_per_side = image_size // patch
        self.patch = patch
        self.dim = dim
        self.num_patches = patches_per_side ** 2

        # Patch embedding: one convolution whose kernel and stride equal the patch size.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch, padding=0)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, dim) * 0.02)

        blocks = []
        for _ in range(depth):
            blocks.append(
                TransformerEncoder(dim=dim, heads=heads, mlp_dim=mlp_dim, dropout=dropout, use_layernorm=use_layernorm)
            )
        self.encoders = nn.ModuleList(blocks)

        if use_layernorm:
            self.final_norm = nn.LayerNorm(dim, eps=1e-6)
        else:
            self.final_norm = RMSNorm(dim, eps=1e-6)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``[1, num_classes]`` for one image.

        The batch dimension is fixed to 1: the token reshape below uses fully
        constant sizes, so an input with a different batch size cannot be
        reshaped and raises an error.
        """
        # [1, dim, Hp, Wp] -> [1, Hp, Wp, dim]
        tokens = self.patch_embed(x).permute(0, 2, 3, 1).contiguous()
        # Constant-shape reshape to [1, N, dim], then add the positional embedding.
        tokens = tokens.view(1, self.num_patches, self.dim) + self.pos_embed
        for encoder in self.encoders:
            tokens = encoder(tokens)
        # Normalise, then average over the token axis -> [1, dim].
        pooled = self.final_norm(tokens).mean(dim=1)
        return self.head(pooled)


# ====== 4) 데이터셋/전처리 ======
# Training-time preprocessing: resize to IMG x IMG, apply a random horizontal
# flip, then convert to a tensor (ToTensor yields values in [0, 1]).
train_tf = transforms.Compose(
    [
        transforms.Resize(size=(IMG, IMG)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)



def build_loaders(batch=BATCH):
    """Build the shuffled training DataLoader over the image folders in TRAIN_DIR.

    Args:
        batch: DataLoader batch size (defaults to the module-level BATCH).

    Returns:
        A DataLoader yielding ``(image_tensor, label)`` batches transformed by
        ``train_tf``.

    Only a training loader is produced. No validation loader is built here,
    and no ``val_tf`` transform is defined in the notebook as shown.
    """
    train_images = datasets.ImageFolder(root=TRAIN_DIR, transform=train_tf)
    loader_options = dict(batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    return DataLoader(train_images, **loader_options)


# ====== 5) 간단 학습 루프 ======
def train_one_epoch(model, loader, optim, device):
    """Run one optimisation pass over ``loader`` with a cross-entropy loss.

    Args:
        model: module mapping an input batch to class logits.
        loader: iterable of ``(inputs, targets)`` batches.
        optim: optimizer over the model's parameters.
        device: device that each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` where both are averaged over the number of
        samples seen during the pass.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    seen = 0
    hits = 0
    running_loss = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        n_samples = inputs.size(0)

        optim.zero_grad(set_to_none=True)
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optim.step()

        # The criterion returns a per-batch mean, so weight it by the batch size.
        running_loss += batch_loss.item() * n_samples
        hits += (outputs.argmax(dim=1) == targets).sum().item()
        seen += n_samples
    return running_loss / seen, hits / seen


@torch.no_grad()
def evaluate(model, loader, device):
    """Compute cross-entropy loss and accuracy over ``loader`` without gradients.

    Args:
        model: module mapping an input batch to class logits; switched to eval mode.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device that each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the number of samples seen.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    n_seen = 0
    n_correct = 0
    loss_total = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        n_samples = inputs.size(0)

        outputs = model(inputs)
        # Per-batch mean loss, re-weighted by the batch size.
        loss_total += criterion(outputs, targets).item() * n_samples
        n_correct += (outputs.argmax(dim=1) == targets).sum().item()
        n_seen += n_samples
    return loss_total / n_seen, n_correct / n_seen

# Seed every RNG the pipeline draws from (weight init, positional embedding,
# DataLoader shuffling, random flips, dropout) so a Restart & Run All gives
# comparable results. Some GPU kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyViT(
    image_size=IMG, patch=16, num_classes=NUM_CLASSES,
    dim=128, depth=5, heads=8, mlp_dim=256, dropout=0.1,
    use_layernorm=False
).to(device)

train_ld = build_loaders()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train for EPOCHS passes and report the per-epoch training loss / accuracy.
for epoch in range(EPOCHS):
    tr_loss, tr_acc = train_one_epoch(model, train_ld, optim, device)
    print(f"[{epoch+1}/{EPOCHS}] train loss {tr_loss:.4f} acc {tr_acc:.3f}")

# Persist the trained weights.
torch.save(model.state_dict(), "tinyvit.pt")


[1/2] train loss 1.4591 acc 0.329
[2/2] train loss 1.2114 acc 0.473


In [None]:
!pip install onnx onnxruntime onnxsim

Collecting onnx
  Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting onnxsim
  Downloading onnxsim-0.4.36.tar.gz (21.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.0/21.0 MB[0m [31m101.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m121.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.23.2-cp312

In [None]:
# Dummy input with the fixed export shape: batch 1, 3 channels, IMG x IMG.
dummy = torch.randn(1, 3, IMG, IMG, dtype=torch.float32)
model.to('cpu')
# The training loop leaves the model in train mode. Switch to eval explicitly
# so dropout is inactive in the exported graph no matter how the exporter
# treats module modes.
model.eval()
# ONNX opset 14.
# NOTE(review): the original note cited opset 13 for Neural ART / embedded
# compatibility, but the call has always passed 14 - confirm the target
# toolchain accepts opset 14 before changing either.
torch.onnx.export(
    model, (dummy,), "tinyvit_fp32.onnx",
    input_names=["input"], output_names=["logits"],
    opset_version=14,
    do_constant_folding=False,
    # Fixed 1x3x224x224 input. A dynamic batch ({"input": {0: "B"}}) would also
    # require changing the hard-coded batch-1 reshape in TinyViT.forward.
    dynamic_axes=None
)
print("Exported: tinyvit_fp32.onnx")

  torch.onnx.export(


Exported: tinyvit_fp32.onnx


In [None]:
import onnx
from onnxsim import simplify

# Load the exported FP32 graph under its own name. Binding it to `model` would
# overwrite the PyTorch module and break a re-run of the export cell.
onnx_fp32 = onnx.load('tinyvit_fp32.onnx')

# Simplify the graph; `check` reports whether the simplified model validated.
onnx_fp32_sim, check = simplify(onnx_fp32)

assert check, "Simplified ONNX model could not be validated"
onnx.save(onnx_fp32_sim, 'tinyvit_fp32_sim.onnx')


In [None]:
# Run onnxruntime's quantization pre-processing CLI on the simplified FP32
# model. Its output file (tinyvit_fp32_infer.onnx) is the model that the static
# quantization cell reads as ONNX_INF.
!python -m onnxruntime.quantization.preprocess --input tinyvit_fp32_sim.onnx --output tinyvit_fp32_infer.onnx

In [None]:
# quantize_static_qdq.py
import os, glob
import numpy as np
from PIL import Image

from onnxruntime.quantization import quantize_static, CalibrationDataReader, CalibrationMethod, QuantType
from onnxruntime.quantization import preprocess
from onnxruntime import InferenceSession

# Calibration image size and folder; the folder can be overridden through the
# FLOWERS_TEST_DIR environment variable.
IMG = 224
TEST_DIR = os.getenv("FLOWERS_TEST_DIR", "/kaggle/input/flowers-dataset/test")

INPUT_NAME = "input"   # graph input name, as given to input_names at ONNX export
ONNX_IN  = "tinyvit_fp32_sim.onnx"    # simplified FP32 model (not referenced again in this cell)
ONNX_INF = "tinyvit_fp32_infer.onnx"  # pre-processed FP32 model fed to quantize_static
ONNX_INT8 = "tinyvit_int8_qdq.onnx"   # quantized model written by quantize_static



# 3-2) 캘리브레이터
class ImageFolderDataReader(CalibrationDataReader):
    """Calibration data reader that serves images from a single folder.

    Files directly inside ``folder`` matching ``*.jpg``, ``*.png`` or ``*.jpeg``
    are sorted by path and at most ``max_images`` of them are kept
    (sub-directories are not searched). Each image is converted to RGB, resized
    to ``img_size`` x ``img_size``, scaled to [0, 1] and returned as a float32
    array of shape ``[1, 3, img_size, img_size]``.

    Args:
        folder: directory containing the calibration images.
        input_name: name of the model input that the arrays are keyed by.
        img_size: side length the images are resized to.
        max_images: maximum number of images served.

    Raises:
        FileNotFoundError: if no matching image is found in ``folder``.
    """

    def __init__(self, folder, input_name, img_size=224, max_images=200):
        self.input_name = input_name
        # Honour the caller's size; it used to be ignored in favour of the global IMG.
        self.img_size = img_size
        matches = []
        for pattern in ("*.jpg", "*.png", "*.jpeg"):
            matches.extend(glob.glob(os.path.join(folder, pattern)))
        self.img_paths = sorted(matches)
        if not self.img_paths:
            raise FileNotFoundError(f"No images found under {folder}")
        self.img_paths = self.img_paths[:max_images]
        self.enum_data = None  # iterator created lazily on the first get_next()
        self.count = 0         # kept for compatibility; not updated by this class

    def get_next(self):
        """Return the next ``{input_name: array}`` feed, or ``None`` when exhausted."""
        if self.enum_data is None:
            self.enum_data = self._data_iter()
        return next(self.enum_data, None)

    def _data_iter(self):
        """Yield one input feed per stored image path."""
        target_size = (self.img_size, self.img_size)
        for p in self.img_paths:
            # Close the file handle as soon as the pixels have been copied out.
            with Image.open(p) as src:
                img = src.convert("RGB").resize(target_size, Image.BILINEAR)
            arr = np.asarray(img, dtype=np.float32) / 255.0      # [H,W,3] in [0,1]
            arr = np.transpose(arr, (2, 0, 1))                   # [3,H,W]
            arr = np.expand_dims(arr, 0)                         # [1,3,H,W]
            yield {self.input_name: arr}

# 3-3) 정적 Q/DQ 양자화
# Static Q/DQ quantization, calibrated on up to 200 images from TEST_DIR.
dr = ImageFolderDataReader(TEST_DIR, INPUT_NAME, img_size=IMG, max_images=200)
quantize_static(
    model_input=ONNX_INF,
    model_output=ONNX_INT8,
    calibration_data_reader=dr,
    # calibration_method is left at the library default; pass
    # CalibrationMethod.MinMax / Percentile / Entropy here to change it.
    per_channel=True,                 # per-channel weight quantization
    reduce_range=False,
    weight_type=QuantType.QInt8,      # INT8 weights
    activation_type=QuantType.QInt8,  # INT8 activations
)
print("Quantized:", ONNX_INT8)

# Quick sanity check: load the quantized model and print its input / output signature.
sess = InferenceSession(ONNX_INT8, providers=["CPUExecutionProvider"])
for label, infos in (("Inputs:", sess.get_inputs()), ("Outputs:", sess.get_outputs())):
    print(label, [info.name for info in infos], infos[0].shape, infos[0].type)

Quantized: tinyvit_int8_qdq.onnx
Inputs: ['input'] [1, 3, 224, 224] tensor(float)
Outputs: ['logits'] [1, 5] tensor(float)


In [None]:
import onnx
from onnxsim import simplify

# Load the quantized graph under its own name so the notebook-wide `model`
# variable is not rebound to an ONNX ModelProto.
onnx_int8 = onnx.load('tinyvit_int8_qdq.onnx')

# Simplify the graph; `check` reports whether the simplified model validated.
onnx_int8_sim, check = simplify(onnx_int8)

assert check, "Simplified ONNX model could not be validated"
onnx.save(onnx_int8_sim, 'tinyvit_int8_sim.onnx')
