In [1]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the "sst2" configuration of the "glue" dataset.
# The bare name on the last line makes the notebook display the returned
# object so its splits and columns can be inspected.
raw_datasets = load_dataset("glue","sst2")
raw_datasets 

Generating train split: 100%|██████████| 67349/67349 [00:00<00:00, 858397.11 examples/s]
Generating validation split: 100%|██████████| 872/872 [00:00<00:00, 439701.02 examples/s]
Generating test split: 100%|██████████| 1821/1821 [00:00<00:00, 698156.09 examples/s]


DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [4]:
# Keep a direct handle on the "train" split for the inspection cells below.
raw_train_dataset = raw_datasets["train"]

In [5]:
# Look at the first training row to see what a single example contains.
raw_train_dataset[0]

{'sentence': 'hide new secretions from the parental units ',
 'label': 0,
 'idx': 0}

In [6]:
# Show the column schema of the training split (type of each column,
# including the class names behind the integer 'label').
raw_train_dataset.features

{'sentence': Value(dtype='string', id=None),
 'label': ClassLabel(names=['negative', 'positive'], id=None),
 'idx': Value(dtype='int32', id=None)}

In [7]:
from transformers import AutoTokenizer

In [8]:
# Identifier of the pretrained checkpoint; it is handed to
# AutoTokenizer.from_pretrained in the next cell.
checkpoint = "bert-base-uncased"

In [9]:
# Instantiate the tokenizer for `checkpoint`; tokenize_function and the data
# collator defined further down both use this object.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [14]:
def tokenize_function(example):
    """Run the notebook-level ``tokenizer`` over the ``"sentence"`` entry.

    Args:
        example: Mapping with a ``"sentence"`` key. Its value is forwarded
            to ``tokenizer`` unchanged.

    Returns:
        Whatever ``tokenizer`` returns when called on that value with
        ``truncation=True``.
    """
    sentences = example["sentence"]
    encoded = tokenizer(sentences, truncation=True)
    return encoded

In [15]:
# Apply tokenize_function to every split of raw_datasets (the progress bars
# below show one Map pass per split). batched=True is passed so that map()
# calls the function on batches of rows rather than row by row -- see the
# `datasets` map() documentation for the exact batching behaviour.
tokenized_datasets = raw_datasets.map(tokenize_function,batched=True)

Map: 100%|██████████| 67349/67349 [00:02<00:00, 30353.76 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 7767.81 examples/s]
Map: 100%|██████████| 1821/1821 [00:00<00:00, 22169.73 examples/s]


In [20]:
# Display the mapped DatasetDict to check which columns map() added next to
# the original 'sentence', 'label' and 'idx'.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [21]:
from transformers import DataCollatorWithPadding 

# Build a padding collator around the same tokenizer; a later cell calls it
# on `samples` to turn variable-length rows into one rectangular batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 

In [25]:
# Slice the first eight tokenized training rows (a dict of column -> list)
# and print both the dict and its items view for inspection.
samples = tokenized_datasets["train"][:8]
print(samples)
print(samples.items())

# Drop the "idx" and "sentence" columns, keeping the rest for the collator.
samples = {
    column: values
    for column, values in samples.items()
    if column not in ("idx", "sentence")
}

# Number of rows kept, followed by the token count of each row.
print(len(samples["input_ids"]))
list(map(len, samples["input_ids"]))

{'sentence': ['hide new secretions from the parental units ', 'contains no wit , only labored gags ', 'that loves its characters and communicates something rather beautiful about human nature ', 'remains utterly satisfied to remain the same throughout ', 'on the worst revenge-of-the-nerds clichés the filmmakers could dredge up ', "that 's far too tragic to merit such superficial treatment ", 'demonstrates that the director of such hollywood blockbusters as patriot games can still turn out a small , personal film with an emotional wallop . ', 'of saucy '], 'label': [0, 0, 1, 0, 0, 0, 1, 1], 'idx': [0, 1, 2, 3, 4, 5, 6, 7], 'input_ids': [[101, 5342, 2047, 3595, 8496, 2013, 1996, 18643, 3197, 102], [101, 3397, 2053, 15966, 1010, 2069, 4450, 2098, 18201, 2015, 102], [101, 2008, 7459, 2049, 3494, 1998, 10639, 2015, 2242, 2738, 3376, 2055, 2529, 3267, 102], [101, 3464, 12580, 8510, 2000, 3961, 1996, 2168, 2802, 102], [101, 2006, 1996, 5409, 7195, 1011, 1997, 1011, 1996, 1011, 11265, 17811, 1

[10, 11, 15, 10, 22, 13, 29, 6]

In [26]:
# Collate the eight filtered rows into a single batch. In the recorded run
# the result holds tensors padded to a common length of 29 (the longest row
# in `samples`), and the 'label' column shows up under the key 'labels'.
batch = data_collator(samples)

In [28]:
# Print the collated batch, then summarize every entry by its tensor shape.
print(batch)
{name: tensor.shape for name, tensor in batch.items()}

{'input_ids': tensor([[  101,  5342,  2047,  3595,  8496,  2013,  1996, 18643,  3197,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  3397,  2053, 15966,  1010,  2069,  4450,  2098, 18201,  2015,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2008,  7459,  2049,  3494,  1998, 10639,  2015,  2242,  2738,
          3376,  2055,  2529,  3267,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  3464, 12580,  8510,  2000,  3961,  1996,  2168,  2802,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2006,  1996,  5409,  7195,  1011,  1997,  101

{'input_ids': torch.Size([8, 29]),
 'token_type_ids': torch.Size([8, 29]),
 'attention_mask': torch.Size([8, 29]),
 'labels': torch.Size([8])}