In [1]:
import torch
import numpy as np
import random
import traceback as tb

In [2]:
from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram

def f(x, y):
    return x + y
example_args = (torch.randn(3, 3), torch.randn(3, 3))

pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)

AssertionError: Expected an nn.Module instance.

# Some details

Different dimensions than the ones expected from our model (obviously) break things

In [None]:
# Works correctly
print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))

# Errors
try:
    print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
except Exception as e:
    tb.print_exc(limit=1)

But there can be some dynamism. We can put some constraints on the inputs

In [None]:
from torch.export import dynamic_dim
constraints = [
    # Input 0, dimension 1 is dynamic
    dynamic_dim(example_args[0], 1),
    # Input 0, dimension 1 must be greater than or equal to 1
    1 <= dynamic_dim(example_args[0], 1),
    # Input 0, dimension 1 must be less than or equal to 10
    dynamic_dim(example_args[0], 1) <= 10,
    # Input 1, dimension 1 is equal to input 0, dimension 1
    dynamic_dim(example_args[1], 1) == dynamic_dim(example_args[0], 1),
]
pre_autograd_aten_dialect = capture_pre_autograd_graph(
    f, example_args, constraints=constraints
)
aten_dialect: ExportedProgram = export(f, example_args, constraints=constraints)

# Works correctly
print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))
print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))

# Errors because it violates our constraint that input 0, dim 1 <= 10
try:
    print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15)))
except Exception:
    tb.print_exc(limit=1)
    
# Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1
try:
    print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2)))
except Exception:
    tb.print_exc(limit=1)