# 2025 EAI Lab 5

## Topic 1 : From PyTorch To ONNX

### Steps:
1.   Define Model Architecture
2.   Load Weights
3.   Export ONNX File
4.   Quantize To INT8
5.   Build Inference Session



In [14]:
# Install or upgrade the packages used in this lab: PyTorch, the ONNX tooling,
# ONNX Runtime, and Gradio.
# NOTE(review): both `onnxruntime` and `onnxruntime-gpu` are requested here; the
# ONNX Runtime install notes advise keeping only one of them per environment.
# Confirm which one is intended.
!pip install -U \
    torch torchvision torchaudio \
    onnx onnxscript onnxruntime onnxruntime-tools onnxruntime-gpu \
    gradio

Collecting torch
  Using cached torch-2.9.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchvision
  Using cached torchvision-0.24.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (5.9 kB)
Collecting torchaudio
  Using cached torchaudio-2.9.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (6.9 kB)
Collecting onnx
  Using cached onnx-1.20.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Collecting onnxscript
  Using cached onnxscript-0.5.6-py3-none-any.whl.metadata (13 kB)
Collecting onnxruntime
  Using cached onnxruntime-1.23.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting onnxruntime-tools
  Using cached onnxruntime_tools-1.7.0-py3-none-any.whl.metadata (14 kB)
Collecting onnxruntime-gpu
  Using cached onnxruntime_gpu-1.23.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.4 kB)
Collecting gradio
  Using cached gradio-6.0.2-py3-none-any.whl.metadata (16 kB)
Collecting

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# TODO
# Design Your ResNet18 Model

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    The main branch (``self.left``) is conv -> BN -> ReLU -> conv -> BN.
    The shortcut (``self.shortcut``) is the identity unless the stride or the
    channel count changes, in which case it is a strided 1x1 conv followed by
    BatchNorm so both branches produce tensors of the same shape.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the projection).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch; only the first convolution may downsample.
        main_branch = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.left = nn.Sequential(*main_branch)

        # An empty Sequential acts as the identity; a projection is needed
        # only when the two branches would otherwise disagree in shape.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))``."""
        summed = self.left(x)
        summed += self.shortcut(x)
        return F.relu(summed)


class ResNet18(nn.Module):
    """ResNet-18 style classifier built from ``BasicBlock`` stages.

    Layout: a 3x3 stem convolution (stride 1) producing 64 channels, four
    stages of two blocks each (64, 128, 256, 512 channels; stages 2-4
    downsample with stride 2), global average pooling, and a linear
    classifier.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.in_channels = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            # Every following block consumes what this one produced.
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        out = self.conv1(x)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Global average pooling. For 32x32 inputs layer4 yields 4x4 maps, so
        # this averages the same 16 values as the former fixed
        # ``F.avg_pool2d(out, 4)``; unlike that call it also yields exactly 512
        # features for other input resolutions, which ``self.fc`` requires.
        out = F.adaptive_avg_pool2d(out, 1)

        out = torch.flatten(out, 1)
        return self.fc(out)


In [13]:
# Network with 10 output classes; its weights are loaded later inside export_onnx.
torch_model = ResNet18(num_classes=10)
# One-element tuple holding a random (1, 3, 32, 32) tensor, used as the example
# input for the ONNX export.
dummy_input = (torch.randn(1, 3, 32, 32),)

def _clean_state_dict(state):
    """Return a copy of ``state`` whose keys match a plain (unwrapped) model.

    * A leading ``"module."`` (added by ``nn.DataParallel``) is stripped. Only
      the prefix is removed, so keys that merely contain that text (for
      example ``"submodule.weight"``) are left intact.
    * Entries whose name contains ``total_ops`` or ``total_params`` are
      dropped; they are profiling counters, not model parameters.
    """
    prefix = "module."
    cleaned = {}
    for key, value in state.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if "total_ops" in name or "total_params" in name:
            continue
        cleaned[name] = value
    return cleaned


def export_onnx(model, dummy, path, onnx_path="NM6131051_FP32.onnx"):
    """Load a checkpoint into ``model`` and export the model to ONNX.

    Args:
        model: ``nn.Module`` that receives the weights and is exported.
        dummy: Example input(s) forwarded to ``torch.onnx.export``.
        path: Checkpoint file readable by ``torch.load``. It may be a state
            dict or a dict that stores one under the ``'state_dict'`` key.
        onnx_path: Destination of the exported ONNX file.

    Returns:
        None. If the checkpoint cannot be loaded, the error is printed and the
        function returns without exporting anything.
    """
    print(f"Loading weights from {path}...")
    try:
        state = torch.load(path, map_location=torch.device("cpu"))

        # Unwrap checkpoints that were saved as {'state_dict': ...}.
        if 'state_dict' in state:
            state = state['state_dict']

        # strict loading (the default) verifies that every key lines up.
        model.load_state_dict(_clean_state_dict(state))
        print("Weights loaded successfully.")
    except Exception as e:
        print(f"Error loading weights: {e}")
        return

    # Export in inference mode.
    model.eval()

    # NOTE(review): the run recorded with this notebook stopped at this call
    # with "ModuleNotFoundError: No module named 'onnxscript'" - make sure that
    # package is importable in the active kernel.
    torch.onnx.export(
        model,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=11,
    )

if __name__ == "__main__":
  # Reminder: upload best_model.pth to the Content folder before running this cell.
  export_onnx(model=torch_model, dummy=dummy_input, path="best_model.pth")


Loading weights from best_model.pth...
Weights loaded successfully.


ModuleNotFoundError: No module named 'onnxscript'

In [None]:
import os, numpy as np
from PIL import Image
import onnxruntime as ort
from onnxruntime.quantization import CalibrationDataReader

# Per-channel mean and standard deviation (one value per RGB channel), applied
# by preprocess_32x32 after pixel values are scaled to [0, 1].
CIFAR10_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
CIFAR10_STD  = np.array([0.2470, 0.2435, 0.2616], dtype=np.float32)

def preprocess_32x32(pil_img: Image.Image) -> np.ndarray:
    """Turn a PIL image into a normalized ``(1, 3, 32, 32)`` float array.

    The image is converted to RGB, resized to 32x32, scaled to [0, 1],
    normalized per channel with ``CIFAR10_MEAN`` / ``CIFAR10_STD``, moved from
    HWC to CHW layout, and given a leading batch axis.
    """
    resized = pil_img.convert("RGB").resize((32, 32))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    normalized = (pixels - CIFAR10_MEAN) / CIFAR10_STD
    chw = normalized.transpose(2, 0, 1)
    return chw[None, ...]  # (1,3,32,32)

class CIFARLikeCalibReader(CalibrationDataReader):
    """Calibration data reader that feeds 32x32 image batches to the quantizer.

    If ``image_dir`` is an existing directory containing ``.jpg``, ``.jpeg``,
    ``.png`` or ``.bmp`` files, batches are built from those images via
    ``preprocess_32x32``. Otherwise the reader yields random standard-normal
    batches of shape ``(batch_size, 3, 32, 32)``.

    Args:
        image_dir: Directory with calibration images, or ``None``.
        input_name: Key used in the feed dict returned by ``get_next``.
        batch_size: Maximum number of samples per batch.
        num_batches: Maximum number of batches emitted before ``get_next``
            returns ``None``.
    """

    # File extensions (lower-case) accepted as calibration images.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, image_dir: str = None, input_name: str = "input",
                 batch_size: int = 32, num_batches: int = 10):
        self.input_name  = input_name
        self.batch_size  = batch_size
        self.num_batches = num_batches
        self.paths = []
        if image_dir and os.path.isdir(image_dir):
            # os.listdir order is arbitrary; sorting makes the calibration
            # batches identical from run to run.
            self.paths = [
                os.path.join(image_dir, f)
                for f in sorted(os.listdir(image_dir))
                if f.lower().endswith(self._IMAGE_SUFFIXES)
            ]
        # No usable images -> fall back to random data.
        self._mode_random = len(self.paths) == 0
        self._pos = 0       # index of the next image path to read
        self._emitted = 0   # number of batches handed out so far

    def get_next(self):
        """Return ``{input_name: batch}`` or ``None`` when the data is exhausted.

        The data is exhausted after ``num_batches`` batches or, in image mode,
        once every image path has been consumed.
        """
        if self._emitted >= self.num_batches:
            return None
        if self._mode_random:
            batch = np.random.randn(self.batch_size, 3, 32, 32).astype(np.float32)
        else:
            items = []
            for _ in range(self.batch_size):
                if self._pos >= len(self.paths):
                    break
                path = self.paths[self._pos]
                self._pos += 1
                # Close the file as soon as the pixels have been converted so
                # handles are not left open across the whole calibration run.
                with Image.open(path) as img:
                    items.append(preprocess_32x32(img))
            if not items:
                return None
            batch = np.concatenate(items, axis=0).astype(np.float32)
        self._emitted += 1
        return {self.input_name: batch}

    def rewind(self):
        """Reset the reader so that iteration starts again from the beginning."""
        self._pos = 0
        self._emitted = 0

# Paths of the FP32 ONNX model to quantize and of the INT8 model to be written.
# NOTE(review): export_onnx writes "NM6131051_FP32.onnx", not this file name -
# confirm FP32_MODEL points at the intended export.
FP32_MODEL = "image_classifier_model.onnx"
INT8_MODEL = "image_classifier_model_int8.onnx"


# Open the FP32 model once just to read the name of its first input; the
# calibration reader uses that name as the key of the feed dict it produces.
_tmp = ort.InferenceSession(FP32_MODEL, providers=["CPUExecutionProvider"])
INPUT_NAME = _tmp.get_inputs()[0].name
print("Calib will use input name:", INPUT_NAME)


In [None]:
from onnxruntime.quantization import quantize_static, QuantType, CalibrationMethod



# Calibration data source. With image_dir=None the reader has no image files
# and yields random float32 batches instead: up to 50 batches holding a single
# (3, 32, 32) sample each.
reader = CIFARLikeCalibReader(
    image_dir=None,
    input_name=INPUT_NAME,
    batch_size=1,
    num_batches=50
)


def quantize_to_int8(fp32_path, int8_path, reader, method="MinMax"):
    """Statically quantize an FP32 ONNX model and save the INT8 result.

    Args:
        fp32_path: Path of the FP32 ONNX model to read.
        int8_path: Path where the quantized model is written.
        reader: ``CalibrationDataReader`` that supplies calibration batches.
        method: Name of a ``CalibrationMethod`` member (e.g. ``"MinMax"``).

    Raises:
        ValueError: If ``method`` is not the name of a ``CalibrationMethod``
            member.
    """
    try:
        calibrate_method = CalibrationMethod[method]
    except KeyError:
        valid = ", ".join(m.name for m in CalibrationMethod)
        raise ValueError(
            f"Unknown calibration method {method!r}; expected one of: {valid}"
        ) from None

    quantize_static(
        model_input=fp32_path,
        model_output=int8_path,
        calibration_data_reader=reader,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        calibrate_method=calibrate_method,
    )
    # Report the path that was actually written, not a module-level constant.
    print("Saved INT8 model:", int8_path)

# Quantize the FP32 model; `method` is left at its default value "MinMax".
quantize_to_int8(FP32_MODEL, INT8_MODEL, reader)

In [None]:
import time
import numpy as np
import onnxruntime as ort

def run(sess, x):
    """Feed ``x`` to the session's first input and return its first output."""
    first_input = sess.get_inputs()[0]
    outputs = sess.run(None, {first_input.name: x})
    return outputs[0]

# Random float32 sample of shape (1, 3, 32, 32) shared by the FP32 and INT8 runs below.
x_demo = np.random.randn(1,3,32,32).astype(np.float32)

# Todo : build session function
def build_session(model_path, providers):
  """Create an ONNX Runtime inference session for ``model_path``.

  Args:
      model_path: Path of the ONNX model file to load.
      providers: Execution providers to try, in priority order
          (for example ``["CPUExecutionProvider"]``).

  Returns:
      The ``ort.InferenceSession`` bound to the model.
  """
  return ort.InferenceSession(model_path, providers=providers)



# One session per model, both restricted to the CPU execution provider.
sess_fp32 = build_session(model_path=FP32_MODEL, providers=["CPUExecutionProvider"])
sess_int8 = build_session(model_path=INT8_MODEL, providers=["CPUExecutionProvider"])

# Run the same input through both models.
y_fp32 = run(sess_fp32, x_demo)
y_int8 = run(sess_int8, x_demo)

# Relative L2 difference of the INT8 output with respect to the FP32 output;
# the 1e-12 term keeps the division defined when the FP32 output norm is zero.
l2_rel = np.linalg.norm(y_fp32 - y_int8) / (np.linalg.norm(y_fp32) + 1e-12)
print(f"[Check] relative L2 diff FP32 vs INT8: {l2_rel:.6f}")

def bench(sess, x, n=50):
    """Return the mean time in seconds of one inference, averaged over ``n`` runs.

    Args:
        sess: Session object exposing ``get_inputs()`` and ``run()``.
        x: Input value fed to the session's first input.
        n: Number of timed runs.
    """
    # The feed dict does not change between runs, so build it once instead of
    # querying the input name inside the timed loop.
    feed = {sess.get_inputs()[0].name: x}
    # perf_counter is a monotonic, high-resolution clock intended for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(n):
        sess.run(None, feed)
    return (time.perf_counter() - start) / n

# Report the average time per inference of each model on the same input.
print("FP32 avg sec:", bench(sess_fp32, x_demo))
print("INT8 avg sec:", bench(sess_int8, x_demo))

# Session options with profiling switched on.
# NOTE(review): `so` is not passed to any session in the code shown here -
# confirm where it is meant to be used.
so = ort.SessionOptions()
so.enable_profiling = True



## Topic 2 : Gradio


In [None]:
# Install Gradio for the web demo (it is also listed in the install cell at the top of the notebook).
! pip install gradio

In [None]:
import onnxruntime as ort
import numpy as np
from PIL import Image
import gradio as gr
import time

# ====== Config ======
MODEL_PATH_INT8 = "image_classifier_model_int8.onnx"   # quantized INT8 ONNX model file
MODEL_PATH_FP32 = "image_classifier_model.onnx"     # FP32 ONNX model file
# Class names; compare_fp32_int8 looks labels up in this list by class index.
LABELS = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']

# Per-channel (RGB) mean and standard deviation constants for input normalization.
CIFAR10_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
CIFAR10_STD  = np.array([0.2470, 0.2435, 0.2616], dtype=np.float32)

# ====== Utils ======
def softmax_np(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x`` computed over all of its elements.

    The global maximum is subtracted before exponentiating so that large
    logits do not overflow; the result has the same shape as ``x`` and sums
    to 1.
    """
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

# TODO : preprocess input image function
def preprocess(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a normalized ``(1, 3, 32, 32)`` float32 batch.

    The image is converted to RGB, resized to 32x32, scaled to [0, 1],
    normalized per channel with ``CIFAR10_MEAN`` / ``CIFAR10_STD``, rearranged
    from HWC to CHW, and given a leading batch axis.

    Raises:
        ValueError: If ``image`` is not a PIL image.
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Plese Upload Image")

    rgb = image.convert("RGB").resize((32, 32))
    arr = np.asarray(rgb, dtype=np.float32) / 255.0   # HWC, values in [0, 1]
    arr = (arr - CIFAR10_MEAN) / CIFAR10_STD          # broadcast over channels
    arr = arr.transpose(2, 0, 1)[np.newaxis, ...]     # -> (1, 3, 32, 32)

    # transpose returns a strided view; hand the runtime a contiguous
    # float32 buffer.
    return np.ascontiguousarray(arr, dtype=np.float32)

# ====== ONNX Sessions ======
# Let ONNX Runtime use every execution provider available in this environment.
providers = ort.get_available_providers()

# INT8 session plus the names of its first input and first output.
# NOTE: build_session is not defined in this cell; it must already exist in the
# running session.
sess_int8 = build_session(MODEL_PATH_INT8, providers=providers)
in_int8  = sess_int8.get_inputs()[0].name
out_int8 = sess_int8.get_outputs()[0].name


# The FP32 model is treated as optional: if it cannot be loaded, keep the error
# text so compare_fp32_int8 can report it instead of raising here.
try:
    sess_fp32 = build_session(MODEL_PATH_FP32, providers=providers)
    in_fp32  = sess_fp32.get_inputs()[0].name
    out_fp32 = sess_fp32.get_outputs()[0].name
    _fp32_err = ""
except Exception as e:
    sess_fp32, in_fp32, out_fp32 = None, None, None
    _fp32_err = f"[FP32 load failure] {type(e).__name__}: {e}"

# ====== Compare FP32 and INT8 ======
# TODO : Compare FP32 and INT8
def compare_fp32_int8(image: Image.Image):
    """Classify ``image`` with the FP32 and INT8 models and compare them.

    Args:
        image: PIL image to classify, or ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(top3_fp32, top3_int8, summary)``. The first two items map
        the three most probable labels to their probabilities for each model;
        ``summary`` is a text report with both inference times and the
        FP32/INT8 speedup. When no image was given or the FP32 session is
        unavailable, both dicts are empty and ``summary`` explains why.
    """
    if image is None:
        return {}, {}, "Please Upload Your Image。"
    if sess_fp32 is None:
        return {}, {}, (_fp32_err or "The FP32 model has not been provided, so a comparison cannot be made.")

    x = preprocess(image)

    def timed_logits(sess, in_name, out_name):
        """Run one inference; return (flattened logits, elapsed milliseconds)."""
        start = time.perf_counter()
        logits = sess.run([out_name], {in_name: x})[0]
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        # The batch holds a single image, so flattening leaves one logit
        # vector that can be indexed by class.
        return np.asarray(logits).reshape(-1), elapsed_ms

    logits_fp32, fp32_ms = timed_logits(sess_fp32, in_fp32, out_fp32)
    logits_int8, int8_ms = timed_logits(sess_int8, in_int8, out_int8)

    p_fp32 = softmax_np(logits_fp32)
    p_int8 = softmax_np(logits_int8)

    def top3_map(p):
        """Map the three most probable labels to their probabilities."""
        idx = np.argpartition(p, -3)[-3:]      # indices of the 3 largest, unordered
        idx = idx[np.argsort(p[idx])[::-1]]    # order them by descending probability
        return {LABELS[i]: float(p[i]) for i in idx}

    top3_fp32 = top3_map(p_fp32)
    top3_int8 = top3_map(p_int8)

    summary = (
        f"FP32 inference time: {fp32_ms:.2f} ms\n"
        f"INT8 inference time: {int8_ms:.2f} ms\n"
        f"Speedup (FP32/INT8): {(fp32_ms / max(int8_ms, 1e-9)):.2f}×"
    )
    return top3_fp32, top3_int8, summary

# ====== Gradio UI ======
# TODO : Building GUI Interface
# One image goes in; three values come out, in the order returned by
# compare_fp32_int8: FP32 top-3, INT8 top-3, and the timing summary.
demo = gr.Interface(
    fn=compare_fp32_int8,
    # type="pil" makes Gradio pass a PIL image, which compare_fp32_int8 expects.
    inputs=gr.Image(type="pil", label="Input image"),
    outputs=[
        gr.Label(num_top_classes=3, label="FP32 top-3"),
        gr.Label(num_top_classes=3, label="INT8 top-3"),
        gr.Textbox(label="Inference time"),
    ],
    title="Image Classifier: FP32 vs INT8 (ONNX Runtime)",
    description=(
        "Upload an image to compare the predictions and inference time of "
        "the FP32 and INT8 ONNX models."
    ),
)

if __name__ == "__main__":
  # share=True asks Gradio to create a temporary public URL in addition to
  # serving the app locally.
  demo.launch(share=True)

