In [1]:
import datasets
import numpy as np
import torch

mnist = datasets.load_dataset("mnist")


def _split_to_tensors(split):
    """Convert one dataset split into an ``(images, labels)`` tensor pair.

    Images are reshaped to rows of 784 values and divided by 255 before being
    wrapped in a ``FloatTensor``; labels are wrapped in a ``LongTensor``.
    """
    pixels = np.array(split["image"]).reshape(-1, 784) / 255
    labels = np.array(split["label"])
    return torch.FloatTensor(pixels), torch.LongTensor(labels)


xtrain, ytrain = _split_to_tensors(mnist["train"])
xtest, ytest = _split_to_tensors(mnist["test"])

# Sanity-check the resulting shapes for both splits.
print(xtrain.shape, ytrain.shape)
print(xtest.shape, ytest.shape)

torch.Size([60000, 784]) torch.Size([60000])
torch.Size([10000, 784]) torch.Size([10000])


In [2]:
# Prefer the GPU when one is available, but fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Two-layer MLP: 784 input features -> 128 hidden units (ReLU) -> 10 logits.
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())

# TensorDataset indexes the underlying tensors on demand, so there is no need
# to materialise one Python tuple per sample up front. Each batch is still an
# (x, y) pair of at most 64 samples, served in dataset order.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(xtrain, ytrain), batch_size=64
)
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(xtest, ytest), batch_size=64
)


def compute_metrics(loader, net=None, target_device=None):
    """Return the classification accuracy of a network over ``loader``.

    Args:
        loader: Iterable yielding ``(x, y)`` batches, where ``x`` is fed to
            the network and ``y`` holds the target class indices.
        net: Module to evaluate. Defaults to the notebook-level ``model``.
        target_device: Device each batch is moved to before the forward pass.
            Defaults to the notebook-level ``device``.

    Returns:
        ``n_correct / n_total``. When the loader yields at least one batch
        this is a 0-dim tensor living on ``target_device``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    net = model if net is None else net
    target_device = device if target_device is None else target_device

    n_correct, n_total = 0, 0
    # Evaluate in eval mode, then restore whatever mode the caller had set so
    # a training loop that calls this between epochs is left undisturbed.
    was_training = net.training
    net.eval()
    try:
        # No gradients are needed for scoring; skipping autograd bookkeeping
        # saves memory and time during evaluation.
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target_device), y.to(target_device)
                logits = net(x)
                preds = torch.argmax(logits, dim=-1)
                n_correct += (preds == y).sum()
                n_total += len(x)
    finally:
        net.train(was_training)
    return n_correct / n_total


# Move the network to the selected device; `pt_model` refers to the same
# module object as `model`.
pt_model = model.to(device)

num_epochs = 5
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    # Accuracy is reported before each epoch's updates, so the epoch-0 line
    # reflects the freshly initialised network.
    print(f"epoch {epoch}, test acc = {compute_metrics(val_loader)}")
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        logits = pt_model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
        # Clear gradients after the update so the next batch starts clean.
        optimizer.zero_grad()

print(f"final test acc = {compute_metrics(val_loader)}")



epoch 0, test acc = 0.13099999725818634
epoch 1, test acc = 0.9383999705314636
epoch 2, test acc = 0.9569000005722046
epoch 3, test acc = 0.9620999693870544
epoch 4, test acc = 0.9649999737739563
final test acc = 0.9681999683380127


In [3]:
import io

import onnx

# Random example batch handed to the exporter.
example_input = torch.randn(64, 784, device=device)

input_names = ["x"]
output_names = ["y"]
# Dimension 0 of every named input and output is exported as the symbolic
# "batch_size" axis.
dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}

# Export options: training mode, no constant folding, parameters stored in the
# model and not exposed as graph inputs.
export_options = dict(
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)

# Serialise into an in-memory buffer, then parse the bytes back into an ONNX
# model object.
f = io.BytesIO()
torch.onnx.export(pt_model, example_input, f, **export_options)
onnx_model = onnx.load_model_from_string(f.getvalue())

In [5]:
import os

from onnxruntime.training import artifacts

# Partition the parameter names by whether the parameter is trainable.
requires_grad = []
frozen_params = []
for name, param in model.named_parameters():
    bucket = requires_grad if param.requires_grad else frozen_params
    bucket.append(name)

# Directory that receives the generated artifacts; created if missing.
out_dir = "mlp"
os.makedirs(out_dir, exist_ok=True)

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    additional_output_names=output_names,
    artifact_directory=out_dir,
)

2024-03-04 22:29:33.942613446 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer ConstantSharing modified: 0 with status: OK
2024-03-04 22:29:33.942637632 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusion modified: 0 with status: OK
2024-03-04 22:29:33.942653101 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer CommonSubexpressionElimination modified: 0 with status: OK
2024-03-04 22:29:33.942658862 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer GeluFusion modified: 0 with status: OK
2024-03-04 22:29:33.942663551 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer SimplifiedLayerNormFusion modified: 0 with status: OK
2024-03-04 22:29:33.942668740 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer FastGeluFusion modified: 0 with status: OK
2024-03-04 22:29:33.942673559 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer Qui