In [30]:
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ConditionalModel(nn.Module):
    def __init__(self):
        super(ConditionalModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)
    def forward(self, x):
        if x.sum() > 0:
            x = self.fc1(x)
        else:
            x = self.fc2(x)
        x = self.relu(x)
        return x
model = ConditionalModel()
model = model.to(device)
model.eval()

ConditionalModel(
  (fc1): Linear(in_features=10, out_features=5, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=10, out_features=3, bias=True)
)

In [60]:
example_input = torch.randn(5, 10)
torch.export.export(model, (example_input,))

UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/var/folders/n1/ygzk0n895x15grwb5rm1ld300000gn/T/ipykernel_38275/3602047528.py", line 11, in forward
    if x.sum() > 0:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


In [32]:
compiled_model = torch.compile(model)
compiled_model

OptimizedModule(
  (_orig_mod): ConditionalModel(
    (fc1): Linear(in_features=10, out_features=5, bias=True)
    (relu): ReLU()
    (fc2): Linear(in_features=10, out_features=3, bias=True)
  )
)

In [33]:
input_data = torch.randn(5, 10).to(device)
model_output = compiled_model(input_data)
model_output

tensor([[0.7042, 0.0000, 1.1878],
        [0.0000, 1.1963, 0.0000],
        [0.0000, 0.0000, 0.2733],
        [0.0000, 1.1183, 0.2549],
        [0.0000, 0.1559, 0.5826]], grad_fn=<CompiledFunctionBackward>)

In [59]:
input_data = torch.randn(5, 10).to(device)
output_original = model(input_data)
print("Output from original model:", output_original)
output_compiled = compiled_model(input_data)
print("Output from compiled model:", output_compiled)

Output from original model: tensor([[0.9162, 0.7815, 0.3962],
        [0.6812, 0.3794, 0.0000],
        [0.0000, 0.0000, 0.5863],
        [0.0426, 0.4930, 0.3120],
        [0.0000, 0.4797, 0.0000]], grad_fn=<ReluBackward0>)
Output from compiled model: tensor([[0.9162, 0.7815, 0.3962],
        [0.6812, 0.3794, 0.0000],
        [0.0000, 0.0000, 0.5863],
        [0.0426, 0.4930, 0.3120],
        [0.0000, 0.4797, 0.0000]], grad_fn=<CompiledFunctionBackward>)


In [63]:
torch.compiler.list_backends()

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']

In [None]:
new_compiled_model = torch.compile(model)
new_compiled_model(input_data)

tensor([[0.9162, 0.7815, 0.3962],
        [0.6812, 0.3794, 0.0000],
        [0.0000, 0.0000, 0.5863],
        [0.0426, 0.4930, 0.3120],
        [0.0000, 0.4797, 0.0000]], grad_fn=<CompiledFunctionBackward>)