# 05 模型导出与推理验证：ONNX（CTR + Matching）

- **目标**：演示 Torch-RecHub 模型导出 ONNX（含 dynamic axes），并用 onnxruntime 做一次最小推理验证。

## 依赖
- 导出：`onnx>=1.20.0`
- 推理验证：`onnxruntime`
- 可选（量化）：`onnxruntime`（INT8 动态量化），`onnxconverter-common`（FP16 转换）


In [1]:
import os
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from tqdm import tqdm

from torch_rechub.basic.features import DenseFeature, SparseFeature, SequenceFeature
from torch_rechub.models.ranking import DeepFM
from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator
from torch_rechub.utils.onnx_export import ONNXExporter
from torch_rechub.utils.model_utils import generate_dummy_input_dict

SEED = 2022
# Prefer the first CUDA GPU, but fall back to CPU so the tutorial also runs on
# machines without a usable GPU. DEVICE is passed to the trainer and to the
# ONNX exporters further down.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)

# Directory that receives every exported .onnx file; created if missing.
EXPORT_DIR = "./onnx_exports"
os.makedirs(EXPORT_DIR, exist_ok=True)
print("EXPORT_DIR:", os.path.abspath(EXPORT_DIR))


EXPORT_DIR: e:\RecommendSystemProject\torch-rechub\tutorials\onnx_exports


In [2]:
# ---------- Part A: CTR (DeepFM) export + onnxruntime inference check ----------

# Relative path to the Criteo sample CSV (resolved against the current working directory).
DATASET_PATH = "../examples/ranking/data/criteo/criteo_sample.csv"
# Training hyper-parameters: EPOCH, LR and WEIGHT_DECAY go to CTRTrainer,
# BATCH_SIZE goes to the dataloader factory.
EPOCH = 1
BATCH_SIZE = 2048
LR = 1e-3
WEIGHT_DECAY = 1e-3


def convert_numeric_feature(val):
    """Map a raw numeric value to an integer bucket id.

    The value is first truncated with ``int()``. Values of 2 or less are
    shifted down by 2; larger values are bucketed as ``int(log(v) ** 2)``.
    """
    as_int = int(val)
    if as_int <= 2:
        return as_int - 2
    squared_log = np.log(as_int) ** 2
    return int(squared_log)


def get_criteo_data_dict(data_path):
    """Load the Criteo CSV and build the feature definitions for a CTR model.

    Columns whose names start with ``I`` are treated as dense features and
    columns starting with ``C`` as sparse features. Every dense column also
    gets a discretized ``<name>_cat`` copy that is handled as a sparse feature.

    Args:
        data_path: Path to the CSV file; a ``.gz`` suffix is read as gzip.

    Returns:
        A tuple ``(dense_feas, sparse_feas, x, y)``: the ``DenseFeature`` list,
        the ``SparseFeature`` list, the frame without the ``label`` column,
        and the ``label`` column.
    """
    if data_path.endswith(".gz"):
        data = pd.read_csv(data_path, compression="gzip")
    else:
        data = pd.read_csv(data_path)

    column_names = data.columns.tolist()
    dense_features = [col for col in column_names if col.startswith("I")]
    sparse_features = [col for col in column_names if col.startswith("C")]

    # Missing categorical values become the string "0", missing numerics become 0.
    data[sparse_features] = data[sparse_features].fillna("0")
    data[dense_features] = data[dense_features].fillna(0)

    # Add a bucketed "<name>_cat" column per dense feature and register it as sparse.
    for col in tqdm(dense_features, desc="discretize dense"):
        bucket_col = col + "_cat"
        sparse_features.append(bucket_col)
        data[bucket_col] = data[col].apply(convert_numeric_feature)

    # MinMaxScaler yields float64 by default, which would make the dataloader /
    # ONNX inputs double precision. Cast to float32 explicitly so the inputs
    # match the exported ONNX model (which usually expects float32).
    scaler = MinMaxScaler()
    scaled_dense = scaler.fit_transform(data[dense_features])
    data[dense_features] = scaled_dense.astype(np.float32)

    # Encode every sparse column (including the new "_cat" ones) as integer ids.
    for col in tqdm(sparse_features, desc="label encode sparse"):
        data[col] = LabelEncoder().fit_transform(data[col])

    dense_feas = [DenseFeature(col) for col in dense_features]
    sparse_feas = [
        SparseFeature(col, vocab_size=data[col].nunique(), embed_dim=16)
        for col in sparse_features
    ]

    y = data["label"]
    x = data.drop(columns=["label"])
    return dense_feas, sparse_feas, x, y


# Build the feature definitions and the preprocessed frame from the Criteo sample.
dense_feas, sparse_feas, x, y = get_criteo_data_dict(DATASET_PATH)
dg = DataGenerator(x, y)
# split_ratio gives two fractions for three dataloaders; presumably they are the
# train and validation shares and the remainder becomes the test split —
# confirm against DataGenerator.generate_dataloader.
train_dl, val_dl, test_dl = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=BATCH_SIZE)

# DeepFM: dense features feed the deep (MLP) part, sparse features feed the FM part.
ctr_model = DeepFM(
    deep_features=dense_feas,
    fm_features=sparse_feas,
    mlp_params={"dims": [64, 32], "dropout": 0.1, "activation": "relu"},
)

# Train for EPOCH epoch(s) on DEVICE; model_path="./" is handed to the trainer
# (presumably where it stores checkpoints — confirm against CTRTrainer).
ctr_trainer = CTRTrainer(
    ctr_model,
    optimizer_params={"lr": LR, "weight_decay": WEIGHT_DECAY},
    n_epoch=EPOCH,
    earlystop_patience=2,
    device=DEVICE,
    model_path="./",
)
ctr_trainer.fit(train_dl, val_dl)

# Export the trained model to ONNX; dynamic_batch=True presumably marks the
# batch dimension as dynamic — confirm against ONNXExporter.export.
ctr_onnx_path = os.path.join(EXPORT_DIR, "deepfm.onnx")
exporter = ONNXExporter(ctr_model, device=DEVICE)
exporter.export(ctr_onnx_path, opset_version=14, dynamic_batch=True, verbose=False)
print("exported:", ctr_onnx_path)


discretize dense: 100%|██████████| 13/13 [00:00<00:00, 4804.47it/s]
label encode sparse: 100%|██████████| 39/39 [00:00<00:00, 13004.04it/s]

the samples of train : val : test are  80 : 11 : 24





epoch: 0


train: 100%|██████████| 1/1 [00:00<00:00,  4.02it/s]
validation: 100%|██████████| 1/1 [00:00<00:00, 24.93it/s]


epoch: 0 validation: auc: 0.16666666666666666
exported: ./onnx_exports\deepfm.onnx


In [3]:
# Minimal onnxruntime inference check against the PyTorch model (small
# floating-point differences are expected).
# Note: when DEVICE is a CUDA device, the batch inputs must be moved to that
# same device before calling the PyTorch model.

try:
    import onnxruntime as ort

    # Take one batch; the dataloader is expected to yield CPU tensors.
    batch_x, _ = next(iter(test_dl))

    ctr_model.eval()
    model_device = next(ctr_model.parameters()).device

    # PyTorch forward pass: move inputs to the model's device and downcast
    # double -> float32.
    batch_x_torch = {
        k: (v.float() if v.dtype == torch.float64 else v).to(model_device)
        for k, v in batch_x.items()
    }
    with torch.no_grad():
        torch_out = ctr_model(batch_x_torch).detach().cpu().numpy()

    # onnxruntime forward pass on CPU; inputs must be numpy arrays.
    ort_sess = ort.InferenceSession(ctr_onnx_path, providers=["CPUExecutionProvider"])

    ort_inputs = {}
    ort_input_info = {i.name: i for i in ort_sess.get_inputs()}
    for k, v in batch_x.items():
        info = ort_input_info.get(k)
        if info is None:
            # onnxruntime rejects feed names that are not graph inputs, while
            # the PyTorch model simply ignores unused dict keys, so skip
            # columns the exported graph does not consume.
            continue

        # ONNX graphs commonly expect float32, so downcast any float64 input.
        if v.dtype == torch.float64:
            v = v.float()
        arr = v.detach().cpu().numpy()

        # Fix the rank from the ONNX input signature: the model often expects
        # (B, 1) while the dataloader provides (B,).
        expected_shape = getattr(info, "shape", None)
        if expected_shape is not None and len(expected_shape) == 2 and arr.ndim == 1:
            arr = arr.reshape(-1, 1)

        ort_inputs[k] = arr

    ort_out = ort_sess.run(None, ort_inputs)[0]

    # Compare element-wise on flattened outputs so that a (B,) vs (B, 1)
    # mismatch cannot silently broadcast into a (B, B) difference matrix.
    if torch_out.size != ort_out.size:
        raise ValueError(
            f"output size mismatch: torch {torch_out.shape} vs onnx {ort_out.shape}"
        )
    max_abs_diff = float(np.max(np.abs(torch_out.reshape(-1) - ort_out.reshape(-1))))
    print("torch_out shape:", torch_out.shape, "onnx_out shape:", ort_out.shape)
    print("max_abs_diff:", max_abs_diff)
except ImportError as e:
    print("onnxruntime not installed, skip inference check:", e)
except Exception as e:
    # Best-effort tutorial check: report the failure instead of aborting the notebook.
    print("inference check failed:", repr(e))


torch_out shape: (24,) onnx_out shape: (24,)
max_abs_diff: 5.960464477539063e-08


## 可选：ONNX 模型量化（INT8 / FP16）

下面演示在已导出的 ONNX 模型基础上进行压缩：
- **INT8 动态量化**：更偏向 CPU 推理加速（常见于 MLP/Linear）。
- **FP16 转换**：更偏向 GPU 推理（降低显存占用）。


In [4]:
# Quantize / convert the exported ONNX model (optional).

from torch_rechub.utils import quantize_model

ctr_onnx_int8_path = os.path.join(EXPORT_DIR, "deepfm.int8.onnx")
ctr_onnx_fp16_path = os.path.join(EXPORT_DIR, "deepfm.fp16.onnx")

# One entry per variant: (mode, output path, extra kwargs, success label, skip label).
# - INT8 dynamic quantization requires onnxruntime.
# - FP16 conversion requires onnx + onnxconverter-common.
_quantize_variants = (
    ("int8", ctr_onnx_int8_path, {}, "exported int8:", "INT8 quantize skipped:"),
    ("fp16", ctr_onnx_fp16_path, {"keep_io_types": True}, "exported fp16:", "FP16 convert skipped:"),
)

for _mode, _out_path, _extra_kwargs, _ok_label, _skip_label in _quantize_variants:
    # Each variant is best-effort: a failure is reported and the next one still runs.
    try:
        quantize_model(ctr_onnx_path, _out_path, mode=_mode, **_extra_kwargs)
        print(_ok_label, _out_path)
    except Exception as e:
        print(_skip_label, repr(e))




exported int8: ./onnx_exports\deepfm.int8.onnx
exported fp16: ./onnx_exports\deepfm.fp16.onnx


In [5]:
# ---------- Part B: Matching (DSSM) two-tower export + minimal inference check ----------

# To keep this tutorial self-contained and fast, build a minimal DSSM model
# (no full training run) and export its user and item towers.

# user tower features
user_features = [
    SparseFeature("user_id", vocab_size=1000, embed_dim=16),
    SparseFeature("gender", vocab_size=3, embed_dim=16),
    SparseFeature("age", vocab_size=10, embed_dim=16),
    SparseFeature("occupation", vocab_size=30, embed_dim=16),
    SparseFeature("zip", vocab_size=5000, embed_dim=16),
    # History sequence, mean-pooled; shared_with="movie_id" presumably reuses the
    # item tower's movie_id embedding table — confirm against SequenceFeature.
    SequenceFeature("hist_movie_id", vocab_size=5000, embed_dim=16, pooling="mean", shared_with="movie_id"),
]

# item tower features
item_features = [
    SparseFeature("movie_id", vocab_size=5000, embed_dim=16),
    SparseFeature("cate_id", vocab_size=50, embed_dim=16),
]

# Untrained DSSM: each tower is configured with a single 64-unit layer and PReLU activation.
match_model = DSSM(
    user_features,
    item_features,
    temperature=0.02,
    user_params={"dims": [64], "activation": "prelu"},
    item_params={"dims": [64], "activation": "prelu"},
)

user_onnx_path = os.path.join(EXPORT_DIR, "user_tower.onnx")
item_onnx_path = os.path.join(EXPORT_DIR, "item_tower.onnx")

# Export each tower to its own ONNX file; `mode` selects which tower is exported.
match_exporter = ONNXExporter(match_model, device=DEVICE)
match_exporter.export(user_onnx_path, mode="user", opset_version=14, dynamic_batch=True, verbose=False)
match_exporter.export(item_onnx_path, mode="item", opset_version=14, dynamic_batch=True, verbose=False)

print("exported:", user_onnx_path)
print("exported:", item_onnx_path)


exported: ./onnx_exports\user_tower.onnx
exported: ./onnx_exports\item_tower.onnx


In [6]:
# Minimal two-tower check: run one onnxruntime forward pass per tower (user /
# item) and compare it with the PyTorch output.

try:
    import onnxruntime as ort

    # Build dummy inputs that match the feature definitions, on whatever device
    # the model currently lives on, so inputs and weights can never disagree
    # even if DEVICE was changed after the export.
    match_device = str(next(match_model.parameters()).device)
    dummy_user = generate_dummy_input_dict(user_features, batch_size=2, seq_length=10, device=match_device)
    dummy_item = generate_dummy_input_dict(item_features, batch_size=2, seq_length=10, device=match_device)

    match_model.eval()
    # Remember the current mode so this check does not leave the model stuck in
    # "item" mode for whatever runs afterwards.
    previous_mode = getattr(match_model, "mode", None)
    try:
        with torch.no_grad():
            # user tower
            match_model.mode = "user"
            torch_user_out = match_model(dummy_user).detach().cpu().numpy()
            # item tower
            match_model.mode = "item"
            torch_item_out = match_model(dummy_item).detach().cpu().numpy()
    finally:
        match_model.mode = previous_mode

    # onnxruntime sessions run on CPU and take numpy inputs.
    user_sess = ort.InferenceSession(user_onnx_path, providers=["CPUExecutionProvider"])
    item_sess = ort.InferenceSession(item_onnx_path, providers=["CPUExecutionProvider"])

    ort_user_in = {k: v.detach().cpu().numpy() for k, v in dummy_user.items()}
    ort_item_in = {k: v.detach().cpu().numpy() for k, v in dummy_item.items()}

    ort_user_out = user_sess.run(None, ort_user_in)[0]
    ort_item_out = item_sess.run(None, ort_item_in)[0]

    print("user torch/onnx shapes:", torch_user_out.shape, ort_user_out.shape)
    print("item torch/onnx shapes:", torch_item_out.shape, ort_item_out.shape)

    print("user max_abs_diff:", float(np.max(np.abs(torch_user_out - ort_user_out))))
    print("item max_abs_diff:", float(np.max(np.abs(torch_item_out - ort_item_out))))
except ImportError as e:
    print("onnxruntime not installed, skip inference check:", e)


user torch/onnx shapes: (2, 64) (2, 64)
item torch/onnx shapes: (2, 64) (2, 64)
user max_abs_diff: 5.960464477539063e-08
item max_abs_diff: 5.960464477539063e-08
