In [1]:
import mindspore
import argparse
import numpy as np
import logging
import mindspore.dataset as ds
import os

import json

from tqdm import tqdm
from datetime import datetime
from mindspore.nn import CrossEntropyLoss
from mindspore import nn, ops
from mindspore.train.serialization import save_checkpoint
from mindspore.dataset import TextFileDataset

from mindnlp.transforms import BertTokenizer
from mindnlp.modules import Accumulator
from mindnlp.models import GPT2Config, GPT2LMHeadModel

[ERROR] ME(19477:140382442993472,MainProcess):2023-05-11-01:57:24.466.974 [mindspore/run_check/_check_version.py:226] Cuda ['10.1', '11.1', '11.6'] version(libcu*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install
[ERROR] ME(19477:140382442993472,MainProcess):2023-05-11-01:57:24.489.465 [mindspore/run_check/_check_version.py:226] Cuda ['10.1', '11.1', '11.6'] version(libcudnn*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install
  from tqdm.autonotebook import tqdm


In [2]:
epochs = 6
batch_size = 8

lr = 1e-4
warmup_steps = 2000
accumulate_step = 2
max_grad_norm = 1.0

log_step = 100

In [3]:
from mindnlp.utils import cache_file

url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path, _ = cache_file('train_with_summ.txt', './', url)

In [4]:
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()

50000

In [5]:
train_dataset, eval_dataset, test_dataset = dataset.split([0.8, 0.1, 0.1])

In [6]:
# article: [CLS] xxxxx [SEP]
# summary: [CLS] xxxxx [SEP]

In [7]:
import numpy as np

def process_dataset(dataset, tokenizer, batch_size=8, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        article_len = len(article)
        summary_len = len(summary)

        sep_id = np.array([tokenizer.sep_token_id])
        pad_id = np.array([tokenizer.pad_token_id])
        if article_len + summary_len > max_seq_len:
            new_article_len = max_seq_len - summary_len
            merged = np.concatenate([article[:new_article_len], sep_id, summary[1:]])
        elif article_len + summary_len - 1 < max_seq_len:
            pad_len = max_seq_len - article_len - summary_len + 1
            pad_text = np.array([tokenizer.pad_token_id] * pad_len)
            merged = np.concatenate([article, summary[1:], pad_text])
        else:
            merged = np.concatenate([article, summary[1:]])
            
        return merged.astype(np.int32)

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(tokenizer, 'article')
    dataset = dataset.map(tokenizer, 'summary')
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], 'input_ids')
    
    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

In [8]:
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [9]:
train_dataset = process_dataset(train_dataset, tokenizer)
eval_dataset = process_dataset(eval_dataset, tokenizer)
test_dataset = process_dataset(test_dataset, tokenizer)

In [10]:
next(train_dataset.create_tuple_iterator())

[Tensor(shape=[8, 1024], dtype=Int32, value=
 [[ 101,  126, 3299 ...    0,    0,    0],
  [ 101,  704, 1744 ...    0,    0,    0],
  [ 101, 1957, 2094 ...    0,    0,    0],
  ...
  [ 101,  868, 5442 ...    0,    0,    0],
  [ 101, 1298, 3175 ... 2658,  511,  102],
  [ 101,  704, 3173 ...    0,    0,    0]])]

In [11]:
len(tokenizer)

21128

In [12]:
from mindnlp._legacy.amp import auto_mixed_precision

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2LMHeadModel(config, ignore_index=tokenizer.pad_token_id)
model = auto_mixed_precision(model, 'O1')

optimizer = nn.AdamWeightDecay(model.trainable_params(), lr)
accumulator = Accumulator(optimizer, accumulate_step, max_grad_norm)



In [13]:
# Define forward function
def forward_fn(input_ids, labels):
    outputs = model(input_ids, labels=labels)
    loss = outputs[0]
    return loss / accumulate_step

# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, model.trainable_params())

# Define function of one-step training
@mindspore.jit
def train_step(data, label):
    loss, grads = grad_fn(data, label)
    accumulator(grads)
    return loss

In [14]:
# 记录模型参数数量
num_parameters = 0
parameters = model.trainable_params()
for parameter in parameters:
    num_parameters += parameter.numel()
print('number of model parameters: {}'.format(num_parameters))

number of model parameters: 118295040


In [None]:
from tqdm import tqdm

total = train_dataset.get_dataset_size()

for epoch in range(epochs):
    with tqdm(total=total) as progress:
        progress.set_description(f'Epoch {epoch}')
        loss_total = 0
        cur_step_nums = 0
        for batch_idx, (input_ids,) in enumerate(train_dataset.create_tuple_iterator()):
            cur_step_nums += 1
            loss = train_step(input_ids, input_ids)
            loss_total += loss

            progress.set_postfix(loss=loss_total/cur_step_nums)
            progress.update(1)

Epoch 0: 100%|███████████████████████████████████████████| 5000/5000 [1:18:41<00:00,  1.06it/s, loss=3.4540372]
Epoch 1: 100%|███████████████████████████████████████████| 5000/5000 [1:18:16<00:00,  1.06it/s, loss=3.5189726]
Epoch 2: 100%|████████████████████████████████████████████| 5000/5000 [1:18:15<00:00,  1.06it/s, loss=3.714254]
Epoch 3:  41%|██████████████████▍                          | 2042/5000 [31:59<47:31,  1.04it/s, loss=3.9296877]