Skip to content

Commit

Permalink
OptimizedLinear implementation (#5355)
Browse files Browse the repository at this point in the history
Optimized version of `nn.Linear` that adds features such as:
      * LoRA w. base weight sharding
      * FP [6,8,12] quantization

Depends on #5336 being merged first

Co-authored-by: @rajhans
Co-authored-by: @aurickq

---------

Co-authored-by: Rajhans Samdani <rajhans.samdani@snowflake.com>
Co-authored-by: Jeff Rasley <jeff.rasley@snowflake.com>
  • Loading branch information
3 people committed Apr 23, 2024
1 parent c66bc42 commit 5e6c9b9
Show file tree
Hide file tree
Showing 8 changed files with 550 additions and 4 deletions.
7 changes: 7 additions & 0 deletions deepspeed/linear/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .optimized_linear import OptimizedLinear
from .config import LoRAConfig, QuantizationConfig
39 changes: 39 additions & 0 deletions deepspeed/linear/config.py
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from dataclasses import dataclass


@dataclass
class LoRAConfig:
"""
Configuration settings for LoRAOptimizedLinear.
Attributes:
lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64.
lora_alpha (float): LoRA scaling factor, default is 16.
base_weight_sharding (int): The degree to which the base weights are sharded,
should typically be set to the data-parallel world size to maximize the memory
reduction benefits. Defaults to 1, which means this feature is disabled.
"""
lora_r: int = 64
lora_alpha: float = 16.
base_weight_sharding: int = 1


@dataclass
class QuantizationConfig:
"""
Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear,
and QuantizedParameter
Attributes:
q_bits (int): The number of bits used for quantization. Default is 8.
mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
group_size (int): The size of the group used for quantization. Default is 512.
"""
q_bits: int = 8
mantissa_bits: int = 3
group_size: int = 512
150 changes: 150 additions & 0 deletions deepspeed/linear/optimized_linear.py
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import is_dataclass
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
"""
Optimized version of nn.Linear that adds features such as:
* LoRA w. base weight sharding
* FP [6,8,12] quantization
Arguments:
input_dim: Required: size of each input sample
output_dim: Required: size of each output sample
bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
quantization_config: Optional: QuantizationConfig defining quantization features
dtype: Optional: parameter dtype, only supports bfloat16 currently
Returns:
Returns a new nn.Module depending on the input config. Either native
torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
"""

def __new__(self,
input_dim: int,
output_dim: int,
bias: bool = False,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
dtype=torch.bfloat16):

if quantization_config is not None and not is_dataclass(quantization_config):
raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
if lora_config is not None and not is_dataclass(lora_config):
raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
if lora_config is None and quantization_config is None:
# Everything disabled, fall back to normal nn.Linear
self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)

elif lora_config:
# lora enabled, quantization may or may not be
self = LoRAOptimizedLinear(input_dim=input_dim,
output_dim=output_dim,
bias=bias,
lora_config=lora_config,
quantization_config=quantization_config,
dtype=dtype)

elif quantization_config:
# only quantization enabled, no lora
self = QuantizedLinear(input_dim=input_dim,
output_dim=output_dim,
bias=bias,
quantization_config=quantization_config,
dtype=dtype)
return self


class LoRAOptimizedLinear(nn.Module):

def __init__(self,
input_dim: int,
output_dim: int,
bias: bool = False,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
device=None,
dtype=torch.bfloat16):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.bias = bias
self.lora_config = lora_config
self.quantization_config = quantization_config
device = get_accelerator().current_device() if device is None else device
assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

self.zero_shards = self.lora_config.base_weight_sharding
self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
torch.nn.init.xavier_uniform_(w)

if self.quantization_config is not None:
assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
else:
self.base_weight = w

self.base_weight.requires_grad = False

# Use RS lora for now.
self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)
# Keeping lora weights in bf16 precision for ease of training.
self.lora_weight_1 = nn.Linear(self.input_dim,
self.lora_config.lora_r,
bias=self.bias,
device=device,
dtype=dtype)
self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
self.output_dim,
bias=self.bias,
device=device,
dtype=dtype)
self.lora_weight_1.weight.requires_grad = True
self.lora_weight_2.weight.requires_grad = True

def full_weight(self):
# This assumes weights are evenly sharded across gpus. which might not be correct.
# in that case, we should flatten before all_gather.
local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
QuantizedParameter) else self.base_weight
tensor_list = [
torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
for _ in range(self.zero_shards)
]
dist.all_gather(tensor_list, local_weight)
weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1))
return weight

def linear_without_F_linear(self, input, weight):
output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
output = output.view(*input.shape[:-1], weight.shape[1])
return output

def forward(self, input_tensor):
# Gather the sharded base weight
if self.zero_shards > 1:
with torch.no_grad():
base_weight = self.full_weight()
elif self.quantization_config:
base_weight = self.base_weight.dequantized()
else:
base_weight = self.base_weight

base_weight_output = F.linear(input_tensor, base_weight)
lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
return base_weight_output + self.lora_scaling_factor * lora_output
137 changes: 137 additions & 0 deletions deepspeed/linear/quantization.py
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize
from .config import QuantizationConfig


class QuantizedParameter(nn.Parameter):
"""
Quantized parameter class that implements weight quantization. Weights
are stored in quantized form on GPUs, and can be dequantized on-the-fly when
needed by the model. The weights are actually quantized during any `.to(device)`.
Arguments:
data (Tensor): parameter tensor.
requires_grad (bool, optional): if the parameter requires gradient. Defaults
to False and is not supported to be True. Argument provided only for interface
compatibility with torch.nn.Parameter.
quantization_config (QuantizationConfig, optional):
quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer
that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also
required since the quantizer is stashed in the Parameter itself, some models
may clone the Parameter by passing an attribute __dict__. For an example, see
tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone
"""

def __new__(
cls,
data: Optional[torch.Tensor] = None,
requires_grad: bool = False, # quantized weights must be frozen
quantization_config: QuantizationConfig = None,
quantizer: Quantizer = None,
):
if requires_grad:
raise ValueError(f"requires_grad=True is not supported with QuantizedParameter")
if data is None:
data = torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
if quantizer is not None:
self.quantizer = quantizer
else:
# if FPQuantizerBuilder is not compatible in this env this init will fail
self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
self._ensure_quantized(self)
return self

def _ensure_quantized(self, tensor: torch.Tensor):
# If the tensor is on the accelerator and is not quantized, then quantize it in-place.
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8:
with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
tensor.data = self.quantizer.quantize(tensor.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
assert tensor.dtype == torch.int8

def dequantized(self) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8:
with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
return self.quantizer.dequantize(self.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
return self.data

def __getstate__(self):
state = self.__dict__
state["data"] = self.data
state["quantization_config"] = self.quantization_config
state["requires_grad"] = self.requires_grad
return state

def __setstate__(self, state):
self.quantizer = state["quantizer"]
self.quantization_config = state["quantization_config"]
self.data = state["data"]
self.requires_grad = state["requires_grad"]

def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
new_instance.quantizer = copy.deepcopy(state["quantizer"])
new_instance.quantization_config = copy.deepcopy(state["quantization_config"])
new_instance.data = copy.deepcopy(state["data"])
return new_instance

def __copy__(self):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
return new_instance

def cuda(self, device=None, non_blocking=False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

def to(self, *args, **kwargs):
"""
Move the parameter to the given device. Then, if the device is a cuda device,
quantize it.
"""
tensor = super().to(*args, **kwargs)
self._ensure_quantized(tensor)
return tensor


class QuantizedLinear(nn.Linear):
"""
Linear layer that implements weight quantization. Parameters
are stored via `QuantizedParameter` and are dequantized on-the-fly during any
forward pass.
"""

def __init__(self,
input_dim: int,
output_dim: int,
bias: bool = False,
quantization_config: QuantizationConfig = None,
dtype=torch.bfloat16):
super().__init__(input_dim, output_dim, bias=bias, dtype=dtype)
assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype"
self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config)

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.dequantized(), self.bias)
2 changes: 1 addition & 1 deletion deepspeed/ops/fp_quantizer/__init__.py
Expand Up @@ -3,4 +3,4 @@

# DeepSpeed Team

from .quantize import FP_Quantize
from .quantize import FP_Quantize, Quantizer
33 changes: 30 additions & 3 deletions deepspeed/ops/fp_quantizer/quantize.py
Expand Up @@ -4,20 +4,47 @@
# DeepSpeed Team

import torch
import abc
from abc import ABC

from deepspeed.ops.op_builder import FPQuantizerBuilder

fp_quant_module = None


class FP_Quantize:
class Quantizer(ABC):
"""
Abstract Quantizer class that implmenents quantize/dequantize methods.
Arguments:
group_size (int, optional): number of values or elements that are grouped
together for the quantization process.
"""

def __init__(self, group_size=512) -> None:
self.group_size = group_size

@abc.abstractmethod
def quantize(self,
input,
q_bits=8,
q_mantisa_bits=3,
stochastic_mode=False,
return_meta_tensor=False) -> torch.Tensor:
...

@abc.abstractmethod
def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
...


class FP_Quantize(Quantizer):

def __init__(self, group_size=512) -> None:
global fp_quant_module
super().__init__(group_size=group_size)
if fp_quant_module is None:
fp_quant_module = FPQuantizerBuilder().load()

self.group_size = group_size
self.orig_dtype = None

def quantize(self,
Expand Down

0 comments on commit 5e6c9b9

Please sign in to comment.