### Datasets使用简介

这个文档将简单介绍Datasets库的各种用法以及可定制的内容

Datasets 是一个用于方便地访问和共享音频、计算机视觉和自然语言处理（NLP）任务的数据集的库。

数据集部分的工作，一部分在于数据集的收集，另一部分在于数据集的处理。Datasets库的出现，一定程度上也使得这两部分的工作变得简单了许多。


In [1]:
import datasets
from transformers import BertTokenizer
from datasets import load_dataset

models.bert transformers
tokenization_bert transformers.models.bert


In [46]:
# load_dataset使用
# https://huggingface.co/datasets 列出了所有可供加载的数据集
dataset_imdb = load_dataset('imdb')
print(dataset_imdb)
print(dataset_imdb["train"].features)

# 只加载某个划分
dataset_imdb_train = load_dataset('imdb',split="train")
print(dataset_imdb_train)

# 切分未切分的Dataset
imdb_train_split = dataset_imdb_train.train_test_split(train_size=0.8,seed=222)
print(imdb_train_split)

# 加载某subset
glue_data = load_dataset('glue',"ax") 
print(glue_data)

Found cached dataset imdb (C:/Users/james/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})
{'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['neg', 'pos'], id=None)}
Found cached dataset imdb (C:/Users/james/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Loading cached split indices for dataset at C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-fc4b4bdc3bc3689c.arrow and C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-64ab916026e5722f.arrow
Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})
DatasetDict({
    train: Da

  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 1104
    })
})


In [5]:
# 本地数据集加载，支持csv、text、json、pandas等格式
# field 参数的具体取值取决于加载的数据集的结构和字段命名  在不同的数据集中，可能会有不同的字段名和结构
squad_it_dataset_train = load_dataset("json", data_files="SQuAD_it-train.json", field="data")
squad_it_dataset_train

Found cached dataset json (C:/Users/james/.cache/huggingface/datasets/json/default-1548ecb804c8e6c7/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['paragraphs', 'title'],
        num_rows: 442
    })
})

In [6]:
# 同时加载训练和测试数据集
data_files = {"train": "SQuAD_it-train.json", "test": "SQuAD_it-test.json"}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")
squad_it_dataset
# data_files也可以直接使用url路径直接访问

Downloading and preparing dataset json/default to C:/Users/james/.cache/huggingface/datasets/json/default-1af5a03550dd4ef0/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to C:/Users/james/.cache/huggingface/datasets/json/default-1af5a03550dd4ef0/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['paragraphs', 'title'],
        num_rows: 442
    })
    test: Dataset({
        features: ['paragraphs', 'title'],
        num_rows: 48
    })
})

In [7]:
# 数据查看
print(dataset_imdb["train"][:4])
print(type(dataset_imdb["train"]))
print(type(dataset_imdb['train'][:4]))
print(type(dataset_imdb['train']["text"]))

{'text': ['I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far b

In [63]:
# 通过加载脚本加载数据集
imdb_new = load_dataset("imdb.py")
print(datasets.Split.TRAIN)

Downloading and preparing dataset imdb/plain_text to C:/Users/james/.cache/huggingface/datasets/imdb/plain_text/1.0.0/44e72d998fe5cf065b7e9149906bef24fba80eb43984607f84ab8ef1d2188a56...


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating unsupervised split: 0 examples [00:00, ? examples/s]

Dataset imdb downloaded and prepared to C:/Users/james/.cache/huggingface/datasets/imdb/plain_text/1.0.0/44e72d998fe5cf065b7e9149906bef24fba80eb43984607f84ab8ef1d2188a56. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

train


In [10]:
# 数据选取与过滤
# 通过select方法与filter方法对数据集中的数据进行选取与过滤
print(dataset_imdb["train"].shuffle(seed = 2).select(range(100)))
imdb_train_pos = dataset_imdb["train"].filter(lambda example: example["label"] == 1 and len(example["text"]) > 100)
print(imdb_train_pos["label"][:150])

Loading cached processed dataset at C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-0f8136b87dae204d.arrow
Dataset({
    features: ['text', 'label'],
    num_rows: 100
})
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [15]:
# Dataset.map函数的使用,用于数据处理
# 用于对加载的数据集中的每个样本应用一个函数或转换操作，并返回转换后的新数据集
# map() 方法返回的是一个新的数据集对象，不会修改原始数据集

# map()为每个text前加入一个prefix
def add_prefix(example):
    example["text"] = 'Prefix: ' + example["text"]
    return example
print(dataset_imdb.map(add_prefix)["train"][0])

# map()创建新列
def compute_text_length(example):
    return {"text_length":len(example["text"])}

def compute_text_length_2(example):
    example["text_length"] = len(example["text"])
    return example
    
print(dataset_imdb.map(compute_text_length_2))
print(dataset_imdb.map(compute_text_length_2)["train"]["text_length"][:100])

Loading cached processed dataset at C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-0d92d288b3d2649c.arrow
Loading cached processed dataset at C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-796d40f4c80728ea.arrow
Loading cached processed dataset at C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-b9fd55e8e74acf18.arrow
Loading cached processed dataset at C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-111535dfb439dde9.arrow
Loading cached processed dataset at C:\Users\james\.cache\huggingface\datasets\imdb\plain_text\1.0.0\d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0\cache-387a9c805c243f9c.arrow
Loading cached processed 

In [16]:
# 利用map进行tokenize
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def prepare_train_features(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=512)

train_token = dataset_imdb["train"].select(range(100)).map(prepare_train_features,remove_columns=["text"])
print(train_token)
print(train_token[0])

path_or_repo_id bert-base-uncased
filename vocab.txt
pretrained_model_name_or_path bert-base-uncased
In else
configuration_file config.json
resolved_config_file C:\Users\james/.cache\huggingface\hub\models--bert-base-uncased\snapshots\0a6aa9128b6194f4f3c4db429b6cb4891cdb421b\config.json
resolved_config_file C:\Users\james/.cache\huggingface\hub\models--bert-base-uncased\snapshots\0a6aa9128b6194f4f3c4db429b6cb4891cdb421b\config.json
configuration_bert transformers.models.bert
config transformers.onnx
configuration_encoder_decoder transformers.models.encoder_decoder


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 100
})
{'label': 0, 'input_ids': [101, 1045, 12524, 1045, 2572, 8025, 1011, 3756, 2013, 2026, 2678, 3573, 2138, 1997, 2035, 1996, 6704, 2008, 5129, 2009, 2043, 2009, 2001, 2034, 2207, 1999, 3476, 1012, 1045, 2036, 2657, 2008, 2012, 2034, 2009, 2001, 8243, 2011, 1057, 1012, 1055, 1012, 8205, 2065, 2009, 2412, 2699, 2000, 4607, 2023, 2406, 1010, 3568, 2108, 1037, 5470, 1997, 3152, 2641, 1000, 6801, 1000, 1045, 2428, 2018, 2000, 2156, 2023, 2005, 2870, 1012, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 1996, 5436, 2003, 8857, 2105, 1037, 2402, 4467, 3689, 3076, 2315, 14229, 2040, 4122, 2000, 4553, 2673, 2016, 2064, 2055, 2166, 1012, 1999, 3327, 2016, 4122, 2000, 3579, 2014, 3086, 2015, 2000, 2437, 2070, 4066, 1997, 4516, 2006, 2054, 1996, 2779, 25430, 14728, 2245, 2055, 3056, 2576, 3314, 2107, 2004, 1996, 5148, 2162, 1998, 2679, 3314, 1999, 1996, 2142, 2163, 1012, 1999, 2090, 4851, 8801, 1

In [19]:
# 大数据情况下，为了节约空间，可以在加载数据集时设置streaming=True
# 当设置 streaming=True 时，数据集将以流的形式从磁盘或远程源逐行读取，并逐行进行处理。这种方式可以提高内存效率，特别适用于处理大型数据集或无法一次性全部加载到内存的数据集。
# 数据集将返回一个流式数据集对象，该对象支持迭代操作，每次迭代返回一个示例。可以使用 for 循环或其他迭代方式逐行处理数据集中的示例，而不需要一次性加载整个数据集。
dataset_imdb_streamed = load_dataset("imdb",split="train",streaming=True)
print(dataset_imdb_streamed)

<datasets.iterable_dataset.IterableDataset object at 0x000001D9BBF3E610>


In [21]:
# 流式数据集的读取
my_iter = iter(dataset_imdb_streamed)
print(next(my_iter))
for i,data in enumerate(dataset_imdb_streamed):
    print("Num:",i,data["text"])
    if(i>5):
        break

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [23]:
dataset_imdb_streamed = dataset_imdb_streamed.skip(50)
print(dataset_imdb_streamed.take(2))
print(list(dataset_imdb_streamed.take(2)))

<datasets.iterable_dataset.IterableDataset object at 0x000001D9B2341DF0>
[{'text': "Terrible movie. Nuff Said.<br /><br />These Lines are Just Filler. The movie was bad. Why I have to expand on that I don't know. This is already a waste of my time. I just wanted to warn others. Avoid this movie. The acting sucks and the writing is just moronic. Bad in every way. The only nice thing about the movie are Deniz Akkaya's breasts. Even that was ruined though by a terrible and unneeded rape scene. The movie is a poorly contrived and totally unbelievable piece of garbage.<br /><br />OK now I am just going to rag on IMDb for this stupid rule of 10 lines of text minimum. First I waste my time watching this offal. Then feeling compelled to warn others I create an account with IMDb only to discover that I have to write a friggen essay on the film just to express how bad I think it is. Totally unnecessary.", 'label': 0}, {'text': 'Assuming this won\'t end up a straight-to-video release, I would hav

In [24]:
# ​ 与Dataset.map()函数类似，如果要做数据预处理的话，可以使用IterableDataset.map()函数进行操作。与Dataset.map()的主要区别在于，IterableDataset.map()处理的结果也需要迭代读取。
tokenized_data = dataset_imdb_streamed.map(prepare_train_features,remove_columns=['text'])
print(next(iter(tokenized_data)))

{'label': 0, 'input_ids': [101, 6659, 3185, 1012, 16371, 4246, 2056, 1012, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 2122, 3210, 2024, 2074, 6039, 2121, 1012, 1996, 3185, 2001, 2919, 1012, 2339, 1045, 2031, 2000, 7818, 2006, 2008, 1045, 2123, 1005, 1056, 2113, 1012, 2023, 2003, 2525, 1037, 5949, 1997, 2026, 2051, 1012, 1045, 2074, 2359, 2000, 11582, 2500, 1012, 4468, 2023, 3185, 1012, 1996, 3772, 19237, 1998, 1996, 3015, 2003, 2074, 22822, 12356, 1012, 2919, 1999, 2296, 2126, 1012, 1996, 2069, 3835, 2518, 2055, 1996, 3185, 2024, 7939, 10993, 17712, 20718, 1005, 1055, 12682, 1012, 2130, 2008, 2001, 9868, 2295, 2011, 1037, 6659, 1998, 4895, 24045, 5732, 9040, 3496, 1012, 1996, 3185, 2003, 1037, 9996, 9530, 18886, 7178, 1998, 6135, 23653, 3538, 1997, 13044, 1012, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 7929, 2085, 1045, 2572, 2074, 2183, 2000, 17768, 2006, 10047, 18939, 2005, 2023, 5236, 3627, 1997, 2184, 3210, 1997, 3793, 6263, 1012, 2034, 1045, 5949, 2026, 2051, 3666, 2023

In [43]:
# datasets 库更侧重于数据集的访问、预处理和转换
# 需要与Dataloader配合使用，也可以实现更高维度上的自定义
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def prepare_features(example):
    encoding = tokenizer(example["text"], truncation=True, padding="max_length", max_length=512)
    label = int(example["label"])
    in_ids = torch.tensor(encoding["input_ids"], dtype=torch.long)
    return {
        "input_ids": in_ids,
        "attention_mask": torch.Tensor(encoding["attention_mask"]),
        "label": label
    }

train_token = dataset_imdb["train"].map(prepare_features,remove_columns=["text"])
print(train_token["input_ids"][0])

path_or_repo_id bert-base-uncased
filename vocab.txt
pretrained_model_name_or_path bert-base-uncased
In else
configuration_file config.json
resolved_config_file C:\Users\james/.cache\huggingface\hub\models--bert-base-uncased\snapshots\0a6aa9128b6194f4f3c4db429b6cb4891cdb421b\config.json
resolved_config_file C:\Users\james/.cache\huggingface\hub\models--bert-base-uncased\snapshots\0a6aa9128b6194f4f3c4db429b6cb4891cdb421b\config.json


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

[101, 1045, 12524, 1045, 2572, 8025, 1011, 3756, 2013, 2026, 2678, 3573, 2138, 1997, 2035, 1996, 6704, 2008, 5129, 2009, 2043, 2009, 2001, 2034, 2207, 1999, 3476, 1012, 1045, 2036, 2657, 2008, 2012, 2034, 2009, 2001, 8243, 2011, 1057, 1012, 1055, 1012, 8205, 2065, 2009, 2412, 2699, 2000, 4607, 2023, 2406, 1010, 3568, 2108, 1037, 5470, 1997, 3152, 2641, 1000, 6801, 1000, 1045, 2428, 2018, 2000, 2156, 2023, 2005, 2870, 1012, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 1996, 5436, 2003, 8857, 2105, 1037, 2402, 4467, 3689, 3076, 2315, 14229, 2040, 4122, 2000, 4553, 2673, 2016, 2064, 2055, 2166, 1012, 1999, 3327, 2016, 4122, 2000, 3579, 2014, 3086, 2015, 2000, 2437, 2070, 4066, 1997, 4516, 2006, 2054, 1996, 2779, 25430, 14728, 2245, 2055, 3056, 2576, 3314, 2107, 2004, 1996, 5148, 2162, 1998, 2679, 3314, 1999, 1996, 2142, 2163, 1012, 1999, 2090, 4851, 8801, 1998, 6623, 7939, 4697, 3619, 1997, 8947, 2055, 2037, 10740, 2006, 4331, 1010, 2016, 2038, 3348, 2007, 2014, 3689, 3836, 1010, 19846

In [60]:
# Dataloader与Datasets结合使用
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch
class MyDataset(Dataset):
    def __init__(self,data_tokenized):
        self.data_tokenized = data_tokenized

    def __getitem__(self,idx):
        item = {key:torch.Tensor(self.data_tokenized[key][idx]) for key in ["input_ids","attention_mask"]}
        item["label"] = self.data_tokenized["label"][idx]
        return item

    def __len__(self):
        return len(self.data_tokenized) # self.data_tokenized.num_rows

In [61]:
train_dataset = MyDataset(train_token)
print(train_dataset[0])
train_data_loader = DataLoader(train_dataset,batch_size = 64,shuffle = True)
for one_batch in train_data_loader:
    print(one_batch["input_ids"].shape)
    break

{'input_ids': tensor([  101.,  1045., 12524.,  1045.,  2572.,  8025.,  1011.,  3756.,  2013.,
         2026.,  2678.,  3573.,  2138.,  1997.,  2035.,  1996.,  6704.,  2008.,
         5129.,  2009.,  2043.,  2009.,  2001.,  2034.,  2207.,  1999.,  3476.,
         1012.,  1045.,  2036.,  2657.,  2008.,  2012.,  2034.,  2009.,  2001.,
         8243.,  2011.,  1057.,  1012.,  1055.,  1012.,  8205.,  2065.,  2009.,
         2412.,  2699.,  2000.,  4607.,  2023.,  2406.,  1010.,  3568.,  2108.,
         1037.,  5470.,  1997.,  3152.,  2641.,  1000.,  6801.,  1000.,  1045.,
         2428.,  2018.,  2000.,  2156.,  2023.,  2005.,  2870.,  1012.,  1026.,
         7987.,  1013.,  1028.,  1026.,  7987.,  1013.,  1028.,  1996.,  5436.,
         2003.,  8857.,  2105.,  1037.,  2402.,  4467.,  3689.,  3076.,  2315.,
        14229.,  2040.,  4122.,  2000.,  4553.,  2673.,  2016.,  2064.,  2055.,
         2166.,  1012.,  1999.,  3327.,  2016.,  4122.,  2000.,  3579.,  2014.,
         3086.,  2015.,  2

In [62]:
class MyDataset(Dataset):
    def __init__(self,encodings,labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self,idx):
        item = {key:torch.Tensor(val[idx]) for key,val in self.encodings.items()}
        item["label"] = int(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_encoding = tokenizer(dataset_imdb["train"]["text"], truncation=True, padding=True, max_length=512)
train_dataset = MyDataset(train_encoding,dataset_imdb["train"]["label"])
train_dataloader = DataLoader(train_dataset,shuffle = False,batch_size = 64)
for one_batch in train_dataloader:
    print(one_batch["input_ids"].shape)
    break


torch.Size([64, 512])
