In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

batch_size = 64
learning_rate = 0.001
num_epochs = 5

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

torch.save(model.state_dict(), 'model.pth')

Epoch [1/5], Loss: 0.1060
Epoch [2/5], Loss: 0.3784
Epoch [3/5], Loss: 0.0432
Epoch [4/5], Loss: 0.2595
Epoch [5/5], Loss: 0.0130


In [33]:
dummy_input = torch.randn(1, 1, 28, 28)  
onnx_file_path = "model.onnx"
torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True, 
                  opset_version=11, do_constant_folding=True, 
                  input_names=['input'], output_names=['output'])
print(f'Model exported to {onnx_file_path}')

Model exported to model.onnx


In [101]:
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [137]:
import time
import onnx
import onnxruntime as ort

onnx_model = onnx.load(onnx_file_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_file_path)

def test_speed(model, data_loader, framework='pytorch'):
    model.eval()
    total_time = 0
    with torch.no_grad():
        for images, _ in data_loader:
            if framework == 'pytorch':
                start_time = time.time()
                model(images)
                total_time += time.time() - start_time
            else:  
                for img in images:  
                    start_time = time.time()
                    ort_inputs = {ort_session.get_inputs()[0].name: img.unsqueeze(0).numpy()} 
                    ort_session.run(None, ort_inputs)
                    total_time += time.time() - start_time
    return total_time / len(data_loader)

onnx_speed = test_speed(model, test_loader, framework='onnx')
print(f'Average inference time (ONNX): {onnx_speed:.6f} seconds per batch')

pytorch_speed = test_speed(model, test_loader, framework='pytorch')
print(f'Average inference time (PyTorch): {pytorch_speed:.6f} seconds per batch')


Average inference time (ONNX): 0.000790 seconds per batch
Average inference time (PyTorch): 0.000046 seconds per batch


In [163]:
import time
import torch
import onnx
import onnxruntime as ort
from sklearn.metrics import accuracy_score

# Загрузка ONNX модели
onnx_model = onnx.load(onnx_file_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_file_path)

def test_speed(model, data_loader, framework='pytorch'):
    total_time = 0
    with torch.no_grad():
        for images, _ in data_loader:
            if framework == 'pytorch':
                start_time = time.time()
                model(images)
                total_time += time.time() - start_time
            else:  
                for img in images:  
                    start_time = time.time()
                    ort_inputs = {ort_session.get_inputs()[0].name: img.unsqueeze(0).numpy()}  # Добавляем размер батча
                    ort_session.run(None, ort_inputs)
                    total_time += time.time() - start_time
    return total_time / len(data_loader)

def test_accuracy(model, data_loader, framework='pytorch'):
    all_preds = []
    all_labels = []
    preds = []
    with torch.no_grad():
        for images, labels in data_loader:      
            if framework == 'pytorch':
                outputs = model(images)
                _, pred = torch.max(outputs, 1)
                preds.extend(pred)
            else:
                for img in images:  
                    ort_inputs = {ort_session.get_inputs()[0].name: img.unsqueeze(0).numpy()}   # Преобразование в numpy
                    outputs = ort_session.run(None, ort_inputs)
                    preds.extend(torch.tensor(outputs[0]).argmax(axis=1).numpy())
                    
            all_preds = preds
            all_labels.extend(labels.numpy())
    accuracy = accuracy_score(all_labels, all_preds)
    return accuracy

# Измерение скорости
onnx_speed = test_speed(model, train_loader, framework='onnx')
print(f'Average inference time (ONNX): {onnx_speed:.6f} seconds per batch')

pytorch_speed = test_speed(model, train_loader, framework='pytorch')
print(f'Average inference time (PyTorch): {pytorch_speed:.6f} seconds per batch')

# Измерение точности
onnx_accuracy = test_accuracy(model, test_loader, framework='onnx')
print(f'Accuracy (ONNX): {onnx_accuracy:.4f}')

pytorch_accuracy = test_accuracy(model, test_loader, framework='pytorch')
print(f'Accuracy (PyTorch): {pytorch_accuracy:.4f}')

Average inference time (ONNX): 0.000741 seconds per batch
Average inference time (PyTorch): 0.000048 seconds per batch
Accuracy (ONNX): 0.9725
Accuracy (PyTorch): 0.9725
