In [6]:
import sys
import os

from typing import List, Optional, Tuple, Union

import time
import torch
import torch.distributed as dist
import numpy as np
import random
import json

from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
)
from dataclasses import dataclass

  from .autonotebook import tqdm as notebook_tqdm


# 1. sst2_10000_mrpc_2000_MixtralMoE_patterns

In [8]:
dataset = load_dataset("marsggbo/sst2_10000_mrpc_2000_MixtralMoE_patterns")['train']
print(len(dataset), dataset[0].keys()) 

12000 dict_keys(['source', 'prompt_len', 'token_idx', 'token_expert_patterns', 'sentence_expert_pattern'])


## 1.1 查看每个句子的pattern情况
每层 token 的最大值和最小值差距有一般都有 2 到 3 倍左右

In [21]:
for index in list(range(0, 10000, 1000))+list(range(10000, 12000, 200)):
    prompt_token_pattern_list = np.array([dataset[index]['token_expert_patterns'][i] for i in range(dataset[index]['prompt_len'])])
    prompt_sentence_pattern = np.sum(prompt_token_pattern_list,0)
    print(index, dataset[index]['source'], prompt_sentence_pattern.max(-1), prompt_sentence_pattern.min(-1))
    print('='*10)

0 sst2 [23 23 17 21 19 20 21 22 26 20 19 18 20 27 19 21 20 18 22 18 23 22 27 21
 22 19 19 17 23 22 27 23] [6 2 9 8 6 7 6 5 4 1 9 0 8 6 6 6 9 9 6 5 6 3 0 3 4 5 7 9 5 6 1 6]
1000 sst2 [23 23 18 22 19 20 20 21 27 21 20 19 20 28 19 21 21 18 22 19 22 25 28 22
 22 21 19 18 25 22 27 24] [ 7  4  9  7  6  9 10  4  5  1  9  0  9  7  6  6  8  9  8  4  6  3  0  3
  5  6  8  9  5  7  2  7]
2000 sst2 [23 24 20 23 19 24 20 20 29 22 20 19 22 28 18 21 22 18 23 19 24 23 29 21
 23 22 19 18 25 22 27 24] [ 8  4  9  7  6  7  9  5  4  1 10  0  8  7  6  6  8 10  6  6  7  3  1  5
  6  5  9  8  5  7  3  8]
3000 sst2 [23 22 20 24 22 24 22 22 26 23 20 19 23 28 19 22 22 18 23 21 27 25 30 24
 22 19 21 18 26 22 27 26] [ 8  2 10  8  7  8  9  6  6  1  9  0 10  7  6  6 10 10  7  6  5  4  2  3
  7  8  7  8  8  8  3  8]
4000 sst2 [23 22 19 23 19 20 24 21 29 25 19 21 26 28 21 21 20 18 23 20 28 22 29 26
 24 21 22 19 25 24 30 25] [ 7  3 10  8  7  7 10  7  4  1  9  0  8  7  6  6 10 10  6  6  5  6  4  4
  8  9  7  9  8 10  3 

## 1.2 查看乱序情况下一个 batch 的 pattern 情况

数据集 shuffle，batch size=16/32/128/256 的情况下，每行最多 token 数和最小 token 数的差距也都是在 2 倍左右。

另外，当batch size=16 的时候，最小 token 数就超过了 128，最大 token 数也接近 512 了，这表明 normal 情况下，expert 的计算 latency 横跨了 3 个 level。实测 0-128 个 token 为 level-1， 128-256为 level-2

In [12]:
indices = list(range(len(dataset)))
np.random.shuffle(indices)
batch_sizes = [
    16,
    32,
    # 64,
    # 128,
    # 256
]

for bs in batch_sizes:
    for i in range(5):
        # 查看 10 个 batch 的情况
        start_idx = bs * i
        end_idx = start_idx + bs
        batch_indices = indices[start_idx:end_idx]
        batch_samples = dataset.select(batch_indices)
        prompt_sentence_pattern_list = []
        for i in range(bs):
            prompt_token_pattern_list = np.array(batch_samples[i]['token_expert_patterns'][:batch_samples[i]['prompt_len']])
            prompt_token_pattern_list = torch.from_numpy(prompt_token_pattern_list).cuda()
            prompt_sentence_pattern_list.append(torch.sum(prompt_token_pattern_list,0))
        batch_prompt_sentence_pattern = torch.stack(prompt_sentence_pattern_list).sum(0)
        print(f"BS={bs}", batch_prompt_sentence_pattern.max(-1)[0], batch_prompt_sentence_pattern.min(-1)[0])
        print('='*10)

BS=16 tensor([441, 418, 387, 428, 378, 417, 419, 385, 546, 424, 381, 411, 457, 518,
        396, 384, 413, 354, 412, 427, 495, 441, 525, 410, 472, 427, 409, 366,
        408, 413, 488, 426], device='cuda:0') tensor([211, 155, 188, 221, 207, 215, 233, 207, 132,  63, 229,  54, 219, 158,
        137, 212, 264, 209, 199, 174, 205, 126,  79, 168, 192, 199, 195, 198,
        193, 167, 119, 184], device='cuda:0')
BS=16 tensor([468, 443, 431, 458, 409, 465, 446, 458, 607, 463, 458, 453, 475, 553,
        416, 446, 424, 429, 452, 438, 574, 519, 502, 515, 524, 452, 423, 419,
        425, 475, 495, 440], device='cuda:0') tensor([222, 206, 214, 254, 248, 261, 259, 272,  96, 131, 223,  37, 238, 193,
        179, 265, 288, 238, 220, 211, 215, 155, 126, 146, 144, 213, 196, 211,
        236, 204, 158, 224], device='cuda:0')
BS=16 tensor([449, 416, 397, 435, 381, 414, 419, 433, 551, 423, 404, 411, 437, 530,
        378, 397, 398, 365, 456, 396, 502, 476, 509, 425, 459, 419, 384, 366,
        412, 446, 

In [5]:
indices = list(range(len(dataset)))
np.random.shuffle(indices)
batch_sizes = [
    # 16,
    # 32,
    64,
    128,
    256
]

for bs in batch_sizes:
    for i in range(5):
        # 查看 10 个 batch 的情况
        start_idx = bs * i
        end_idx = start_idx + bs
        batch_indices = indices[start_idx:end_idx]
        batch_samples = dataset.select(batch_indices)
        prompt_sentence_pattern_list = []
        for i in range(bs):
            prompt_token_pattern_list = np.array(batch_samples[i]['token_expert_patterns'][:batch_samples[i]['prompt_len']])
            prompt_token_pattern_list = torch.from_numpy(prompt_token_pattern_list).cuda()
            prompt_sentence_pattern_list.append(torch.sum(prompt_token_pattern_list,0))
        batch_prompt_sentence_pattern = torch.stack(prompt_sentence_pattern_list).sum(0)
        print(f"BS={bs}", batch_prompt_sentence_pattern.max(-1)[0], batch_prompt_sentence_pattern.min(-1)[0])
        print('='*10)

BS=64 tensor([1771, 1660, 1524, 1740, 1462, 1643, 1536, 1620, 2160, 1616, 1585, 1568,
        1729, 2055, 1522, 1618, 1636, 1363, 1680, 1638, 1944, 1763, 2017, 1617,
        1815, 1589, 1519, 1388, 1648, 1703, 1965, 1680], device='cuda:0') tensor([875, 631, 775, 905, 807, 898, 933, 894, 546, 345, 799,  98, 875, 663,
        621, 761, 961, 800, 795, 672, 824, 544, 369, 591, 672, 768, 715, 857,
        769, 733, 467, 733], device='cuda:0')
BS=64 tensor([1799, 1675, 1507, 1711, 1509, 1664, 1582, 1552, 2220, 1613, 1506, 1594,
        1782, 2047, 1477, 1612, 1637, 1390, 1704, 1622, 1904, 1767, 2051, 1605,
        1768, 1557, 1475, 1379, 1643, 1699, 1979, 1705], device='cuda:0') tensor([836, 565, 818, 943, 767, 847, 859, 783, 502, 248, 836,  83, 798, 634,
        602, 676, 969, 821, 766, 601, 748, 551, 381, 597, 683, 787, 733, 856,
        765, 759, 411, 707], device='cuda:0')
BS=64 tensor([1748, 1627, 1429, 1668, 1478, 1581, 1552, 1543, 2107, 1517, 1442, 1514,
        1675, 2019, 1444, 1549

## 1.3 查看按prompt长度排序后一个batch的pattern情况

In [9]:
sample_indices = list(range(0, 10000, 10)) + list(range(10000, 12000, 2))
print(f"#samples={len(sample_indices)}")
prompt_len_list = torch.tensor(dataset.select(sample_indices)['prompt_len']).cuda()
sorted_indices = prompt_len_list.sort().indices.to('cpu').numpy()


#samples=2000


In [13]:
prompt_len_list[sorted_indices][::10]

tensor([ 52,  53,  53,  53,  53,  53,  54,  54,  54,  54,  54,  54,  54,  54,
         55,  55,  55,  55,  55,  55,  55,  55,  55,  56,  56,  56,  56,  56,
         56,  56,  57,  57,  57,  57,  57,  57,  58,  58,  58,  58,  58,  58,
         59,  59,  59,  59,  59,  60,  60,  60,  60,  60,  61,  61,  61,  61,
         62,  62,  62,  62,  63,  63,  63,  63,  64,  64,  64,  65,  65,  66,
         66,  66,  67,  67,  68,  68,  68,  69,  70,  70,  71,  71,  72,  73,
         73,  74,  75,  76,  77,  77,  79,  80,  81,  82,  83,  85,  87,  90,
         92,  95,  98,  99,  99, 100, 100, 101, 102, 102, 103, 104, 104, 105,
        105, 106, 106, 106, 107, 107, 108, 108, 109, 109, 110, 110, 111, 111,
        112, 112, 113, 113, 114, 114, 115, 115, 116, 116, 117, 117, 117, 118,
        118, 119, 119, 119, 120, 120, 121, 121, 121, 122, 122, 123, 123, 124,
        124, 124, 125, 125, 126, 126, 126, 127, 127, 128, 128, 128, 129, 129,
        130, 130, 131, 131, 132, 132, 133, 133, 133, 134, 134, 1

In [10]:

batch_sizes = [
    16,
    32,
    # 64,
    128,
    256
]

for bs in batch_sizes:
    for i in range(5):
        # 查看 10 个 batch 的情况
        start_idx = bs * i
        end_idx = start_idx + bs
        batch_indices = sorted_indices[start_idx:end_idx]
        batch_samples = dataset.select(batch_indices)
        prompt_sentence_pattern_list = []
        for i in range(bs):
            prompt_token_pattern_list = np.array(batch_samples[i]['token_expert_patterns'][:batch_samples[i]['prompt_len']])
            prompt_token_pattern_list = torch.from_numpy(prompt_token_pattern_list).cuda()
            prompt_sentence_pattern_list.append(torch.sum(prompt_token_pattern_list,0))
        batch_prompt_sentence_pattern = torch.stack(prompt_sentence_pattern_list).sum(0)
        print(f"BS={bs}", batch_prompt_sentence_pattern.max(-1)[0], batch_prompt_sentence_pattern.min(-1)[0])
        print('='*10)

BS=16 tensor([373, 360, 279, 337, 305, 331, 337, 329, 409, 320, 304, 295, 312, 433,
        305, 337, 320, 277, 352, 297, 377, 353, 443, 341, 359, 309, 291, 277,
        380, 352, 432, 372], device='cuda:0') tensor([ 95,  38, 145, 133, 103, 117, 122,  71,  65,  16, 145,   0, 128,  98,
         96,  96, 135, 147, 105,  67,  79,  44,   8,  54,  77,  91, 118, 134,
         89,  97,  23, 101], device='cuda:0')
BS=16 tensor([369, 364, 278, 336, 311, 330, 337, 326, 409, 324, 306, 297, 314, 433,
        304, 336, 321, 275, 352, 298, 375, 357, 441, 341, 356, 317, 291, 275,
        378, 353, 432, 369], device='cuda:0') tensor([ 94,  42, 146, 130,  97, 117, 123,  70,  66,  18, 144,   0, 131,  97,
         96,  98, 131, 149, 108,  63,  74,  37,   7,  61,  67,  89, 117, 137,
         86,  99,  20, 102], device='cuda:0')
BS=16 tensor([374, 362, 281, 337, 308, 327, 340, 325, 411, 321, 306, 297, 307, 442,
        305, 336, 319, 276, 352, 297, 375, 353, 446, 338, 357, 310, 291, 275,
        373, 355, 