# Prototype hybrid embedding: data-parallel frequent categories and model-parallel infrequent categories

In [1]:
import numpy as np
from copy import deepcopy

In [2]:
batch_size = 64  # samples per batch
num_slots = 10  # number of embedding tables; each sample has one category id per table
num_nodes = 2  # number of simulated nodes
num_networks_per_node = 4  # GPUs per node; each GPU is addressed by a flat "network id"

In [3]:
def flatten_data(data):
    """Concatenate batches and turn per-table category ids into global ids.

    Each entry of ``data`` is indexed as ``entry[0]`` (per-table embedding
    sizes) and ``entry[1]`` (a ``(num_tables, batch_size)`` array whose
    ``[j, i]`` element is the local category id of sample ``i`` in table
    ``j``). Batches are concatenated along the sample axis. Each table's ids
    are then shifted by the total size of all preceding tables so ids from
    different tables never collide. The result is flattened table-major:
    all samples of table 0, then all of table 1, and so on.

    Args:
        data: non-empty sequence of ``(embedding_sizes, batch)`` pairs; the
            table sizes are taken from the first pair.

    Returns:
        1-D ``np.int32`` array of length ``num_tables * total_samples``.
        Non-integer inputs are truncated on the final int32 cast.
    """
    # np.concatenate returns a fresh array, so no defensive deepcopy is needed.
    samples_data = np.concatenate([entry[1] for entry in data], axis=1)

    embedding_sizes = np.asarray(data[0][0])
    num_tables = samples_data.shape[0]

    # Exclusive prefix sum: ids of table j start after all ids of tables 0..j-1.
    table_offsets = np.zeros(num_tables, dtype=np.int64)
    table_offsets[1:] = np.cumsum(embedding_sizes[:num_tables])[:-1]

    # Broadcast each table's offset across its row, then flatten row-major,
    # i.e. element (j, i) lands at position j * num_samples + i.
    global_ids = samples_data + table_offsets[:, np.newaxis]
    return global_ids.astype(np.int32).ravel()

In [4]:
# Generate synthetic data
embed_sizes = np.random.randint(1, 2*batch_size, num_slots);
num_categories = sum(embed_sizes)
print("num_categories:", num_categories)

num_categories: 593


In [5]:
print(",".join(str(table_size) for table_size in embed_sizes))  # comma-separated table sizes

31,42,40,64,73,55,26,124,117,21


In [6]:
# Synthetic training data: num_batches batches, each a (num_slots, batch_size)
# array whose row j holds category ids drawn uniformly from [0, embed_sizes[j]).
data = []
num_batches = 15
for _ in range(num_batches):
    # Category ids are indices, so store them as integers. The previous
    # np.zeros default buffer held them as float64.
    batch = np.empty((num_slots, batch_size), dtype=np.int64)
    for j in range(num_slots):
        batch[j, :] = np.random.randint(0, embed_sizes[j], batch_size)
    data.append((embed_sizes, batch))

In [7]:
# All batches concatenated, table-major, as global category ids (int32).
samples = flatten_data(data)

In [8]:
# configure nodes and gpus

class Gpu:
    """One simulated GPU of the hybrid-embedding prototype.

    Several attributes are filled in by other code: ``gpu_id`` and ``node``
    by ``Node``, ``category_location`` by the notebook, and ``num_frequent``
    by ``init_embedding_cache``. They are all declared here as ``None`` so the
    full attribute set is visible and safe to inspect before assignment.
    """

    def __init__(self):
        self.gpu_id = None  # index of this GPU within its node; set by Node
        self.node = None  # owning Node; set by Node
        self.num_frequent = None  # set by init_embedding_cache
        self.frequent_categories = None
        self.category_frequent_index = None
        self.frequent_embedding_vectors = None
        self.frequent_partial_gradients = None
        # (num_categories, 2) location table, assigned from outside this class
        self.category_location = None

    def init_embedding_cache(self, num_frequent, embedding_vec_size):
        """Allocate zeroed flat float32 buffers for the frequent-category cache.

        Both buffers hold ``num_frequent * embedding_vec_size`` values and are
        independent arrays.
        """
        self.num_frequent = num_frequent
        buffer_size = num_frequent * embedding_vec_size
        self.frequent_embedding_vectors = np.zeros(buffer_size, dtype=np.float32)
        self.frequent_partial_gradients = np.zeros(buffer_size, dtype=np.float32)
        
class Node:
    """A simulated compute node owning ``num_gpus`` freshly created Gpu objects."""

    def __init__(self, num_gpus, node_id=None):
        # Declared here so the attribute always exists; callers may also
        # assign it after construction.
        self.node_id = node_id
        self.gpus = [Gpu() for _ in range(num_gpus)]
        for local_id, gpu in enumerate(self.gpus):
            gpu.gpu_id = local_id
            gpu.node = self  # back-reference so a GPU can reach its node

class Network:
    """Placeholder for the inter-GPU communication layer.

    Only stores the list of nodes. The collective operations below are
    stubs that do nothing in this prototype.
    """

    def __init__(self, nodes):
        self.nodes = nodes  # nodes participating in the (future) collectives

    def all_reduce(self):
        """Stub: does nothing yet."""
        pass

    def all_to_all(self):
        """Stub: does nothing yet."""
        pass

In [9]:
# Build the simulated cluster: num_nodes nodes, each with num_networks_per_node GPUs.
nodes = [Node(num_networks_per_node) for _ in range(num_nodes)]
for position, created_node in enumerate(nodes):
    created_node.node_id = position
# Flat node-major list of every GPU; a GPU's index here is its network id.
gpus = [gpu for node in nodes for gpu in node.gpus]
num_gpus = len(gpus)
network = Network(nodes)

In [10]:
# Occurrence count of every global category id across all batches.
uniques, counts = np.unique(samples, return_counts=True)
# "Frequent" = seen more than num_batches times in total, i.e. more than
# once per batch on average.
threshold = num_batches
mask = np.greater(counts, threshold)
frequent = {category for category in uniques[mask]}
num_frequent = len(frequent)
print(num_frequent, "/", num_categories)

237 / 593


In [11]:
# category_location[c] = (network id, slot on that network) for every
# infrequent category c. Infrequent categories are dealt round-robin over
# all GPUs in ascending id order. Frequent categories keep the sentinel
# value num_categories in both columns.
num_infrequent = num_categories - num_frequent
category_location = np.full((num_categories, 2), num_categories, dtype=np.int32)

infrequent_mask = np.ones(num_categories, dtype=bool)
infrequent_mask[np.fromiter(frequent, dtype=np.intp, count=num_frequent)] = False

# Integer rank of each infrequent category among all infrequent ones
# (0 for frequent categories). Previously this was a float64 array.
infrequent_index = np.zeros(num_categories, dtype=np.int64)
infrequent_index[infrequent_mask] = np.arange(num_infrequent)

infrequent_ranks = infrequent_index[infrequent_mask]
category_location[infrequent_mask, 0] = infrequent_ranks % num_gpus
category_location[infrequent_mask, 1] = infrequent_ranks // num_gpus

In [12]:
# Every GPU references the same category_location array (no copy is made).
for gpu in gpus:
    gpu.category_location = category_location

In [13]:
n_display = 20  # number of leading categories whose location is printed below
class color:
    """ANSI terminal escape sequences used to highlight printed output."""

    # Foreground colours
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    END = "\033[0m"  # reset all attributes
print(f'{color.BOLD}{color.RED}category          |->  category location {color.END}')
# Clamp the range: table sizes are random, so num_categories can be smaller
# than n_display and indexing past the table would raise IndexError.
for category in range(min(n_display, num_categories)):
    location = category_location[category, :]
    # Frequent categories hold the sentinel num_categories and have no location.
    if location[0] < num_categories:
        print(f'category {color.BOLD}{category:3d}{color.END}      |->  category location {color.BOLD}{color.GREEN}{location}{color.END}')
    else:
        print(f'category {color.BOLD}{category:3d}{color.END}      |->  category location   {color.BOLD}{color.RED}END{color.END}')

[1m[91mcategory          |->  category location [0m
category [1m  0[0m      |->  category location   [1m[91mEND[0m
category [1m  1[0m      |->  category location   [1m[91mEND[0m
category [1m  2[0m      |->  category location   [1m[91mEND[0m
category [1m  3[0m      |->  category location   [1m[91mEND[0m
category [1m  4[0m      |->  category location   [1m[91mEND[0m
category [1m  5[0m      |->  category location   [1m[91mEND[0m
category [1m  6[0m      |->  category location   [1m[91mEND[0m
category [1m  7[0m      |->  category location   [1m[91mEND[0m
category [1m  8[0m      |->  category location   [1m[91mEND[0m
category [1m  9[0m      |->  category location   [1m[91mEND[0m
category [1m 10[0m      |->  category location   [1m[91mEND[0m
category [1m 11[0m      |->  category location   [1m[91mEND[0m
category [1m 12[0m      |->  category location   [1m[91mEND[0m
category [1m 13[0m      |->  category location   [1m[91mE

In [14]:
# Sanity check: exactly the infrequent categories received a network id;
# every other row still holds the sentinel num_categories.
assert np.sum(category_location[:,0] != num_categories) == num_infrequent

# Index calculations

In [18]:
from bisect import bisect_left

def get_node_gpu(node_id, gpu_id, node_list=None):
    """Look up a node and one of its GPUs by id.

    Args:
        node_id: ``node_id`` of the node to find.
        gpu_id: ``gpu_id`` of the GPU within that node.
        node_list: nodes to search; defaults to the module-level ``nodes``.

    Returns:
        ``(node, gpu)`` tuple.

    Raises:
        KeyError: if no node has ``node_id`` or that node has no GPU with
            ``gpu_id``. Previously a missing node crashed with AttributeError
            and a missing GPU silently returned ``(node, None)``.
    """
    # Linear search is fine for a handful of simulated nodes.
    if node_list is None:
        node_list = nodes
    node = next((n for n in node_list if n.node_id == node_id), None)
    if node is None:
        raise KeyError(f"Not found: node {node_id}")
    gpu = next((g for g in node.gpus if g.gpu_id == gpu_id), None)
    if gpu is None:
        raise KeyError(f"Not found: node {node_id}, GPU {gpu_id}")
    return node, gpu

def get_network_id(node_id, gpu_id):
    """Return the index of (node ``node_id``, GPU ``gpu_id``) in the flat ``gpus`` list.

    That index is what the rest of the notebook calls the network id.

    Raises:
        KeyError: if no GPU matches.
    """
    for network_id, candidate in enumerate(gpus):
        if candidate.node.node_id != node_id:
            continue
        if candidate.gpu_id == gpu_id:
            return network_id
    raise KeyError(f"Not found: node {node_id}, GPU {gpu_id}")

def cub_DeviceSelect(gpu, samples, network_id):
    """Return the positions in ``samples`` whose category lives on ``network_id``.

    The flag for position k is
    ``gpu.category_location[samples[k], 0] == network_id``. The result is the
    ascending array of positions whose flag is set.
    """
    owner_of_sample = gpu.category_location[samples, 0]
    return np.flatnonzero(owner_of_sample == network_id)

# model indices: forward-send, backward-receive
def calculate_model_indices(samples, node_id, gpu_id):
    """Find the batch positions whose category is stored on this GPU.

    The batch is viewed as ``num_gpus`` sections of ``samples.size // num_gpus``
    positions each.

    Returns:
        ``(sample_model_indices, network_offset_model_indices)``:
        - ascending positions in ``samples`` whose category lives on this GPU;
        - int32 offsets of length ``num_gpus + 1``. Entries
          ``[off[i], off[i + 1])`` of the positions fall in section ``i``,
          and the last offset is the total count.
    """
    _, gpu = get_node_gpu(node_id, gpu_id)
    network_id = get_network_id(node_id, gpu_id)
    section_size = samples.size // num_gpus

    owned_positions = cub_DeviceSelect(gpu, samples, network_id)
    section_starts = np.arange(num_gpus) * section_size

    offsets = np.empty(num_gpus + 1, dtype=np.int32)
    # owned_positions is ascending, so a left-sided search yields the first
    # owned position at or after each section start.
    offsets[:-1] = np.searchsorted(owned_positions, section_starts, side="left")
    offsets[-1] = owned_positions.size

    return owned_positions, offsets

# network indices: forward-receive, backward-send
def calculate_network_indices(samples, node_id, gpu_id):
    """Group this GPU's section of the batch by the GPU owning each category.

    The batch is split into ``num_gpus`` sections of
    ``samples.size // num_gpus`` positions, and this GPU takes the section
    whose index equals its network id. The last section also takes any
    remainder. This matches ``calculate_model_indices``, whose final offset
    assigns those trailing positions to the last section; before this fix
    the remainder was dropped here, so send and receive counts disagreed.

    Returns:
        ``(sample_network_indices, sample_network_offsets)``:
        - positions within the section whose category has a location
          (i.e. is not the sentinel ``num_categories``), stably sorted by
          owning network id;
        - int32 offsets of length ``num_gpus + 1``. Entries
          ``[off[i], off[i + 1])`` are owned by network ``i``.
    """
    _, gpu = get_node_gpu(node_id, gpu_id)

    section_size = samples.size // num_gpus
    network_id = get_network_id(node_id, gpu_id)
    start_idx = network_id * section_size
    if network_id == num_gpus - 1:
        end_idx = samples.size  # last section absorbs the remainder
    else:
        end_idx = start_idx + section_size
    sub_batch = samples[start_idx:end_idx]

    owner = gpu.category_location[sub_batch, 0]
    # Frequent categories carry the sentinel num_categories and are skipped.
    has_owner = owner < num_categories
    positions = np.flatnonzero(has_owner)
    owners = owner[has_owner]

    # A stable sort keeps positions in sample order within each destination.
    order = np.argsort(owners, kind="stable")
    sample_network_indices = positions[order]
    sorted_owners = owners[order]

    sample_network_offsets = np.empty(num_gpus + 1, dtype=np.int32)
    sample_network_offsets[:-1] = np.searchsorted(
        sorted_owners, np.arange(num_gpus), side="left"
    )
    sample_network_offsets[-1] = owners.size

    return sample_network_indices, sample_network_offsets

In [19]:
iteration = 0  # NOTE(review): not read by any code visible in this notebook
batch = flatten_data([data[0]])  # first batch only, as flat global category ids

In [20]:
# (node_id, gpu_id) key for every GPU, in node-major order.
gpu_keys = [(node_.node_id, gpu_.gpu_id) for node_ in nodes for gpu_ in node_.gpus]

# Model side: batch positions each GPU serves, split per destination section.
model_indices = {}
model_indices_offsets = {}
for key in gpu_keys:
    model_indices[key], model_indices_offsets[key] = calculate_model_indices(batch, *key)

# Network side: positions of each GPU's own section, grouped by owning GPU.
network_indices = {}
network_indices_offsets = {}
for key in gpu_keys:
    network_indices[key], network_indices_offsets[key] = calculate_network_indices(batch, *key)

In [21]:
print(list(category_location.ravel()))  # row-major dump of the (num_categories, 2) table

[593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 0, 0, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 1, 0, 593, 593, 593, 593, 593, 593, 2, 0, 593, 593, 593, 593, 593, 593, 3, 0, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 593, 

In [22]:
print(list(batch))  # flat global category ids of the first batch

[20, 24, 7, 26, 19, 20, 22, 28, 13, 27, 23, 16, 19, 8, 10, 14, 5, 17, 20, 9, 25, 27, 28, 27, 5, 25, 5, 26, 27, 12, 8, 6, 5, 10, 6, 26, 12, 6, 9, 19, 19, 5, 0, 27, 30, 22, 3, 26, 3, 10, 13, 5, 13, 24, 14, 30, 22, 0, 24, 9, 2, 18, 13, 9, 65, 62, 62, 46, 40, 67, 61, 70, 72, 63, 33, 66, 55, 37, 42, 70, 31, 68, 67, 47, 69, 58, 72, 70, 49, 54, 39, 37, 71, 51, 72, 56, 54, 53, 65, 40, 48, 43, 32, 60, 67, 52, 58, 52, 41, 66, 50, 39, 69, 54, 54, 59, 55, 49, 69, 50, 32, 58, 50, 36, 38, 38, 57, 46, 102, 111, 110, 102, 91, 91, 91, 91, 81, 97, 80, 85, 79, 89, 100, 78, 80, 90, 109, 109, 79, 74, 110, 108, 108, 90, 79, 96, 97, 96, 109, 81, 108, 84, 98, 92, 93, 87, 104, 100, 81, 95, 99, 90, 99, 78, 82, 91, 86, 78, 93, 94, 95, 91, 80, 107, 110, 101, 86, 77, 111, 112, 84, 77, 160, 164, 151, 146, 117, 146, 147, 133, 175, 175, 118, 148, 147, 163, 141, 147, 154, 164, 165, 116, 171, 158, 128, 167, 163, 170, 149, 154, 129, 169, 124, 147, 150, 113, 154, 155, 172, 176, 169, 115, 136, 151, 159, 125, 131, 174, 119

In [23]:
# One brace-delimited row per GPU, ordered by (node_id, gpu_id).
for key in sorted(model_indices):
    row_values = model_indices[key]
    print("{%s}," % ','.join(str(value) for value in row_values))

{215,245,249,259,266,270,273,283,292,332,449,453,455,460,476,483,492,501,502,516,524,527,530,539,542,546,547,558,574},
{95,208,219,222,226,261,262,272,275,281,284,291,293,304,315,316,318,323,456,464,479,486,498,500,510,521,543,549,555,564,566,570},
{103,217,227,235,277,282,286,313,350,367,450,458,478,481,488,490,507,520,529,559,563,567},
{228,254,260,267,268,274,294,300,301,317,319,372,457,466,468,477,487,489,494,505,509,513,514,544,550,553,556,560,561,575},
{213,220,248,253,264,271,297,299,309,357,373,465,469,471,473,503,504,506,517,519,534,536,562},
{192,196,237,298,329,344,354,382,448,454,463,470,480,484,493,508,512,518,522,523,525,533,540,541,551,565,572},
{202,205,216,247,296,306,307,311,346,369,377,452,459,467,475,495,496,497,499,515,526,531,535,538,545,552},
{193,203,209,232,238,256,265,280,285,289,290,303,305,310,322,340,360,451,461,462,472,474,482,485,491,511,528,532,537,548,554,557,568,569,571,573},


In [24]:
# One brace-delimited row per GPU, ordered by (node_id, gpu_id).
for key in sorted(model_indices_offsets):
    row_values = model_indices_offsets[key]
    print("{%s}," % ','.join(str(value) for value in row_values))

{0,0,0,1,9,10,15,28,29},
{0,0,1,5,17,18,21,29,32},
{0,0,1,4,8,10,13,20,22},
{0,0,0,1,11,12,16,27,30},
{0,0,0,2,9,11,15,22,23},
{0,0,0,3,4,8,12,25,27},
{0,0,0,3,8,11,15,26,26},
{0,0,0,5,14,17,22,32,36},


In [25]:
# One brace-delimited row per GPU, ordered by (node_id, gpu_id).
for key in sorted(network_indices):
    row_values = network_indices[key]
    print("{%s}," % ','.join(str(value) for value in row_values))

{},
{15,23},
{55,48,59,62,66,57,67,75,68,53,60,32,36,77,42,45,56,33,43,49,72,78},
{5,9,19,26,30,33,43,52,21,22,32,35,41,44,51,53,64,75,76,78,37,42,46,73,14,20,27,28,34,54,60,61,77,79,8,13,24,31,57,59,69,58,7,56,66,67,71,16,25,40,45,49,50,63,65,70},
{12,3,30,47,52,37,53,9,24,34,62,26,49,57,2,20,40},
{49,53,55,60,76,56,64,79,50,58,78,57,66,68,77,65,69,71,73,48,54,63,70,52,59,67,75,51,61,62,72,74},
{3,12,21,22,36,44,47,50,59,62,66,67,78,6,18,20,30,41,63,69,75,1,8,10,27,40,49,79,7,9,14,25,29,33,34,64,70,73,76,23,24,26,37,39,54,56,0,4,13,28,32,38,42,43,45,53,60,61,71,15,16,17,19,35,46,51,55,58,65,72,2,5,11,31,48,52,57,68,74,77},
{14,4,6,10,3,7,0,1,15,2,5,12,8,9,11,13},


In [26]:
# One brace-delimited row per GPU, ordered by (node_id, gpu_id).
for key in sorted(network_indices_offsets):
    row_values = network_indices_offsets[key]
    print("{%s}," % ','.join(str(value) for value in row_values))

{0,0,0,0,0,0,0,0,0},
{0,0,1,2,2,2,2,2,2},
{0,1,5,8,9,11,14,17,22},
{0,8,20,24,34,41,42,47,56},
{0,1,2,4,5,7,11,14,17},
{0,5,8,11,15,19,23,27,32},
{0,13,21,28,39,46,59,70,80},
{0,1,4,6,9,10,12,12,16},
