In [45]:
import os
import random

from pymongo import MongoClient

In [46]:
# Export article summaries from MongoDB into the plain-text train/validation
# files read by load_dataset("text", ...) further down. Skipped once both files
# exist, so re-running the notebook does not query the database again.
if not os.path.exists('clm_train.txt') or not os.path.exists('clm_valid.txt'):
    limit = 100000
    # If there are too many documents, training has to be done in batches.
    c = MongoClient()
    try:
        docs = list(c.article.crawl.find(projection=['summary'], limit=limit))
    finally:
        # The connection is only needed for this one fetch.
        c.close()
    # Private, fixed-seed RNG: reproducible split without touching global `random` state.
    random.Random(42).shuffle(docs)
    n_doc = len(docs)
    n_train = 1000  # full run: int(0.9 * n_doc)
    n_valid = 1100  # exclusive END index of the validation slice (100 docs); full run: n_doc
    for path, split_docs in (('clm_train.txt', docs[:n_train]),
                             ('clm_valid.txt', docs[n_train:n_valid])):
        # Explicit UTF-8: the summaries are Chinese text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(path, 'w', encoding='utf-8') as fw:
            for doc in split_docs:
                summary = doc.get('summary')
                if not summary:
                    # A projection does not guarantee the field exists; skip empty/missing ones.
                    continue
                # End every summary with a newline: the "text" loader splits on
                # lines, so without it the last line of one article would be
                # glued to the first line of the next one.
                fw.write(summary.rstrip('\n') + '\n')
            

In [1]:
from datasets import load_dataset

# One plain-text file per split; the "text" builder reads them line by line.
# NOTE(review): the variable `datasets` shadows the `datasets` package name.
split_files = {"train": 'clm_train.txt', "validation": 'clm_valid.txt'}
datasets = load_dataset("text", data_files=split_files)

Using custom data configuration default-ba418acf56910de3
Reusing dataset text (/home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


In [2]:
# Hub id of the pretrained Chinese XLNet used for both the tokenizer and the model.
model_checkpoint: str = "hfl/chinese-xlnet-base"

In [3]:
from transformers import AutoTokenizer

# Fast tokenizer matching the checkpoint chosen above.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
    use_fast=True,
)

In [4]:
def tokenize_function(examples):
    """Tokenize the batched ``examples["text"]`` column with the notebook tokenizer.

    Returns whatever the tokenizer returns for a list of strings (its encoded columns).
    """
    raw_lines = examples["text"]
    encoded = tokenizer(raw_lines)
    return encoded

In [5]:
# Tokenize every split in 4 worker processes; the raw "text" column is removed
# so only the tokenizer's output columns are left for group_texts.
tokenized_datasets = datasets.map(
    tokenize_function,
    remove_columns=["text"],
    num_proc=4,
    batched=True,
)

Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-5ead30a09a33aaa8.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-65e70f768df2b145.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-1618ffc1edb35c25.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-b554fb49e0341706.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-3d2

In [6]:
# Number of tokens per training example produced by group_texts.
block_size: int = 128

In [7]:
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized rows and re-split it into fixed-size blocks.

    Meant for ``Dataset.map(group_texts, batched=True)``: ``examples`` maps each
    tokenizer output column to a list of per-row token lists.

    Args:
        examples: dict of column name -> list of lists; must contain ``input_ids``.
        chunk_size: length of every output block. When omitted, the
            notebook-level ``block_size`` is used, so existing
            ``map(group_texts, ...)`` calls behave exactly as before.

    Returns:
        dict with the same columns, each a list of ``chunk_size``-long lists,
        plus ``labels`` (a copy of ``input_ids``). Tokens after the last full
        block of the batch are dropped.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    # Flatten each column in linear time (sum(list_of_lists, []) is quadratic).
    concatenated_examples = {
        k: list(chain.from_iterable(rows)) for k, rows in examples.items()
    }
    total_length = len(concatenated_examples["input_ids"])
    # Drop the small remainder; padding could be used instead if the model
    # supported it. Customize this part to your needs.
    total_length = (total_length // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    # Labels are a plain copy of the inputs; no shifting is done here.
    # NOTE(review): XLNet's LM head does not shift labels itself, so this only
    # makes sense if the training collator builds perm_mask/target_mapping — confirm.
    result["labels"] = result["input_ids"].copy()
    return result

In [8]:
# Each map batch of 1000 tokenized lines is concatenated and cut into
# block_size chunks by group_texts (the leftover tail of each batch is dropped).
lm_datasets = tokenized_datasets.map(group_texts, num_proc=4, batched=True, batch_size=1000)

Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-acb1fdeb0e8502a1.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-14264015c3e3949d.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-3937d3f26488a288.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-c19169df58e012be.arrow
Loading cached processed dataset at /home/min/.cache/huggingface/datasets/text/default-ba418acf56910de3/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-c9d

In [9]:
# Decode one training block back to text to eyeball the concatenation/chunking.
second_block_ids = lm_datasets["train"][1]["input_ids"]
tokenizer.decode(second_block_ids)

'、无人机和828辆消防车参与了灭火工作。<sep><cls>  报道称,土耳其过去8天发生近200起森林野火,其中15起林火未获控制。西南部木拉省一处火场延烧附近火力发电厂边缘,电厂已疏散人员,易燃易爆物也已移除。<sep><cls>  土耳其总统埃尔多安称:“这座火力发电厂可能被烧成平地......要不是风势猛烈,火势早就控制下来。”<sep><cls>  于一周前最早传出大火的南部安塔利亚'

In [10]:
from transformers import AutoModelForCausalLM

# Pretrained weights with a language-modelling head, from the same checkpoint as the tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
)

In [11]:
def textgen(prompt='我希望', max_new_tokens=20, show_prefix=True):
    """Sample a short continuation of ``prompt`` from the global ``model`` and print it.

    A fixed Chinese passage is prepended to the prompt before encoding, so the
    model conditions on a longer context than the short prompt alone.

    Args:
        prompt: text to continue (default '我希望', the originally hard-coded prompt).
        max_new_tokens: number of tokens to sample beyond the encoded prefix + prompt.
        show_prefix: if True (default, original behaviour) the printed text
            includes the fixed passage; if False only ``prompt`` and the
            sampled continuation are printed.
    """
    prefix = '在所有的木偶表演中，提线木偶难度最大。提线木偶是演员自上而下以数十条丝线操纵木偶表演的艺术。演员必须熟练掌握10多种理线技巧和30多种组织提线以表演各个行当、各种动作的“线规”，才有资格走上舞台。'
    prompt_text = prefix + prompt
    # Keep the input on the model's device: once the model has been moved
    # (e.g. to GPU by training), a CPU tensor would make generate() fail.
    encoded_prompt = tokenizer.encode(
        prompt_text, add_special_tokens=False, return_tensors='pt'
    ).to(model.device)
    # The model may still be in training mode (e.g. right after trainer.train());
    # switch to eval so dropout does not perturb generation.
    model.eval()
    output_sequences = model.generate(
        encoded_prompt,
        max_length=max_new_tokens + encoded_prompt.shape[-1],
        temperature=0.1,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        do_sample=True,
    )
    text = tokenizer.decode(output_sequences[0], clean_up_tokenization_spaces=True)
    if not show_prefix:
        # Cut at the length of the *decoded* input: decoding normalises some
        # characters (the '，' in the prefix comes back as ','), so the raw
        # len(prompt_text) is not a reliable offset.
        decoded_input = tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
        text = prompt + text[len(decoded_input):]
    print(text)

In [12]:
# Baseline sample from the pretrained checkpoint, before any fine-tuning.
textgen()

在所有的木偶表演中,提线木偶难度最大。提线木偶是演员自上而下以数十条丝线操纵木偶表演的艺术。演员必须熟练掌握10多种理线技巧和30多种组织提线以表演各个行当、各种动作的“线规”,才有资格走上舞台。我希望,在我看来,提线木偶是个好演员。 好演员,是个


In [1]:
from transformers import Trainer, TrainingArguments

In [2]:
# Outputs go to ./test-clm; the validation split is evaluated at the end of every epoch.
training_args = TrainingArguments(
    output_dir="test-clm",
    evaluation_strategy="epoch",
    weight_decay=0.01,
    learning_rate=2e-5,
)

  return torch._C._cuda_getDeviceCount() > 0


In [15]:
from transformers import DataCollatorForPermutationLanguageModeling

# XLNet's LM head (what AutoModelForCausalLM loads for this checkpoint) does not
# shift `labels`, and without a perm_mask every position can attend to the token
# it is asked to predict. Training directly on labels == input_ids (as produced
# by group_texts) would therefore teach the model to copy its input.
# The permutation-LM collator rebuilds input_ids / perm_mask / target_mapping /
# labels from input_ids the way XLNet is meant to be trained.
# NOTE(review): this collator expects an even sequence length; block_size is 128.
data_collator = DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
    data_collator=data_collator,
)

In [None]:
# Fine-tune; the run log below shows the defaults in effect (3 epochs, batch size 8 per device).
trainer.train()

***** Running training *****
  Num examples = 8094
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 3036


Epoch,Training Loss,Validation Loss


In [21]:
# Sample again after fine-tuning to compare with the pre-training baseline.
textgen()

In [5]:
import torch

# Does PyTorch see a CUDA device in this kernel?
cuda_available = torch.cuda.is_available()
cuda_available

False