f32 quantization fp16 #324

Open · wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions examples/CMakeLists.txt
@@ -24,6 +24,10 @@ if(BUDDY_DSL_EXAMPLES)
  add_subdirectory(ToyDSL)
endif()

if(BUDDY_QUANTIZATION_EXAMPLES)
  add_subdirectory(Quantization)
endif()

configure_lit_site_cfg(
  ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
  ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
81 changes: 81 additions & 0 deletions examples/Quantization/CMakeLists.txt
@@ -0,0 +1,81 @@
cmake_minimum_required(VERSION 3.10)
project(Quantization)

find_package(OpenMP REQUIRED)

# Command that generates the model object file (forward.o) from forward.mlir.
add_custom_command(
  OUTPUT forward.o
  COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/Quantization/forward.mlir
    -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-arith))" |
  ${BUDDY_BINARY_DIR}/buddy-opt
    -arith-expand
    -eliminate-empty-tensors
    -empty-tensor-to-alloc-tensor
    -one-shot-bufferize
    -matmul-paralell-vectorization-optimize
    -batchmatmul-optimize
    -convert-linalg-to-affine-loops
    -affine-loop-fusion
    -affine-parallelize
    -lower-affine
    -convert-scf-to-openmp
    -func-bufferize
    -arith-bufferize
    -tensor-bufferize
    -buffer-deallocation
    -finalizing-bufferize
    -convert-vector-to-scf
    -expand-strided-metadata
    -convert-vector-to-llvm
    -memref-expand
    -arith-expand
    -convert-arith-to-llvm
    -finalize-memref-to-llvm
    -convert-scf-to-cf
    -llvm-request-c-wrappers
    -convert-openmp-to-llvm
    -convert-arith-to-llvm
    -convert-math-to-llvm
    -convert-math-to-libm
    -convert-func-to-llvm
    -reconcile-unrealized-casts |
  ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
  ${LLVM_MLIR_BINARY_DIR}/llvm-as |
  ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3
    -o ${BUDDY_BINARY_DIR}/../examples/Quantization/forward.o
  DEPENDS buddy-opt ${BUDDY_EXAMPLES_DIR}/Quantization/forward.mlir
  COMMENT "Building forward.o"
  VERBATIM)

add_library(QUANTIZATION STATIC forward.o)

# forward.o is produced by the custom command above, not compiled by CMake.
SET_SOURCE_FILES_PROPERTIES(
  forward.o
  PROPERTIES
  EXTERNAL_OBJECT true
  GENERATED true)

SET_TARGET_PROPERTIES(
  QUANTIZATION
  PROPERTIES
  LINKER_LANGUAGE C)

add_executable(quantization-run addqf.cpp)
target_link_directories(quantization-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR})

set(QUANTIZATION_LIBS
  QUANTIZATION
  mlir_c_runner_utils
  omp
  OpenMP::OpenMP_CXX # use the OpenMP target exported by find_package(OpenMP)
)

# Apply -mf16c (F16C conversion intrinsics) to the target itself; a
# directory-level add_compile_options() after add_executable() would not apply.
target_compile_options(quantization-run PRIVATE -mf16c)

if(BUDDY_MLIR_USE_MIMALLOC)
  list(APPEND QUANTIZATION_LIBS mimalloc)
endif()

target_link_libraries(quantization-run ${QUANTIZATION_LIBS})
96 changes: 96 additions & 0 deletions examples/Quantization/addqf.cpp
@@ -0,0 +1,96 @@
#include <buddy/Core/Container.h>
#include <buddy/LLM/TextContainer.h>
#include <chrono>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>

using namespace buddy;

extern "C" void _mlir_ciface_forward(MemRef<_Float16, 3> *result, MemRef<float, 3> *params, MemRef<float, 3> *input);

std::vector<std::vector<float>> readDataFromFile(const std::string &filePath,
                                                 size_t rows, size_t cols) {
  std::ifstream file(filePath);
  std::vector<std::vector<float>> data;
  if (file.is_open()) {
    std::string line;
    while (std::getline(file, line)) {
      std::vector<float> row;
      std::istringstream iss(line);
      float value;
      while (iss >> value) {
        row.push_back(value);
      }
      if (row.size() == cols) {
        data.push_back(row);
      } else {
        std::cerr << "Row size does not match expected cols: " << row.size()
                  << " vs " << cols << std::endl;
      }
    }
    file.close();
  } else {
    std::cerr << "Unable to open file: " << filePath << std::endl;
  }
  if (data.size() != rows) {
    std::cerr << "Data size does not match expected rows: " << data.size()
              << " vs " << rows << std::endl;
  }
  return data;
}

void printVector(const std::vector<float> &vec, const std::string &name) {
  std::cout << name << ": ";
  for (const auto &val : vec) {
    std::cout << val << " ";
  }
  std::cout << std::endl;
}


int main() {
  // NOTE: machine-specific absolute paths from the original example; adjust
  // them to your own checkout.
  std::string inputFilePath = "/home/xujiahao/Quantization/buddy-mlir/examples/Quantization/input_data.txt";
  std::string paramsFilePath = "/home/xujiahao/Quantization/buddy-mlir/examples/Quantization/params_data.txt";

  // Each file holds 10 rows of 6 floats; the flat 60-element buffers are
  // viewed below as 10x3x2 MemRefs.
  std::vector<std::vector<float>> inputData = readDataFromFile(inputFilePath, 10, 6);
  std::vector<std::vector<float>> paramsData = readDataFromFile(paramsFilePath, 10, 6);

  std::vector<float> flatInputData;
  for (const auto &row : inputData) {
    flatInputData.insert(flatInputData.end(), row.begin(), row.end());
  }

  std::vector<float> flatParamsData;
  for (const auto &row : paramsData) {
    flatParamsData.insert(flatParamsData.end(), row.begin(), row.end());
  }

  std::vector<_Float16> flatResultData(10 * 3 * 2, 0.0f);

  printVector(flatInputData, "Flat Input Data");
  printVector(flatParamsData, "Flat Params Data");

  intptr_t input_sizes[3] = {10, 3, 2};
  intptr_t params_sizes[3] = {10, 3, 2};
  intptr_t result_sizes[3] = {10, 3, 2};

  MemRef<float, 3> input(flatInputData.data(), input_sizes);
  MemRef<float, 3> params(flatParamsData.data(), params_sizes);
  MemRef<_Float16, 3> result(flatResultData.data(), result_sizes);

  _mlir_ciface_forward(&result, &params, &input);

  _Float16 *resultData = result.getData();
  for (int i = 0; i < 10; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 2; ++k) {
        // _Float16 has no operator<< overload; print through float.
        std::cout << static_cast<float>(resultData[i * 3 * 2 + j * 2 + k]) << " ";
      }
      std::cout << "\n";
    }
    std::cout << "\n";
  }

  return 0;
}
16 changes: 16 additions & 0 deletions examples/Quantization/forward.mlir
@@ -0,0 +1,16 @@
module {
  func.func @forward(%arg0: tensor<10x3x2xf32>, %arg1: tensor<10x3x2xf32>) -> tensor<10x3x2xf16> {
    %0 = tosa.cast %arg0 : (tensor<10x3x2xf32>) -> tensor<10x3x2xf16>
    %1 = tosa.cast %arg1 : (tensor<10x3x2xf32>) -> tensor<10x3x2xf16>
    %2 = tosa.add %0, %1 : (tensor<10x3x2xf16>, tensor<10x3x2xf16>) -> tensor<10x3x2xf16>
    return %2 : tensor<10x3x2xf16>
  }
}
// Reference f32 version (no f16 cast), kept for comparison:
// module {
//   func.func @forward(%arg0: tensor<10x3x2xf32>, %arg1: tensor<10x3x2xf32>) -> tensor<10x3x2xf32> {
//     %0 = tosa.cast %arg0 : (tensor<10x3x2xf32>) -> tensor<10x3x2xf32>
//     %1 = tosa.cast %arg1 : (tensor<10x3x2xf32>) -> tensor<10x3x2xf32>
//     %2 = tosa.add %0, %1 : (tensor<10x3x2xf32>, tensor<10x3x2xf32>) -> tensor<10x3x2xf32>
//     return %2 : tensor<10x3x2xf32>
//   }
// }
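
A quick NumPy cross-check of the lowering above (a sketch, assuming the data
files produced by gen_random_data.py below):

import numpy as np

# Load the f32 inputs (10 rows x 6 values each), viewed as 10x3x2 as in addqf.cpp.
a = np.loadtxt('input_data.txt', dtype=np.float32).reshape(10, 3, 2)
b = np.loadtxt('params_data.txt', dtype=np.float32).reshape(10, 3, 2)

# Mirror forward.mlir: cast each operand to f16, then add in f16.
expected = a.astype(np.float16) + b.astype(np.float16)
print(expected)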
16 changes: 16 additions & 0 deletions examples/Quantization/gen_random_data.py
@@ -0,0 +1,16 @@
import numpy as np


# 10x6 random values; the C++ driver views each flattened file as a 10x3x2 tensor.
input_data = np.random.rand(10, 6)

params_data = np.random.rand(10, 6)

with open('input_data.txt', 'w') as f:
    for row in input_data:
        f.write(' '.join(map(str, row)) + '\n')

with open('params_data.txt', 'w') as f:
    for row in params_data:
        f.write(' '.join(map(str, row)) + '\n')

print("gen done!")
10 changes: 10 additions & 0 deletions examples/Quantization/input_data.txt
@@ -0,0 +1,10 @@
0.41918253352224544 0.26502703213899703 0.6693360575806869 0.17467943759777538 0.3333227064407196 0.4456407571464961
0.8448478465044279 0.13435349286541443 0.9287191035936655 0.14173813628694976 0.7113423853332846 0.9300744462452724
0.7535358764626087 0.36814028212321803 0.8335676861737891 0.028444284371489048 0.06378162522030628 0.9993912226798591
0.2983775144617824 0.6390499602428215 0.5121750865967225 0.9102147243459047 0.004130308972184471 0.04907280040137574
0.5151397779095757 0.2209919270157623 0.3773935888973331 0.553358519583838 0.5256209743662532 0.021602893514575894
0.903654322647565 0.04323882051448136 0.2930296998068076 0.8429590594546426 0.864420702193492 0.9569747690038598
0.07427449542908304 0.0659057566834852 0.36465044968792926 0.7834495240055113 0.24462079605664167 0.9693060703790637
0.8939088621169848 0.08124940616466991 0.25435076790627764 0.09213120538205166 0.09289695705195133 0.7063502174740877
0.5320911316559717 0.028041998772447063 0.599732840142648 0.07420770844862024 0.9826711147591535 0.6137843810407524
0.29541514261981994 0.47574393586928765 0.4883320995133651 0.520450382312919 0.9247619547184011 0.558747650905918
10 changes: 10 additions & 0 deletions examples/Quantization/params_data.txt
@@ -0,0 +1,10 @@
0.763104628057672 0.8377160652735918 0.8513019402280178 0.9123585227696165 0.4481681308588009 0.5982249853645559
0.19817860971211676 0.53725221421351 0.26093784911288354 0.3244533436781908 0.1439918594907188 0.3024487875373958
0.25396301930256315 0.027000801899303317 0.3585077148026633 0.47260237404623273 0.9793354709026868 0.5838658977777783
0.23533749606637933 0.49431224471215074 0.017819815315252452 0.78779958041884 0.9354206226161268 0.6476278607509749
0.17964232043147288 0.3387666692451754 0.12736492493137852 0.12213961002427098 0.2372822400837763 0.1886183359686744
0.07069298807841318 0.2938284146990432 0.7330911350505327 0.486429869254954 0.40514411287635765 0.1815344585588189
0.1384082900157909 0.002927967558442379 0.17326401686662396 0.44172741447457875 0.8122124811078606 0.7703457864281187
0.7759915330763716 0.985798377512695 0.10427071520807762 0.5034154421989058 0.789812781818153 0.3761954305624271
0.7186803446742486 0.22064728131468347 0.44928494467100166 0.9338660536738846 0.047811145082879536 0.5242529834929468
0.5582416544283104 0.29649787440236564 0.418186520656883 0.3056981671555391 0.49005481049330957 0.7210165708991532
4 changes: 4 additions & 0 deletions frontend/Python/graph/graph.py
@@ -256,6 +256,8 @@ def lower_to_top_level_ir(self):
                np_type = np.dtype(np.int64)
            case "f32":
                np_type = np.dtype(np.float32)
            case "f16":
                np_type = np.dtype(np.float16)
            case _:
                raise NotImplementedError(f"Unsupported dtype {dtype}")
        self._output_memref.append(
@@ -391,6 +393,8 @@ def _str_to_mlir_dtype(self, dtype: str) -> ir.Type:
                return ir.IntegerType.get_signless(64)
            case TensorDType.Float32:
                return ir.F32Type.get()
            case TensorDType.Float16:
                return ir.F16Type.get()
            case TensorDType.Bool:
                return ir.IntegerType.get_signless(1)
            case _:
3 changes: 3 additions & 0 deletions frontend/Python/ops/func.py
@@ -111,17 +111,20 @@ def param_extract(
    else:
        output_shape = list(node.tensor_meta["shape"])
    subview_size = functools.reduce(lambda x, y: x * y, output_shape)

    offset_attr = ir._denseI64ArrayAttr([offset], None)
    size_attr = ir._denseI64ArrayAttr([subview_size], None)
    stride = [1]
    stride_attr = ir._denseI64ArrayAttr(stride, None)
    memref_attr = ir.Attribute.parse("strided<[1], offset: {}>".format(offset))

    if offset == 0:
        memref_type = ir.MemRefType.get([subview_size], memref_element_type)
    else:
        memref_type = ir.MemRefType.get(
            [subview_size], memref_element_type, memref_attr
        )

    memref_subview_op = memref.SubViewOp(
        memref_type,
        params_mlir_node,
71 changes: 69 additions & 2 deletions frontend/Python/ops/tosa.py
@@ -220,6 +220,21 @@ def addmm_op(
    return op


def bmm_op_quantized(node: BatchMatmulOp, symbol_table, quantized_dtype='fp16') -> ir.Operation:
    """
    Import batch matrix multiplication, emitting the result in the quantized
    float dtype (f16 by default).
    """
    input_ = symbol_table.get((str(node.args[0]), 0))
    mat2 = symbol_table.get((str(node.args[1]), 0))

    input_shp = ir.RankedTensorType(input_.type).shape
    mat2_shp = ir.RankedTensorType(mat2.type).shape
    sizes = [input_shp[0], input_shp[1], mat2_shp[2]]

    # Get the quantized element type for the result tensor.
    result_element_type = get_quantized_type(quantized_dtype)
    result_type = ir.RankedTensorType.get(sizes, result_element_type)
    op = tosa.MatMulOp(result_type, input_, mat2)
    return op


def bmm_op(node: BatchMatmulOp, symbol_table) -> ir.Operation:
    """
    Import batch matrix multiplication operation.
@@ -247,6 +262,56 @@ def add_op(node: AddOp, symbol_table):
    return _gen_arith_binary_op(input1, input2, tosa.AddOp)


def get_quantized_type(dtype):
    """Map a quantized dtype string ('fp16' or 'bf16') to an MLIR float type."""
    if dtype == 'fp16':
        return ir.F16Type.get()
    elif dtype == 'bf16':
        return ir.BF16Type.get()
    else:
        raise ValueError(f"Unsupported quantized type: {dtype}")


def calculate_scale_and_zero_point(min_val, max_val, num_bits=8):
    qmin = -2 ** (num_bits - 1)
    qmax = 2 ** (num_bits - 1) - 1

    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = round(qmax - max_val / scale)

    return scale, zero_point


def quantize(tensor, scale, zero_point):
    return numpy.round(tensor / scale + zero_point).astype(numpy.int8)


def dequantize(tensor, scale, zero_point):
    return (tensor.astype(numpy.float32) - zero_point) * scale
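
# The int8 helpers above are not exercised by the f16 path; the following
# minimal sketch (assuming this module's numpy import) shows the
# scale/zero-point round trip on a small array:
#
#     data = numpy.array([0.0, 0.25, 0.5, 1.0], dtype=numpy.float32)
#     scale, zero_point = calculate_scale_and_zero_point(data.min(), data.max())
#     # scale = (1.0 - 0.0) / 255 ~= 0.00392; zero_point = round(127 - 1.0 / scale) = -128
#     q = quantize(data, scale, zero_point)   # int8 codes in [-128, 127]
#     d = dequantize(q, scale, zero_point)    # recovers data to within one scale step
#     assert numpy.allclose(d, data, atol=scale)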

def reduce_tensor_to_scalar(tensor, reduce_op):
    """
    Reduce tensor to a scalar value using the given reduce operation.
    """
    shape = tensor.shape
    while len(shape) > 1:
        tensor = reduce_op(tensor, axes=[0])
        shape = tensor.shape
    return tensor

def addqf_op(node: AddOp, symbol_table, quantized_dtype='fp16'):
    """
    Import tensor addition with float quantization: cast both operands to the
    quantized dtype (f16 by default), then emit the MLIR TOSA `add` operation.
    """
    input1 = symbol_table.get((str(node.args[0]), 0), node.args[0])
    input2 = symbol_table.get((str(node.args[1]), 0), node.args[1])

    # Convert input tensors to the quantized element type.
    quantized_type = get_quantized_type(quantized_dtype)
    input1_cast = tosa.cast(
        ir.RankedTensorType.get(ir.RankedTensorType(input1.type).shape, quantized_type),
        input1,
    )
    input2_cast = tosa.cast(
        ir.RankedTensorType.get(ir.RankedTensorType(input2.type).shape, quantized_type),
        input2,
    )

    # Perform the addition in the quantized dtype.
    return _gen_arith_binary_op(input1_cast, input2_cast, tosa.AddOp)


def sub_op(node: SubOp, symbol_table):
    """
    Import tensor subtraction operation.
@@ -1349,14 +1414,15 @@ def clamp_max_op(node: ClampMaxOp, symbol_table):


ops_registry = {
    "AddOp": addqf_op,  # quantized f16 add (replaces the plain add_op)
    "MulOp": mul_op,
    "SubOp": sub_op,
    "SumDimOp": sum_op,
    "TanhOp": tanh_op,
    "AmaxOp": amax_op,
    "RsqrtOp": rsqrt_op,
    "BatchMatmulOp": bmm_op_quantized,  # quantized f16 batch matmul
    "CloneOp": clone_op,
    "DivOp": div_op,
    "ExpOp": exp_op,
@@ -1382,4 +1448,5 @@ def clamp_max_op(node: ClampMaxOp, symbol_table):
"MeanOp": mean_op,
"ClampMinOp": clamp_min_op,
"ClampMaxOp": clamp_max_op,
# "BmmQFOP": bmm_op_quantized,
}