In [3]:
import torch
import numpy as np


In [3]:
data = torch.load('pattern_matrices.pt')
print(len(data), data[0].keys())

2000 dict_keys(['prompt_text', 'prompt_token_ids', 'prompt_attention_mask', 'token_ids', 'token_pattern_matrices'])


# 1. 统计数据集信息

1. prompt 统计信息

In [12]:
prompt_lens = np.array([len(sample['prompt_token_ids']) for sample in data.values()])
prompt_lens.min(), prompt_lens.max(), prompt_lens.mean(), prompt_lens.std()

(41, 382, 65.696, 30.360823177246033)

2. decoding 统计信息

In [22]:
decoding_tokens = []
for i in range(len(data)):
    sample = data[i]
    decoding_tokens.append(sample['token_ids'][prompt_lens[i]:])
decoding_tokens = torch.stack(decoding_tokens)
print(decoding_tokens.shape)
decoding_token_lens = np.array([len(x) for x in decoding_tokens])
decoding_token_lens.min(), decoding_token_lens.max(), decoding_token_lens.mean(), decoding_token_lens.std()

torch.Size([2000, 64])


(64, 64, 64.0, 0.0)

In [23]:
decoding_token_pattern_matrices = []
for i in range(len(data)):
    sample = data[i]
    decoding_token_pattern_matrices.append(sample['token_pattern_matrices'][prompt_lens[i]:]) # (#decoding_tokens, #layers, #experts)
decoding_token_pattern_matrices = torch.stack(decoding_token_pattern_matrices)
decoding_token_pattern_matrices.shape, decoding_token_pattern_matrices[0].shape, decoding_token_pattern_matrices[0][0]

(torch.Size([2000, 63, 32, 8]),
 torch.Size([63, 32, 8]),
 tensor([[1., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 1., 0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 1., 1., 0.],
         [0., 0., 0., 0., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 1., 1., 0., 0.],
         [0., 0., 0., 1., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0., 1., 0., 0.],
         [0., 0., 0., 0., 1., 0., 1., 0.],
         [0., 0., 1., 0., 1., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 1., 0.],
         [1., 0., 1., 0., 0., 0., 0., 0.],
         [0., 1., 1., 0., 0., 0., 0., 0.],
         [0., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 1., 0.],
         [0., 0., 0., 0., 1., 0., 0., 1.],
         [0., 1., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0

# 构建 Predictor 训练集

In [57]:
from datasets import Dataset
hf_data = {
    'prompt_text': [],
    'prompt_tokens_len': [],
    'token_ids': [],
    # 'prompt_token_ids': [],
    # 'decoding_token_ids': [],
    'token_pattern_matrices': []
}

# for i in range(10):
for i in range(len(data)):
    sample = data[i]
    prompt_text = sample['prompt_text']
    padded_prompt_token_ids = sample['prompt_token_ids']
    prompt_attention_mask = sample['prompt_attention_mask']
    start_index = prompt_attention_mask.argmax().item()
    token_ids = sample['token_ids'][start_index:-1]
    prompt_token_ids = padded_prompt_token_ids[start_index:]
    prompt_tokens_len = len(prompt_token_ids)
    decoding_token_ids = sample['token_ids'][len(prompt_attention_mask):-1]
    token_pattern_matrices = sample['token_pattern_matrices'][start_index:]
    assert len(token_ids)==len(decoding_token_ids)+len(prompt_token_ids)
    assert token_ids.numpy().tolist()==prompt_token_ids.numpy().tolist()+decoding_token_ids.numpy().tolist()
    assert len(token_pattern_matrices)==len(token_ids)
    hf_data['prompt_text'].append(prompt_text)
    hf_data['prompt_tokens_len'].append(prompt_tokens_len)
    hf_data['token_ids'].append(token_ids)
    # hf_data['prompt_token_ids'].append(prompt_token_ids)
    # hf_data['decoding_token_ids'].append(decoding_token_ids)
    hf_data['token_pattern_matrices'].append(token_pattern_matrices)
hf_data = Dataset.from_dict(hf_data)
hf_data

Dataset({
    features: ['prompt_text', 'prompt_tokens_len', 'token_ids', 'token_pattern_matrices'],
    num_rows: 2000
})

In [58]:
prompt_token_ids.shape, prompt_attention_mask, prompt_attention_mask.sum(), prompt_attention_mask.argmax().item(), decoding_token_ids.shape, token_ids.shape, token_pattern_matrices.shape

(torch.Size([382]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        

In [59]:
hf_data.push_to_hub('marsggbo/mixtral_8x7b_moe_alpaca_2k_token_pattern')

Creating parquet from Arrow format: 100%|██████████| 2/2 [00:01<00:00,  1.08ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:04<00:00,  4.94s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/marsggbo/mixtral_8x7b_moe_alpaca_2k_token_pattern/commit/2fc2ba69bd0cf6b5f5dba433d10f5641ba53048f', commit_message='Upload dataset', commit_description='', oid='2fc2ba69bd0cf6b5f5dba433d10f5641ba53048f', pr_url=None, pr_revision=None, pr_num=None)

In [61]:
from datasets import load_dataset

dataset = load_dataset("marsggbo/mixtral_8x7b_moe_alpaca_2k_token_pattern")
dataset

Downloading readme: 100%|██████████| 459/459 [00:00<00:00, 4.31MB/s]
Downloading data: 100%|██████████| 8.02M/8.02M [00:00<00:00, 8.42MB/s]
Generating train split: 100%|██████████| 2000/2000 [00:01<00:00, 1911.88 examples/s]


DatasetDict({
    train: Dataset({
        features: ['prompt_text', 'prompt_tokens_len', 'token_ids', 'token_pattern_matrices'],
        num_rows: 2000
    })
})

In [66]:
from datasets import Dataset

In [63]:
lens = [x for x in dataset['train']['prompt_tokens_len']]
print(np.min(lens), np.max(lens), np.mean(lens), np.std(lens))
# lens = [len(x) for x in dataset['test']['prompt_token_ids']]
# print(np.min(lens), np.max(lens), np.mean(lens), np.std(lens))

37 382 59.2955 21.39596176267849


In [69]:

from typing import List, Optional, Tuple, Union
Union[str, Dataset]

typing.Union[str, datasets.arrow_dataset.Dataset]

In [65]:
sample = dataset['train'][0]
prompt_tokens_len = sample['prompt_tokens_len']
len(sample['token_ids'][:prompt_tokens_len]), np.stack(sample['token_pattern_matrices']).shape

(39, (102, 32, 8))

In [77]:
labels = np.stack(sample['token_pattern_matrices'])
labels = torch.from_numpy(labels).int()
labels

tensor([[[0, 1, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 1],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 1]],

        [[0, 0, 0,  ..., 1, 1, 0],
         [1, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 1, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 1, 0]],

        [[0, 0, 0,  ..., 1, 1, 0],
         [0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 1, 0, 0],
         ...,
         [1, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 1, 0],
         [0, 0, 1,  ..., 0, 1, 0]],

        ...,

        [[0, 0, 1,  ..., 1, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 1, 1],
         ...,
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 1, 0,  ..., 0, 1, 0],
         [0, 0, 0,  ..., 0, 1, 0]],

        [[1, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 1, 0, 1],
         [1,

In [73]:
torch.tensor(sample['token_ids'], dtype=int)

tensor([    1, 20811,   349,   396, 13126,   369, 13966,   264,  3638, 28723,
        12018,   264,  2899,   369,  6582,  1999,  2691,   274,   272,  2159,
        28723,    13,    13, 27332,  3133,  3112, 28747,    13, 28784,   648,
        28705, 28770,   327,  1550,    13,    13, 27332, 12107, 28747,    13,
        28784,   648, 28705, 28770, 21588, 28705, 28774, 28723,     2, 10020,
        28744,     2, 10020, 28744,    13,    13, 27332, 11530, 12107, 28747,
           13, 28774, 28723,     2, 10020, 28744,     2, 10020, 28744,    13,
           13, 27332,  1529, 11009,   352, 28747,    13,  7477,   368,   967,
        28705, 28784,   304, 28705, 28770,  2553, 28725,   368,   625, 28705,
        28774, 28723,   415,  1474, 28705, 28774,   349,   272,  2648,   302,
        28705, 28784])

In [15]:
def acc_precision_recall_f1(y_true_origin, y_pred_origin):
    bs, seq_len, num_layer, num_experts = y_true_origin.shape
    y_true = np.reshape(y_true_origin, (bs, seq_len, num_layer * num_experts))
    y_pred = np.reshape(y_pred_origin, (bs, seq_len, num_layer * num_experts))
    y_true = np.transpose(y_true, (1, 0, 2)) # (seq, bs, num_layer * num_experts)
    y_pred = np.transpose(y_pred, (1, 0, 2)) # (seq, bs, num_layer * num_experts)
    y_true = (y_true.sum(1)>0).astype(int) # (seq, num_layer * num_experts)
    y_pred = (y_pred.sum(1)>0).astype(int) # (seq, num_layer * num_experts)
    print(y_true.shape, y_true)
    # 真正例 (True Positives)
    TP = np.sum((y_true == 1) & (y_pred == 1))
    
    # 假正例 (False Positives)
    FP = np.sum((y_true == 0) & (y_pred == 1))
    
    # 假负例 (False Negatives)
    FN = np.sum((y_true == 1) & (y_pred == 0))
    
    # 真负例 (True Negatives)
    TN = np.sum((y_true == 0) & (y_pred == 0))

    y_true = y_true.reshape(-1, 256)
    y_pred = y_pred.reshape(-1, 256)
    print(f"origin y_true.shape={y_true.shape}")
    indices = np.any(y_true, axis=-1)
    print(indices.shape)
    y_true = y_true[indices]
    y_pred = y_pred[indices]
    print(f"filtered y_true.shape={y_true.shape}")

    # 准确率
    num_tokens = y_true.shape[0]
    accuracy = TP / (num_tokens*64)
    recall = 0
    precision = 0
    f1 = 0
    print(f"non-padding ratio: {indices.sum()}/{len(indices)}={indices.sum()/len(indices)}\n")

    return {
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision,
        'f1': f1,
    }

In [16]:
seq_len = 12
pad_seq_len = 4 
y_true = torch.randint(0,2,(4,8,32,8)).numpy()
y_pred = torch.randint(0,2,(4,8,32,8)).numpy()
print(acc_precision_recall_f1(y_true, y_pred))

(8, 256) [[1 1 1 ... 1 1 1]
 [1 1 0 ... 1 1 1]
 [1 1 1 ... 1 1 1]
 ...
 [0 1 1 ... 1 1 0]
 [1 1 1 ... 1 1 1]
 [1 1 1 ... 1 1 1]]
origin y_true.shape=(8, 256)
(8,)
filtered y_true.shape=(8, 256)
non-padding ratio: 8/8=1.0

{'accuracy': 3.46875, 'recall': 0, 'precision': 0, 'f1': 0}


In [6]:
dataset_names = {
    "auto_categorization": 328,
    "tense": 286,
    "disfl_qa": 8000,
    "semantic_parsing_in_context_sparc": 1160,
    "word_sorting": 1900,
    "linguistics_puzzles": 2000,
}

In [14]:
import datasets
dataset_name = "tasksource/bigbench"
# names = datasets.get_dataset_config_names(dataset_name)

names = list(dataset_names.keys())
all_inputs = []
for name in names:
    print(name)
    all_inputs.append(datasets.load_dataset(dataset_name, name))

auto_categorization


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


tense
disfl_qa
semantic_parsing_in_context_sparc
word_sorting
linguistics_puzzles


In [15]:
all_inputs[0]

DatasetDict({
    train: Dataset({
        features: ['inputs', 'targets', 'multiple_choice_targets', 'multiple_choice_scores', 'idx'],
        num_rows: 263
    })
    validation: Dataset({
        features: ['inputs', 'targets', 'multiple_choice_targets', 'multiple_choice_scores', 'idx'],
        num_rows: 65
    })
})

In [18]:

train_all_inputs = []
valid_all_inputs = []
for dataset in all_inputs:
    train_all_inputs += [text for text in dataset["train"]["inputs"]]
    valid_all_inputs += [text for text in dataset["validation"]["inputs"]]
len(train_all_inputs), len(valid_all_inputs)

(10936, 2733)

In [1]:
import torch
import numpy as np
import datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = datasets.load_dataset("marsggbo/bigbench4switch64_pattern_predictor")['train']
dataset

Dataset({
    features: ['prompt_text', 'prompt_ids', 'decode_ids', 'prompt_pattern', 'decode_pattern'],
    num_rows: 10936
})

In [3]:
np.array(dataset[0]['decode_pattern']).shape

(6, 32)

In [4]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
tokenizer.padding_side = 'left'
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
model.lm_head = torch.nn.Linear(768, 6*64)
model = model.cuda(1)

In [5]:
model.__class__, model.config.hidden_size

(transformers.models.t5.modeling_t5.T5ForConditionalGeneration, 768)

In [6]:
params = sum([p.numel() for p in model.parameters()])
params*2/1024**2

425.718017578125

In [13]:
import time
decoder_lens = list(range(1, 16))
bs = 32
encoder_len = 512
device = torch.device("cuda:1")
time_costs = {}
# warmup

input_ids = torch.randint(0, 100, (bs, encoder_len)).to(device)
attention_mask = torch.ones(bs, encoder_len).to(device)
decoder_input_ids =torch.randint(0, 100, (bs, 10)).to(device)
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
)

for decoder_len in decoder_lens:
    input_ids = torch.randint(0, 100, (bs, encoder_len)).to(device)
    attention_mask = torch.ones(bs, encoder_len).to(device)
    decoder_input_ids =torch.randint(0, 100, (bs, decoder_len)).to(device)
    torch.cuda.synchronize()
    start = time.time()
    for i in range(10):
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        )
    torch.cuda.synchronize()
    end = time.time()
    time_costs[decoder_len] = (end-start)/10
time_costs

{1: 0.2365039348602295,
 2: 0.2400728702545166,
 3: 0.24014568328857422,
 4: 0.24187352657318115,
 5: 0.2428675889968872,
 6: 0.24388659000396729,
 7: 0.2439492702484131,
 8: 0.2441192626953125,
 9: 0.24286985397338867,
 10: 0.24423601627349853,
 11: 0.24433786869049073,
 12: 0.24466307163238527,
 13: 0.24510047435760499,
 14: 0.2453702926635742,
 15: 0.24580740928649902}

In [53]:
bs = 32
batch_indices = [list(range(i,i+bs)) for i in range(0, len(dataset), bs)]
all_prompt_text = np.array(dataset['prompt_text'])
all_decode_ids = np.array(dataset['decode_ids'])
all_decode_pattern = np.array(dataset['decode_pattern'])
num_experts_per_layer = 64
for indices in batch_indices[:4]:
    batch_text = all_prompt_text[indices].tolist()
    data = tokenizer(batch_text, return_tensors="pt", return_attention_mask=True, padding=True)
    input_ids = data.input_ids.cuda()
    attention_mask = data.attention_mask.cuda()
    decoder_input_ids = torch.tensor(all_decode_ids[indices]).cuda()
    decode_pattern = torch.tensor(all_decode_pattern[indices]).permute(0,2,1).cuda()
    decode_pattern = torch.nn.functional.one_hot(decode_pattern, num_classes=num_experts_per_layer).float()
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
    )
    logits = out.logits
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits.view(-1, num_experts_per_layer),
            decode_pattern.view(-1, num_experts_per_layer),
            reduction='mean')
    print(out.logits.shape, loss)

torch.Size([8, 32, 384]) tensor(0.6928, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
torch.Size([8, 32, 384]) tensor(0.6928, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
torch.Size([8, 32, 384]) tensor(0.6927, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
torch.Size([8, 32, 384]) tensor(0.6928, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


In [41]:
out.logits.shape

torch.Size([8, 32, 32128])

In [36]:
decode_pattern.shape, decoding_token_ids.shape

(torch.Size([8, 32, 6]), torch.Size([8, 32]))

In [2]:
import torch
import torch.nn.functional as F

def prepare_batch(input_seqs, target_seqs, pad_token_id=0):
    # 输入和目标序列的填充
    input_padded = torch.nn.utils.rnn.pad_sequence(
        input_seqs[::-1], batch_first=True, padding_value=pad_token_id).flip(dims=[1])
    target_padded = torch.nn.utils.rnn.pad_sequence(
        target_seqs, batch_first=True, padding_value=pad_token_id)

    # 创建编码器和解码器的掩码
    input_mask = (input_padded != pad_token_id)
    target_mask = (target_padded != pad_token_id)

    return input_padded, input_mask, target_padded, target_mask

# 假设 input_seqs 和 target_seqs 是预处理后的序列列表，其中每个元素是一个tensor
input_seqs = [torch.tensor([1, 2, 3, 4, 5, 6]), torch.tensor([4, 5])]
target_seqs = [torch.tensor([1, 2]), torch.tensor([3, 4, 5, 6, 7, 8])]

input_padded, input_mask, target_padded, target_mask = prepare_batch(input_seqs, target_seqs)
(input_padded.shape, input_mask.shape, target_padded.shape, target_mask.shape)
(input_padded, input_mask, target_padded, target_mask)

(tensor([[0, 0, 0, 0, 5, 4],
         [6, 5, 4, 3, 2, 1]]),
 tensor([[False, False, False, False,  True,  True],
         [ True,  True,  True,  True,  True,  True]]),
 tensor([[1, 2, 0, 0, 0, 0],
         [3, 4, 5, 6, 7, 8]]),
 tensor([[ True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True]]))

In [36]:
bs=2
seq_len = torch.randint(1, 10, (bs,))
target_seqs = [torch.randint(0,100,(seq_len[i], )) for i in range(bs)]
pad_target_seqs = torch.nn.utils.rnn.pad_sequence(target_seqs, batch_first=True, padding_value=0)
print([x.shape for x in target_seqs])
print(pad_target_seqs.shape, pad_target_seqs)


labels = [torch.randint(0, 2, (seq_len[i], 16)) for i in range(bs)]
print([x.shape for x in labels])
pad_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
print(pad_labels.shape)

[torch.Size([9]), torch.Size([3])]
torch.Size([2, 9]) tensor([[52, 96, 67, 19, 30, 99, 36, 52, 34],
        [55, 53, 64,  0,  0,  0,  0,  0,  0]])
[torch.Size([9, 16]), torch.Size([3, 16])]
torch.Size([2, 9, 16])


In [37]:
pad_labels[1]

tensor([[0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1],
        [0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [13]:
import torch
import numpy as np
from sklearn.cluster import SpectralClustering

def hamming_distance(a, b):
    return (a != b).sum(dim=(1, 2))

def sim_func(pattern_list):
    # 将矩阵展平
    patterns = torch.stack(pattern_list, dim=0)
    flat_patterns = patterns.view(patterns.size(0), -1)
    # 计算两两之间的Hamming距离
    dist = torch.cdist(flat_patterns, flat_patterns, p=0)
    # 将Hamming距离转换为相似度
    similarity = 1 - dist / flat_patterns.size(1)
    print(similarity)
    return similarity

def partition_func(similarities, k):
    # 转换为 numpy 矩阵
    sim_matrix_np = similarities.cpu().numpy()
    n_clusters = k
    # 应用 spectral clustering
    clustering = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', assign_labels='kmeans')
    labels = clustering.fit_predict(sim_matrix_np)

    # 将结果转换为分组
    groups = [[] for _ in range(n_clusters)]
    for idx, label in enumerate(labels):
        groups[label].append(idx)
    return groups

def scheduler(pattern_list, k):
    similarities = sim_func(pattern_list)
    groups = partition_func(similarities, k)
    return groups

# Example usage
L, E, N = 10, 10, 6  # Dimensions and number of matrices
k = 2  # Number of groups to partition
pattern_list = [torch.randint(0, 2, (L, E), device='cuda').float() for _ in range(N)]
grouped_indices = scheduler(pattern_list, k)
print(grouped_indices)


tensor([[1.0000, 0.5300, 0.4400, 0.5500, 0.4900, 0.5300],
        [0.5300, 1.0000, 0.4900, 0.5400, 0.5400, 0.5000],
        [0.4400, 0.4900, 1.0000, 0.6100, 0.4900, 0.5500],
        [0.5500, 0.5400, 0.6100, 1.0000, 0.5600, 0.5400],
        [0.4900, 0.5400, 0.4900, 0.5600, 1.0000, 0.5200],
        [0.5300, 0.5000, 0.5500, 0.5400, 0.5200, 1.0000]], device='cuda:0')
[[2, 3, 5], [0, 1, 4]]


In [12]:
%timeit scheduler(pattern_list, k)

The slowest run took 235.17 times longer than the fastest. This could mean that an intermediate result is being cached.
440 ms ± 684 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [57]:
import torch

def initialize_indices(n, k):
    """随机初始化质心的索引"""
    indices = torch.randperm(n)[:k]
    return indices

def update_clusters(data_similarity, indices):
    """更新簇分配，选择与质心最相似的点"""
    # 提取与质心相关的相似度
    cluster_similarities = data_similarity[:, indices]
    # 每个点被分配到最相似质心的簇中
    labels = torch.argmax(cluster_similarities, dim=1)
    return labels

def update_centroids(data_similarity, labels, k):
    """根据新的簇分配更新质心"""
    new_indices = torch.zeros(k, dtype=torch.long, device='cuda')
    within_cluster_similarities = []
    for i in range(k):
        # 找到属于同一簇的所有点的相似度总和
        within_cluster_similarity = data_similarity[labels == i][:, labels == i].sum(dim=1)
        # 选择使内部相似度最大化的点作为新质心
        new_indices[i] = (labels == i).nonzero()[within_cluster_similarity.argmax()]
        within_cluster_similarities.append(within_cluster_similarity.max())
    return new_indices, within_cluster_similarities

def kmeans_similarity(data_similarity, k, num_epochs=100):
    """执行基于相似度矩阵的K-means聚类"""
    n = data_similarity.size(0)
    indices = initialize_indices(n, k).cuda()
    labels = torch.zeros(n, dtype=torch.long, device='cuda')

    for epoch in range(num_epochs):
        new_labels = update_clusters(data_similarity, indices)
        if torch.equal(labels, new_labels):
            # print('break', epoch)
            break
        labels = new_labels
        indices, within_cluster_similarities = update_centroids(data_similarity, labels, k)
        # for i, idx in enumerate(indices):
        #     print(f"Epoch-{epoch} {idx} {within_cluster_similarities[i]}")

    # 收集每个簇的成员索引
    clusters = {i: (labels == i).nonzero(as_tuple=True)[0].cpu().numpy().tolist() for i in range(k)}
    return labels, indices, clusters
    # return list(clusters.items())

def sim_func(pattern_list):
    # 将矩阵展平
    patterns = torch.stack(pattern_list, dim=0)
    flat_patterns = patterns.view(patterns.size(0), -1)
    # 计算两两之间的Hamming距离
    dist = torch.cdist(flat_patterns, flat_patterns, p=0)
    # 将Hamming距离转换为相似度
    similarity = 1 - dist / flat_patterns.size(1)
    print(similarity)
    return similarity

def scheduler(pattern_list, num_epochs=100, k=2):
    data_similarity = sim_func(pattern_list)
    labels, centroids_indices, clusters = kmeans_similarity(data_similarity, k, num_epochs)
    indices_within_cluster = list(clusters.values())
    return indices_within_cluster

# 示例用法
n = 128  # 数据点的数量
L = 6
E = 32
k = 2    # 簇的数量
data = [torch.randint(0, 2, (L, E)).cuda().float() for _ in range(n)]
indices_within_cluster = scheduler(data)
for i, cluster in enumerate(indices_within_cluster):
    print(f"Cluster {i} ({len(cluster)}): {cluster}")

tensor([[1.0000, 0.5469, 0.4323,  ..., 0.5365, 0.4583, 0.4948],
        [0.5469, 1.0000, 0.4062,  ..., 0.5208, 0.4740, 0.5104],
        [0.4323, 0.4062, 1.0000,  ..., 0.4792, 0.4844, 0.5000],
        ...,
        [0.5365, 0.5208, 0.4792,  ..., 1.0000, 0.4635, 0.5104],
        [0.4583, 0.4740, 0.4844,  ..., 0.4635, 1.0000, 0.4531],
        [0.4948, 0.5104, 0.5000,  ..., 0.5104, 0.4531, 1.0000]],
       device='cuda:0')
Cluster 0 (70): [3, 6, 7, 13, 14, 16, 17, 20, 21, 22, 26, 27, 29, 30, 31, 33, 34, 36, 39, 40, 41, 42, 44, 45, 46, 48, 49, 50, 53, 54, 55, 57, 59, 60, 61, 62, 63, 64, 65, 69, 72, 74, 77, 78, 80, 84, 85, 86, 88, 89, 90, 92, 94, 95, 98, 99, 100, 101, 102, 106, 108, 109, 111, 112, 114, 117, 118, 122, 125, 127]
Cluster 1 (58): [0, 1, 2, 4, 5, 8, 9, 10, 11, 12, 15, 18, 19, 23, 24, 25, 28, 32, 35, 37, 38, 43, 47, 51, 52, 56, 58, 66, 67, 68, 70, 71, 73, 75, 76, 79, 81, 82, 83, 87, 91, 93, 96, 97, 103, 104, 105, 107, 110, 113, 115, 116, 119, 120, 121, 123, 124, 126]


In [58]:
def balance_clusters(clusters, target_size):
    # 将簇的索引列表转换为数组以便操作
    cluster_sizes = {i: len(cluster) for i, cluster in clusters.items()}
    overfilled = {i: members for i, members in clusters.items() if len(members) > target_size}
    underfilled = {i: members for i, members in clusters.items() if len(members) < target_size}

    # 调整簇中的数据点分配
    adjusted_clusters = dict(clusters)  # 开始调整的副本
    transfer_list = []  # 存储需要移动的数据点

    # 收集过多的数据点
    for idx, members in overfilled.items():
        excess = len(members) - target_size
        transfer_list.extend((idx, member) for member in members[-excess:])
        adjusted_clusters[idx] = members[:-excess]  # 移除多余的成员

    # 将收集到的数据点重新分配到需要填充的簇中
    for idx, member in transfer_list:
        for u_idx, u_members in underfilled.items():
            if len(u_members) < target_size:
                adjusted_clusters[u_idx].append(member)
                underfilled[u_idx].append(member)
                if len(u_members) + 1 == target_size:
                    break

    # 确保调整后的簇大小正确
    assert all(len(members) == target_size for members in adjusted_clusters.values())

    return adjusted_clusters

def postprocess_clusters(indices_within_cluster, n, k):
    target_size = n // k
    clusters = {i: cluster for i, cluster in enumerate(indices_within_cluster)}
    balanced_clusters = balance_clusters(clusters, target_size)
    return list(balanced_clusters.values())

# 使用后处理调整簇的大小
balanced_cluster_indices = postprocess_clusters(indices_within_cluster, n, k)
for i, cluster in enumerate(balanced_cluster_indices):
    print(f"Balanced Cluster {i} ({len(cluster)}): {cluster}")


Balanced Cluster 0 (64): [3, 6, 7, 13, 14, 16, 17, 20, 21, 22, 26, 27, 29, 30, 31, 33, 34, 36, 39, 40, 41, 42, 44, 45, 46, 48, 49, 50, 53, 54, 55, 57, 59, 60, 61, 62, 63, 64, 65, 69, 72, 74, 77, 78, 80, 84, 85, 86, 88, 89, 90, 92, 94, 95, 98, 99, 100, 101, 102, 106, 108, 109, 111, 112]
Balanced Cluster 1 (64): [0, 1, 2, 4, 5, 8, 9, 10, 11, 12, 15, 18, 19, 23, 24, 25, 28, 32, 35, 37, 38, 43, 47, 51, 52, 56, 58, 66, 67, 68, 70, 71, 73, 75, 76, 79, 81, 82, 83, 87, 91, 93, 96, 97, 103, 104, 105, 107, 110, 113, 115, 116, 119, 120, 121, 123, 124, 126, 114, 114, 117, 117, 118, 118]


In [59]:
len([ 11860, 10, 363, 410, 8, 2243, 13, 6923, 8032, 363, 410, 8, 10098, 533, 5779, 3553, 38, 12592, 16, 5744, 9, 15, 17, 7, 3, 18, 180, 1538, 7, 3, 208, 472, 362, 1222, 3969, 15, 2217, 14080, 7, 3, 58, 2625, 10, 37, 96, 4333, 12, 370, 364, 96, 365, 3, 9164, 12062, 1108, 11526, 8275, 12, 151, 113, 428, 364, 96, 21, 3, 60, 51, 444, 2661, 96, 3, 6, 902, 1328, 42, 771, 1756, 3, 5, 242, 677, 3, 6, 16, 4480, 7617, 7, 2235, 35, 3, 208, 1648, 76, 450, 4049, 20, 7499, 22276, 89, 624, 35, 23, 3896, 3, 1621, 127, 20, 14204, 138, 29, 23, 354, 624, 88, 23, 26, 3, 9, 10098, 6297, 2301, 12, 15575, 298, 3, 25289, 3, 9, 1188, 16, 3, 9, 569, 1034, 495, 3, 6, 11, 47, 1219, 3, 88, 228, 59, 916, 250, 10098, 973, 243, 163, 151, 2127, 16, 8, 12023, 228, 428, 1281, 1867, 3, 5, 37, 2243, 13, 6923, 1213, 24, 8, 4333, 12, 370, 364, 2930, 3, 6, 34, 47, 1461, 1231, 3, 6, 11, 8, 3356, 47, 1077, 73, 4998, 3676, 3, 10, 578, 46, 1115, 16, 8, 1144, 538, 133, 36, 631, 12, 6665, 8, 12372, 2674, 13, 207, 3602, 13, 4831, 3, 5, 37, 2243, 13, 6923, 65, 1213, 24, 6980, 1073, 7250, 1067, 8, 7401, 13, 1108, 11526, 3, 6, 250, 1086, 8, 538, 2731, 34, 3, 6, 713, 1146, 1073, 405, 59, 3, 5, 1685, 124, 2389, 12052, 38, 3, 9, 313, 3, 5, 86, 5744, 9, 15, 17, 7, 3, 18, 180, 1538, 7, 3, 208, 472, 362, 1222, 3969, 15, 2217, 14080, 7, 8667, 5744, 9, 15, 17, 7, 3, 18, 180, 1538, 7, 7760, 255, 225, 36, 29560, 26, 57, 10098, 569, 958, 21, 1358, 13, 4281, 1058, 16, 3434, 3, 5, 37, 10098, 533, 5779, 3, 12327, 8, 1058, 12592, 3, 6, 78, 255, 3, 15585, 48, 12103, 8, 4333, 41, 13, 8, 2968, 533, 5998, 3, 61, 12, 370, 364, 3, 5, 3, 8656, 10524, 5776, 24, 2833, 364, 225, 59, 36, 3, 12327, 38, 1456, 3, 6, 11, 225, 59, 1590, 441, 1108, 11526, 3, 5, 299, 8, 2243, 13, 6923, 1213, 533, 47, 3, 9, 96, 313, 96, 237, 713, 8, 789, 41, 1066, 145, 8, 313, 11095, 3, 61, 1866, 21, 8, 313, 3, 5, 868, 5779, 228, 36, 24125, 16, 17285, 53, 12, 29560, 1221, 21, 1035, 364, 7979, 3, 99, 8, 533, 124, 1204, 44, 234, 47, 406, 64, 76, 15, 7230, 3, 6, 11, 34, 2348, 96, 1038, 1035, 2056, 96, 30, 84, 5872, 3476, 15, 26, 38, 1389, 11, 1316, 3, 5, 37, 2243, 2311, 24, 8, 928, 4616, 13, 3, 9, 1868, 18686, 2794, 7809, 3, 6, 11, 48, 19, 92, 1176, 16, 8, 2625, 13, 8, 1270, 3, 31, 3, 7, 868, 1685, 1387, 3, 5, 71, 1583, 45, 452, 364, 3, 6, 430, 6280, 1057, 13, 364, 33, 273, 12910, 38, 6016, 3, 5, 10854, 348, 7, 3, 208, 14343, 15, 526, 222, 49, 4049, 1534, 9, 22350, 17, 1213, 24, 8, 12023, 3, 31, 9464, 13, 11513, 5962, 3, 6, 379, 8, 30693, 7, 57, 128, 24858, 30, 11618, 41, 68, 59, 10098, 1157, 7, 3, 61, 352, 12, 1975, 5391, 3, 6, 4728, 1067, 1108, 11526, 16889, 3, 5, 37, 2243, 13, 6923, 1053, 15, 26, 24, 3, 29, 4667, 9798, 4845, 130, 6478, 16, 66, 1144, 2315, 3, 6, 11, 78, 48, 7641, 15, 26, 45, 119, 1488, 213, 813, 17448, 42, 119, 16172, 3, 18, 1281, 1756, 47, 1426, 12, 19841, 3, 5, 156, 46, 1756, 405, 1590, 441, 1108, 11526, 3, 6, 3, 9, 19841, 54, 36, 24125, 365, 1108, 9065, 42, 147, 4055, 53, 1502, 1597, 57, 8, 2243, 13, 6923, 3, 5, 86, 28196, 9682, 7, 3, 22480, 3, 208, 3271, 4049, 7430, 11389, 2, 29, 3, 9, 268, 24, 1916, 27592, 647, 7, 41, 28, 24330, 195, 25381, 11, 430, 8175, 5654, 3, 61, 13090, 12, 1921, 3, 9, 10098, 973, 24, 19551, 53, 2107, 3874, 722, 3, 5, 37, 2243, 13, 6923, 1213, 8, 10098, 30693, 6665, 26, 3, 9, 12372, 2674, 12, 1709, 96, 30995, 11336, 16, 16944, 3415, 96, 379, 9932, 8, 3733, 45, 8299, 1085, 15541, 3, 6, 2932, 6011, 3410, 16, 8, 10098, 3212, 3, 5, 86, 20336, 3012, 11516, 35, 7635, 3, 208, 4523, 29, 3, 9, 96, 6124, 26, 11956, 96, 268, 47, 18168, 57, 8, 4523, 29, 6098, 3, 5, 94, 2944, 9901, 6124, 4740, 364, 45, 3, 9, 1270, 1669, 718, 10035, 7, 291, 3937, 3, 6, 68, 2797, 141, 4973, 15, 26, 581, 96, 1556, 44, 9357, 96, 4527, 3, 5, 37, 2243, 13, 6923, 1213, 24, 8, 2968, 18896, 701, 13, 936, 21377, 3, 6, 84, 365, 29320, 8, 4514, 3, 6, 410, 3476, 38, 3, 9, 24125, 19841, 30, 4333, 12, 370, 364, 3, 5, 86, 18515, 3625, 76, 15991, 9, 20, 377, 2810, 4243, 3, 208, 4625, 9882, 836, 8306, 15, 2234, 4922, 52, 26, 23, 9, 20, 1414, 7, 115, 32, 9, 8, 2243, 13, 6923, 92, 1213, 24, 8, 538, 3, 23507, 63, 30, 9531, 3, 6, 11, 3, 9, 10736, 21, 3, 9, 17223, 4900, 2046, 1669, 24, 141, 1916, 1396, 9531, 364, 3, 6, 47, 24125, 12, 1709, 7712, 11, 9531, 213, 151, 3, 31, 3, 7, 2441, 130, 1385, 12355, 5560, 3, 5, 37, 4514, 47, 7385, 342, 38, 48, 47, 46, 2016, 11, 1316, 194, 12, 8000, 8, 2261, 982, 13, 7712, 24, 7931, 147, 8, 1396, 3, 5, 86, 8, 1799, 28965, 3, 9, 563, 13, 131, 2420, 7, 130, 10763, 3676, 16, 1108, 898, 24, 8, 495, 973, 65, 1597, 3, 5, 11801, 10, 1 ])

961

In [52]:
%timeit kmeans_similarity(data_similarity, k, 5)

29.8 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
