In [1]:
import datasets
from datasets import load_dataset

In [2]:
# Extra keyword arguments that are forwarded to `load_dataset`.
dataset_args = {"keep_linebreaks": True}

In [4]:
# Load the text file with the "text" dataset builder; `dataset_args`
# supplies the additional builder options.
data_files = "sample-wiki.txt"
extension = "text"
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)

Using custom data configuration default-706b6e6ca36d15ce


Downloading and preparing dataset text/default to /home/husein/.cache/huggingface/datasets/text/default-706b6e6ca36d15ce/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Dataset text downloaded and prepared to /home/husein/.cache/huggingface/datasets/text/default-706b6e6ca36d15ce/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
# Display the loaded dataset object (splits, features and row counts).
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 10000
    })
})

In [6]:
# Display the column names of the "train" split.
raw_datasets["train"].column_names

['text']

In [7]:
# Choose the column to tokenize: prefer a column literally named "text",
# otherwise fall back to the first column of the "train" split.
column_names = raw_datasets["train"].column_names
if "text" in column_names:
    text_column_name = "text"
else:
    text_column_name = column_names[0]

In [10]:
from transformers.testing_utils import CaptureLogger
from transformers import AutoTokenizer
import transformers

# Load the tokenizer from the `./malay-cased-gpt2` path.
tokenizer = AutoTokenizer.from_pretrained("./malay-cased-gpt2")

In [11]:
# Logger whose output is captured while tokenizing, so the "sequence too long"
# warning can be detected and followed by an explanatory message.
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

def tokenize_function(examples):
    """Tokenize the `text_column_name` column of a batch.

    Args:
        examples: batch mapping that is indexed with `text_column_name`.

    Returns:
        Whatever `tokenizer` returns for the selected column.
    """
    with CaptureLogger(tok_logger) as captured:
        encoded = tokenizer(examples[text_column_name])

    # Inputs for causal LM can be far longer than block_size; they are chunked
    # later, so tell the reader that the tokenizer's length warning is harmless.
    length_warning_seen = "Token indices sequence length is longer than the" in captured.out
    if length_warning_seen:
        tok_logger.warning(
            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
            " before being passed to the model."
        )
    return encoded

In [12]:
# Tokenize in batches; the columns listed in `column_names` are removed from
# the result.
tokenized_datasets = raw_datasets.map(
    tokenize_function, batched=True, remove_columns=column_names, desc="Running tokenizer on dataset"
)

Running tokenizer on dataset:   0%|          | 0/10 [00:00<?, ?ba/s]

In [20]:
from itertools import chain

# Length of every chunk produced by `group_texts`.
block_size = 1024

def group_texts(examples):
    """Concatenate a batch of tokenized texts and cut it into `block_size` chunks.

    Args:
        examples: mapping from column name to a list of per-example lists;
            it must contain an "input_ids" column.

    Returns:
        A dict with the same columns, each holding only chunks of exactly
        `block_size` items, plus a "labels" column that is a shallow copy of
        the "input_ids" chunk list. Items that do not fill a whole chunk are
        dropped.
    """
    # Flatten every column into one long list.
    merged = {}
    for column in examples.keys():
        merged[column] = list(chain.from_iterable(examples[column]))

    # The first column's length decides how far the chunk starts may run.
    first_column = list(examples.keys())[0]
    usable_length = len(merged[first_column])
    if usable_length >= block_size:
        usable_length -= usable_length % block_size

    # Split by chunks of block_size.
    result = {}
    for column, items in merged.items():
        blocks = []
        for start in range(0, usable_length, block_size):
            block = items[start : start + block_size]
            # Keep complete chunks only; a short trailing (or sole) chunk is discarded.
            if len(block) == block_size:
                blocks.append(block)
        result[column] = blocks

    result["labels"] = list(result["input_ids"])
    return result

In [21]:
# Regroup the tokenized examples batch by batch using `group_texts`.
lm_datasets = tokenized_datasets.map(group_texts, batched=True, desc=f"Grouping texts in chunks of {block_size}")

Grouping texts in chunks of 1024:   0%|          | 0/10 [00:00<?, ?ba/s]

In [29]:
# Index the row first and the column second: this reads a single example
# instead of materializing the entire `input_ids` column to take one element.
lm_datasets["train"][0]["input_ids"]

[202,
 17724,
 78,
 2459,
 339,
 264,
 277,
 450,
 46,
 264,
 277,
 16,
 20319,
 5,
 440,
 879,
 312,
 269,
 354,
 6986,
 1731,
 3568,
 8288,
 12,
 560,
 2810,
 3386,
 3964,
 5172,
 292,
 3664,
 281,
 8346,
 16451,
 17,
 202,
 1188,
 1550,
 3664,
 388,
 5343,
 328,
 579,
 91,
 17,
 202,
 202,
 43,
 5565,
 82,
 16,
 3474,
 1172,
 17,
 202,
 43,
 5565,
 82,
 16,
 3474,
 1172,
 354,
 3748,
 962,
 2061,
 517,
 399,
 673,
 606,
 15074,
 9945,
 3696,
 281,
 339,
 264,
 7274,
 14056,
 87,
 15,
 14789,
 17,
 202,
 1225,
 560,
 15074,
 2142,
 14789,
 3986,
 81,
 3278,
 17,
 202,
 27736,
 1541,
 2061,
 517,
 4645,
 5816,
 72,
 708,
 443,
 295,
 15,
 3463,
 430,
 9299,
 297,
 778,
 10369,
 1334,
 1267,
 15,
 20080,
 17,
 202,
 1225,
 27940,
 430,
 3345,
 297,
 9747,
 367,
 440,
 5565,
 82,
 16,
 3474,
 1172,
 292,
 910,
 15,
 502,
 2061,
 1181,
 2121,
 2423,
 3051,
 367,
 1195,
 450,
 58,
 299,
 5816,
 72,
 708,
 443,
 295,
 1073,
 202,
 1225,
 486,
 1128,
 565,
 536,
 1217,
 15074,
 388,
 15641,

In [28]:
# Index the row first and the column second: this reads a single example
# instead of materializing the entire `labels` column to take one element.
lm_datasets["train"][0]["labels"]

[202,
 17724,
 78,
 2459,
 339,
 264,
 277,
 450,
 46,
 264,
 277,
 16,
 20319,
 5,
 440,
 879,
 312,
 269,
 354,
 6986,
 1731,
 3568,
 8288,
 12,
 560,
 2810,
 3386,
 3964,
 5172,
 292,
 3664,
 281,
 8346,
 16451,
 17,
 202,
 1188,
 1550,
 3664,
 388,
 5343,
 328,
 579,
 91,
 17,
 202,
 202,
 43,
 5565,
 82,
 16,
 3474,
 1172,
 17,
 202,
 43,
 5565,
 82,
 16,
 3474,
 1172,
 354,
 3748,
 962,
 2061,
 517,
 399,
 673,
 606,
 15074,
 9945,
 3696,
 281,
 339,
 264,
 7274,
 14056,
 87,
 15,
 14789,
 17,
 202,
 1225,
 560,
 15074,
 2142,
 14789,
 3986,
 81,
 3278,
 17,
 202,
 27736,
 1541,
 2061,
 517,
 4645,
 5816,
 72,
 708,
 443,
 295,
 15,
 3463,
 430,
 9299,
 297,
 778,
 10369,
 1334,
 1267,
 15,
 20080,
 17,
 202,
 1225,
 27940,
 430,
 3345,
 297,
 9747,
 367,
 440,
 5565,
 82,
 16,
 3474,
 1172,
 292,
 910,
 15,
 502,
 2061,
 1181,
 2121,
 2423,
 3051,
 367,
 1195,
 450,
 58,
 299,
 5816,
 72,
 708,
 443,
 295,
 1073,
 202,
 1225,
 486,
 1128,
 565,
 536,
 1217,
 15074,
 388,
 15641,