In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
from datasets import Dataset

In [2]:
# Load the 'dair-ai/emotion' dataset and pick out its 'train' split.
dataset_all = load_dataset(path='dair-ai/emotion')
dataset_train = dataset_all['train']
dataset_train  # displayed as a datasets.arrow_dataset.Dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 16000
})

### filter

In [3]:
def filter_func(data):
    """Predicate for `Dataset.filter`: keep examples whose text starts with 'i feel as'.

    Args:
        data: a single example, indexable by the key 'text' (a string).

    Returns:
        True when the text begins with the prefix 'i feel as', False otherwise.
        This is a plain prefix match, so e.g. 'i feel assured ...' also passes.
    """
    # Debug output: shows what the predicate receives on each call.
    print('data:', data)
    prefix = 'i feel as'
    text = data['text']
    return text.startswith(prefix)


# Run filter_func over the training split; the resulting dataset only contains
# the examples for which filter_func returned True.
start_with_ar = dataset_train.filter(filter_func)
print(start_with_ar)  # printed as a datasets.arrow_dataset.Dataset
print(len(start_with_ar), start_with_ar['text'])

Dataset({
    features: ['text', 'label'],
    num_rows: 100
})
100 ['i feel as confused about life as a teenager or as jaded as a year old man', 'i feel as a child innocent feelings illustrating a', 'i feel as if i was abused in some way', 'i feel as if i should be punished for neglecting you', 'i feel as though ive reached a point in my career where im highly respected there', 'i feel as if i must blog constantly for all my loyal fans the baker thia sandwich the scruncher and of course mini t rex', 'i feel as if this opportunity to return to moz is gods gracious gracious way of giving me that heat desire despite my own self doubt and uncertainty in the past', 'i feel as if the leaders of countries do not depict the people of their countries because for the love of god i hope no one thought at all i was in any way supportive or like george w', 'i feel assured that my mind is not one', 'i feel as though i am being a little neglectful of my fellow bloggers', 'i feel as though i cant bea

### map

In [4]:
# Load the tokenizer belonging to the checkpoint named below; map_func uses it.
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)


def map_func(data, name='hello'):
    """Tokenize the 'text' field of whatever `Dataset.map` passes in.

    Args:
        data: a single example or a batch of examples, indexable by 'text'.
        name: not used in the body; it only exists so that extra keyword
            arguments can be forwarded to this function by the caller.

    Returns:
        The tokenizer output for data["text"], produced with padding=True and
        truncation=True.
    """
    # Debug output: shows what the mapping function receives on each call.
    print('data: ', data)
    texts = data["text"]
    encoded = tokenizer(texts, padding=True, truncation=True)
    return encoded

In [None]:
# Apply map_func to every example of the training split. The result keeps the
# original columns and adds the ones returned by map_func.
datatset_map = dataset_train.map(
    map_func,
    # Process each example on its own (batched=False is also the default).
    batched=False,  # the comma was missing here, which made the cell a SyntaxError
    # Max number of processes used for the mapping (default: num_proc=1).
    num_proc=8,
)
# Original columns joined with the columns added by map_func.
print(datatset_map)  # printed as a datasets.arrow_dataset.Dataset
for i in datatset_map['input_ids'][:10]:
    print(len(i))  # lengths differ from example to example

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 16000
})
7
23
12
22
8
17
30
20
25
6


In [6]:
# NOTE(review): the variable name ends in `_10`, but the batch size used is 2.
datatset_map_batched_10 = dataset_train.map(map_func,
                                            # Batched processing: 2 examples per call (the default is batch_size=1000)
                                            batched=True, batch_size=2, 
                                            # Extra keyword arguments forwarded to map_func
                                            fn_kwargs={"name": 'hello jave!'})
print(datatset_map_batched_10)
for i in datatset_map_batched_10['input_ids'][:30]:
    print(len(i))  # lengths are equal within each batch of 2 examples

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 16000
})
23
23
22
22
17
17
30
30
25
25
23
23
14
14
44
44
10
10
25
25
46
46
31
31
57
57
20
20
25
25


In [7]:
datatset_map_batched_all = dataset_train.map(map_func,
                                             # Batched processing: the whole dataset is handled in a single call
                                             batched=True, batch_size=None)
for i in datatset_map_batched_all['input_ids'][:50]:
    print(len(i))  # every example has the same length

87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87
87


In [8]:
def not_eq_len(data):
    """Build a new column 'c' from the first two entries of column 'a'.

    Args:
        data: a batch, indexable by 'a', whose value supports slicing.

    Returns:
        A dict with the single key "c" holding data["a"][:2]. For batches with
        more than two rows the new column is therefore shorter than the input.
    """
    # Debug output: shows the batch handed to the function.
    print("data: ", data)
    first_two = data["a"][:2]
    return dict(c=first_two)

# Small 4-row dataset with columns 'a' and 'b', used for the column-length examples.
dataset = Dataset.from_dict({"a": [0, 1, 2, 4], "b": [3, 4, 5, 6]})

dataset

Dataset({
    features: ['a', 'b'],
    num_rows: 4
})

In [9]:
# Deliberately failing example: columns 'a' and 'b' have length 4 while the new
# column 'c' has length 2, so the mapped batch cannot be written.
# The error is caught and printed (instead of left uncaught) so that
# "Restart Kernel -> Run All" still reaches the cells below.
try:
    dataset.map(not_eq_len, batched=True)
except Exception as err:  # broad on purpose: this is a demo and the concrete error class is not imported here
    print(f"{type(err).__name__}: {err}")

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

data:  {'a': [0, 1, 2, 4], 'b': [3, 4, 5, 6]}


ArrowInvalid: Column 2 named c expected length 4 but got length 2

In [10]:
# Same mapping, but with the original columns dropped: 'c' is then the only
# column, so its length of 2 no longer conflicts with 'a' and 'b'.
dataset.map(not_eq_len, batched=True, 
            # Remove a selection of columns while doing the mapping.
            remove_columns=["a", "b"]
            )

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

data:  {'a': [0, 1, 2, 4], 'b': [3, 4, 5, 6]}


Dataset({
    features: ['c'],
    num_rows: 2
})