# Target
**Objective**:

Optimize ResNet18 (CIFAR-10) from FP32 to INT8 with minimal accuracy degradation, and provide comparisons on model size, latency, throughput, and accuracy.

**Deliverables**:

1. FP32 baseline, PTQ model, and QAT model
2. TorchScript/ONNX exports
3. Standardized benchmarking scripts with tables and plots

# 1. Environment
Python 3.10+, PyTorch ≥ 2.2（torch.ao.quantization）, torchvision, onnx, onnxruntime, TensorBoard

1. Train/Test Split: torchvision.datasets.CIFAR10
2. Recommended Transforms:
   + Training: RandomCrop(32, 4), RandomHorizontalFlip(), ToTensor(), Normalize(mean, std)
   + Test/Calibration: ToTensor(), Normalize(mean, std)
3. Batch Sizes: train_bs = 128, test_bs = 256

**Preprocessing pipeline**

This preprocessing pipeline converts CIFAR-10 (32×32, [0,1], NCHW) → (224×224, [-1,1], NHWC), fully aligned with the input specification of Keras ResNet50V2.
+ For the training set, random cropping and horizontal flipping are added as augmentations
+ For the test set, only resizing and normalization are applied.

In [23]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import tensorflow as tf
from tensorflow.keras.applications.resnet_v2 import preprocess_input

# 1) 不做 Normalize（保持 [0,1]），保持你原来的增强
train_transform_clean = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),                # -> [0,1], NCHW
])
test_transform_clean = T.Compose([
    T.ToTensor(),                # -> [0,1]
])

root = "/mnt/qat/cifar10"
train_set_clean = datasets.CIFAR10(root=root, train=True,  download=False, transform=train_transform_clean)
test_set_clean  = datasets.CIFAR10(root=root, train=False, download=False, transform=test_transform_clean)

# 2) 关键：DataLoader 用单进程，避免第二个 epoch 崩溃
def make_torch_loaders(train_bs=128, test_bs=256):
    train_loader = DataLoader(train_set_clean, batch_size=train_bs, shuffle=True,
                              num_workers=0, pin_memory=False, persistent_workers=False)
    test_loader  = DataLoader(test_set_clean,  batch_size=test_bs, shuffle=False,
                              num_workers=0, pin_memory=False, persistent_workers=False)
    return train_loader, test_loader

# 3) 生成器：NCHW -> NHWC，不做额外变量创建
def torch_dl_generator(dl):
    for x, y in dl:
        yield x.permute(0, 2, 3, 1).numpy(), y.numpy()

# 4) 把 PyTorch loader 包成 tf.data。注意：不要跨 epoch 复用，按 epoch 重建
IMG = 224
def map_fn(img, label):
    img = tf.image.resize(img, (IMG, IMG))   # 32 -> 224
    img = preprocess_input(img * 255.0)      # ResNet50V2 预处理
    return img, tf.cast(label, tf.int32)

def make_tf_dataset_from_loader(dl, batch_size):
    ds = tf.data.Dataset.from_generator(
        lambda: torch_dl_generator(dl),
        output_signature=(
            tf.TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None,),           dtype=tf.int64),
        )
    ).unbatch()
    ds = (ds.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))
    return ds

# 2. Model Training

## 2.1 load basic model

In [2]:
import tensorflow as tf
print(tf.__version__)
print(tf.config.list_physical_devices('GPU'))

2.20.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## 2.2 add classification layer

In [2]:
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras import layers, models, optimizers

# 1) backbone
base_model = ResNet50V2(include_top=False, weights=None,
                        input_shape=(224, 224, 3), pooling="avg")

base_model.load_weights("resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5")

# 2) add classification layer
x = layers.Dense(256, activation="relu")(base_model.output)
x = layers.Dropout(0.5)(x)
output = layers.Dense(10, activation="softmax")(x)
model = models.Model(inputs=base_model.input, outputs=output)

# 3) Freeze → Train the head, then fine-tune
base_model.trainable = False
model.compile(optimizer=optimizers.Adam(1e-3),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])

## 2.3 training

In [3]:
# 批大小
TRAIN_BS, TEST_BS = 64, 128
EPOCHS = 5

# 先做一次 test_ds（验证集不用每轮重建）
_, test_loader = make_torch_loaders(TRAIN_BS, TEST_BS)
test_ds = make_tf_dataset_from_loader(test_loader, TEST_BS)
val_steps = len(test_loader)   

for ep in range(EPOCHS):
    # 关键：每个 epoch 重新创建 train_loader 和 train_ds
    train_loader, _ = make_torch_loaders(TRAIN_BS, TEST_BS)
    train_ds = make_tf_dataset_from_loader(train_loader, TRAIN_BS)
    steps_per_epoch = len(train_loader)

    print(f"\n==== Epoch {ep+1}/{EPOCHS} ====")
    model.fit(
        train_ds,
        epochs=1,                         # 一次只跑1个epoch
        steps_per_epoch=steps_per_epoch,  # 显式告知步数
        validation_data=test_ds,
        validation_steps=val_steps
    )

==== Epoch 1/5 ====
  1/782 ━━━━━━━━━━━━━━━━━━━━ 4:18:21 20s/step - accuracy: 0.1094 - loss: 3.4535
I0000 00:00:1758820413.279835    2695 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
782/782 ━━━━━━━━━━━━━━━━━━━━ 138s 152ms/step - accuracy: 0.8062 - loss: 0.5749 - val_accuracy: 0.8607 - val_loss: 0.3932

==== Epoch 2/5 ====
782/782 ━━━━━━━━━━━━━━━━━━━━ 91s 116ms/step - accuracy: 0.8489 - loss: 0.4453 - val_accuracy: 0.8738 - val_loss: 0.3567

==== Epoch 3/5 ====
782/782 ━━━━━━━━━━━━━━━━━━━━ 91s 116ms/step - accuracy: 0.8579 - loss: 0.4123 - val_accuracy: 0.8785 - val_loss: 0.3525

==== Epoch 4/5 ====
 51/782 ━━━━━━━━━━━━━━━━━━━━ 1:11 98ms/step - accuracy: 0.8640 - loss: 0.4066

## 2.4 unfreeze the last few layers, train

In [13]:
from tensorflow.keras import optimizers, callbacks
from tensorflow.keras.losses import SparseCategoricalCrossentropy
# 其他层/模型同理：from tensorflow.keras import layers, models

In [17]:
import tensorflow as tf
from tensorflow.keras import optimizers, losses, callbacks

# 1) 解冻：只放开最后一个 stage（conv5），其余保持冻结
base_model.trainable = True
for layer in base_model.layers:
    name = layer.name
    # 冻结除 conv5_block* 以外的层；同时把 BN 全冻结（更稳）
    if ("conv5_block" in name):
        # conv5 的 BN 也建议不训练，避免统计量漂移
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False
        else:
            layer.trainable = True
    else:
        layer.trainable = False

# 2) 重新编译：小学习率 + label smoothing；可用 AdamW 稍微稳一点
try:
    from tensorflow_addons.optimizers import AdamW
    opt = AdamW(learning_rate=2e-5, weight_decay=1e-4)
except Exception:
    opt = optimizers.Adam(learning_rate=2e-5)


# 选择一个优化器（冻结阶段用大点 LR，解冻微调用小点 LR）
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)   # 冻结阶段
# 解冻微调时改成：opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

# 选择 loss（用 tf.keras 的枚举，兼容性最好）
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    # label_smoothing=0.1,  # 你的环境如果支持就保留；不支持就删掉
    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
)

model.compile(optimizer=opt, loss=loss, metrics=["accuracy"])

# 3) 回调（你可以沿用之前的）
cbs = [
    callbacks.ModelCheckpoint("best_finetune.h5", save_best_only=True,
                              monitor="val_accuracy", mode="max"),
    callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6),
    callbacks.EarlyStopping(monitor="val_accuracy", patience=4, restore_best_weights=True),
]

# 4) 微调开跑（3–6 个 epoch 通常够）
history_ft = model.fit(
    train_ds,
    epochs=5,
    validation_data=test_ds,
    callbacks=cbs
)

Epoch 1/5
    782/Unknown [1m120s[0m 137ms/step - accuracy: 0.8630 - loss: 0.4158



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m148s[0m 172ms/step - accuracy: 0.8725 - loss: 0.3887 - val_accuracy: 0.8962 - val_loss: 0.2988 - learning_rate: 0.0010
Epoch 2/5
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 131ms/step - accuracy: 0.8996 - loss: 0.3038



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m122s[0m 155ms/step - accuracy: 0.9013 - loss: 0.3005 - val_accuracy: 0.9020 - val_loss: 0.3012 - learning_rate: 0.0010
Epoch 3/5
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 132ms/step - accuracy: 0.9126 - loss: 0.2642



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m122s[0m 156ms/step - accuracy: 0.9125 - loss: 0.2642 - val_accuracy: 0.9021 - val_loss: 0.2867 - learning_rate: 0.0010
Epoch 4/5
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 132ms/step - accuracy: 0.9174 - loss: 0.2502



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m123s[0m 158ms/step - accuracy: 0.9177 - loss: 0.2490 - val_accuracy: 0.9040 - val_loss: 0.2961 - learning_rate: 0.0010
Epoch 5/5
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 132ms/step - accuracy: 0.9281 - loss: 0.2141



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m123s[0m 157ms/step - accuracy: 0.9246 - loss: 0.2260 - val_accuracy: 0.9062 - val_loss: 0.3071 - learning_rate: 0.0010


## 2.5 freeze BatchNorm (BN) layers

In [18]:
import tensorflow as tf

# 冻结 BN 更稳
for l in base_model.layers:
    if isinstance(l, tf.keras.layers.BatchNormalization):
        l.trainable = False

# 小学习率
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),   # 如环境支持可加 label_smoothing=0.1
    metrics=['accuracy'],
)

cbs = [
    tf.keras.callbacks.ModelCheckpoint('best_finetune.h5', save_best_only=True,
                                       monitor='val_accuracy', mode='max'),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                         patience=1, min_lr=1e-6),
    tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=3,
                                     restore_best_weights=True),
]
model.fit(train_ds, epochs=3, validation_data=test_ds, callbacks=cbs)

Epoch 1/3
    782/Unknown [1m124s[0m 140ms/step - accuracy: 0.9355 - loss: 0.1936



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 172ms/step - accuracy: 0.9385 - loss: 0.1799 - val_accuracy: 0.9196 - val_loss: 0.2596 - learning_rate: 1.0000e-05
Epoch 2/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 131ms/step - accuracy: 0.9445 - loss: 0.1621



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m121s[0m 154ms/step - accuracy: 0.9458 - loss: 0.1577 - val_accuracy: 0.9236 - val_loss: 0.2547 - learning_rate: 1.0000e-05
Epoch 3/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 132ms/step - accuracy: 0.9468 - loss: 0.1717



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m121s[0m 154ms/step - accuracy: 0.9493 - loss: 0.1518 - val_accuracy: 0.9253 - val_loss: 0.2516 - learning_rate: 1.0000e-05


<keras.src.callbacks.history.History at 0x7fed45a9b5e0>

# 3. Base model save

In [19]:
model.save("resnet50_fp32_baseline.h5")               # 结构+权重
# model.save_weights("resnet50_fp32_baseline_weights.h5")



## 3.1 model load

In [17]:
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras import layers, models

# 跟保存时完全一致的结构
base_model = ResNet50V2(include_top=False, weights=None,
                        input_shape=(224, 224, 3), pooling="avg")

# 接分类头
x = layers.Dense(256, activation="relu")(base_model.output)
x = layers.Dropout(0.5)(x)
output = layers.Dense(10, activation="softmax")(x)
model = models.Model(inputs=base_model.input, outputs=output)

In [18]:
# 推荐加载方式
model.load_weights("resnet50_fp32_baseline.h5")

In [1]:
import tensorflow as tf
from tensorflow import keras   # ✅ 用 tf.keras 入口
layers, models = keras.layers, keras.models

print("TF:", tf.__version__)            # 2.20.0
print("Keras from:", keras.__file__)    # 指向 site-packages/keras/...

2025-09-26 23:03:25.581830: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


TF: 2.20.0
Keras from: /root/miniconda3/envs/myconda/lib/python3.10/site-packages/keras/_tf_keras/keras/__init__.py


## 3.2 Record FP32 baseline accuracy

In [20]:
# 记录 FP32 基线精度（后面对比）
fp32_val = model.evaluate(test_ds, return_dict=True)
print(fp32_val)  # {'loss': ..., 'accuracy': ...}

[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 175ms/step - accuracy: 0.9253 - loss: 0.2516
{'accuracy': 0.9253000020980835, 'loss': 0.2515576481819153}


# 4. Model Quantization with PTQ (Post-Training Quantization)

## 4.1 Dynamic Range

PyTorch native quantization mainly targets CPU (x86 → FBGEMM; ARM → QNNPACK). For GPU, INT8 inference typically relies on TensorRT or ONNX Runtime EP.

**Workflow：**

1. Copy the FP32 model and apply fusion (Conv + BN + ReLU)
2. Set backend and qconfig (use fbgemm for x86)
3. Run prepare_fx (or prepare in eager mode)
4. Calibration: run forward passes on 300–1000 representative samples (no backpropagation)
5. convert_fx → obtain INT8 model
6. Evaluation/benchmark

In [21]:
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_dr = converter.convert()
open("resnet50_ptq_dynamic.tflite","wb").write(tflite_dr)

INFO:tensorflow:Assets written to: /tmp/tmp8i9hgauo/assets


INFO:tensorflow:Assets written to: /tmp/tmp8i9hgauo/assets


Saved artifact at '/tmp/tmp8i9hgauo'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  140673706896448: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673706889056: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707199120: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707197360: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707197184: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707198416: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707208976: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707207392: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707210560: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707206160: TensorSpec(shape=(), dtype=tf.resource, name=None)
  1406737072088

W0000 00:00:1758825704.212615    2644 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1758825704.212643    2644 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
I0000 00:00:1758825704.363328    2644 mlir_graph_optimization_pass.cc:437] MLIR V1 optimization pass is not enabled


24577728

## 4.2 Full-Integer Quantization (INT8):

In [25]:
from tqdm import tqdm
import tensorflow as tf

N_SAMPLES = 512  # 推荐 200~1000

def representative_gen():
    # 用 tqdm 包装，显示进度
    for i, (x, _) in enumerate(tqdm(test_ds.unbatch().take(N_SAMPLES),
                                    total=N_SAMPLES,
                                    desc="Calibrating")):
        yield [tf.cast(tf.expand_dims(x, 0), tf.float32)]

In [26]:
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# 如需全 INT8 I/O 接口（非必须）：
# converter.inference_input_type = tf.int8
# converter.inference_output_type = tf.int8
tflite_int8 = converter.convert()
open("resnet50_ptq_int8.tflite", "wb").write(tflite_int8)

INFO:tensorflow:Assets written to: /tmp/tmpbv0oliso/assets


INFO:tensorflow:Assets written to: /tmp/tmpbv0oliso/assets


Saved artifact at '/tmp/tmpbv0oliso'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  140673706896448: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673706889056: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707199120: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707197360: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707197184: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707198416: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707208976: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707207392: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707210560: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140673707206160: TensorSpec(shape=(), dtype=tf.resource, name=None)
  1406737072088

W0000 00:00:1758826593.746931    2644 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1758826593.746955    2644 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
Calibrating: 100%|██████████| 512/512 [02:46<00:00,  3.07it/s]
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32


24784688

# 5. QAT

Insert fake quantization operators during training so the model learns to adapt to quantization errors. This approach typically achieves significantly better accuracy than PTQ.

**Workflow:**
1. Copy/initialize the FP32 model → apply fusion
2. Set QAT qconfig (e.g., get_default_qat_qconfig("fbgemm"))
3. prepare_qat_fx → continue training for 5–20 epochs (much fewer than training from scratch)
4. convert_fx → obtain the true INT8 quantized model
5. Evaluate/benchmark

In [2]:
import os
os.environ.pop("TF_USE_LEGACY_KERAS", None)

'1'

In [8]:
import os
os.environ.pop("TF_USE_LEGACY_KERAS", None)  # 建议关闭 legacy 开关，避免奇怪兼容问题

import tensorflow as tf, tensorflow_model_optimization as tfmot
from tensorflow import keras

print("TF:", tf.__version__)
print("TF-MOT:", tfmot.__version__)
print("Functional:", getattr(model, "_is_graph_network", False))  # True 才能 QAT

TF: 2.20.0
TF-MOT: 0.8.0
Functional: False


In [10]:
new_model = model

In [4]:
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.keras import compat as kcompat

# 1) 取出 TF-MOT 实际使用的 keras 命名空间（可能是 tf_keras）
keras_q = kcompat.keras
layers_q = keras_q.layers
models_q = keras_q.models

# 2) 用这一套 keras 重建同结构模型，并把你现有模型权重拷过去
#    （你的现有模型是 `model`，来自 tensorflow.keras）
inputs = keras_q.Input(shape=(224, 224, 3))
base   = keras_q.applications.ResNet50V2(include_top=False, weights=None, pooling="avg")
x = base(inputs, training=False)
x = layers_q.Dense(256, activation="relu")(x)
x = layers_q.Dropout(0.5)(x)
outputs = layers_q.Dense(10, activation="softmax")(x)
model_q = models_q.Model(inputs, outputs, name="resnet50v2_cifar10_q")

# 拷贝权重（结构一致即可拷贝成功）
model_q.set_weights(model.get_weights())

# 3) 只量化“头部”两层（避免量化嵌套子模型）
QuantizeAnnotate = tfmot.quantization.keras.quantize_annotate_layer
QuantizeApply    = tfmot.quantization.keras.quantize_apply

inp = model_q.input
h   = base(inp, training=False)  # base 是上面同命名空间构建的子模型
h   = QuantizeAnnotate(layers_q.Dense(256, activation="relu", name="qa_fc"))(h)
h   = layers_q.Dropout(0.5)(h)
out = QuantizeAnnotate(layers_q.Dense(10, activation="softmax", name="qa_logits"))(h)

annotated = models_q.Model(inp, out, name="resnet50v2_head_qat")
qat_model = QuantizeApply(annotated)

# 4) 冻结 BN 更稳（可选）
for l in qat_model.layers:
    if isinstance(l, layers_q.BatchNormalization):
        l.trainable = False

# 5) 编译训练
qat_model.compile(
    optimizer=keras_q.optimizers.Adam(1e-5),
    loss=keras_q.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

In [24]:
# 超参数
TRAIN_BS, TEST_BS = 64, 128
EPOCHS = 2

# 固定 test_ds（验证集只建一次）
_, test_loader = make_torch_loaders(TRAIN_BS, TEST_BS)
test_ds = make_tf_dataset_from_loader(test_loader, TEST_BS)
val_steps = len(test_loader)

for ep in range(EPOCHS):
    # 每个 epoch 重建 train_loader / train_ds
    train_loader, _ = make_torch_loaders(TRAIN_BS, TEST_BS)
    train_ds = make_tf_dataset_from_loader(train_loader, TRAIN_BS)
    steps_per_epoch = len(train_loader)

    print(f"\n==== QAT Epoch {ep+1}/{EPOCHS} ====")
    qat_model.fit(
        train_ds,
        epochs=1,                         # 一次只跑 1 个 epoch
        steps_per_epoch=steps_per_epoch,  # 显式告知步数
        validation_data=test_ds,
        validation_steps=val_steps
    )


==== QAT Epoch 1/5 ====


2025-09-26 23:28:05.323215: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91300
2025-09-26 23:28:09.590331: I external/local_xla/xla/service/service.cc:163] XLA service 0x7faed4a488d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-09-26 23:28:09.590362: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA RTX A2000 12GB, Compute Capability 8.6
2025-09-26 23:28:09.598451: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1758900489.756915     935 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.



==== QAT Epoch 2/5 ====

==== QAT Epoch 3/5 ====
 62/782 [=>............................] - ETA: 4:30 - loss: 0.2078 - accuracy: 0.9337

KeyboardInterrupt: 

==== QAT Epoch 1/5 ====
782/782 [==============================] - 365s 403ms/step - loss: 0.9048 - accuracy: 0.6864 - val_loss: 0.2608 - val_accuracy: 0.9198

==== QAT Epoch 2/5 ====
782/782 [==============================] - 309s 396ms/step - loss: 0.2605 - accuracy: 0.9216 - val_loss: 0.2345 - val_accuracy: 0.9316

==== QAT Epoch 3/5 ====
 62/782 [=>............................] - ETA: 4:30 - loss: 0.2078 - accuracy: 0.9337

In [25]:
N = 512
def representative_gen():
    for x,_ in test_ds.unbatch().take(N):
        yield [tf.cast(tf.expand_dims(x,0), tf.float32)]

conv = tf.lite.TFLiteConverter.from_keras_model(qat_model)
conv.optimizations = [tf.lite.Optimize.DEFAULT]
conv.representative_dataset = representative_gen
conv.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_qat = conv.convert()
open("resnet50_qat_int8.tflite","wb").write(tflite_qat)

INFO:tensorflow:Assets written to: /tmp/tmpfcyt0xzx/assets


INFO:tensorflow:Assets written to: /tmp/tmpfcyt0xzx/assets
W0000 00:00:1758901242.453378     868 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1758901242.453429     868 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-09-26 23:40:42.453919: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpfcyt0xzx
2025-09-26 23:40:42.505355: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-09-26 23:40:42.505409: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpfcyt0xzx
I0000 00:00:1758901242.747376     868 mlir_graph_optimization_pass.cc:437] MLIR V1 optimization pass is not enabled
2025-09-26 23:40:42.785192: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-09-26 23:40:44.165408: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpfcyt0xzx
2025-09-26 23:40:44.509

24815928

# 6. Results Comparison (FP32 vs PTQ vs QAT)

In [26]:
import time, numpy as np, tensorflow as tf

def tflite_bench(tflite_path, sample):
    inter = tf.lite.Interpreter(model_path=tflite_path)
    inter.allocate_tensors()
    inp = inter.get_input_details()[0]["index"]
    out = inter.get_output_details()[0]["index"]
    # 预热
    for _ in range(20):
        inter.set_tensor(inp, sample)
        inter.invoke()
    # 计时
    ts=[]
    for _ in range(200):
        t0=time.perf_counter()
        inter.set_tensor(inp, sample)
        inter.invoke()
        ts.append((time.perf_counter()-t0)*1000)
    return np.percentile(ts,[50,90])

# 取一小批样本做基准
x_one, _ = next(iter(test_ds.take(1)))
x_one = x_one[:1].numpy()
print("PTQ-dynamic ms P50/P90:", tflite_bench("resnet50_ptq_dynamic.tflite", x_one))
print("PTQ-int8   ms P50/P90:", tflite_bench("resnet50_ptq_int8.tflite", x_one))
print("QAT-int8   ms P50/P90:", tflite_bench("resnet50_qat_int8.tflite", x_one))

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


PTQ-dynamic ms P50/P90: [135.43599844 138.98892291]
PTQ-int8   ms P50/P90: [120.72722614 124.73276444]
QAT-int8   ms P50/P90: [120.73456123 125.14217161]


In [29]:
import os, time, numpy as np, tensorflow as tf

# ========= 通用工具 =========
def _quantize_like(interpreter, x_float):
    """将 float32 NHWC 批量输入量化/或直喂，返回喂给 TFLite 的 numpy 数组。"""
    d = interpreter.get_input_details()[0]
    dtype = d["dtype"]
    if dtype == np.float32:
        return x_float.astype(np.float32)
    elif dtype == np.int8:
        scale, zero = d["quantization"]
        # 将 float 量化成 int8；注意 clip 防溢出
        return np.clip(np.round(x_float / scale + zero), -128, 127).astype(np.int8)
    else:
        raise ValueError(f"Unsupported input dtype: {dtype}")

def _dequantize_like(interpreter, y):
    """将 TFLite 输出（可能是 int8/float32）统一成 float32 概率/对数its。"""
    d = interpreter.get_output_details()[0]
    dtype = d["dtype"]
    if dtype == np.float32:
        return y.astype(np.float32)
    elif dtype == np.int8:
        scale, zero = d["quantization"]
        return (y.astype(np.float32) - zero) * scale
    else:
        raise ValueError(f"Unsupported output dtype: {dtype}")

# ========= TFLite 延迟测试 =========
def tflite_bench(tflite_path, sample, runs=200, warmup=20, num_threads=1):
    inter = tf.lite.Interpreter(model_path=tflite_path, num_threads=num_threads)
    inter.allocate_tensors()
    i_idx = inter.get_input_details()[0]["index"]
    o_idx = inter.get_output_details()[0]["index"]

    x = _quantize_like(inter, sample)

    # 预热
    for _ in range(warmup):
        inter.set_tensor(i_idx, x)
        inter.invoke()

    ts = []
    for _ in range(runs):
        t0 = time.perf_counter()
        inter.set_tensor(i_idx, x)
        inter.invoke()
        _ = inter.get_tensor(o_idx)
        ts.append((time.perf_counter() - t0) * 1000)
    ts = np.array(ts)
    return np.percentile(ts, [50, 90])

# ========= TFLite Top-1 准确率 =========
def tflite_top1_accuracy(tflite_path, ds, max_samples=None, num_threads=1):
    inter = tf.lite.Interpreter(model_path=tflite_path, num_threads=num_threads)
    inter.allocate_tensors()
    i_idx = inter.get_input_details()[0]["index"]
    o_idx = inter.get_output_details()[0]["index"]

    n_correct, n_total = 0, 0
    for x, y in ds:
        x_np = x.numpy()
        y_np = y.numpy()
        # 逐条推理（batch=1延迟更稳定；也可批量，但多数 TFLite 模型是 for b=1）
        for i in range(x_np.shape[0]):
            xi = x_np[i:i+1]
            inter.set_tensor(i_idx, _quantize_like(inter, xi))
            inter.invoke()
            logits = _dequantize_like(inter, inter.get_tensor(o_idx))
            pred = np.argmax(logits, axis=-1)[0]
            n_correct += int(pred == int(y_np[i]))
            n_total += 1
            if (max_samples is not None) and (n_total >= max_samples):
                acc = n_correct / n_total
                return acc
    acc = n_correct / max(n_total, 1)
    return acc

# ========= Keras Top-1 =========
def keras_top1_accuracy(keras_model, ds):
    res = keras_model.evaluate(ds, verbose=0, return_dict=True)
    return float(res.get("accuracy", res.get("acc", 0.0)))

# ========= 汇总评测 =========
def bytes_mb(path): return os.path.getsize(path)/1e6 if os.path.exists(path) else None

def evaluate_all(model_fp32, test_ds,
                 tflite_dyn_path="resnet50_ptq_dynamic.tflite",
                 tflite_int8_path="resnet50_ptq_int8.tflite",
                 tflite_qat_path="resnet50_qat_int8.tflite",
                 bench_threads=1,
                 acc_threads=1,
                 acc_samples=None):
    # 取一个样本做延迟评测（与预处理后的张量一致）
    xb, yb = next(iter(test_ds.take(1)))
    x_one = xb[:1].numpy()

    results = []

    # 1) FP32 Keras
    try:
        acc = keras_top1_accuracy(model_fp32, test_ds)
    except Exception as e:
        acc = None
        print("[WARN] Keras acc eval failed:", e)
    results.append({
        "name": "FP32 (Keras)",
        "size_MB": None,
        "acc": acc,
        "p50_ms": None,
        "p90_ms": None
    })

    # 2) PTQ-dynamic
    if os.path.exists(tflite_dyn_path):
        acc = tflite_top1_accuracy(tflite_dyn_path, test_ds, max_samples=acc_samples, num_threads=acc_threads)
        p50, p90 = tflite_bench(tflite_dyn_path, x_one, num_threads=bench_threads)
        results.append({
            "name": "PTQ-dynamic",
            "size_MB": bytes_mb(tflite_dyn_path),
            "acc": acc,
            "p50_ms": float(p50),
            "p90_ms": float(p90),
        })

    # 3) PTQ-int8
    if os.path.exists(tflite_int8_path):
        acc = tflite_top1_accuracy(tflite_int8_path, test_ds, max_samples=acc_samples, num_threads=acc_threads)
        p50, p90 = tflite_bench(tflite_int8_path, x_one, num_threads=bench_threads)
        results.append({
            "name": "PTQ-int8",
            "size_MB": bytes_mb(tflite_int8_path),
            "acc": acc,
            "p50_ms": float(p50),
            "p90_ms": float(p90),
        })

    # 4) QAT-int8
    if os.path.exists(tflite_qat_path):
        acc = tflite_top1_accuracy(tflite_qat_path, test_ds, max_samples=acc_samples, num_threads=acc_threads)
        p50, p90 = tflite_bench(tflite_qat_path, x_one, num_threads=bench_threads)
        results.append({
            "name": "QAT-int8",
            "size_MB": bytes_mb(tflite_qat_path),
            "acc": acc,
            "p50_ms": float(p50),
            "p90_ms": float(p90),
        })

    # 打印表格
    print("\n=== Summary (threads: bench={}, acc={}; samples={} ) ===".format(
        bench_threads, acc_threads, acc_samples if acc_samples else "ALL"))
    print("{:<12} {:>10} {:>10} {:>10} {:>10}".format("Model", "Size(MB)", "Top1", "P50(ms)", "P90(ms)"))
    for r in results:
        print("{:<12} {:>10} {:>10} {:>10} {:>10}".format(
            r["name"],
            "-" if r["size_MB"] is None else f"{r['size_MB']:.2f}",
            "-" if r["acc"] is None else f"{r['acc']*100:.2f}",
            "-" if r["p50_ms"] is None else f"{r['p50_ms']:.2f}",
            "-" if r["p90_ms"] is None else f"{r['p90_ms']:.2f}",
        ))
    return results

# ========== 调用示例 ==========
# 已有的：model（FP32 Keras 基线）、test_ds、三个 tflite 文件路径
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)
results = evaluate_all(
    model_fp32=model,
    test_ds=test_ds,
    tflite_dyn_path="resnet50_ptq_dynamic.tflite",
    tflite_int8_path="resnet50_ptq_int8.tflite",
    tflite_qat_path="resnet50_qat_int8.tflite",
    bench_threads=4,   
    acc_threads=1,
    acc_samples=5000,  # 为了速度可先抽样评估 5k 张
)


2025-09-26 23:52:44.911383: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-09-26 23:52:44.911448: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 4409749761582512711
2025-09-26 23:52:44.911470: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 17874141467772767791



=== Summary (threads: bench=4, acc=1; samples=5000 ) ===
Model          Size(MB)       Top1    P50(ms)    P90(ms)
FP32 (Keras)          -      92.53          -          -
PTQ-dynamic       24.58      92.20      37.88      45.27
PTQ-int8          24.78      89.82      33.46      39.35
QAT-int8          24.82      91.02      31.95      37.93


# 7. Model Export and Deployment

In [30]:
def representative_gen():
    for x, _ in test_ds.take(200):  # 用 200 个 batch 作为代表性样本
        for i in range(x.shape[0]):
            yield [tf.cast(x[i:i+1], tf.float32)]

In [31]:
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

# 输入/输出改为 int8
converter.inference_input_type  = tf.int8
converter.inference_output_type = tf.int8

tflite_model = converter.convert()

with open("resnet50_qat_int8.tflite", "wb") as f:
    f.write(tflite_model)

INFO:tensorflow:Assets written to: /tmp/tmpq9idq8kz/assets


INFO:tensorflow:Assets written to: /tmp/tmpq9idq8kz/assets
W0000 00:00:1758904392.994792     868 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1758904392.994845     868 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-09-27 00:33:12.995147: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpq9idq8kz
2025-09-27 00:33:13.053817: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-09-27 00:33:13.053855: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpq9idq8kz
2025-09-27 00:33:13.323445: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-09-27 00:33:14.701602: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpq9idq8kz
2025-09-27 00:33:15.062202: I tensorflow/cc/saved_model/loader.cc:471] SavedModel load for tags { serve }; Status: success: OK. Took 206705

KeyboardInterrupt: 

In [32]:
interpreter = tf.lite.Interpreter(model_path="resnet50_qat_int8.tflite")
interpreter.allocate_tensors()

inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# 取 1 张 CIFAR-10 样本
x_one, _ = next(iter(test_ds.take(1)))
x_one = x_one[:1].numpy()

# 按输入量化
scale, zero = inp["quantization"]
x_quant = (x_one / scale + zero).astype(np.int8)

interpreter.set_tensor(inp["index"], x_quant)
interpreter.invoke()
y_pred = interpreter.get_tensor(out["index"])
print("Predicted:", y_pred.argmax())

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
  x_quant = (x_one / scale + zero).astype(np.int8)
  x_quant = (x_one / scale + zero).astype(np.int8)


ValueError: Cannot set tensor: Got value of type INT8 but expected type FLOAT32 for input 0, name: serving_default_input_5:0 

## 1. TorchScript（CPU inference）

In [None]:
# src/export.py
import torch
m_int8 = ...  # 加载 PTQ 或 QAT 模型到 CPU
m_int8.eval()
example = torch.randn(1,3,32,32)
ts = torch.jit.trace(m_int8, example)
ts = torch.jit.freeze(ts)
ts.save("results/resnet18_int8_ts.pt")

## 2. ONNX

In [None]:
torch.onnx.export(m_int8, example, "results/resnet18_int8.onnx",
                  input_names=["input"], output_names=["logits"],
                  opset_version=17, dynamic_axes={"input":{0:"N"}, "logits":{0:"N"}})