Add CompressedBackend for Onebit optimizers #5473

Merged: 19 commits, Jun 5, 2024
Changes from 10 commits
6 changes: 4 additions & 2 deletions accelerator/xpu_accelerator.py
@@ -266,9 +266,9 @@ def get_op_builder(self, class_name):
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
- from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
+ from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
except ImportError:
- from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
+ from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder

if class_name == "AsyncIOBuilder":
return AsyncIOBuilder
@@ -278,6 +278,8 @@ def get_op_builder(self, class_name):
return CPUAdamBuilder
elif class_name == "FusedAdamBuilder":
return FusedAdamBuilder
+ elif class_name == "PackbitsBuilder":
+     return PackbitsBuilder
else:
return None

100 changes: 100 additions & 0 deletions csrc/xpu/packbits/packing.cpp
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ipex.h>
#include <torch/extension.h>
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace xpu;

void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1)
{
// pack the sign bits of eight consecutive floats into one byte (non-negative -> 1, negative -> 0)
int i = item_ct1;
for (int j = 0; j < 8; ++j) {
int k = i * 8 + j;
int bit = k < input_size && (!sycl::signbit(input[k]));
output[i] |= bit << (7 - j);
}
}

void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
{
// use the bit value to set float, bit 0 -> float -1, bit 1 -> float 1
int i = item_ct1;
output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2;
}
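
For reference, here is a minimal NumPy sketch (illustrative only, not part of this PR) of the semantics the two kernels implement; NumPy's MSB-first packbits matches the `bit << (7 - j)` layout above:

import numpy as np

x = np.array([0.5, -1.0, 0.0, -2.5, 3.0, -0.1, 7.0, -4.0], dtype=np.float32)

# packbitskernel: bit j of output byte i is 1 iff input[i * 8 + j] is non-negative
packed = np.packbits(~np.signbit(x))            # array([170], dtype=uint8) == 0b10101010

# unpackbitskernel: bit 1 -> +1.0, bit 0 -> -1.0
unpacked = np.unpackbits(packed).astype(np.float32) * 2 - 1
assert (unpacked == np.sign(x + (x == 0))).all()  # zeros decode as +1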

sycl::queue get_current_queue(at::Device device)
{
c10::impl::VirtualGuardImpl impl(device.type());
c10::Stream _stream = impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
sycl::queue queue = xpu::get_queue_from_stream(_stream);
return queue;
}

Review comment (Contributor): @Liangliang-Ma the function documentation needs to be moved to line 39 right before the function def line.
Reply (Contributor Author): Updated.

at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
{
/*
pack float tensor into uint8 tensor. Every eight float elements get packed into one uint8:
a float x >= 0 is packed as a '1' bit, a float x < 0 as a '0' bit.
Arguments:
tensor: the float tensor to pack.
input_size: numel of the input tensor.
rank: device id, used to select the corresponding stream.
*/
at::Device device = "xpu:" + std::to_string(rank);
sycl::queue q = get_current_queue(device);

int packed_size = (input_size + 7) / 8;
auto uint8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
at::Tensor packed = torch::zeros({packed_size}, uint8_options);

float* input = (float*)tensor.data_ptr();
uint8_t* output = (uint8_t*)packed.data_ptr();

auto event = q.submit([&](sycl::handler& cgh) {
cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) {
packbitskernel(input, output, input_size, item_ct1);
});
});

return packed;
}

at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
{
/*
unpack uint8 tensor into float tensor. Every uint8 element gets unpacked into eight floats:
a '1' bit is converted to float(1), a '0' bit to float(-1).
Arguments:
tensor: the uint8 tensor to unpack.
input_size: numel of the input tensor.
rank: device id, used to select the corresponding stream.
*/
at::Device device = "xpu:" + std::to_string(rank);
sycl::queue q = get_current_queue(device);

auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);
at::Tensor unpacked = torch::empty({input_size * 8}, float_options);

uint8_t* input = (uint8_t*)tensor.data_ptr();
float* output = (float*)unpacked.data_ptr();

auto event = q.submit([&](sycl::handler& cgh) {
cgh.parallel_for<>(range(input_size * 8),
[=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); });
});

return unpacked;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("packbits", &packbits, "DeepSpeed XPU packbits (C++)");
m.def("unpackbits", &unpackbits, "DeepSpeed XPU unpackbits (C++)");
}
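
An illustrative usage sketch of these bindings (assumes a DeepSpeed install with the XPU accelerator; variable names are mine and mirror deepspeed/runtime/comm/compressed.py below):

import torch
from deepspeed.ops.op_builder import PackbitsBuilder

rank = 0
packer = PackbitsBuilder().load()  # JIT-compiles this extension on first use

grad = torch.randn(1024, device=f"xpu:{rank}")
packed = packer.packbits(grad.float(), grad.numel(), rank)  # uint8, one eighth the elements
signs = packer.unpackbits(packed, packed.numel(), rank)     # flat tensor of +/-1.0 floats
assert signs.numel() == grad.numel()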
137 changes: 137 additions & 0 deletions deepspeed/runtime/comm/compressed.py
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numpy as np
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder


class CompressedBackend(object):

def __init__(self, mpu=None):
if mpu is None:
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
else:
self.mpu = mpu
self.world_group = self.mpu.get_data_parallel_group()
self.size = dist.get_world_size(group=self.world_group)
self.rank = dist.get_rank(group=self.world_group)
self.packer = PackbitsBuilder().load()

def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
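# Non-blocking gather: the root posts an irecv for every other rank and copies its own
# contribution in place; non-root ranks isend to the root. Returns the outstanding requests.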
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
else:
recvbuf[rank] = sendbuf
else:
req.append(dist.isend(sendbuf, group=group, dst=root))
return req

def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
if rank == root:
for idx in range(size):
if idx != rank:
dist.recv(recvbuf[idx], src=idx, group=group)
else:
recvbuf[rank] = sendbuf
else:
dist.send(sendbuf, group=group, dst=root)

def pack(self, buffer, size):
# pack float tensor into uint8 tensor
packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
return packed.reshape(size, -1)

def unpack(self, buffer, size, dtype):
# unpack uint8 to float tensor
unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
return unpacked.reshape(size, -1).to(dtype)

def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):
original_shape = buffer_m.size()
if len(original_shape) > 1:
buffer_m = torch.flatten(buffer_m)

# align size of original_buffer and error
original_size = buffer_m.numel()
worker_error_size = worker_error.numel()
if original_size != worker_error_size:
empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])

buffer_m.add_(worker_error)
worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))

worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
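# The chained ops map sign values {-1, 0, +1} to {-1, +1} (zero counts as positive),
# so worker_error now holds the residual of the 1-bit quantization for error feedback.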

sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8)

recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
dtype=sign_list_packed_tmp[0].dtype,
device=sign_list_packed_tmp.device)

sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]

recvbuf_scale = [
torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name())
for _ in range(self.size)
]

# communication phase 1
# all to all for sign
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
# all gather for scale
dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten()
compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \
.mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
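# Each row of the unpacked matrix is one worker's +/-1 chunk; scaling row i by
# scale_i / world_size and summing over rows averages the workers' compressed chunks.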

compensated_server_m.add_(server_error)

server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

server_error.set_(compensated_server_m -
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8)

# recvbuf_sign_server
recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])],
dtype=recvbuf_sign.dtype,
device=server_sign_packed.device)

recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)]

# recvbuf_scale_server
recvbuf_scale_server_tmp = torch.zeros([self.size, 1],
dtype=worker_scale.dtype,
device=server_sign_packed.device)

recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)]

# communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

recvbuf_sign_server = torch.stack(recvbuf_sign_server)

flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()

buffer_m.data.copy_(
self.unpack(flattened_recvbuf_sign_server, self.size,
torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)

if original_size != worker_error_size:
buffer_m = buffer_m[0:original_size]
if len(original_shape) > 1:
buffer_m = buffer_m.reshape(original_shape)

return buffer_m
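
The numerics above are 1-bit sign quantization with error feedback, as used by the Onebit optimizers. A minimal single-process sketch of the worker-side step (illustrative only; no communication, and the helper name is mine):

import torch

def one_bit_compress(buffer: torch.Tensor, error: torch.Tensor):
    # Error feedback: fold in the residual left over from the previous step.
    buffer = buffer + error
    # One scale per tensor preserves the average magnitude of the entries.
    scale = buffer.norm() / buffer.numel() ** 0.5
    # 1-bit quantization: keep only the sign (zero treated as positive).
    signs = torch.where(buffer >= 0, torch.ones_like(buffer), -torch.ones_like(buffer))
    # Remember what the quantizer threw away, for the next iteration.
    new_error = buffer - scale * signs
    return scale * signs, new_error

x = torch.randn(8)
compressed, err = one_bit_compress(x, torch.zeros(8))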
4 changes: 4 additions & 0 deletions deepspeed/runtime/fp16/onebit/adam.py
@@ -101,6 +101,10 @@ def __init__(self,
from deepspeed.runtime.comm.hccl import HcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
+ elif self.comm_backend_name == 'compressed':
+     from deepspeed.runtime.comm.compressed import CompressedBackend
+     self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
+     self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu)
self.size = self.comm_backend_handle.size

self.divider = int(self.size * 8 / np.gcd(self.size, 8))
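# e.g. world size 4 -> gcd(4, 8) = 4 -> divider = 8; the flattened gradient is padded
# to a multiple of this so it splits evenly across ranks and across packed bytes.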
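For context, a usage sketch (not part of this diff; config values are illustrative): the new backend is selected through OneBitAdam's existing comm_backend_name parameter, which the elif above inspects:

# Hypothetical DeepSpeed config enabling 1-bit Adam with the compressed backend.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 1e-3,
            "freeze_step": 400,
            "comm_backend_name": "compressed",  # previously "nccl", "mpi", or "hccl"
        },
    },
}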
4 changes: 4 additions & 0 deletions deepspeed/runtime/fp16/onebit/lamb.py
@@ -123,6 +123,10 @@ def __init__(self,
from deepspeed.runtime.comm.hccl import HcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
+ elif self.comm_backend_name == 'compressed':
+     from deepspeed.runtime.comm.compressed import CompressedBackend
+     self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
+     self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu)

self.size = self.comm_backend_handle.size

4 changes: 4 additions & 0 deletions deepspeed/runtime/fp16/onebit/zoadam.py
@@ -114,6 +114,10 @@ def __init__(self,
from deepspeed.runtime.comm.hccl import HcclBackend
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
+ elif self.comm_backend_name == 'compressed':
+     from deepspeed.runtime.comm.compressed import CompressedBackend
+     self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
+     self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu)
self.size = self.comm_backend_handle.size

self.divider = int(self.size * 8 / np.gcd(self.size, 8))
1 change: 1 addition & 0 deletions op_builder/xpu/__init__.py
@@ -7,3 +7,4 @@
from .cpu_adagrad import CPUAdagradBuilder
from .fused_adam import FusedAdamBuilder
from .async_io import AsyncIOBuilder
+ from .packbits import PackbitsBuilder
26 changes: 26 additions & 0 deletions op_builder/xpu/packbits.py
@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .builder import SYCLOpBuilder


class PackbitsBuilder(SYCLOpBuilder):
BUILD_VAR = "DS_BUILD_PACK_BITS"
NAME = "pack_bits"

def __init__(self):
super().__init__(name=self.NAME)

def absolute_name(self):
return f'deepspeed.ops.{self.NAME}_op'

def sources(self):
return ['csrc/xpu/packbits/packing.cpp']

def include_paths(self):
return ['csrc/xpu/includes']

def cxx_args(self):
args = super().cxx_args()
return args + self.version_dependent_macros()
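
For reference, a sketch of the standard DeepSpeed op-builder flow this enables (the environment variable comes from BUILD_VAR above): prebuild by setting DS_BUILD_PACK_BITS=1 at install time, or JIT-compile on first use:

from deepspeed.ops.op_builder import PackbitsBuilder

builder = PackbitsBuilder()
if builder.is_compatible():      # sanity-check the SYCL toolchain first
    pack_bits = builder.load()   # compiles csrc/xpu/packbits/packing.cpp if needed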