In [11]:
import torch
import torch.nn as nn
import numpy as np
import onnx
import onnxruntime
from onnxruntime.training import artifacts
import torch.nn.functional as F


In [12]:
torch.__version__, onnx.__version__, onnxruntime.__version__

('2.1.0', '1.14.1', '1.16.3')

In [13]:
path_to_forward_only_onnx_model = 'lstm_model.onnx'

# Load the forward-only ONNX model
model = onnx.load(path_to_forward_only_onnx_model)

# Extract model's parameters
all_params = [param.name for param in model.graph.initializer]

trainable_layers = ['fc', 'onnx']
requires_grad = [param for param in all_params if any(layer in param for layer in trainable_layers)]
frozen_params = [param for param in all_params if param not in requires_grad]
print(requires_grad, frozen_params)

print(model.graph.output[0].name)

['fc.weight', 'fc.bias', 'onnx::LSTM_109', 'onnx::LSTM_110', 'onnx::LSTM_111'] []
my_output


In [14]:
path_to_output_artifact_directory = 'training_artifacts'
artifacts.generate_artifacts(model,
                             requires_grad=requires_grad,
                             frozen_params=frozen_params,
                             loss=artifacts.LossType.CrossEntropyLoss,
                             optimizer=artifacts.OptimType.AdamW,
                             artifact_directory=path_to_output_artifact_directory)

2023-12-26 13:21:27,334 root [INFO] - Loss function enum provided: CrossEntropyLoss
2023-12-26 13:21:27,336 root [DEBUG] - Building training block _TrainingBlock
2023-12-26 13:21:27,336 root [DEBUG] - Building block: CrossEntropyLoss
2023-12-26 13:21:27,346 root [DEBUG] - Building gradient graph for training block _TrainingBlock
2023-12-26 13:21:27.319396 [I:onnxruntime:Default, constant_sharing.cc:256 ApplyImpl] Total shared scalar initializer count: 1
2023-12-26 13:21:27,371 root [DEBUG] - The loss output is onnx::loss::8. The gradient graph will be built starting from onnx::loss::8_grad.
2023-12-26 13:21:27,379 root [DEBUG] - Adding gradient accumulation nodes for training block _TrainingBlock
2023-12-26 13:21:27,381 root [INFO] - Training model path training_artifacts/training_model.onnx already exists. Overwriting.
2023-12-26 13:21:27,384 root [INFO] - Saved training model to training_artifacts/training_model.onnx
2023-12-26 13:21:27,385 root [INFO] - Eval model path training_arti

In [5]:
from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Assuming the paths are correctly defined
path_to_the_checkpoint_artifact = 'training_artifacts/checkpoint'
path_to_the_training_model = 'training_artifacts/training_model.onnx'
path_to_the_eval_model = 'training_artifacts/eval_model.onnx'
path_to_the_optimizer_model = 'training_artifacts/optimizer_model.onnx'

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

optimizer = Optimizer(path_to_the_optimizer_model, module)

In [6]:
import numpy as np

def generate_training_data(data_size):
    # Generate random integers for input data X
    X = np.random.uniform(0.0, 10.0, (data_size, 6))

    # Compute output data y
    y = (np.sum(X, axis=1) / 20)
    y = y.astype(int)
    # Compute Y as the sum of each row in X divided by 3

    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

X, y = generate_training_data(500)

X_train = X.detach().numpy().reshape(-1, 6, 1)
y_train = y.detach().numpy()

In [7]:
# Define epochs and batch size
epochs = 100
batch_size = 8  # You can adjust the batch size as needed

# Training loop

for epoch in range(epochs):
    for i in range(0, len(X_train), batch_size):
        # Extract batches
        batch_X = X_train[i:i + batch_size]
        batch_y = y_train[i:i + batch_size]

        # Set the module to training mode
        module.train()

        # Forward pass (assuming the module accepts input and target)
        training_loss = module(batch_X, batch_y)

        # Backward pass and optimization
        optimizer.step()

        # Reset gradients
        module.lazy_reset_grad()

    # Print epoch statistics, etc.
    if (epoch+1) % 10 == 0:
        print(f'Epoch {epoch + 1}/{epochs}, Training Loss: {training_loss}')

Epoch 10/100, Training Loss: 0.11016654968261719
Epoch 20/100, Training Loss: 0.04419594258069992
Epoch 30/100, Training Loss: 0.050201013684272766
Epoch 40/100, Training Loss: 0.05832836031913757
Epoch 50/100, Training Loss: 0.049569305032491684
Epoch 60/100, Training Loss: 0.027245810255408287
Epoch 70/100, Training Loss: 0.023111190646886826
Epoch 80/100, Training Loss: 0.015518329106271267
Epoch 90/100, Training Loss: 0.030633224174380302
Epoch 100/100, Training Loss: 0.01015745010226965


In [8]:
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)


In [10]:
# Assuming 'my_output' is the name of your onnx model's output
output_names = ['my_output']  # List of output names

# Export the model for inferencing
module.export_model_for_inferencing('inference_model.onnx', output_names)

import onnxruntime

ort_session = onnxruntime.InferenceSession("inference_model.onnx", providers=["CPUExecutionProvider"])
# ort_session = onnxruntime.InferenceSession("classification_model.onnx", providers=["CPUExecutionProvider"])


ort_inputs = {ort_session.get_inputs()[0].name: X_train[10:15]}

ort_outs = ort_session.run(None, ort_inputs)

print(ort_outs[0].argmax(axis=1), y_train[10:15])

[1 2 1 1 2] [1 2 1 1 2]
