[embedding] rename FreqAwareEmbedding -> CachedEmbedding #1699

Merged · 5 commits · Oct 13, 2022
10 changes: 5 additions & 5 deletions colossalai/nn/parallel/layers/__init__.py
@@ -3,12 +3,12 @@
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
 
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
-    'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr',
+    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
 ]
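For downstream code, this PR is a pure rename: the same classes remain exported under new names. A minimal sketch of the import change (only the names shown in the `__all__` diff above are involved):

```python
# Before this PR (removed):
# from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag

# After this PR, the same classes under their new names:
from colossalai.nn.parallel.layers import CachedEmbeddingBag, ParallelCachedEmbeddingBag
```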
14 changes: 7 additions & 7 deletions colossalai/nn/parallel/layers/cache_embedding/__init__.py
@@ -1,13 +1,13 @@
 from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
-from .freq_aware_embedding import FreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
+from .parallel_cached_embedding import ParallelCachedEmbeddingBag
 from .embedding_config import TablewiseEmbeddingBagConfig
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise
-from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
+from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
-    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy',
+    'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
 ]
@@ -352,7 +352,8 @@ def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
 
         # move sure the cuda rows will not be evicted!
         with record_function("(cache) prepare_rows_on_cuda"):
-            self._prepare_rows_on_cuda(comm_cpu_row_idxs)
+            with self.timer("prepare_rows_on_cuda") as timer:
+                self._prepare_rows_on_cuda(comm_cpu_row_idxs)
 
         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
@@ -7,10 +7,10 @@
 from torch.nn.parameter import Parameter
 
 
-class FreqAwareEmbeddingBag(BaseEmbeddingBag):
-    """FreqAwareEmbeddingBag
+class CachedEmbeddingBag(BaseEmbeddingBag):
+    """CachedEmbeddingBag
 
-    Frequency Aware Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
+    Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
     It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`.
     You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU.
 
@@ -54,8 +54,8 @@ def __init__(self,
                  buffer_size: int = 0,
                  pin_weight: bool = False,
                  evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
-                                                    scale_grad_by_freq, sparse, mode, include_last_offset)
+        super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
+                                                 scale_grad_by_freq, sparse, mode, include_last_offset)
 
         assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
         self.evict_strategy = evict_strategy
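For reference, a minimal usage sketch of the renamed class, mirroring the construction in tests/test_layers/test_cache_embedding.py further down; the table size, embedding width, and cache_ratio values here are illustrative only:

```python
import torch
from colossalai.nn.parallel.layers import CachedEmbeddingBag, EvictionStrategy

# Keep the full embedding table in CPU memory; cache a fraction of rows on the GPU.
model = CachedEmbeddingBag(
    10000,                                # num_embeddings (illustrative)
    128,                                  # embedding_dim (illustrative)
    mode='mean',
    include_last_offset=True,
    cache_ratio=0.1,                      # fraction of rows resident in GPU memory
    ids_freq_mapping=None,                # optional per-id frequency statistics
    evict_strategy=EvictionStrategy.LFU,  # or EvictionStrategy.DATASET
).to(torch.device('cuda', 0))
```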
@@ -2,7 +2,7 @@
 import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 from colossalai.nn._ops._utils import dual_all_to_all
 
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
@@ -28,7 +28,7 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
     return offset, offset + size_list[rank], False
 
 
-class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
 
     def __init__(self,
                  num_embeddings,
@@ -56,7 +56,7 @@ def __init__(self,
                                                            embedding_dim, self.rank, self.world_size)
         self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
 
-        super(ParallelFreqAwareEmbeddingBag,
+        super(ParallelCachedEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
@@ -115,7 +115,7 @@ def from_pretrained(
             ids_freq_mapping: Optional[List[int]] = None,
             warmup_ratio: float = 0.7,
             buffer_size: int = 0,
-    ) -> 'ParallelFreqAwareEmbeddingBag':
+    ) -> 'ParallelCachedEmbeddingBag':
         rows, cols = embedding.shape
         embedding_bag = cls(rows,
                             cols,
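Apart from the return annotation, from_pretrained is untouched by this PR. A minimal sketch under the new name, following run_parallel_freq_aware_embed_columnwise in the test file; the coloweight setup is abbreviated from that same test, and coloweight itself is assumed to be defined:

```python
from colossalai.nn.parallel.layers import ParallelCachedEmbeddingBag

# `coloweight` is a ColoParameter holding the full embedding table, sharded
# column-wise across ranks as set up in the test below:
#   coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
#   coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]),
#                              ComputeSpec(ComputePattern.TP1D))
model = ParallelCachedEmbeddingBag.from_pretrained(
    coloweight,                 # assumed defined as above
    include_last_offset=True,
    freeze=False,
)
```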
@@ -2,7 +2,7 @@
 import torch.distributed as dist
 import torch.nn.functional as F
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 from .cache_mgr import EvictionStrategy
 from .embedding_config import TablewiseEmbeddingBagConfig
 from colossalai.tensor import ProcessGroup
@@ -12,9 +12,9 @@
 import time
 
 
-class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
     """
-    all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
+    all tables assigned to this class instance are managed by a single CachedEmbeddingBag.
     Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
     """
 
@@ -62,7 +62,7 @@ def __init__(self,
         self.cache_ratio = cache_ratio
         # table-associate cache
         cuda_row_num = int(cache_ratio * self.num_embeddings)
-        super(ParallelFreqAwareEmbeddingBagTablewise,
+        super(ParallelCachedEmbeddingBagTablewise,
               self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
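A sketch of tablewise construction under the new name, following run_parallel_freq_aware_embed_tablewise in the test file. The TablewiseEmbeddingBagConfig keyword arguments are assumptions inferred from the fields this diff reads off the config (and per the docstring above, cuda_row_num, buffer_size, and initial_weight would be ignored by this class anyway); it must run inside an initialized distributed context:

```python
from colossalai.nn.parallel.layers import (ParallelCachedEmbeddingBagTablewise,
                                           TablewiseEmbeddingBagConfig)

# One config per table; each table is assigned to exactly one rank.
embedding_bag_config_list = [
    TablewiseEmbeddingBagConfig(num_embeddings=100, assigned_rank=0),  # assumed kwargs
    TablewiseEmbeddingBagConfig(num_embeddings=200, assigned_rank=1),
]
model = ParallelCachedEmbeddingBagTablewise(
    embedding_bag_config_list,
    embedding_dim=5,
)
```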
@@ -3,7 +3,7 @@
 import torch.nn as nn
 from torch.profiler import record_function
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 
 from colossalai.tensor import ProcessGroup
 from colossalai.nn._ops._utils import dual_all_to_all_tablewise
@@ -14,9 +14,9 @@
 import abc
 
 
-class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
+class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
     """
-    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
+    every table assigned to this class instance is managed by a CachedEmbeddingBag.
     """
 
     def __init__(self,
@@ -34,7 +34,7 @@ def __init__(self,
                  warmup_ratio=0.7,
                  pin_weight=False,
                  evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
+        super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__()
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
@@ -49,31 +49,31 @@ def __init__(self,
         self.include_last_offset = include_last_offset
         self.pg = ProcessGroup(tp_degree=self.world_size)
 
-        # prepare FreqAwareEmbeddingBag list
+        # prepare CachedEmbeddingBag list
 
-        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
+        self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList()
         for config in embedding_bag_config_list:
             if config.assigned_rank != self.rank:
                 continue
-            self.freq_aware_embedding_bag_list.append(
-                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
-                                      embedding_dim=embedding_dim,
-                                      padding_idx=padding_idx,
-                                      max_norm=max_norm,
-                                      norm_type=norm_type,
-                                      scale_grad_by_freq=scale_grad_by_freq,
-                                      sparse=sparse,
-                                      _weight=config.initial_weight,
-                                      mode=mode,
-                                      include_last_offset=include_last_offset,
-                                      dtype=dtype,
-                                      device=device,
-                                      cuda_row_num=config.cuda_row_num,
-                                      ids_freq_mapping=config.ids_freq_mapping,
-                                      warmup_ratio=warmup_ratio,
-                                      buffer_size=config.buffer_size,
-                                      pin_weight=pin_weight,
-                                      evict_strategy=evict_strategy))
+            self.cached_embedding_bag_list.append(
+                CachedEmbeddingBag(num_embeddings=config.num_embeddings,
+                                   embedding_dim=embedding_dim,
+                                   padding_idx=padding_idx,
+                                   max_norm=max_norm,
+                                   norm_type=norm_type,
+                                   scale_grad_by_freq=scale_grad_by_freq,
+                                   sparse=sparse,
+                                   _weight=config.initial_weight,
+                                   mode=mode,
+                                   include_last_offset=include_last_offset,
+                                   dtype=dtype,
+                                   device=device,
+                                   cuda_row_num=config.cuda_row_num,
+                                   ids_freq_mapping=config.ids_freq_mapping,
+                                   warmup_ratio=warmup_ratio,
+                                   buffer_size=config.buffer_size,
+                                   pin_weight=pin_weight,
+                                   evict_strategy=evict_strategy))
 
         # prepare list shape for all_to_all output
         self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
@@ -109,8 +109,8 @@ def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sampl
             if per_sample_weights != None:
                 local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
             with record_function("(tablewise) tablewise forward"):
-                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
-                                                                               local_per_sample_weights))
+                local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets,
+                                                                           local_per_sample_weights))
 
         # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
         local_output = torch.cat(local_output_list, 1)
@@ -126,13 +126,13 @@ def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sampl
     def element_size(self):
         if len(self.assigned_table_list) == 0:
             return 0
-        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
+        return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
 
     def print_comm_stats_(self):
         cuda_to_cpu_elem_num = 0
         cpu_to_cuda_elem_num = 0
-        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
-            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
-            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
+        for cached_embedding_bag in self.cached_embedding_bag_list:
+            cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
+            cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
         print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
        print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
38 changes: 19 additions & 19 deletions tests/test_layers/test_cache_embedding.py
@@ -12,8 +12,8 @@
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
@@ -106,13 +106,13 @@ def test_reorder_with_freq():
 def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
     evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
-    model = FreqAwareEmbeddingBag(NUM_EMBED,
-                                  EMBED_DIM,
-                                  mode='mean',
-                                  include_last_offset=True,
-                                  cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
-                                  ids_freq_mapping=None,
-                                  evict_strategy=evict_strategy).to(device)
+    model = CachedEmbeddingBag(NUM_EMBED,
+                               EMBED_DIM,
+                               mode='mean',
+                               include_last_offset=True,
+                               cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
+                               ids_freq_mapping=None,
+                               evict_strategy=evict_strategy).to(device)
 
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
@@ -151,14 +151,14 @@ def test_freq_aware_embed(use_LFU: bool):
 @pytest.mark.parametrize('init_freq', [True, False])
 def test_lfu_strategy(init_freq: bool):
     # minimal test to check behavior
-    Bag = FreqAwareEmbeddingBag(5,
-                                5,
-                                cache_ratio=3 / 5,
-                                buffer_size=0,
-                                pin_weight=True,
-                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
-                                warmup_ratio=1.0,
-                                evict_strategy=EvictionStrategy.LFU)
+    Bag = CachedEmbeddingBag(5,
+                             5,
+                             cache_ratio=3 / 5,
+                             buffer_size=0,
+                             pin_weight=True,
+                             ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
+                             warmup_ratio=1.0,
+                             evict_strategy=EvictionStrategy.LFU)
 
     # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
     offsets = torch.tensor([0], device="cuda:0")
@@ -233,7 +233,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
         _weight = torch.cat([weight_table1, weight_table2], 0)
     else:
         _weight = weight_table3
-    model = ParallelFreqAwareEmbeddingBagTablewise(
+    model = ParallelCachedEmbeddingBagTablewise(
        embedding_bag_config_list,
        embedding_dim=5,
        _weight=_weight,
@@ -300,7 +300,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
 
-    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
+    model = ParallelCachedEmbeddingBag.from_pretrained(
        coloweight,
        include_last_offset=True,
        freeze=False,