In [13]:
import math
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler
from recsys.modules.embeddings import BlockEmbeddingBag, QREmbeddingBag

# Example inputs
vocab_size = 1024*1024 # |V|
embedding_dim = 512 # E_e
block_embedding_dim = 128 # E_b

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

input_shape = (16384, 16384)
input_x = torch.randint(0, vocab_size, input_shape).to(device)
reduce_op = 'max'

# Initiate modules
blk_embed = BlockEmbeddingBag(
                vocab_size, 
                block_embedding_dim,
                embedding_dim,
                mode=reduce_op,
                device=device)

qr_embed = QREmbeddingBag(
                embedding_dim,
                num_buckets=math.ceil(math.sqrt(vocab_size)),
                verbose=False).to(device)

costly_embed = nn.EmbeddingBag(
            vocab_size, 
            embedding_dim, 
            mode=reduce_op,
            device=device)

# Query step:
import time

t1 = time.time()
blk_output = blk_embed(input_x.clone())
t2 = time.time()
costly_output = costly_embed(input_x.clone())
t3 = time.time()
# with profile(
#         activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#         profile_memory=True, 
#         record_shapes=True,
#         on_trace_ready=tensorboard_trace_handler('prof-log/qr_emb_fw'),
# ) as prof:
qr_output = qr_embed(input_x.clone())
t4 = time.time()

grad = torch.randn(blk_output.shape).to(device)
t5 = time.time()
blk_output.backward(grad.clone())
t6 = time.time()
costly_output.backward(grad.clone())
t7 = time.time()
# qr_output.backward(grad.clone())
t8 = time.time()

print(f'Forward time comparison: {t2-t1}s : {t3-t2}s: {t4-t3}s')
print(f'Backward time comparison: {t6-t5}s : {t7-t6}s: {t8-t7}s')

def compute_mem(model): 
    mem_params = sum([param.nelement()*param.element_size() 
                    for param in model.parameters()])
    mem_bufs = sum([buf.nelement()*buf.element_size() 
                    for buf in model.buffers()])
    mem = mem_params + mem_bufs
    return mem // 1024**2

print(f'''memory comparison: blk_embed:{compute_mem(blk_embed)}mb
                        : costly_embed:{compute_mem(costly_embed)}mb
                        : qr_embed:{compute_mem(qr_embed)}mb''')

Forward time comparison: 0.0004940032958984375s : 0.00017881393432617188s: 0.00563812255859375s
Backward time comparison: 0.0012764930725097656s : 0.0002722740173339844s: 2.193450927734375e-05s
memory comparison: blk_embed:512mb
                        : costly_embed:2048mb
                        : qr_embed:4mb


For costly embed (w/ 4dp), the maximum supported embedding size is 35 * 0.5b = 17.5b
For block (mv) embed (w/ 4tp+dp), the maximum supported embedding size is 35 * 0.5b * 16 = 280b

Save exponentially small memory to the world size

world size is w
vocabulary size is V
embedding dim is E
block embedding dim is E//w

total saved space from all devices = (V//w*E - V//w*E//w - E//w*E) * w

#param limit at one device = l
new #param limit = l * w**2

Communication: allreduce step 
    - Round-robin: O((w-1)*V*E)
Computation: linear layer:
    - O(w*E*E//w)
Other cost: embedding lookup:
    - O(w*V//w)

Overall, time complexity is O(V+V*E*(w-1)+E^2) = O(V) when V >> E >> w, which is identical to one single embedding (V,E)'s lookup.

Bottleneck of qr embedding is two embedding lookups, which can be parallelized.

In [None]:
# embedding parallelism

Cache size= cache_sets * cache_lines * embedding_dim * element_size

cache_sets = 500_000
cache_lines = 1
embedding_dim = 256
element_size = 4

In [3]:
import torch 
model = torch.nn.Linear(10,10)
[param.element_size() for param in model.parameters()]

[4, 4]

In [7]:
cache_sets = 500_000
cache_lines = 1
embedding_dim = 256
element_size = 4

cache_size= cache_sets * cache_lines * embedding_dim * element_size
cache_size / 1024**3

0.476837158203125

In [19]:
import numpy as np
# Passed.
def lbmgr_fair_divide(field_dims,num_groups) -> None:
    dim_per_rank = sum(field_dims) // num_groups
    dim_indices = np.array(range(len(field_dims)))
    
    cuts = dict()
    num_cuts = 0
    
    agg = dim_per_rank
    # Find cut positions and shard groups
    for ind in dim_indices:
        while field_dims[ind] > agg:
            if num_cuts >= num_groups - 1:
                break
            if ind in cuts:
                cuts[ind].append(agg)
            else:
                cuts[ind] = [agg]
            agg += dim_per_rank
            num_cuts += 1
        
        agg -= field_dims[ind]
    
    emb_dim = dim_per_rank
    
    return cuts, emb_dim

In [21]:
cuts, emb_dim = lbmgr_fair_divide([10,200,400,2000,1000,500],4)

In [24]:
import torch
tensor = torch.randn((12,6))

In [28]:
torch.max(tensor,torch.zeros(tensor.shape))

tensor([[0.0000, 2.2175, 1.1030, 0.3941, 0.6608, 0.0000],
        [0.0000, 0.0000, 0.0000, 2.0703, 0.0000, 1.5998],
        [0.0000, 1.5906, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.8165, 0.0000, 0.0000, 0.7085],
        [0.0000, 0.0000, 0.1005, 0.0087, 1.3287, 0.7515],
        [0.0000, 0.5046, 0.0000, 0.0000, 0.0000, 0.4024],
        [0.0000, 0.0000, 0.0000, 1.4834, 0.0000, 0.0000],
        [0.0000, 0.1852, 0.0000, 2.3544, 0.4991, 0.3674],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2637, 0.5050, 2.1840, 0.0000, 0.0000, 0.7402],
        [0.1266, 0.0000, 0.4667, 0.0000, 0.3489, 0.0000],
        [1.0668, 1.1004, 0.3868, 0.7628, 1.3871, 0.9825]])

In [5]:
import numpy as np
import torch

embeddings_per_feat=[100,2020,3203,3434,3023,4545,123,4566]
num_groups = 4

def fair_initialize(embeddings_per_feat=embeddings_per_feat,num_groups=num_groups) -> None:
    num_embeddings_per_rank = sum(embeddings_per_feat) // num_groups
    dim_indices = np.array(range(len(embeddings_per_feat)))
    groups = []
    offsets = []
    _curr_grp = []
    _curr_offs = [0]

    cuts = dict()
    _num_cuts = 0
    _agg = num_embeddings_per_rank
    # Find cut positions and shard groups
    for ind in dim_indices:
        while embeddings_per_feat[ind] > _agg:
            if _num_cuts >= num_groups - 1: # never cut when enough groups
                break
            if ind in cuts.keys():
                cuts[ind].append(_agg)
            else:
                cuts[ind] = [_agg]
            _num_cuts += 1
            
            offsets.append(torch.tensor(_curr_offs))
            _curr_offs = [0]
            _curr_grp.append(ind)
            groups.append(_curr_grp)
            _curr_grp = []
            
            _agg += num_embeddings_per_rank
        
        if _agg >= embeddings_per_feat[ind] and len(_curr_offs) == 1:
            _curr_offs.append(embeddings_per_feat[ind]-(_agg-num_embeddings_per_rank))
        else:
            _curr_offs.append(embeddings_per_feat[ind])
        
        _agg -= embeddings_per_feat[ind]
        _curr_grp.append(ind)
    
    offsets.append(torch.tensor(_curr_offs[:-1]))
    for i in range(len(offsets)):
        offsets[i] = torch.cumsum(offsets[i], dim=0)
    groups.append(_curr_grp)
        
    return num_embeddings_per_rank, cuts, groups, offsets

fair_initialize()

(5253,
 {2: [3133], 4: [1749], 5: [3979]},
 [[0, 1, 2], [2, 3, 4], [4, 5], [5, 6, 7]],
 [tensor([   0,  100, 2120]),
  tensor([   0,   70, 3504]),
  tensor([   0, 1274]),
  tensor([  0, 566, 689])])

In [39]:
# Complete load balancer
from recsys.modules.embeddings import LoadBalanceManager
    
lbmgr = LoadBalanceManager([100,2020,3203,3434302,3023,45459,123,4566],4,128,False)

In [40]:
rand_input = []

for i in range(len([100,2020,3203,3434302,3023,45459,123,4566])):
    rand_input.append(torch.randint(0,[100,2020,3203,3434302,3023,45459,123,4566][i],
                                            size=(8,)).unsqueeze(1))
    
rand_input_tensor = torch.cat(rand_input,dim=1)

In [29]:
rand_input_tensor

tensor([[     87,     573,     812, 2086251,    2501,   10109,     108,    2057],
        [      1,      29,    1133, 3358788,    1908,    1268,      15,    2393],
        [     10,    1185,    1839, 2375642,    1997,   20508,      17,    4385],
        [     93,     586,     916, 2344477,    1541,   19701,      84,    4052],
        [     54,    1904,    1948, 1378270,     570,   25398,      32,    1569],
        [     16,    1674,    1889,  471842,    1077,   38336,     117,    2978],
        [     60,    1044,    1953, 3246596,     625,   19882,      20,    2391],
        [     69,    1381,    1512, 1169965,     442,   29359,      71,    1383]])

In [45]:
lbmgr.embeddings_per_feat

[100, 2020, 3203, 3434302, 3023, 45459, 123, 4566]

In [46]:
lbmgr.offsets

[tensor([      0, 3434302]),
 tensor([   0, 2020]),
 tensor([    0, 45459]),
 tensor([   0, 4566])]

In [43]:
lbmgr.groups

[array([3, 0]), array([1, 2]), array([5, 6]), array([7, 4])]

In [41]:
lbmgr.get_num_embeddings_on_rank(0)

3434402

In [47]:
print(rand_input_tensor)
lbmgr.shard_tensor(rand_input_tensor,rank=0)

tensor([[     69,    1590,    2101, 3002450,    1443,   18711,       3,    4532],
        [     83,    1153,    2446, 1065773,     561,   31027,      58,    3101],
        [     40,     396,     967, 1962002,    1553,   24817,      54,      77],
        [     38,     888,     843, 2855171,    1787,   22663,     107,     395],
        [     42,    1190,    2661, 1749865,    2260,    2255,      96,    2929],
        [     76,     821,    1003, 1770283,    2643,   13591,     118,    2505],
        [     58,    1583,     151, 1194557,     109,   35343,     109,    3112],
        [     36,    1546,    2257, 2188749,     462,   29215,      51,     226]])


tensor([[3002450, 3434371],
        [1065773, 3434385],
        [1962002, 3434342],
        [2855171, 3434340],
        [1749865, 3434344],
        [1770283, 3434378],
        [1194557, 3434360],
        [2188749, 3434338]])