Given a static program:
import bitblas
# enabling debug output
bitblas.common.MAX_ERROR_MESSAGE_LENGTH = 100000
bitblas.set_log_level("Debug")
matmul_config = bitblas.MatmulConfig(
M=1024, # M dimension
N=17280, # N dimension
K=3200, # K dimension
A_dtype="int8", # activation A dtype
W_dtype="int8", # weight W dtype
accum_dtype="int32", # accumulation dtype
out_dtype="float32", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how to calculate zeros
propagate_a=False,
propagate_b=3,
)
matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)
matmul.hardware_aware_finetune(topk=20, parallel_build=True)
print(matmul.scheduled_ir_module)
print(matmul.get_source())

The async copy instructions are correctly lowered. However, when we encounter a dynamic case:
import bitblas
# enabling debug output
bitblas.common.MAX_ERROR_MESSAGE_LENGTH = 100000
bitblas.set_log_level("Debug")
matmul_config = bitblas.MatmulConfig(
M=[1024], # M dimension
N=17280, # N dimension
K=3200, # K dimension
A_dtype="int8", # activation A dtype
W_dtype="int8", # weight W dtype
accum_dtype="int32", # accumulation dtype
out_dtype="float32", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how to calculate zeros
propagate_a=False,
propagate_b=3,
)
matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)
matmul.hardware_aware_finetune(topk=20, parallel_build=True)
print(matmul.scheduled_ir_module)
print(matmul.get_source())

The generated source is:
#pragma unroll
for (int ax0_ax1_ax2_ax3_ax4_fused_2_1 = 0; ax0_ax1_ax2_ax3_ax4_fused_2_1 < 4; ++ax0_ax1_ax2_ax3_ax4_fused_2_1) {
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(buf_dyn_shmem + (((((((ax3_0_0 + 1) & 1) * 8192) + (((int)threadIdx.y) * 4096)) + (((int)threadIdx.z) * 2048)) + (ax0_ax1_ax2_ax3_ax4_fused_2_1 * 512)) + (((int)threadIdx.x) * 16)))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)(buf_dyn_shmem + (((((((ax3_0_0 + 1) & 1) * 8192) + (((int)threadIdx.y) * 4096)) + (((int)threadIdx.z) * 2048)) + (ax0_ax1_ax2_ax3_ax4_fused_2_1 * 512)) + (((int)threadIdx.x) * 16))))
);
#endif
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
:: "r"(addr), "l"((void*)(B + ((((((((((int)blockIdx.x) * 409600) + (((int)threadIdx.y) * 204800)) + (((int)threadIdx.z) * 102400)) + ((ax0_ax1_ax2_ax3_ax4_fused_2_1 >> 1) * 51200)) + (ax3_0_0 * 1024)) + ((ax0_ax1_ax2_ax3_ax4_fused_2_1 & 1) * 512)) + (((int)threadIdx.x) * 16)) + 1024))), "n"(16)
);
}
}
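For context, the snippet below is a minimal, self-contained sketch of the cp.async pattern the generated kernel relies on (a 16-byte global-to-shared copy followed by a commit/wait), assuming an sm_80+ GPU and compilation with nvcc -arch=sm_80. It is illustrative only and not the exact code TVM/BitBLAS emits; names like cp_async_demo are hypothetical.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void cp_async_demo(const int8_t* __restrict__ gmem, int8_t* __restrict__ out) {
  __shared__ int8_t smem[32 * 16];
  // Convert the generic shared-memory pointer to a 32-bit shared-space address,
  // as the TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST branch in the listing above does.
  unsigned int addr = static_cast<unsigned int>(
      __cvta_generic_to_shared(smem + (((int)threadIdx.x) * 16)));
  // Issue a 16-byte asynchronous global->shared copy (cp.async.cg).
  __asm__ __volatile__(
      "cp.async.cg.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"((void *)(gmem + (((int)threadIdx.x) * 16))), "n"(16));
  // Commit the copy group and wait for it to land before reading shared memory.
  __asm__ __volatile__("cp.async.commit_group;");
  __asm__ __volatile__("cp.async.wait_group 0;");
  __syncthreads();
  out[threadIdx.x] = smem[((int)threadIdx.x) * 16];
}

int main() {
  int8_t *gmem = nullptr, *out = nullptr;
  cudaMalloc(&gmem, 32 * 16);
  cudaMalloc(&out, 32);
  cudaMemset(gmem, 1, 32 * 16);
  cp_async_demo<<<1, 32>>>(gmem, out);
  cudaDeviceSynchronize();
  printf("%s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(gmem);
  cudaFree(out);
  return 0;
}

Apart from the dynamic-shape indexing, this is the same shared-address cast plus cp.async sequence that appears in the generated source above.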