[embeddings] add already_split_along_rank flag for tablewise mode (#1584)
CsRic committed Sep 13, 2022
1 parent 77399dc commit f3403ff
Showing 2 changed files with 41 additions and 17 deletions.
@@ -9,6 +9,7 @@
from colossalai.nn._ops._utils import dual_all_to_all_tablewise

from typing import List
+import time


class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
@@ -79,8 +80,43 @@ def __init__(self,
        for rank in self.rank_of_tables:
            self.embedding_dim_per_rank[rank] += embedding_dim

-    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
-        batch_size = (offsets.shape[0]) // self.global_tables_num
+    def forward(self,
+                indices: torch.Tensor,
+                offsets: torch.Tensor = None,
+                per_sample_weights=None,
+                shape_hook=None,
+                already_split_along_rank=True):
+        if not already_split_along_rank:
+            # not recommended: splitting here costs time on every forward pass.
+            batch_size = (offsets.shape[0]) // self.global_tables_num
+            local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
+                batch_size, indices, offsets, per_sample_weights)
+        else:
+            # recommended: inputs were already split along the assigned ranks.
+            batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
+            local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
+        with torch.no_grad():
+            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                       local_per_sample_weights, self.include_last_offset, self.padding_idx)
+        local_output = torch.cat(local_output.split(batch_size), 1)
+        remains = batch_size % self.world_size
+        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
+        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        if shape_hook is not None:
+            output_full = shape_hook(output_full)
+        return output_full

+    def split_along_rank(self,
+                         batch_size,
+                         indices: torch.Tensor,
+                         offsets: torch.Tensor = None,
+                         per_sample_weights=None):
+        '''
+        If the input indices and offsets have not been split along the assigned ranks yet, this function does it.
+        It takes time, so consider splitting the data during batch loading instead.
+        '''
        local_indices_list: List[torch.Tensor] = []
        local_offsets_list: List[torch.Tensor] = []
        if per_sample_weights != None:
@@ -145,20 +181,7 @@ def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sampl
        local_per_sample_weights = None
        if per_sample_weights != None:
            local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
-                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
-                                       local_per_sample_weights, self.include_last_offset, self.padding_idx)
-        local_output = torch.cat(local_output.split(batch_size), 1)
-
-        remains = batch_size % self.world_size
-        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
-        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
-        if shape_hook is not None:
-            output_full = shape_hook(output_full)
-        return output_full
+        return local_indices, local_offsets, local_per_sample_weights

    def print_comm_stats_(self):
        self.cache_weight_mgr.print_comm_stats()
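
The stride computation used for the all-to-all above can be checked in isolation. Below is a minimal standalone sketch of that arithmetic in plain Python (no distributed setup required); the free-standing function name scatter_strides is only for illustration and does not appear in the source:

def scatter_strides(batch_size: int, world_size: int) -> list:
    # Ranks whose index is below the remainder get one extra sample,
    # so the strides always sum exactly to batch_size.
    remains = batch_size % world_size
    return [batch_size // world_size + int(i < remains) for i in range(world_size)]

assert scatter_strides(10, 3) == [4, 3, 3]   # e.g. a batch of 10 samples on 3 ranks
assert sum(scatter_strides(10, 3)) == 10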
tests/test_layers/test_cache_embedding.py (3 changes: 2 additions & 1 deletion)
@@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
    in KJT format
    '''
    res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
-                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
+                already_split_along_rank=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
    if rank == 0:
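For orientation, here is a minimal usage sketch of the two paths the new flag selects, reusing the indices/offsets tensors from the test above. It assumes model is an already constructed ParallelFreqAwareEmbeddingBagTablewise running on a CUDA-capable rank (distributed setup and table assignment are omitted); the local variable names are illustrative only:

import torch

indices = torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device='cuda')
offsets = torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device='cuda')

# Slower path: hand the full, unsplit batch to forward() and let it call
# split_along_rank() internally on every step.
out = model(indices, offsets, already_split_along_rank=False)

# Recommended path: split once per batch (ideally inside the data loader),
# then reuse the per-rank slices and skip the per-step split.
batch_size = offsets.shape[0] // model.global_tables_num
local_indices, local_offsets, local_weights = model.split_along_rank(
    batch_size, indices, offsets, per_sample_weights=None)
out = model(local_indices, local_offsets,
            per_sample_weights=local_weights,
            already_split_along_rank=True)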
