
QPyTorch Support (BFP quantization) #910

@Svoch

Description


Hi folks! My team and I are looking into adding compiler support for block floating point (BFP) in Torch-MLIR, and we are wondering what you think about extending Torch-MLIR to handle these cases. Below is a dummy test network I used as an experiment to compile a PyTorch model with BFP quantization from qtorch via torch_mlir:

Input Description

The model is a basic MatMul followed by a BFP cast and a ReLU activation.

import torch
import torch.nn as nn
from qtorch.quant import block_quantize


class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_size):
        super(SimpleModel, self).__init__()
        self.matmul = nn.Linear(input_dim, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        # MatMul -> BFP cast (wl=8, blocked along dim 0, round-to-nearest) -> ReLU
        matmul_out = self.matmul(x.flatten(1))
        quantized_matmul_out = block_quantize(matmul_out, wl=8, dim=0, rounding="nearest")
        relu_out = self.relu(quantized_matmul_out)
        return relu_out
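
For concreteness, here is how the model is instantiated and exercised in eager mode - a minimal sketch with hypothetical dimensions (input_dim=16, output_size=4) and a random batch of two inputs:

model = SimpleModel(input_dim=16, output_size=4)
example_input = torch.randn(2, 16)

# Forward propagation, including the in-forward BFP cast, runs in eager mode.
output = model(example_input)
print(output.shape)  # torch.Size([2, 4])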

Observed Behaviour

I used the torch_mlir.compile() API to compile the module to TOSA IR. While the module runs forward propagation fine in eager mode, compilation hits an assert in qtorch.quant_function.block_quantize() complaining about an invalid "rounding mode". Also, removing the BFP quantization from the module's forward() yields a successful compilation.
Lastly, if I quantize the input tensors of the module up front and then call torch_mlir.compile() on them, there doesn't seem to be any issue - are the casts optimized out in this case?
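
For context, the compile calls above look roughly like the following - a minimal sketch that reuses model and example_input from the snippet above and assumes torch_mlir.compile() accepts a "tosa" output type:

import torch_mlir

# Compiling the module with the BFP cast inside forward() to TOSA IR -
# this is where the invalid "rounding mode" assert is hit in
# qtorch.quant_function.block_quantize().
tosa_module = torch_mlir.compile(model, example_input, output_type="tosa")

# Quantizing the input tensor up front and compiling on that instead,
# as described above, goes through without any issue.
quantized_input = block_quantize(example_input, wl=8, dim=0, rounding="nearest")
tosa_module = torch_mlir.compile(model, quantized_input, output_type="tosa")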

Script to Reproduce

For convenience, I opened a draft PR with a minimal script that reproduces the issue: #909

FYI @silvasean
