Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,15 +394,6 @@
"AtenIntBoolOpModule_basic",
"AtenIntMM_basic",
"AtenItemFpOpModule_basic",
"AtenMatmulQMixedSigni8Transpose_basic",
"AtenMatmulQMixedSigni8_basic",
"AtenMatmulQint8MV_basic",
"AtenMatmulQint8_basic",
"AtenMatmulQint8VM_basic",
"AtenMatmulQint8VV_basic",
"AtenMmQMixedSigni8_basic",
"AtenMmQint8_basic",
"AtenMmQuint8_basic",
"QuantizedReluInt32_basic",
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
Expand Down
130 changes: 75 additions & 55 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):


# ==============================================================================
# For DQ-Q fake quantization ops
import torch.ao.quantization.fx._decomposed


class AtenMmQint8(torch.nn.Module):
Expand All @@ -352,12 +354,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
qy = torch.dequantize(qy)
qz = torch.mm(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.0215, -25, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.0176, 18, -128, 127, torch.int8
)
z = torch.mm(x, y)
return z


@register_test_case(module_factory=lambda: AtenMmQint8())
Expand All @@ -384,12 +388,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160)
qy = torch.dequantize(qy)
qz = torch.mm(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.199, 65, 0, 255, torch.uint8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.0215, 160, 0, 255, torch.uint8
)
z = torch.mm(x, y)
return z


@register_test_case(module_factory=lambda: AtenMmQuint8())
Expand All @@ -416,12 +422,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
qy = torch.dequantize(qy)
qz = torch.mm(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.03, -66, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.025, 160, 0, 255, torch.uint8
)
z = torch.mm(x, y)
return z


@register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
Expand Down Expand Up @@ -475,12 +483,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
qy = torch.dequantize(qy)
qz = torch.matmul(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.0215, -25, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.0176, 18, -128, 127, torch.int8
)
z = torch.matmul(x, y)
return z


@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
Expand All @@ -505,12 +515,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
qy = torch.dequantize(qy)
qz = torch.matmul(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.0215, -25, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.0176, 18, -128, 127, torch.int8
)
z = torch.matmul(x, y)
return z


@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
Expand All @@ -535,12 +547,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
qy = torch.dequantize(qy)
qz = torch.matmul(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.0215, -25, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.0176, 18, -128, 127, torch.int8
)
z = torch.matmul(x, y)
return z


@register_test_case(module_factory=lambda: AtenMatmulQint8MV())
Expand All @@ -565,12 +579,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
qy = torch.dequantize(qy)
qz = torch.matmul(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.0215, -25, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.0176, 18, -128, 127, torch.int8
)
z = torch.matmul(x, y)
return z


@register_test_case(module_factory=lambda: AtenMatmulQint8())
Expand All @@ -597,12 +613,14 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
qy = torch.dequantize(qy)
qz = torch.matmul(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.03, -66, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.025, 160, 0, 255, torch.uint8
)
z = torch.matmul(x, y)
return z


@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
Expand All @@ -629,13 +647,15 @@ def __init__(self):
]
)
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
qy = torch.dequantize(qy)
qy = torch.transpose(qy, 1, 2)
qz = torch.matmul(qx, qy)
return qz
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 0.03, -66, -128, 127, torch.int8
)
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
y, 0.025, 160, 0, 255, torch.uint8
)
y = torch.transpose(y, 1, 2)
z = torch.matmul(x, y)
return z


@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
Expand Down
Loading