Description
Hi folks! My team and I are looking into adding compiler support for block floating point (BFP) in Torch-MLIR, and we are wondering what you think about extending Torch-MLIR to support these cases. Below is a dummy test network I used as an experiment to compile a PyTorch model with BFP additions from qtorch via torch_mlir:
Input Description
The model is a basic MatMul followed by a BFP cast and a ReLU activation.
import torch.nn as nn
from qtorch.quant import block_quantize

class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_size):
        super(SimpleModel, self).__init__()
        self.matmul = nn.Linear(input_dim, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        matmul_out = self.matmul(x.flatten(1))
        # Block-quantize the linear output to an 8-bit word length along dim 0.
        quantized_matmul_out = block_quantize(matmul_out, wl=8, dim=0, rounding="nearest")
        relu_out = self.relu(quantized_matmul_out)
        return relu_out
Observed Behaviour
I used the torch_mlir.compile() API to compile the module into TOSA IR. While the module's forward pass runs fine on its own, compilation hits an assert in qtorch.quant_function.block_quantize() for not having a valid "rounding mode". Also, removing the BFP quantization from the module's forward pass yields a successful compilation.
Lastly, if I block-quantize the input tensors of the module first and then call torch_mlir.compile() on them, there doesn't seem to be any issue - are the casts optimized out in this case?
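For reference, here is a minimal sketch of the two compile flows described above. The input_dim/output_size values, the example tensor shape, and the string form output_type="tosa" are placeholders from my local setup rather than a canonical reproducer; the actual script is in the PR linked below.

import torch
import torch_mlir
from qtorch.quant import block_quantize

model = SimpleModel(input_dim=16, output_size=8).eval()
example_input = torch.randn(4, 16)

# Case 1: quantization inside forward() -- compilation trips the
# "rounding mode" assert in qtorch.quant_function.block_quantize().
tosa_module = torch_mlir.compile(model, example_input, output_type="tosa")

# Case 2: block-quantize the example inputs ahead of time and compile
# with those -- this path compiled without hitting the assert for me.
quantized_input = block_quantize(example_input, wl=8, dim=0, rounding="nearest")
tosa_module = torch_mlir.compile(model, quantized_input, output_type="tosa")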
Script to Reproduce
For convenience, I made a draft PR with a minimal script to reproduce the issue I'm hitting here: #909
FYI @silvasean