ZeRO-Inference refresh (#4197)
* INT4 weight only quantization (#479)

* INT4 weight only quantization

* pre-commit

* fix UT

* fix UT

* fix UT

* fix UT

* fix UT

* fix UT

* fix UT

* add zero3 test

* quantize small weights first to prevent OOM

* fold quantization config into ds_config

* Fix license & refactor ds_config & rebase master

* fix UT

* Moving quantization into post_init_method and add int4 dequantization kernel (#522)

* Add experimental int4 dequantize kernel

* move quantization into post_init_method

* fix

* Refactor: move int4 code to deepspeed/inference (#528)

* Move int 4 code to deepspeed/inference

* fix

* fix

* fix

* zero++ tutorial PR (#3783)

* [Fix] _conv_flops_compute when padding is a str and stride=1 (#3169)

* fix conv_flops_compute when padding is a str and stride=1

* fix error

* change type of paddings to tuple

* fix padding calculation

* apply formatting check

---------

Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* fix interpolate flops compute (#3782)

* use `Flops Profiler` to test `model.generate()` (#2515)

* Update profiler.py

* pre-commit run --all-files

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

---------

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>

* revert PR #3611 (#3786)

* bump to 0.9.6

* ZeRO++ chinese blog (#3793)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* remove staging trigger (#3792)

* DeepSpeed-Triton for Inference (#3748)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* ZeRO++ (#3784)

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>

* adding zero++ to navigation panel of deepspeed.ai (#3796)

* Add ZeRO++ Japanese blog (#3797)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* add ZeRO++ Japanese blog

* add links

---------

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>

* Bug Fixes for autotuner and flops profiler (#1880)

* fix autotuner when backward is not called

* fix format

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* Missing strided copy for gated MLP (#3788)

Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* Requires grad checking. (#3789)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* bump to 0.10.0

* Fix Bug in transform.cu (#3534)

* Bug fix

* Fixed formatting error

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* bug fix: triton importing error (#3799)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Fix dequant bug

* Address PR feedback

* Use super() __exit__

* Fix unit tests

---------

Co-authored-by: Donglin Zhuang <donglinzhuang@outlook.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Bill Luo <50068224+zhiruiluo@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Guorun <84232793+CaffreyR@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: stephen youn <13525892+stephen-youn@users.noreply.github.com>
Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Ramya Ramineni <62723901+rraminen@users.noreply.github.com>
22 people committed Sep 11, 2023
1 parent 542dc0d commit aa4a740
Showing 14 changed files with 1,232 additions and 18 deletions.
8 changes: 8 additions & 0 deletions csrc/includes/quantization.h
@@ -90,3 +90,11 @@ void launch_sr_fake_quantize_kernel_asym(T* vals,
int group_num,
int num_bits,
cudaStream_t stream);

void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream);
23 changes: 23 additions & 0 deletions csrc/quantization/pt_binding.cpp
@@ -136,6 +136,26 @@ at::Tensor dequantize(at::Tensor& quantized_data,
return output;
}

at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
at::Tensor& scale_buffer,
at::Tensor& min_val_buffer,
int num_group,
int group_size)
{
auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto output = torch::empty({num_group, group_size}, output_options);

launch_dequantize_int4_to_half_experimental((uint8_t*)data_in.data_ptr(),
(half*)output.data_ptr(),
(half*)scale_buffer.data_ptr(),
(half*)min_val_buffer.data_ptr(),
num_group,
group_size,
at::cuda::getCurrentCUDAStream());

return output;
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
@@ -247,6 +267,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("quantize", &quantize_kernel);
m.def("dequantize", &dequantize<__half>);
m.def("dequantize_fp32", &dequantize<float>);
m.def("dequantize_int4_to_half_experimental",
&dequantize_int4_to_half_experimental,
"Dequantize int4 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
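
With the binding registered, the new op is reachable from Python through the quantizer op builder. The following is a minimal usage sketch, not part of the commit: it assumes the extension has been rebuilt with this binding, and picks num_group as a multiple of 32 so the launcher's grid heuristic (num_group / 32 blocks, shown in the .cu file below) yields at least one block.

import torch
from deepspeed.ops.op_builder import QuantizerBuilder

# JIT-build (or load) the quantizer extension that pt_binding.cpp belongs to.
quantizer = QuantizerBuilder().load()

num_group, group_size = 64, 512
# Two int4 values are packed per byte, so each group occupies group_size // 2 bytes.
packed = torch.randint(0, 256, (num_group, group_size // 2),
                       dtype=torch.uint8, device='cuda')
scales = torch.rand(num_group, dtype=torch.half, device='cuda') + 0.5
mins = torch.rand(num_group, dtype=torch.half, device='cuda')

out = quantizer.dequantize_int4_to_half_experimental(
    packed, scales, mins, num_group, group_size)
assert out.shape == (num_group, group_size) and out.dtype == torch.half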
230 changes: 230 additions & 0 deletions csrc/quantization/quantize_int4.cu
@@ -0,0 +1,230 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "memory_access_utils.h"

template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedArray {
using Element = T;
static const int kElements = N;

__device__ __host__ AlignedArray() {}

__device__ __host__ AlignedArray(const T& rhs)
{
#pragma unroll
for (int idx = 0; idx < kElements; ++idx) { this->at(idx) = rhs; }
}

__device__ __host__ T& operator[](int offset)
{
return reinterpret_cast<T&>(this->buffer[offset]);
}

__device__ __host__ const T& operator[](int offset) const
{
return reinterpret_cast<const T&>(this->buffer[offset]);
}

__device__ __host__ T& at(int offset) { return reinterpret_cast<T&>(this->buffer[offset]); }

__device__ __host__ const T& at(int offset) const
{
return reinterpret_cast<const T&>(this->buffer[offset]);
}

__device__ __host__ AlignedArray<T, N> operator+(const AlignedArray<T, N>& rhs) const
{
AlignedArray<T, N> ret;

#pragma unroll
for (int idx = 0; idx < kElements; ++idx) { ret[idx] = this->at(idx) + rhs.at(idx); }

return ret;
}

__device__ __forceinline__ void clear()
{
#pragma unroll
for (int idx = 0; idx < kElements; ++idx) { this->at(idx) = Element(0); }
}

Element buffer[N];
};

template <typename T>
struct reduce_max {
__device__ __forceinline__ T operator()(const T& lhs, const T& rhs)
{
return lhs > rhs ? lhs : rhs;
}
};

template <typename T>
struct reduce_min {
__device__ __forceinline__ T operator()(const T& lhs, const T& rhs)
{
return lhs < rhs ? lhs : rhs;
}
};

template <typename T, int N>
struct subtract {
__device__ __forceinline__ AlignedArray<T, N> operator()(const AlignedArray<T, N>& lhs,
const T& rhs)
{
AlignedArray<T, N> ret;

#pragma unroll
for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] - rhs; }

return ret;
}
};

template <typename T, int N>
struct plus {
__device__ __forceinline__ AlignedArray<T, N> operator()(const AlignedArray<T, N>& lhs,
const T& rhs)
{
AlignedArray<T, N> ret;

#pragma unroll
for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] + rhs; }

return ret;
}
};

template <typename T, int N>
struct multiply {
__device__ __forceinline__ AlignedArray<T, N> operator()(const AlignedArray<T, N>& lhs,
const T& rhs)
{
AlignedArray<T, N> ret;

#pragma unroll
for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] * rhs; }

return ret;
}
};

template <typename T, int N>
struct clamp {
__device__ __forceinline__ AlignedArray<T, N> operator()(const AlignedArray<T, N>& lhs,
const T& min_val,
const T& max_val)
{
AlignedArray<T, N> ret;

#pragma unroll
for (int idx = 0; idx < N; ++idx) {
ret[idx] = reduce_max<T>()(reduce_min<T>()(lhs[idx], max_val), min_val);
}

return ret;
}
};

template <typename T, int N>
struct round_int;

template <int N>
struct round_int<half, N> {
__device__ __forceinline__ AlignedArray<half, N> operator()(const AlignedArray<half, N>& lhs)
{
AlignedArray<half, N> ret;

#pragma unroll
for (int idx = 0; idx < N; ++idx) { ret[idx] = hrint(lhs[idx]); }

return ret;
}
};

template <typename T, int N>
struct divide {
__device__ __forceinline__ AlignedArray<T, N> operator()(const AlignedArray<T, N>& lhs,
const T& rhs)
{
AlignedArray<T, N> ret;

#pragma unroll
for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] / rhs; }

return ret;
}
};

template <typename T, int N, typename Reducer>
__device__ __forceinline__ T to_scalar(const AlignedArray<T, N>& data)
{
Reducer re;
T res = data[0];

#pragma unroll
for (int idx = 1; idx < N; ++idx) { res = re(res, data[idx]); }

return res;
}

template <int N>
__device__ __forceinline__ AlignedArray<half, N * 2> int4_to_half(
const AlignedArray<uint8_t, N>& data)
{
AlignedArray<half, N * 2> ret;

#pragma unroll
for (int idx = 0; idx < N * 2; idx += 2) {
ret[idx] = half(int(data[idx / 2] >> 4));
ret[idx + 1] = half(int(data[idx / 2] & 0xf));
}

return ret;
}

__global__ void dequantize_int4_to_half(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size)
{
using AccessType = AlignedArray<uint8_t, 4>;
using AccessTypeOut = AlignedArray<half, 8>;

for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8;
idx += blockDim.x * gridDim.x) {
int id_group = idx / (group_size / 8);
AccessType value = reinterpret_cast<AccessType*>(data_in)[idx];
half scale = scale_buffer[id_group];
half min_value = min_val_buffer[id_group];

AccessTypeOut output = int4_to_half(value);
output = divide<half, 8>()(output, scale);
output = plus<half, 8>()(output, min_value);

reinterpret_cast<AccessTypeOut*>(data_out)[idx] = output;
}
}

void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream)
{
int num_warp = num_group / 4;
int num_block = num_warp / 8; // 256 threads per block

dequantize_int4_to_half<<<num_block, 256, 0, stream>>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}
2 changes: 0 additions & 2 deletions deepspeed/inference/__init__.py
@@ -2,5 +2,3 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .engine import InferenceEngine
2 changes: 2 additions & 0 deletions deepspeed/inference/config.py
@@ -100,6 +100,8 @@ class BaseQuantConfig(DeepSpeedConfigModel):

class WeightQuantConfig(BaseQuantConfig):
enabled = True
quantized_initialization: Dict = {}
post_init_quant: Dict = {}


class ActivationQuantConfig(BaseQuantConfig):
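
The two new fields give WeightQuantConfig a home for both quantization paths: quantizing weights as the model is constructed (quantized_initialization) and quantizing an already-built model (post_init_quant). A sketch of how this might be driven from a ds_config follows; the nested keys are illustrative assumptions, not taken from this diff.

ds_config = {
    "weight_quantization": {
        # Quantize weights as modules are created (see the "quantize small
        # weights first to prevent OOM" commit above). Keys here are assumed.
        "quantized_initialization": {
            "num_bits": 4,
            "group_size": 64,
        },
        # Alternatively, quantize after model construction, per the
        # post_init_method commits above:
        # "post_init_quant": {...},
    }
}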
4 changes: 4 additions & 0 deletions deepspeed/inference/quantization/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
(diffs for the remaining 8 changed files not shown)
