In [1]:
import os
import torch
import pickle
import numpy as np
from tqdm import tqdm
from datetime import datetime
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Subset

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
def unpickle(file):
    with open(file, "rb") as fo:
        dict = pickle.load(fo, encoding="bytes")

    return dict
def load_data(data_path):
    data = unpickle(data_path)
    # print(data.keys())
    _data = np.array(data["data"])
    _labels = np.array(data["labels"])
    print("data loaded.")
    return _data, _labels

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # shortcut（当通道数或 stride 不一致时）
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResidualCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.layer1 = ResidualBlock(3, 32)
        self.layer2 = ResidualBlock(32, 64)
        self.layer3 = ResidualBlock(64, 128)

        self.pool = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(self.layer1(x))   # 32×32 → 16×16
        x = self.pool(self.layer2(x))   # 16×16 → 8×8
        x = self.pool(self.layer3(x))   # 8×8 → 4×4

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = x.view(-1, 128 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def get_accuracy(model, loader, topk=1):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.topk(outputs.data, topk, dim=1)
            total += labels.size(0)
            for i in range(labels.size(0)):
                if labels[i] in predicted[i]:
                    correct += 1
    return correct / total

def train_one_epoch(model, criterion, optimizer, train_loader, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def validate_model(model, criterion, test_loader, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_loss /= total
    test_accuracy = correct / total
    return test_loss, test_accuracy

## 数据处理

In [8]:
# 读取原始数据
data, labels = load_data("HASYv2")
labels = labels.squeeze(1)

data = np.transpose(data, (3, 2, 0, 1))
data = torch.from_numpy(data).float()
labels = torch.from_numpy(labels).long()

data loaded.


In [19]:
# 读取增强数据
ckpt = torch.load("HASYv2_balanced_500.pt", map_location="cpu")

data = ckpt["data"]
labels = ckpt["labels"]

In [20]:
print(data.shape)
print(labels.shape)

torch.Size([398626, 3, 32, 32])
torch.Size([398626])


In [21]:
dataset = TensorDataset(data, labels)

labels_np = labels.numpy()

splitter = StratifiedShuffleSplit(
    n_splits=1,
    test_size=0.2,
    random_state=114
)

train_idx, test_idx = next(splitter.split(np.zeros(len(labels_np)), labels_np))

train_dataset = Subset(dataset, train_idx)
test_dataset  = Subset(dataset, test_idx)

batch_size = 64

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

## 训练

In [22]:
num_classes = 370
num_epochs = 7
lr = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = SimpleCNN(num_classes).to(device)
model = ResidualCNN(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(num_epochs):
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")

    train_loss, train_acc = train_one_epoch(
        model, criterion, optimizer, train_loader, device
    )
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    val_loss, val_acc = validate_model(
        model, criterion, test_loader, device
    )
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")


Epoch [1/7]


Training: 100%|██████████| 4983/4983 [00:43<00:00, 115.32it/s]

Train Loss: 1.0779, Train Acc: 0.6953





Val Loss: 0.7288, Val Acc: 0.7729

Epoch [2/7]


Training: 100%|██████████| 4983/4983 [00:43<00:00, 113.68it/s]

Train Loss: 0.6250, Train Acc: 0.7957





Val Loss: 0.5884, Val Acc: 0.8064

Epoch [3/7]


Training: 100%|██████████| 4983/4983 [00:44<00:00, 111.37it/s]

Train Loss: 0.5131, Train Acc: 0.8263





Val Loss: 0.5257, Val Acc: 0.8267

Epoch [4/7]


Training: 100%|██████████| 4983/4983 [00:44<00:00, 113.17it/s]

Train Loss: 0.4401, Train Acc: 0.8462





Val Loss: 0.4981, Val Acc: 0.8338

Epoch [5/7]


Training: 100%|██████████| 4983/4983 [00:42<00:00, 116.23it/s]

Train Loss: 0.3853, Train Acc: 0.8627





Val Loss: 0.4901, Val Acc: 0.8354

Epoch [6/7]


Training: 100%|██████████| 4983/4983 [00:43<00:00, 115.81it/s]

Train Loss: 0.3423, Train Acc: 0.8758





Val Loss: 0.4746, Val Acc: 0.8450

Epoch [7/7]


Training: 100%|██████████| 4983/4983 [00:42<00:00, 116.53it/s]

Train Loss: 0.3078, Train Acc: 0.8868





Val Loss: 0.4681, Val Acc: 0.8491


## 模型效果测试

In [23]:
for i in range(1,6):
    print(get_accuracy(model, test_loader, topk=i))

0.8491332814891002
0.9521360660261395
0.9740109876326418
0.9833178636831147
0.9879838446679878


In [24]:
from collections import Counter

train_count = Counter(labels.tolist())

def bucket_by_count(n):
    if n >= 1000:
        return "head (>=1000)"
    elif n >= 200:
        return "mid (200-999)"
    elif n >= 50:
        return "tail (50-199)"
    else:
        return "extreme-tail (<50)"


model.eval()

per_class_correct = Counter()
per_class_total = Counter()

with torch.no_grad():
    for inputs, labels_batch in test_loader:
        inputs = inputs.to(device)
        labels_batch = labels_batch.to(device)

        outputs = model(inputs)
        preds = torch.argmax(outputs, dim=1)

        for y_true, y_pred in zip(labels_batch, preds):
            y_true = int(y_true.item())
            per_class_total[y_true] += 1
            if y_true == int(y_pred.item()):
                per_class_correct[y_true] += 1

bucket_acc = {
    "head (>=1000)": [],
    "mid (200-999)": [],
    "tail (50-199)": [],
    "extreme-tail (<50)": []
}

for cls in per_class_total:
    acc = per_class_correct[cls] / per_class_total[cls]
    bucket = bucket_by_count(train_count[cls])
    bucket_acc[bucket].append(acc)

print("Per-bucket accuracy:")
for bucket, accs in bucket_acc.items():
    if len(accs) == 0:
        continue
    print(
        f"{bucket:22s} | "
        f"classes: {len(accs):3d} | "
        f"mean acc: {sum(accs)/len(accs):.3f}"
    )

Per-bucket accuracy:
head (>=1000)          | classes: 369 | mean acc: 0.844


## 转成ONNX

In [25]:
model.eval()
dummy_input = torch.randn(1, 3, 32, 32, device=device)

import torch.onnx

onnx_path = "simplecnn_fp32.onnx"

torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "logits": {0: "batch_size"}
    }
)

In [83]:
!pip install onnx onnxruntime onnxruntime-tools

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Collecting onnx
  Downloading http://mirrors.aliyun.com/pypi/packages/fb/71/d3fec0dcf9a7a99e7368112d9c765154e81da70fcba1e3121131a45c245b/onnx-1.20.1-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.5/17.5 MB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting onnxruntime
  Downloading http://mirrors.aliyun.com/pypi/packages/ef/88/9cc25d2bafe6bc0d4d3c1db3ade98196d5b355c0b273e6a5dc09c5d5d0d5/onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting onnxruntime-tools
  Downloading http://mirrors.aliyun.com/pypi/packages/6f/b0/db0e73356df0aaa8737e6f13c0dac499b5d904d3fa267c8ebf24515e8001/onnxruntime_tools-1.7.0-py3-none-any.whl (212 kB)


In [26]:
import numpy as np
from torch.utils.data import DataLoader
from onnxruntime.quantization import CalibrationDataReader

class CalibDataReader(CalibrationDataReader):
    def __init__(self, dataloader, num_batches=10):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)
        self.num_batches = num_batches
        self.count = 0

    def get_next(self):
        if self.count >= self.num_batches:
            return None
        self.count += 1

        inputs, _ = next(self.iterator)
        return {"input": inputs.numpy()}

calib_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True
)

calib_reader = CalibDataReader(calib_loader, num_batches=20)


In [27]:
from onnxruntime.quantization import quantize_static, QuantType

quant_onnx_path = "residualcnn_augument_int8.onnx"

quantize_static(
    model_input=onnx_path,
    model_output=quant_onnx_path,
    calibration_data_reader=calib_reader,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QUInt8
)




## ONNX推理

In [28]:
# ONNX推理
import onnxruntime as ort

sess = ort.InferenceSession(
    quant_onnx_path,
    providers=["CPUExecutionProvider"]
)
def onnx_infer(session, inputs):
    ort_inputs = {"input": inputs.numpy()}
    logits = session.run(None, ort_inputs)[0]
    return logits

correct = 0
total = 0

for inputs, labels in test_loader:
    logits = onnx_infer(sess, inputs)
    preds = np.argmax(logits, axis=1)

    correct += (preds == labels.numpy()).sum()
    total += labels.size(0)

print("ONNX INT8 Acc:", correct / total)

ONNX INT8 Acc: 0.8485061335072623


In [91]:
for inputs, labels in test_loader:
    print(inputs.shape)
    print(labels.shape)
    break

torch.Size([64, 3, 32, 32])
torch.Size([64])


In [92]:
for i in sess.get_inputs():
    print(i.name, i.shape, i.type)

input ['batch_size', 3, 32, 32] tensor(float)
