Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
150 commits
Select commit Hold shift + click to select a range
d8884e6
Refactor BatchMatMulEmitter and BatchMatMulSelector for improved read…
LeiWang1999 Jul 5, 2024
fc84173
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
02f64de
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
397eee6
disable failure email for ci
LeiWang1999 Jul 5, 2024
20f6ad1
remove email notifications.
LeiWang1999 Jul 6, 2024
b93c394
move relax pass from testing to mlc_llm
LeiWang1999 Jul 6, 2024
ba6a6df
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
257693a
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
9bb7f49
Lint Fix
LeiWang1999 Jul 6, 2024
39e7614
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
93eb5a5
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
aa66a90
bug fix in test
LeiWang1999 Jul 6, 2024
ae14a53
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 6, 2024
79b08e4
lint fix.
LeiWang1999 Jul 6, 2024
86fd036
test cuda i4 kernel
LeiWang1999 Jul 7, 2024
6b73a21
Refactor copyright notice in i4matmul.hpp
LeiWang1999 Jul 7, 2024
0ba90c1
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 7, 2024
086d208
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 7, 2024
47a3abd
refactor test as version below python 3.9 cannot handle int32 overflow.
LeiWang1999 Jul 8, 2024
024b247
format lint for test
LeiWang1999 Jul 8, 2024
bfedeaa
Refactor test_int4b_fp16_convert.py for improved readability and main…
LeiWang1999 Jul 8, 2024
e672a23
remove unused design file
LeiWang1999 Jul 8, 2024
21e5430
move tile device from package to base
LeiWang1999 Jul 8, 2024
fd11940
dummy impl for codegen
LeiWang1999 Jul 8, 2024
9ccfa85
Refactor file structure for ladder_permutate module
LeiWang1999 Jul 8, 2024
7c7d73e
Refactor backend class and fix typos in comments
LeiWang1999 Jul 8, 2024
47d5fc5
Deep refactor Lib related code.
LeiWang1999 Jul 8, 2024
53dd0dd
remove ci pull.
LeiWang1999 Jul 10, 2024
d58ac43
LintFix
LeiWang1999 Jul 10, 2024
37cb07c
refactor builder for whl build
LeiWang1999 Jul 10, 2024
f5b9999
Refactor TIRWrapper.wrap() method to include an assertion for the opt…
LeiWang1999 Jul 11, 2024
fb78244
Refactor lib_generator to set library and source paths
LeiWang1999 Jul 11, 2024
706e227
lint fix
LeiWang1999 Jul 11, 2024
63f5515
BitNet vllm integration
LeiWang1999 Jul 16, 2024
de91c0d
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 16, 2024
b9655fd
chore: update codespell to version 2.3.0
LeiWang1999 Jul 16, 2024
fff385f
Lintfix
LeiWang1999 Jul 16, 2024
72a98e7
Bump version to 0.0.1.dev13
LeiWang1999 Jul 18, 2024
5646ab5
lint fix
LeiWang1999 Jul 18, 2024
b965863
disable fast decoding [u]int4xint8 by default.
LeiWang1999 Jul 21, 2024
1198fc7
optimize from dict design in Hint
LeiWang1999 Jul 21, 2024
014213c
Implement SplitK
LeiWang1999 Jul 21, 2024
e0ca752
bitnet benchmark generation.
LeiWang1999 Jul 21, 2024
81b9cf0
Add benchmark script for BitNet integration
LeiWang1999 Jul 21, 2024
02edc0b
AtomicAdd Support
LeiWang1999 Jul 21, 2024
1a70c2d
LintFix
LeiWang1999 Jul 21, 2024
28d851c
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 21, 2024
c447a95
ci fix when 3rdparty tvm is initialized.
LeiWang1999 Jul 21, 2024
79a001b
bug fix for setup
LeiWang1999 Jul 21, 2024
31813b2
fix a bug in block reduce
LeiWang1999 Jul 21, 2024
78b6a3d
typo fix
LeiWang1999 Jul 21, 2024
9c55218
BUG Fix for block reduce.
LeiWang1999 Jul 22, 2024
1aa8868
Lint fix
LeiWang1999 Jul 22, 2024
22f70bf
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
5f082a5
Refactor block reduce schedule template
LeiWang1999 Jul 22, 2024
b4fb31e
transform branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
35eaa00
Fix subproject commit reference in 3rdparty/tvm
LeiWang1999 Jul 22, 2024
254dd74
chore: update submodule branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
31a44aa
force update config.cmake
LeiWang1999 Jul 22, 2024
427800e
Bug fix
LeiWang1999 Jul 22, 2024
96db111
Fix subproject commit reference in 3rdparty/cutlass
LeiWang1999 Jul 22, 2024
38b251a
chore: Add submodule for cutlass library
LeiWang1999 Jul 22, 2024
87d1c5a
update tl cutlass path
LeiWang1999 Jul 22, 2024
6200b1e
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
0ffe0b5
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 22, 2024
8e08e77
format fix
LeiWang1999 Jul 22, 2024
df05a64
Copy CUTLASS to the package directory
LeiWang1999 Jul 22, 2024
4f529c5
Refactor setup.py to include additional TVM header files
LeiWang1999 Jul 22, 2024
d02bbc7
lint fix
LeiWang1999 Jul 23, 2024
cffe3fd
bug fix
LeiWang1999 Jul 23, 2024
a8bed74
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 23, 2024
d4eb5fd
Implement Matmul Benchmark Design
LeiWang1999 Jul 23, 2024
4c6c2c1
chore: Update BitBLAS Matmul benchmark script
LeiWang1999 Jul 23, 2024
0acaca1
lint fix
LeiWang1999 Jul 23, 2024
54d2227
Refactor BitBLASMatmulOpsBenchmark for improved readability and maint…
LeiWang1999 Jul 23, 2024
c2edefb
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
e0bc723
lint fix
LeiWang1999 Jul 23, 2024
a4e68d1
Benchmark bot test
LeiWang1999 Jul 23, 2024
df7e9aa
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
1c03365
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
4f319fc
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
a8833d4
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
803f6c6
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
df4572b
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
45ded45
int8 test case
LeiWang1999 Jul 23, 2024
4229676
Refactor compare_benchmark.py to handle missing benchmark results gra…
LeiWang1999 Jul 23, 2024
b883290
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
476ffee
ci fix
LeiWang1999 Jul 23, 2024
9bd34ff
disable ci for test benchmark
LeiWang1999 Jul 23, 2024
e86f4b2
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
75f3dd9
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
79e04aa
remove cli installation
LeiWang1999 Jul 23, 2024
cdd3345
chore: Create virtual environment and install dependencies for benchmark
LeiWang1999 Jul 23, 2024
f099938
Merge branch 'main' into dev
LeiWang1999 Jul 23, 2024
f211ad4
chore: Update benchmark workflow to include comparison step
LeiWang1999 Jul 23, 2024
ddde02a
Lint fix
LeiWang1999 Jul 24, 2024
8045ce9
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 24, 2024
21aee89
Merge branch 'dev' of https://github.com/LeiWang1999/MSBitBLAS into dev
LeiWang1999 Jul 24, 2024
ef1b158
upodate tvm cmmit
LeiWang1999 Jul 25, 2024
a8d8841
Imporve lower warp memory pass
LeiWang1999 Jul 30, 2024
686b929
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 30, 2024
7736c38
Bug fix
LeiWang1999 Jul 30, 2024
199affc
Enhance to support warp schedule.
LeiWang1999 Jul 31, 2024
9d0c25d
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 31, 2024
7c1f52e
Enhance LOP3 Instructions
LeiWang1999 Jul 31, 2024
d1b2bc7
Enhance LOP3 Instructions
LeiWang1999 Jul 31, 2024
2aac6d0
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 31, 2024
802abde
add test for stage3 propagate
LeiWang1999 Jul 31, 2024
d339037
implement propagate func
LeiWang1999 Jul 31, 2024
0f6a033
Stage3 Ladder Permutate integration
LeiWang1999 Jul 31, 2024
00ec916
get_ladder_stage3_propagate
LeiWang1999 Jul 31, 2024
5316577
comments benchmark scirpts as the setting is too big
LeiWang1999 Jul 31, 2024
dd070f9
ci fix for benchmark
LeiWang1999 Jul 31, 2024
6fcc368
lint fix
LeiWang1999 Jul 31, 2024
705580b
chore: Update benchmark workflow to trigger on pull request comments
LeiWang1999 Jul 31, 2024
c5ba940
Add LDMatrix Transform 3
LeiWang1999 Aug 1, 2024
1566990
Support GPTQ Test
LeiWang1999 Aug 1, 2024
c6c70ef
Fuse BlockReduce Schedule
LeiWang1999 Aug 1, 2024
36128f3
Support mma propagate 3
LeiWang1999 Aug 1, 2024
23ff5f4
Support MMA Propagate Stage 3
LeiWang1999 Aug 1, 2024
de3bf08
Lint Fix
LeiWang1999 Aug 1, 2024
d9830ba
Merge block reduce for dequantze config.
LeiWang1999 Aug 1, 2024
e5a4485
fix codeql
LeiWang1999 Aug 2, 2024
a04282b
chore: Update submodule reference to latest commit
LeiWang1999 Aug 4, 2024
314d3e9
chore: Disable common subexpression elimination in TIR passes
LeiWang1999 Aug 4, 2024
f7d33bb
Lint Fix
LeiWang1999 Aug 4, 2024
db633ed
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Aug 4, 2024
201155a
4bit related lop3 updates.
LeiWang1999 Aug 4, 2024
2b73662
lint fix
LeiWang1999 Aug 4, 2024
1a6a0fd
gptq test fix
LeiWang1999 Aug 4, 2024
e84e3ef
Fix for test
LeiWang1999 Aug 4, 2024
f0fbb55
lint fix
LeiWang1999 Aug 4, 2024
bf30688
lint fix
LeiWang1999 Aug 4, 2024
9a360ba
typofix
LeiWang1999 Aug 4, 2024
ee94536
QuantCompress Test
LeiWang1999 Aug 5, 2024
930cd76
chore: Refactor quant_compress_impl.py for readability and maintainab…
LeiWang1999 Aug 5, 2024
8c24776
Enhance docs to update latest works.
LeiWang1999 Aug 5, 2024
c018e3c
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Aug 5, 2024
38f1713
Refactor weight executors in Matmul class for improved readability an…
LeiWang1999 Aug 5, 2024
4a578ce
Refactor weight executors in Matmul class for improved readability an…
LeiWang1999 Aug 5, 2024
4e7126b
Refactor weight executors in Matmul class for improved readability an…
LeiWang1999 Aug 5, 2024
de9fd2e
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into upda…
LeiWang1999 Aug 5, 2024
e405aa2
removed legacy operator
LeiWang1999 Aug 5, 2024
5709db1
Refactor weight executors in Matmul class for improved readability an…
LeiWang1999 Aug 5, 2024
2d90e7b
LintFix
LeiWang1999 Aug 5, 2024
c2d2cfa
Fix GPTQ Repack with the latest weight transform
LeiWang1999 Aug 5, 2024
ed6a0a1
lint fix
LeiWang1999 Aug 5, 2024
d23ab47
bug fix for rescale dequantize
LeiWang1999 Aug 5, 2024
af16059
test fix
LeiWang1999 Aug 5, 2024
ac316fd
typo fix
LeiWang1999 Aug 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion bitblas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401
from .module import Linear # noqa: F401

import warnings
import functools
import logging
from tqdm import tqdm

Expand Down Expand Up @@ -89,4 +90,26 @@ def _init_logger():

_init_logger()


def deprecated(reason):
"""
This is a decorator which can be used to mark functions as deprecated.
It will result in a warning being emitted when the function is used.
"""

def decorator(func):

@functools.wraps(func)
def new_func(*args, **kwargs):
warnings.warn(
f"Call to deprecated function {func.__name__} ({reason}).",
category=DeprecationWarning,
stacklevel=2)
return func(*args, **kwargs)

return new_func

return decorator


__version__ = "0.0.1.dev13"
9 changes: 5 additions & 4 deletions bitblas/gpu/matmul_mma_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2264,10 +2264,11 @@ def get_idx():
lop3_intrin_info["compute"],
)
# Assume the grouped K is the last dim of the scaling
grouped_k = sch.get(bf).reads[1].buffer.shape[-1]
# TODO(lei): This is a hack to get the loop extent
loop_extent = 8 if out_dtype == "float16" else 16
sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k)
if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
grouped_k = sch.get(bf).reads[1].buffer.shape[-1]
# TODO(lei): This is a hack to get the loop extent
loop_extent = 8 if out_dtype == "float16" else 16
sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k)
import_source.append(lop3_intrin_info["c_source"])

def tensorize_init_store_compute():
Expand Down
21 changes: 20 additions & 1 deletion bitblas/module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ def unpack_qzeros(qzeros, bits):
return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)


def unpack_qweight(qweight, bits):
qweight = qweight.view(torch.int8)
elems_per_int8 = 8 // bits
unpacked_weight = torch.zeros(
(qweight.shape[0], qweight.shape[1] * elems_per_int8),
dtype=torch.int8,
device=qweight.device,
requires_grad=False,
)
for col in range(unpacked_weight.shape[1]):
i = col % elems_per_int8
unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> (bits * i))

# Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303
# NOTE: It appears that casting after the `unpacked_zeros + 1` is important.
return torch.bitwise_and(unpacked_weight, 2**bits - 1)


class Linear(nn.Module):
opt_M = [1, 16, 32, 64, 128, 256, 512]
STORAGE_DTYPE = "int8" # assume int8 storage
Expand Down Expand Up @@ -279,8 +297,9 @@ def load_and_transform_weight(
def repack_from_gptq(self, gptq_module):
# qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
intweight = unpack_qweight(qweight, self.bits).contiguous()
if self.bitblas_matmul.weight_transform is not None:
qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda()
qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda()
self.qweight = qweight
# scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)
Expand Down
3 changes: 1 addition & 2 deletions bitblas/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .operator import Operator, OperatorConfig # noqa: F401
from .matmul import Matmul, MatmulConfig # noqa: F401
from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401
from .general_matmul import Matmul, MatmulConfig # noqa: F401
from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401
from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401
from .quant_compress import QuantCompress, QuantCompressConfig # noqa: F401
49 changes: 32 additions & 17 deletions bitblas/ops/general_matmul/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from bitblas.utils.target_detector import auto_detect_nvidia_target
from dataclasses import dataclass
from ..ladder_permutate import LadderPermutate, LadderPermutateConfig
from ..quant_compress import QuantCompress, QuantCompressConfig
from ..lop3_permutate import LOP3Permutate, LOP3PermutateConfig
import logging
import torch
Expand Down Expand Up @@ -292,6 +293,7 @@ def dispatch_tir(self,
# create permutate_opertors
self.ladder_permutate_a = self._assign_ladder_permutate_a(target, enable_tuning)
self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning)
self.weight_compress = self._assign_weight_compress(target, enable_tuning)
self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning)
# create cpu weight executors
self.input_executors = self._create_input_executors()
Expand Down Expand Up @@ -338,11 +340,14 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool):
del enable_tuning

if self.propagate_b:
# weight transform should be done in the unpacked level
# otherwise the bit trick should be applied and that is
# too complex to be implemented in the ladder permutation.
ladder_permutate_config = LadderPermutateConfig(
M=self.N,
N=self.K,
datatype=self.A_dtype,
dequantize_bits=self.bit,
dequantize_bits=-1,
storage_dtype=self.storage_dtype,
propagate_kind="B",
transpose_matrix=self.layout == "nt",
Expand All @@ -354,6 +359,25 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool):
)
return None

def _assign_weight_compress(self, target: Target, enable_tuning: bool):
# unused variables
del target
del enable_tuning

require_compress: bool = self.bit in [1, 2, 4]
if require_compress:
quant_compress_config = QuantCompressConfig(
M=self.N,
N=self.K,
input_dtype=self.storage_dtype,
storage_dtype=self.storage_dtype,
dequantize_bits=self.bit)
return QuantCompress(
config=quant_compress_config,
target=tvm.target.Target("llvm"),
)
return None

def _assign_lop3_permutate(self, target: Target, enable_tuning: bool):
# unused variables
del target
Expand Down Expand Up @@ -381,10 +405,12 @@ def _create_input_executors(self):

def _create_weight_executors(self):
weight_executors = OPExecutorCPU()
if self.fast_decoding:
weight_executors.append(self.lop3_permutate)
if self.propagate_b is not TransformKind.NonTransform:
weight_executors.append(self.ladder_permutate_b)
if self.weight_compress is not None:
weight_executors.append(self.weight_compress)
if self.fast_decoding:
weight_executors.append(self.lop3_permutate)
return weight_executors

def _select_implementation(self):
Expand Down Expand Up @@ -452,10 +478,6 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
return self.weight_transform(weight.cpu()).cuda().contiguous()
return weight

from bitblas.quantization import general_compress
import torch
import numpy as np

source_format, bit = self.source_format, self.bit

# Process integer source format
Expand All @@ -464,20 +486,13 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
weight = torch.clamp(weight, -maxq, maxq).char() + maxq
elif source_format in ["fp_e5m2", "fp_e4m3"]:
weight = weight.view(torch.int8)
weight = weight.int()
else:
# For non-integer formats, simply convert weights to integers
weight = weight.int()

np_storage_dtype = getattr(np, self.storage_dtype)

weight = general_compress(
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)

weight = torch.from_numpy(weight).cuda().contiguous()
# And assume weight is in the range of [-128, 127] for int8
weight = weight.char()

# Apply an optional weight transformation if specified
if self.weight_transform is not None:
Expand Down
18 changes: 18 additions & 0 deletions bitblas/ops/ladder_permutate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..operator import Operator
from .ladder_permutate_impl import select_implementation
from dataclasses import dataclass
import torch


@dataclass(frozen=True)
Expand Down Expand Up @@ -57,6 +58,23 @@ def _select_implementation(self):
target_instruction=self.target_instruction,
)

def forward(self, inp, out=None):
if out is None:
out_shape, out_dtype = self.retrieve_output_shape()
out = torch.zeros(out_shape, dtype=out_dtype).to(inp.device)
self.torch_func(inp, out)
return out

def retrieve_output_shape(self):
"""
Retrieve the output shape of the operator
"""
func = self.prim_func
param = func.params[-1]
assert param in func.buffer_map, f"param {param} not in buffer_map"
arg = func.buffer_map[param]
return [int(i) for i in arg.shape], getattr(torch, arg.dtype)

@property
def M(self):
return self.config.M
Expand Down
18 changes: 15 additions & 3 deletions bitblas/ops/lop3_permutate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,23 @@ def _select_implementation(self):
dequantize_bits=self.dequantize_bits,
)

def forward(self, weight, res):
def forward(self, inp, out=None):
out_shape = inp.shape
out_dtype = inp.dtype
if out is None:
# lop3 transform does not change the shape of the input tensor
out = torch.zeros(out_shape, dtype=out_dtype)
# reinterpret the input tensor to int32 format
args = [arg.view(torch.int32) for arg in [weight, res]]
shape_2dim = self.retrieve_2d_weight_shape()
args = [arg.view(inp.dtype).view(shape_2dim).view(torch.int32) for arg in [inp, out]]
self.torch_func(*args)
return args[-1].view(weight.dtype)
return args[-1].view(out_dtype).view(out_shape)

def retrieve_2d_weight_shape(self):
storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
elems_per_byte = storage_nbit // self.dequantize_bits
weight_shape = (self.M, self.N // elems_per_byte)
return weight_shape

@property
def M(self):
Expand Down
Loading