In [1]:
from transformers import DataCollatorForLanguageModeling
from transformers import AutoTokenizer
import torch

### Custom collate function

In [2]:
def collator(data):
    """Collate a sequence of per-example dicts into one dict of per-key lists.

    Args:
        data: Sequence of mappings. The keys of the first mapping decide
            which fields are collected.

    Returns:
        A dict mapping each key of ``data[0]`` to the list of that key's
        values across all examples, in input order. An empty input yields
        an empty dict.

    Raises:
        KeyError: If a later example lacks a key present in the first one.
    """
    # An empty batch has no first element to take the keys from.
    if len(data) == 0:
        return {}
    return {key: [example[key] for example in data] for key in data[0]}

In [3]:
# Two hand-written examples sharing the same three keys, used as input for `collator`.
data = [
    {
        'input_ids': [1, 2, 3],
        'attention_mask': [1, 1, 1],
        'labels': 0,
    },
    {
        'input_ids': [4, 5, 6],
        'attention_mask': [1, 1, 1],
        'labels': 1,
    },
]

In [4]:
# Each key now maps to a list holding that field from every example, in input order.
collator(data)

{'input_ids': [[1, 2, 3], [4, 5, 6]],
 'attention_mask': [[1, 1, 1], [1, 1, 1]],
 'labels': [0, 1]}

### DataCollatorForLanguageModeling

In [17]:
# Sample sentences of different lengths, so that batching them requires padding.
texts = [
    "Hello, how are you?",
    "I am fine, thank you!",
    "This is a sample sentence for language modeling.",
    "Another example sentence to train the model."
]

In [11]:
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the end-of-sequence token as the padding token so the collator can pad batches.
# NOTE(review): with pad == eos, genuine EOS tokens in the inputs would presumably
# also be masked to -100 in the labels — confirm if real EOS tokens are ever added.
tokenizer.pad_token = tokenizer.eos_token
# Initialize the data collator. With mlm=False the labels are a copy of input_ids
# with padded positions set to -100 (see the collated batch output further down).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)



In [25]:
# Tokenize every sentence on its own (no padding yet): one encoding per text.
encoded_inputs = list(map(tokenizer, texts))
encoded_inputs

[{'input_ids': [15496, 11, 703, 389, 345, 30], 'attention_mask': [1, 1, 1, 1, 1, 1]},
 {'input_ids': [40, 716, 3734, 11, 5875, 345, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]},
 {'input_ids': [1212, 318, 257, 6291, 6827, 329, 3303, 21128, 13], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]},
 {'input_ids': [6610, 1672, 6827, 284, 4512, 262, 2746, 13], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}]

In [19]:
# Token id 0 decodes to "!", which is why the ids of "I am fine, thank you!" end in 0
# (it is a real token there, not padding).
tokenizer.decode(0)

'!'

In [29]:
# Collate all four encodings at once: every row is padded to the longest sequence
# with the pad token id, attention_mask is 0 on padding, and labels are -100 there.
batch = data_collator(encoded_inputs)
batch

{'input_ids': tensor([[15496,    11,   703,   389,   345,    30, 50256, 50256, 50256],
        [   40,   716,  3734,    11,  5875,   345,     0, 50256, 50256],
        [ 1212,   318,   257,  6291,  6827,   329,  3303, 21128,    13],
        [ 6610,  1672,  6827,   284,  4512,   262,  2746,    13, 50256]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0]]), 'labels': tensor([[15496,    11,   703,   389,   345,    30,  -100,  -100,  -100],
        [   40,   716,  3734,    11,  5875,   345,     0,  -100,  -100],
        [ 1212,   318,   257,  6291,  6827,   329,  3303, 21128,    13],
        [ 6610,  1672,  6827,   284,  4512,   262,  2746,    13,  -100]])}

In [26]:
# Let a DataLoader do the batching: the collator runs once per batch of two,
# so each batch is padded only to the longest sequence within that batch.
dataloader = torch.utils.data.DataLoader(
    dataset=encoded_inputs,
    batch_size=2,
    collate_fn=data_collator,
)
next(iter(dataloader))

{'input_ids': tensor([[15496,    11,   703,   389,   345,    30, 50256],
        [   40,   716,  3734,    11,  5875,   345,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[15496,    11,   703,   389,   345,    30,  -100],
        [   40,   716,  3734,    11,  5875,   345,     0]])}