In [2]:
import os

# Route outgoing HTTP(S) requests through a local proxy, presumably so that
# `load_dataset` can reach the Hugging Face Hub from the author's network.
# NOTE(review): this address is specific to one machine — adjust or drop it.
os.environ.update({
    'http_proxy': 'http://127.0.0.1:7890',
    'https_proxy': 'http://127.0.0.1:7890',
})

In [1]:
from datasets import load_dataset

## basics

- data collator: collates a list of raw examples into one batch, i.e. prepares the model inputs from the raw inputs
    - 比如长度处理（如 padding）
    - 更广义上来说，就是数据预处理

```
# bs: 4
1, 3, 4
1, 7, 5, 2, 2, 4
1, 3, 2, 4
1, 6, 4

# bs: 4, seq_len: 6
# padding
1, 3, 4, 0, 0, 0
1, 7, 5, 2, 2, 4
1, 3, 2, 4, 0, 0
1, 6, 4, 0, 0, 0
```

### 一个标准用法

```
from transformers.data import DefaultDataCollator
from transformers import Trainer

data_collator = DefaultDataCollator()
trainer = Trainer(model=model, 
                  args=args, 
                  train_dataset=dataset, 
                  data_collator=data_collator)
```

In [8]:
from transformers.data import DefaultDataCollator
from transformers.data import DataCollatorWithPadding
from transformers.data import DataCollatorForTokenClassification
from transformers.data import DataCollatorForSeq2Seq
from transformers.data import DataCollatorForLanguageModeling

```
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer
)

data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
)


data_collator = DataCollatorForSeq2Seq(
    model=model,
    tokenizer=tokenizer,
)
```

### for seq2seq

```
# (input, label) pairs
(1, 3, 4)
(5, 3, 10)

(1, 7, 5, 2, 2, 4)
(8, 9, 3, 8, 2, 10)

(1, 3, 2, 4)
(3, 4, 8, 10)


# inputs
(1, 3, 4, 0, 0, 0)
(1, 7, 5, 2, 2, 4)
(1, 3, 2, 4, 0, 0)

# labels
(5, 3, 10, -100, -100, -100)
(8, 9, 3, 8, 2, 10)
(3, 4, 8, 10, -100, -100)
```

### for LM

```

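# causal LM: no masking; labels are copied from input_ids
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# masked LM: randomly mask ~15% of the tokens
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)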
```

```
# bs: 4
1, 3, 4
1, 7, 5, 2, 2, 4
1, 3, 2, 4
1, 6, 4

# bs: 4, seq_len: 6
# padding + random masking (mlm=True)
1, [MASK], 4, 0, 0, 0
1, 7, 5, [MASK], 2, 4
1, 3, 2, [MASK], 0, 0
1, [MASK], 4, 0, 0, 0
```

## custom collator

In [4]:
# Load the IMDB training split with Hugging Face `datasets`; the bare `ds` on the
# last line displays its features and row count.
ds = load_dataset("imdb", split="train")
ds

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [5]:
# Inspect the fields of a single example: these are the keys a custom collator must group.
ds[0].keys()

dict_keys(['text', 'label'])

In [8]:
def collator(data):
    """Collate a batch of example dicts into a dict of per-key lists.

    Turns ``[{"text": t0, "label": l0}, {"text": t1, "label": l1}, ...]`` into
    ``{"text": [t0, t1, ...], "label": [l0, l1, ...]}``. The keys are taken
    from the first example (in its order); keys that appear only in later
    examples are ignored.

    Args:
        data: A sized, indexable, iterable collection of mapping-like examples
            (e.g. a list of dicts, or a ``datasets.Dataset``).

    Returns:
        dict mapping each key of the first example to the list of that key's
        values across all examples, in input order. An empty ``data`` yields
        ``{}``.

    Raises:
        KeyError: if a later example is missing a key present in the first one.
    """
    # Guard the empty batch: ``data[0]`` below would otherwise raise IndexError.
    if len(data) == 0:
        return {}
    keys = list(data[0])
    batch = {key: [] for key in keys}
    # One pass over the examples, rather than a full pass per key.
    for example in data:
        for key in keys:
            batch[key].append(example[key])
    return batch

In [12]:
# Collate the whole split in one call (a collator normally receives one batch at
# a time) and check that every column holds one value per row.
res_dict = collator(ds)
print(*(len(res_dict[column]) for column in ("text", "label")), sep="\n")

25000
25000
