参考:
https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/chinese_gpt2
https://github.com/yuanzhoulvpi2017/zero_nlp/blob/main/chinese_gpt2/train_chinese_gpt2.ipynb

In [None]:
from datasets import load_dataset, DatasetDict
from glob import glob
import random
random.seed(42)

# Number of data files held out as the validation split.
num_test_files = 50

# Without recursive=True, "**" matches a single path component exactly like "*",
# so this collects every entry one level below each gpt2_data/<subdir>/.
# NOTE(review): every match is assumed to be a CSV file (directories would also match).
# glob() returns entries in arbitrary, filesystem-dependent order; sorting makes the
# seeded random.sample below produce the same train/valid split on every machine.
all_file_list = sorted(glob(pathname="gpt2_data/*/**"))
if len(all_file_list) <= num_test_files:
    raise ValueError(
        f"Need more than {num_test_files} data files under gpt2_data/ to build a "
        f"train/valid split, found {len(all_file_list)}"
    )
test_file_list = random.sample(all_file_list, num_test_files)
test_file_set = set(test_file_list)  # O(1) membership for the filter below
train_file_list = [path for path in all_file_list if path not in test_file_set]

print(f"train files: {len(train_file_list)}, valid files: {len(test_file_list)}")

raw_datasets = load_dataset(
    "csv",
    data_files={"train": train_file_list, "valid": test_file_list},
    cache_dir="cache_data",
)

print(raw_datasets)



In [None]:
from transformers import AutoTokenizer

# The BERT Chinese WordPiece vocabulary is reused as the GPT-2 vocabulary
# (the model config below takes vocab_size from len(tokenizer)).
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
# Length, in tokens, of every training window.
context_length = 512

# Sanity check on the first two training documents: with return_overflowing_tokens
# a long text is split into several chunks of at most `context_length` tokens, and
# overflow_to_sample_mapping records which input text each chunk came from.
outputs = tokenizer(
    raw_datasets["train"][:2]["content"],
    return_length=True,
    return_overflowing_tokens=True,
    max_length=context_length,
    truncation=True,
)

print("Input IDs length: {}".format(len(outputs["input_ids"])))
print("Input chunk lengths: {}".format(outputs["length"]))
print("Chunk mapping: {}".format(outputs["overflow_to_sample_mapping"]))

# Register a GPT-2 style "<|endoftext|>" marker and use it for BOS, EOS and UNK alike.
tokenizer.add_special_tokens(
    dict.fromkeys(("bos_token", "eos_token", "unk_token"), "<|endoftext|>")
)


def tokenize(element):
    """Cut a batch of documents into fixed-length token windows.

    Every text in ``element["content"]`` is tokenized and split into consecutive
    chunks of at most ``context_length`` tokens. Only chunks that are exactly
    ``context_length`` tokens long are kept, so a trailing chunk shorter than that
    (and any document shorter than one window) is dropped.

    Returns a dict with a single ``"input_ids"`` column, as expected by
    ``Dataset.map(..., batched=True)``.
    """
    encoded = tokenizer(
        element["content"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_windows = [
        ids
        for n_tokens, ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_windows}


# The batched function changes the number of rows, so every original column
# must be removed from the output.
tokenized_datasets = raw_datasets.map(
    tokenize,
    remove_columns=raw_datasets.column_names["train"],
    batched=True,
)
print(tokenized_datasets)

In [None]:
from transformers import GPT2LMHeadModel, AutoConfig

import torch

# Stock GPT-2 architecture, resized to the BERT-Chinese vocabulary and wired to the
# special-token ids registered on the tokenizer. GPT-2 sizes its position-embedding
# table from `n_positions`; setting only `n_ctx` would leave the "gpt2" default in
# place, so both are set to the window length the data was cut to.
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    n_positions=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")


from transformers import DataCollatorForLanguageModeling

# NOTE(review): bert-base-chinese already ships a [PAD] token; here pad is set to
# "<|endoftext|>", which is also the EOS and UNK token above. With mlm=False the
# collator presumably sets labels to -100 wherever input_ids == pad_token_id, which
# would exclude those tokens from the loss — confirm this is intended.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# tokenize() keeps only full context_length windows; if no document was long enough,
# fail here with a clear message instead of deep inside the Trainer.
if len(tokenized_datasets["train"]) == 0:
    raise ValueError(
        f"No training windows of exactly {context_length} tokens were produced"
    )

# Inspect the tensors the collator builds for a few examples.
preview_size = min(5, len(tokenized_datasets["train"]))
out = data_collator([tokenized_datasets["train"][i] for i in range(preview_size)])
for key in out:
    print(f"{key} shape: {out[key].shape}")


from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="suanni",  # chinese_gpt2_big
    # Effective batch per device per optimizer step: 16 * 8 = 128 sequences.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # NOTE(review): newer transformers releases renamed this to `eval_strategy`;
    # rename it if your installed version rejects `evaluation_strategy`.
    evaluation_strategy="steps",
    eval_steps=2_000,
    logging_steps=2_000,
    gradient_accumulation_steps=8,
    num_train_epochs=2,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=2_000,
    # Use fp16 mixed precision only when a CUDA device is present; otherwise
    # train in fp32 rather than failing at argument validation.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

In [None]:
# Train, then write the final weights (and the tokenizer passed to the Trainer) to
# args.output_dir so the result can be reloaded with from_pretrained("suanni")
# instead of being available only inside periodic checkpoint-* folders.
trainer.train()
trainer.save_model()