In [1]:
#全局变量
hub_token = open('/root/hub_token.txt').read().strip()
repo_id = 'lansinuote/nlp.8.generation'
push_to_hub = True

In [2]:
from transformers import AutoTokenizer

#加载编码器
tokenizer = AutoTokenizer.from_pretrained('gpt2')

#添加pad
#tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})

print(tokenizer)

#编码试算
tokenizer.batch_encode_plus([
    'hide new secretions from the parental units',
    'contains no wit , only labored gags'
])

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})


{'input_ids': [[24717, 649, 3200, 507, 422, 262, 21694, 4991], [3642, 1299, 645, 20868, 837, 691, 2248, 1850, 308, 3775]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [3]:
from datasets import load_dataset, concatenate_datasets


def get_dataset():
    #加载数据
    dataset = load_dataset('imdb')

    #重新切分数据集
    dataset = concatenate_datasets(
        [dataset['train'], dataset['test'], dataset['unsupervised']])

    dataset = dataset.train_test_split(test_size=0.01, seed=0)

    #采样,数据量太大了跑不动
    dataset['train'] = dataset['train'].shuffle(0).select(range(80000))
    dataset['test'] = dataset['test'].shuffle(0).select(range(200))

    #分词
    def f(data):
        #移除<br/>
        for i in range(len(data['text'])):
            data['text'][i] = data['text'][i].replace('<br /><br />', ' ')

        data = tokenizer.batch_encode_plus(data['text'])

        return data

    dataset = dataset.map(f,
                          batched=True,
                          num_proc=4,
                          batch_size=1000,
                          remove_columns=['text', 'label'])

    #过滤掉太短的句子
    def f(data):
        return [sum(i) >= 25 for i in data['attention_mask']]

    dataset = dataset.filter(f, batched=True, num_proc=4, batch_size=1000)

    #拼合句子到统一的长度
    def f(data):
        block_size = 512

        #展平数据
        input_ids = []
        for i in data['input_ids']:
            input_ids.extend(i)

        #切断数据
        data = {'input_ids': [], 'attention_mask': []}
        for i in range(len(input_ids) // block_size):
            block = input_ids[i * block_size:i * block_size + block_size]
            data['input_ids'].append(block)
            data['attention_mask'].append([1] * block_size)

        #设置labels
        data['labels'] = data['input_ids'].copy()

        return data

    dataset = dataset.map(
        f,
        batched=True,
        batch_size=1000,
        num_proc=4,
    )

    return dataset


if push_to_hub:
    dataset = get_dataset()
    dataset.push_to_hub(repo_id=repo_id, token=hub_token)

#直接使用我处理好的数据集
dataset = load_dataset(path=repo_id)

dataset

Found cached dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-38626b27fd3c3e47.arrow and /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c7f4724d606bfae9.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9454b0d197fed5e7.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-6958bd3ad6d5941e.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-1f42b360a5a1a053.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c30d784104e55b6b.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-eeb64723393e8ef8.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-7c28727ea8e8b19e.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-ea23f3b4845c990f.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c55bd9b275a43cc3.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-69853f83df46ec3d.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-e8c2508ae15ca61c.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-fbe193a4a7cce931_00000_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-fbe193a4a7cce931_00001_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-fbe193a4a7cce931_00002_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-fbe193a4a7cce931_00003_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-03d7c2d18639a36f_00000_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-03d7c2d18639a36f_00001_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-03d7c2d18639a36f_00002_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-03d7c2d18639a36f_00003_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-aca74ade5419face.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-7ba71373fd178881.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-06f5e1cbe8c5517e.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-2d0413709e38aacd.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-54493ec401eaeea8.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9753890946b1a2af.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-061bfcf64fa77821.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-e10badbdc7153a65.arrow
Pushing split train to the Hub.
Resuming upload of the dataset shards.


Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Pushing split test to the Hub.
Resuming upload of the dataset shards.


Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration lansinuote--nlp.8.generation-8b7658b3335fddd9


Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--nlp.8.generation-8b7658b3335fddd9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/95.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/266k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/44863 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/107 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--nlp.8.generation-8b7658b3335fddd9/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 44863
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 107
    })
})

In [4]:
import torch
from transformers.data.data_collator import default_data_collator

#数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=default_data_collator,
    shuffle=True,
    drop_last=True,
)

for i, data in enumerate(loader):
    break

len(loader), data

(5607,
 {'input_ids': tensor([[ 1842,   284,   511,  ..., 18974,   286,   262],
          [   13,  1375, 11258,  ...,  2223,   318, 39976],
          [23304,   393,  1997,  ...,  2222,  3589,   284],
          ...,
          [14169, 30953,   475,  ...,   284,   651,   262],
          [ 2646,    13,  1675,  ..., 19147,    11,   290],
          [   30,   383,  3807,  ...,   470,   651,   651]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]),
  'labels': tensor([[ 1842,   284,   511,  ..., 18974,   286,   262],
          [   13,  1375, 11258,  ...,  2223,   318, 39976],
          [23304,   393,  1997,  ...,  2222,  3589,   284],
          ...,
          [14169, 30953,   475,  ...,   284,   651,   262],
          [ 2646,    13,  1675,  ..., 19147,    11,   290],
          [   30,   383,  3

In [5]:
from transformers import AutoModelForCausalLM, GPT2Model, PreTrainedModel, PretrainedConfig

#加载模型
#model = AutoModelForCausalLM.from_pretrained('gpt2')


#定义下游任务模型
class Model(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.pretrained = GPT2Model.from_pretrained('gpt2')
        self.fc = torch.nn.Linear(768, tokenizer.vocab_size, bias=False)

        #加载预训练模型的参数
        parameters = AutoModelForCausalLM.from_pretrained('gpt2')
        self.fc.load_state_dict(parameters.lm_head.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels):
        logits = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask)
        logits = logits.last_hidden_state

        logits = self.fc(logits)

        shift_logits = logits[:, :-1].flatten(end_dim=1)
        shift_labels = labels[:, 1:].flatten()

        loss = self.criterion(shift_logits, shift_labels)

        return {
            'loss': loss,
            'logits': logits,
        }


model = Model(PretrainedConfig())

#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)

with torch.no_grad():
    out = model(**data)

out['loss'], out['logits'].shape

16303.7184


(tensor(3.9160), torch.Size([8, 512, 50257]))

In [6]:
def generate(text):

    def generate_loop(data):
        with torch.no_grad():
            out = model(**data)

        #取最后一个字
        #[5, b, 50257]
        out = out['logits']
        #[5, 50257]
        out = out[:, -1]

        #第50大的值,以此为分界线,小于该值的全部赋值为负无穷
        #[5, 50257] -> [5, 50]
        topk_value = torch.topk(out, 50).values
        #[5, 50] -> [5] -> [5, 1]
        topk_value = topk_value[:, -1].unsqueeze(dim=1)

        #赋值
        #[5, 50257]
        out = out.masked_fill(out < topk_value, -float('inf'))

        #根据概率采样,无放回,所以不可能重复
        #[5, 50257] -> [5, 1]
        out = out.softmax(dim=1)
        out = out.multinomial(num_samples=1)

        data['input_ids'] = torch.cat([data['input_ids'], out], dim=1)
        data['attention_mask'] = torch.ones_like(data['input_ids'])
        data['labels'] = data['input_ids'].clone()

        if data['input_ids'].shape[1] >= 30:
            return data

        return generate_loop(data)

    #重复5遍
    data = tokenizer.batch_encode_plus([text] * 5, return_tensors='pt')
    data['labels'] = data['input_ids'].clone()

    data = generate_loop(data)

    for i in range(5):
        print(i, tokenizer.decode(data['input_ids'][i]))


generate('I love this')

0 I love this guy for that dude at times..."

"What was he doing when you said you'd 'leave'? He's so cool."
1 I love this book.


After an hour I found the book's back-story amusing. It was about an Englishman's journey from a
2 I love this game. I'm not complaining because it takes place in a great fantasy world." (Sarviv Smith)<|endoftext|>T-Mobile
3 I love this." You hear me?! How can we be all equal? Because we all see it every day...

And you take a step
4 I love this game and I am using my keyboard as my desktop controller!

Thanks to the other reviewers I have heard so many good words for


In [7]:
from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.train()
    model.to(device)

    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 50 == 0:
            labels = data['labels'][:, 1:]
            out = out['logits'].argmax(dim=2)[:, :-1]

            accuracy = (labels == out).sum().item() / labels.numel()

            lr = optimizer.state_dict()['param_groups'][0]['lr']

            print(i, loss.item(), lr, accuracy)

    model.to('cpu')


if push_to_hub:
    train()
    model.push_to_hub(repo_id=repo_id, use_auth_token=hub_token)



0 4.082333564758301 1.9996433030140896e-05 0.28498043052837574
50 3.887619972229004 1.981808453718566e-05 0.3094422700587084
100 3.889702081680298 1.9639736044230427e-05 0.31237769080234834
150 3.979553461074829 1.9461387551275193e-05 0.30039138943248533
200 3.8883557319641113 1.9283039058319958e-05 0.31727005870841485
250 3.834787368774414 1.9104690565364724e-05 0.3084637964774951
300 3.9670369625091553 1.892634207240949e-05 0.3150684931506849
350 3.758570432662964 1.8747993579454255e-05 0.31727005870841485
400 4.112035751342773 1.856964508649902e-05 0.2837573385518591
450 3.6851837635040283 1.8391296593543786e-05 0.3221624266144814
500 3.706080675125122 1.821294810058855e-05 0.31727005870841485
550 3.8967978954315186 1.8034599607633317e-05 0.32069471624266144
600 3.845363140106201 1.7856251114678083e-05 0.3111545988258317
650 3.824476718902588 1.767790262172285e-05 0.31727005870841485
700 3.8658738136291504 1.7499554128767614e-05 0.3133561643835616
750 3.8619003295898438 1.7321205635

pytorch_model.bin:   0%|          | 0.00/665M [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
#直接使用我训练好的模型
model = Model.from_pretrained(repo_id)

generate('I love this')

Downloading (…)lve/main/config.json:   0%|          | 0.00/105 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/665M [00:00<?, ?B/s]

0 I love this film, its very much the work of John McQuaglish. If this movie had nothing else in it other than what it tells
1 I love this movie, I watched many of its adventures, I think it has something to do with who I am, how I treat my people.
2 I love this movie even though I haven't seen it many times now. I never thought I'd like this movie so much, but now, it
3 I love this show, and I hope others can enjoy it. It is certainly one of my top five favorite sitcoms. I would highly recommend this
4 I love this show (especially the last episode). The cast of the show are great - I loved Jessica Biel's performance as Elmer and Matt
