[embedding] rename FreqAwareEmbedding -> CachedEmbedding (#1699)
feifeibear committed Oct 13, 2022
1 parent 0e52f3d commit 21962e1
Showing 8 changed files with 77 additions and 76 deletions.
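This commit is a pure rename: every FreqAware* symbol becomes the corresponding Cached* symbol, with signatures and behavior unchanged. For downstream code, a minimal migration sketch (the new import path is taken from the __init__.py diffs below; the removed names are on the deleted lines):

# Before this commit (names removed by the diff below):
# from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag

# After this commit (renamed equivalents, same signatures):
from colossalai.nn.parallel.layers import CachedEmbeddingBag, ParallelCachedEmbeddingBag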
10 changes: 5 additions & 5 deletions colossalai/nn/parallel/layers/__init__.py
@@ -3,12 +3,12 @@
from .embedding import ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module

-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache

__all__ = [
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
-    'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr',
+    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
]
14 changes: 7 additions & 7 deletions colossalai/nn/parallel/layers/cache_embedding/__init__.py
@@ -1,13 +1,13 @@
from .cache_mgr import CachedParamMgr, EvictionStrategy
from .copyer import LimitBuffIndexCopyer
-from .freq_aware_embedding import FreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
+from .parallel_cached_embedding import ParallelCachedEmbeddingBag
from .embedding_config import TablewiseEmbeddingBagConfig
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise
-from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
+from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache

__all__ = [
-    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy',
+    'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
]
3 changes: 2 additions & 1 deletion colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -352,7 +352,8 @@ def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:

# move sure the cuda rows will not be evicted!
with record_function("(cache) prepare_rows_on_cuda"):
-            self._prepare_rows_on_cuda(comm_cpu_row_idxs)
+            with self.timer("prepare_rows_on_cuda") as timer:
+                self._prepare_rows_on_cuda(comm_cpu_row_idxs)

self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)

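Besides the rename, this hunk wraps _prepare_rows_on_cuda in a self.timer(...) context manager. The timer itself is defined outside this diff; a minimal sketch of a compatible helper, assuming it only needs to accumulate wall-clock time per label (the _Timer name and dict bookkeeping are illustrative, not the actual CachedParamMgr API):

import time
from contextlib import contextmanager

class _Timer:
    """Illustrative named timer; not the actual CachedParamMgr implementation."""

    def __init__(self):
        self.elapsed = {}    # label -> accumulated seconds

    @contextmanager
    def __call__(self, label):
        # `self.timer = _Timer()` in __init__ would make `with self.timer("x"):` work.
        start = time.perf_counter()
        try:
            yield self
        finally:
            self.elapsed[label] = self.elapsed.get(label, 0.0) + time.perf_counter() - start

Note that CUDA kernels launch asynchronously, so a real GPU-side timer would synchronize or use CUDA events rather than relying on wall-clock time alone.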
colossalai/nn/parallel/layers/cache_embedding/{freq_aware_embedding.py → cached_embedding.py}
@@ -7,10 +7,10 @@
from torch.nn.parameter import Parameter


-class FreqAwareEmbeddingBag(BaseEmbeddingBag):
-    """FreqAwareEmbeddingBag
+class CachedEmbeddingBag(BaseEmbeddingBag):
+    """CachedEmbeddingBag
-    Frequency Aware Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
+    Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`.
You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU.
@@ -54,8 +54,8 @@ def __init__(self,
buffer_size: int = 0,
pin_weight: bool = False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
-                                                    scale_grad_by_freq, sparse, mode, include_last_offset)
+        super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
+                                                 scale_grad_by_freq, sparse, mode, include_last_offset)

assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
self.evict_strategy = evict_strategy
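The renamed class keeps a torch.nn.EmbeddingBag-style interface. A usage sketch built from the constructor arguments visible in this diff and in the updated tests further down (table sizes here are illustrative):

import torch
from colossalai.nn.parallel.layers import CachedEmbeddingBag, EvictionStrategy

bag = CachedEmbeddingBag(num_embeddings=10000,
                         embedding_dim=64,
                         mode='mean',
                         include_last_offset=True,
                         cache_ratio=0.1,            # keep ~10% of rows resident on GPU
                         ids_freq_mapping=None,      # optional per-id frequency statistics
                         evict_strategy=EvictionStrategy.LFU).to('cuda')

# EmbeddingBag-style inputs: flat indices plus bag offsets.
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], device='cuda')
offsets = torch.tensor([0, 4, 8], device='cuda')    # include_last_offset=True closes the last bag
out = bag(indices, offsets)                         # shape: (2, 64)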
colossalai/nn/parallel/layers/cache_embedding/{parallel_freq_aware_embedding.py → parallel_cached_embedding.py}
@@ -2,7 +2,7 @@
import torch.nn.functional as F
from typing import List, Optional, Iterator, Tuple

-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
from colossalai.nn._ops._utils import dual_all_to_all

from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
@@ -28,7 +28,7 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
return offset, offset + size_list[rank], False


-class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBag(CachedEmbeddingBag):

def __init__(self,
num_embeddings,
@@ -56,7 +56,7 @@ def __init__(self,
embedding_dim, self.rank, self.world_size)
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index

-        super(ParallelFreqAwareEmbeddingBag,
+        super(ParallelCachedEmbeddingBag,
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
@@ -115,7 +115,7 @@ def from_pretrained(
ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio: float = 0.7,
buffer_size: int = 0,
-    ) -> 'ParallelFreqAwareEmbeddingBag':
+    ) -> 'ParallelCachedEmbeddingBag':
rows, cols = embedding.shape
embedding_bag = cls(rows,
cols,
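Apart from the return annotation, from_pretrained is unchanged: it wraps a ColoParameter sharded along the embedding dimension. A condensed sketch of that call, following run_parallel_freq_aware_embed_columnwise in the test diff at the end of this page (the weight shape is illustrative, and torch.distributed is assumed to be initialized):

import torch
import torch.distributed as dist
from colossalai.nn.parallel.layers import ParallelCachedEmbeddingBag
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec

world_size = dist.get_world_size()
weight = torch.randn(1000, 64)    # illustrative full (unsharded) embedding table

coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]),
                           ComputeSpec(ComputePattern.TP1D))

model = ParallelCachedEmbeddingBag.from_pretrained(coloweight,
                                                   include_last_offset=True,
                                                   freeze=False)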
colossalai/nn/parallel/layers/cache_embedding/{parallel_freq_aware_embedding_tablewise.py → parallel_cached_embedding_tablewise.py}
@@ -2,7 +2,7 @@
import torch.distributed as dist
import torch.nn.functional as F

-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
from .cache_mgr import EvictionStrategy
from .embedding_config import TablewiseEmbeddingBagConfig
from colossalai.tensor import ProcessGroup
@@ -12,9 +12,9 @@
import time


-class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
"""
-    all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
+    all tables assigned to this class instance are managed by a single CachedEmbeddingBag.
Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
"""

@@ -62,7 +62,7 @@ def __init__(self,
self.cache_ratio = cache_ratio
# table-associate cache
cuda_row_num = int(cache_ratio * self.num_embeddings)
-        super(ParallelFreqAwareEmbeddingBagTablewise,
+        super(ParallelCachedEmbeddingBagTablewise,
self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
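In the tablewise layout each rank owns whole tables, described by a list of TablewiseEmbeddingBagConfig objects. A condensed construction sketch following run_parallel_freq_aware_embed_tablewise in the test diff below; the exact config keyword names are an assumption based on the config attributes read in this file (config.num_embeddings, config.assigned_rank), and torch.distributed is assumed to be initialized:

from colossalai.nn.parallel.layers import ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig

# One config per table; assigned_rank decides which process hosts it (kwargs assumed).
embedding_bag_config_list = [
    TablewiseEmbeddingBagConfig(num_embeddings=100, assigned_rank=0),
    TablewiseEmbeddingBagConfig(num_embeddings=100, assigned_rank=0),
    TablewiseEmbeddingBagConfig(num_embeddings=300, assigned_rank=1),
]

model = ParallelCachedEmbeddingBagTablewise(
    embedding_bag_config_list,
    embedding_dim=5,
    cache_ratio=0.5,    # per the docstring above, cuda_row_num from the configs is ignored here
)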
colossalai/nn/parallel/layers/cache_embedding/{parallel_freq_aware_embedding_tablewise_split_cache.py → parallel_cached_embedding_tablewise_split_cache.py}
@@ -3,7 +3,7 @@
import torch.nn as nn
from torch.profiler import record_function

-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag

from colossalai.tensor import ProcessGroup
from colossalai.nn._ops._utils import dual_all_to_all_tablewise
@@ -14,9 +14,9 @@
import abc


-class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
+class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
    """
-    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
+    every table assigned to this class instance is managed by a CachedEmbeddingBag.
"""

def __init__(self,
@@ -34,7 +34,7 @@ def __init__(self,
warmup_ratio=0.7,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
+        super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
@@ -49,31 +49,31 @@ def __init__(self,
self.include_last_offset = include_last_offset
self.pg = ProcessGroup(tp_degree=self.world_size)

-        # prepare FreqAwareEmbeddingBag list
+        # prepare CachedEmbeddingBag list

-        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
+        self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList()
for config in embedding_bag_config_list:
if config.assigned_rank != self.rank:
continue
-            self.freq_aware_embedding_bag_list.append(
-                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
-                                      embedding_dim=embedding_dim,
-                                      padding_idx=padding_idx,
-                                      max_norm=max_norm,
-                                      norm_type=norm_type,
-                                      scale_grad_by_freq=scale_grad_by_freq,
-                                      sparse=sparse,
-                                      _weight=config.initial_weight,
-                                      mode=mode,
-                                      include_last_offset=include_last_offset,
-                                      dtype=dtype,
-                                      device=device,
-                                      cuda_row_num=config.cuda_row_num,
-                                      ids_freq_mapping=config.ids_freq_mapping,
-                                      warmup_ratio=warmup_ratio,
-                                      buffer_size=config.buffer_size,
-                                      pin_weight=pin_weight,
-                                      evict_strategy=evict_strategy))
+            self.cached_embedding_bag_list.append(
+                CachedEmbeddingBag(num_embeddings=config.num_embeddings,
+                                   embedding_dim=embedding_dim,
+                                   padding_idx=padding_idx,
+                                   max_norm=max_norm,
+                                   norm_type=norm_type,
+                                   scale_grad_by_freq=scale_grad_by_freq,
+                                   sparse=sparse,
+                                   _weight=config.initial_weight,
+                                   mode=mode,
+                                   include_last_offset=include_last_offset,
+                                   dtype=dtype,
+                                   device=device,
+                                   cuda_row_num=config.cuda_row_num,
+                                   ids_freq_mapping=config.ids_freq_mapping,
+                                   warmup_ratio=warmup_ratio,
+                                   buffer_size=config.buffer_size,
+                                   pin_weight=pin_weight,
+                                   evict_strategy=evict_strategy))

# prepare list shape for all_to_all output
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
@@ -109,8 +109,8 @@ def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sampl
if per_sample_weights != None:
local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
with record_function("(tablewise) tablewise forward"):
-                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
-                                                                               local_per_sample_weights))
+                local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets,
+                                                                           local_per_sample_weights))

# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
local_output = torch.cat(local_output_list, 1)
@@ -126,13 +126,13 @@
def element_size(self):
if len(self.assigned_table_list) == 0:
return 0
-        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
+        return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size()

def print_comm_stats_(self):
cuda_to_cpu_elem_num = 0
cpu_to_cuda_elem_num = 0
-        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
-            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
-            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
+        for cached_embedding_bag in self.cached_embedding_bag_list:
+            cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
+            cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
38 changes: 19 additions & 19 deletions tests/test_layers/test_cache_embedding.py
@@ -12,8 +12,8 @@
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
from typing import List

NUM_EMBED, EMBED_DIM = 10, 8
@@ -106,13 +106,13 @@ def test_reorder_with_freq():
def test_freq_aware_embed(use_LFU: bool):
device = torch.device('cuda', 0)
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
-    model = FreqAwareEmbeddingBag(NUM_EMBED,
-                                  EMBED_DIM,
-                                  mode='mean',
-                                  include_last_offset=True,
-                                  cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
-                                  ids_freq_mapping=None,
-                                  evict_strategy=evict_strategy).to(device)
+    model = CachedEmbeddingBag(NUM_EMBED,
+                               EMBED_DIM,
+                               mode='mean',
+                               include_last_offset=True,
+                               cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
+                               ids_freq_mapping=None,
+                               evict_strategy=evict_strategy).to(device)

assert model.weight.shape[0] == NUM_EMBED
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
@@ -151,14 +151,14 @@ def test_freq_aware_embed(use_LFU: bool):
@pytest.mark.parametrize('init_freq', [True, False])
def test_lfu_strategy(init_freq: bool):
# minimal test to check behavior
-    Bag = FreqAwareEmbeddingBag(5,
-                                5,
-                                cache_ratio=3 / 5,
-                                buffer_size=0,
-                                pin_weight=True,
-                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
-                                warmup_ratio=1.0,
-                                evict_strategy=EvictionStrategy.LFU)
+    Bag = CachedEmbeddingBag(5,
+                             5,
+                             cache_ratio=3 / 5,
+                             buffer_size=0,
+                             pin_weight=True,
+                             ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
+                             warmup_ratio=1.0,
+                             evict_strategy=EvictionStrategy.LFU)

# print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
offsets = torch.tensor([0], device="cuda:0")
@@ -233,7 +233,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
_weight = torch.cat([weight_table1, weight_table2], 0)
else:
_weight = weight_table3
-    model = ParallelFreqAwareEmbeddingBagTablewise(
+    model = ParallelCachedEmbeddingBagTablewise(
embedding_bag_config_list,
embedding_dim=5,
_weight=_weight,
@@ -300,7 +300,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))

-    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
+    model = ParallelCachedEmbeddingBag.from_pretrained(
coloweight,
include_last_offset=True,
freeze=False,
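The test above validates the cache against a dense reference: since CachedEmbeddingBag keeps the full table in (optionally pinned) CPU memory, its weight can seed an ordinary torch.nn.EmbeddingBag and the two outputs can be compared. A self-contained sketch of that pattern, condensed from test_freq_aware_embed (the BATCH_SIZE value and random inputs are illustrative):

import torch
from colossalai.nn.parallel.layers import CachedEmbeddingBag

NUM_EMBED, EMBED_DIM, BATCH_SIZE = 10, 8, 8    # BATCH_SIZE illustrative
device = torch.device('cuda', 0)

model = CachedEmbeddingBag(NUM_EMBED, EMBED_DIM, mode='mean', include_last_offset=True,
                           cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0)).to(device)

# The full table lives on CPU inside the cached bag, so it can seed a dense reference.
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
                                                  mode='mean',
                                                  include_last_offset=True,
                                                  freeze=False)

indices = torch.randint(low=0, high=NUM_EMBED, size=(BATCH_SIZE * 3,), device=device)
offsets = torch.arange(0, BATCH_SIZE * 3 + 1, 3, device=device)

assert torch.allclose(model(indices, offsets), ref_model(indices, offsets)), \
    "cached forward should match the dense reference"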
