In [1]:
import torch
import torch_mlir

# 1. Define a minimal PyTorch model
class SimpleModel(torch.nn.Module):
    """Tiny network: one 3x3 convolution (3 -> 16 channels) followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the convolution, then through the ReLU activation."""
        features = self.conv(x)
        activated = self.relu(features)
        return activated

# 2. Instantiate the model in inference mode and create a sample input
model = SimpleModel()
model.eval()  # same effect as SimpleModel().eval(); eval() hands back the module
example_input = torch.randn(1, 3, 32, 32)  # NCHW layout


def _lower(output_type):
    """Compile ``model`` against ``example_input`` for the requested output type."""
    return torch_mlir.compile(model, example_input, output_type=output_type)


# 3. Convert to MLIR in the torch dialect
torch_mlir_module = _lower("torch")
print("=== Torch Dialect ===", torch_mlir_module, sep="\n")

# 4. Lower further to the linalg dialect (low-level tensor operations)
linalg_mlir_module = _lower("linalg-on-tensors")
print("\n=== Linalg Dialect ===", linalg_mlir_module, sep="\n")

# 5. Export as LLVM IR (which can ultimately be compiled to machine code)
# NOTE(review): "llvm" may not be an output type that torch_mlir.compile()
# accepts (the ones usually listed are torch, linalg-on-tensors, tosa,
# stablehlo/mhlo and raw) — confirm against the installed torch-mlir version;
# reaching the LLVM dialect may require running a separate lowering backend.
llvm_mlir_module = _lower("llvm")
print("\n=== LLVM IR ===", llvm_mlir_module, sep="\n")

ModuleNotFoundError: No module named 'torch_mlir'