In [1]:
import mindspore
import argparse
import numpy as np
import logging
import mindspore.dataset as ds
import os

import json

from tqdm import tqdm
from datetime import datetime
from mindspore.nn import CrossEntropyLoss
from mindspore import nn, ops
from mindspore.train.serialization import save_checkpoint
from mindspore.dataset import TextFileDataset

from mindnlp.transforms import BertTokenizer
from mindnlp.modules import Accumulator
from mindnlp.models import GPT2Config, GPT2LMHeadModel

In [2]:
# Training schedule: number of passes over the training data and rows per batch.
# NOTE(review): `batch_size` is not passed to process_dataset in the visible
# cells (which falls back to its own default of 8) — confirm they should match.
epochs, batch_size = 6, 8

# Optimisation: `lr` is handed to the optimizer; `accumulate_step` and
# `max_grad_norm` are handed to the gradient Accumulator.
# NOTE(review): `warmup_steps` is not referenced in the visible cells.
lr, warmup_steps = 1e-4, 2000
accumulate_step, max_grad_norm = 2, 1.0

# Logging interval (not referenced in the visible cells).
log_step = 100

In [3]:
from mindnlp.utils import cache_file

# Location of the NLPCC-2017 summarization training file.
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
# Fetch the file into the current directory through mindnlp's cache_file helper
# (presumably reusing an existing local copy — TODO confirm). Only the first
# element of the returned pair, the local path, is used.
path, _ = cache_file('train_with_summ.txt', './', url)

In [4]:
# Open the downloaded file as a text dataset; shuffle=False keeps the file order.
dataset = TextFileDataset(str(path), shuffle=False)
# Last expression of the cell: displays the number of rows in the dataset.
dataset.get_dataset_size()

50000

In [5]:
# Partition the rows into train / eval / test using 80% / 10% / 10% proportions.
# NOTE(review): no random seed is set in the visible cells; if split() randomizes
# by default, the partitions may differ between kernel restarts — confirm.
train_dataset, eval_dataset, test_dataset = dataset.split([0.8, 0.1, 0.1])

In [6]:
# Token layout of each field after tokenization:
# article: [CLS] xxxxx [SEP]
# summary: [CLS] xxxxx [SEP]

In [7]:
import numpy as np

def process_dataset(dataset, tokenizer, batch_size=8, max_seq_len=1024, shuffle=False):
    """Turn raw JSON lines into batches of fixed-length ``input_ids``.

    Every row of ``dataset`` carries one JSON document in its ``text`` column
    with the fields ``article`` and ``summarization``. Both fields are
    tokenized, joined into one sequence (the summary's leading token is
    dropped), and truncated or padded so that every sequence is exactly
    ``max_seq_len`` tokens long. The result is then batched.

    Args:
        dataset: object providing ``map``, ``batch`` and ``shuffle``, with a
            ``text`` column holding one JSON document per row.
        tokenizer: callable that maps a text array to a 1-D array of token
            ids and exposes ``sep_token_id`` and ``pad_token_id``.
        batch_size: number of rows per batch; also used as the shuffle buffer
            size when ``shuffle`` is true.
        max_seq_len: exact length of every produced sequence.
        shuffle: when true, shuffle the batched dataset.

    Returns:
        The transformed, batched dataset with an int32 ``input_ids`` column.
    """
    sep_id = np.array([tokenizer.sep_token_id])

    def read_map(text):
        # One JSON document per row: pull out the two text fields.
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # The summary's first token is dropped; the article already ends the
        # first segment, or an explicit separator is inserted on truncation.
        summary_tail = summary[1:]

        if len(article) + len(summary) > max_seq_len:
            # Too long: shorten the article, but always keep at least its first
            # token. Without the lower bound, a summary as long as or longer
            # than max_seq_len made the slice bound zero or negative, dropping
            # the leading token or yielding a sequence of the wrong length.
            keep = max(max_seq_len - len(summary), 1)
            merged = np.concatenate([article[:keep], sep_id, summary_tail])
            if len(merged) > max_seq_len:
                # The summary alone overflows the budget: cut it and re-terminate
                # with the separator so every row has the same length.
                merged = merged[:max_seq_len]
                merged[-1] = tokenizer.sep_token_id
        else:
            # Short enough: right-pad up to max_seq_len (pad_len may be 0).
            pad_len = max_seq_len - len(article) - len(summary_tail)
            padding = np.full(pad_len, tokenizer.pad_token_id)
            merged = np.concatenate([article, summary_tail, padding])

        return merged.astype(np.int32)

    dataset = dataset.map(read_map, 'text', ['article', 'summary'], ['article', 'summary'])
    dataset = dataset.map(tokenizer, 'article')
    dataset = dataset.map(tokenizer, 'summary')
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids'], ['input_ids'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        # NOTE(review): this runs after batch(), so whole batches are reordered
        # rather than individual samples — confirm that is intended.
        dataset = dataset.shuffle(batch_size)

    return dataset

In [8]:
# Load the pretrained 'bert-base-chinese' tokenizer; the same tokenizer is
# applied to both the article and the summary text.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [9]:
# Run every split through the same tokenization / padding / batching pipeline,
# in the order train, eval, test, using process_dataset's default settings.
train_dataset, eval_dataset, test_dataset = (
    process_dataset(split, tokenizer)
    for split in (train_dataset, eval_dataset, test_dataset)
)

In [10]:
# Sanity check: pull the first batch from the training pipeline and display it.
next(train_dataset.create_tuple_iterator())

[Tensor(shape=[8, 1024], dtype=Int32, value=
 [[ 101, 8119,  118 ...    0,    0,    0],
  [ 101, 3343, 1814 ... 7164,  511,  102],
  [ 101, 6598, 3160 ...    0,    0,    0],
  ...
  [ 101,  928, 2622 ...    0,    0,    0],
  [ 101, 3173, 3857 ...    0,    0,    0],
  [ 101, 1346, 5440 ...    0,    0,    0]])]

In [11]:
# Tokenizer length; the same value is later passed as vocab_size to GPT2Config.
len(tokenizer)

21128

In [12]:
from mindnlp._legacy.amp import auto_mixed_precision

# GPT-2 configuration whose vocabulary size is taken from the tokenizer.
config = GPT2Config(vocab_size=len(tokenizer))
# ignore_index is the pad token id — presumably so padded positions are left
# out of the language-modelling loss (TODO confirm against GPT2LMHeadModel).
model = GPT2LMHeadModel(config, ignore_index=tokenizer.pad_token_id)
# Convert the model for mixed-precision execution at level 'O1'.
model = auto_mixed_precision(model, 'O1')

# AdamWeightDecay over all trainable parameters with the constant rate `lr`.
# NOTE(review): `warmup_steps` is defined with the other hyperparameters but is
# not used here — confirm whether a warmup schedule was intended.
optimizer = nn.AdamWeightDecay(model.trainable_params(), lr)
# The Accumulator receives the optimizer, the accumulation step count and
# max_grad_norm; train_step hands it the gradients of each step.
accumulator = Accumulator(optimizer, accumulate_step, max_grad_norm)

In [13]:
from mindspore import ops, ms_function
from mindspore.amp import init_status, all_finite, DynamicLossScaler
# Define forward function

# Dynamic loss scaler starting at 2**10; scale_factor and scale_window govern
# how the scale is changed when train_step calls loss_scaler.adjust().
loss_scaler = DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=1000)

def forward_fn(input_ids, labels):
    """Run the model and return the loss prepared for gradient computation.

    The first element of the model output is taken as the loss. It is divided
    by ``accumulate_step`` and then scaled by the global ``loss_scaler``.

    Args:
        input_ids: token id batch fed to the model.
        labels: labels forwarded to the model's ``labels`` argument.

    Returns:
        The loss divided by ``accumulate_step`` and scaled by ``loss_scaler``.
    """
    outputs = model(input_ids, labels=labels)
    loss = outputs[0]
    # Dividing by accumulate_step before scaling — presumably so that gradients
    # summed over the accumulation window average out (TODO confirm).
    return loss_scaler.scale(loss / accumulate_step)

# Build a function that returns both forward_fn's value and its gradients with
# respect to the model's trainable parameters (no gradients w.r.t. the inputs).
grad_fn = ops.value_and_grad(forward_fn, None, model.trainable_params())

# Define function of one-step training
@ms_function
def train_step(data, label):
    """Run one training step with dynamic loss scaling.

    Computes the scaled loss and its gradients, unscales the loss, and, only
    when every gradient is finite, unscales the gradients and passes them to
    the global ``accumulator``. The loss scaler is adjusted with the
    finiteness flag on every call.

    Args:
        data: input batch passed to ``forward_fn`` as ``input_ids``.
        label: labels passed to ``forward_fn``.

    Returns:
        A pair ``(loss, is_finite)``: the unscaled loss (still divided by
        ``accumulate_step`` inside ``forward_fn``) and the finiteness flag.
    """
    status = init_status()
    # Tie the input to the status value so the status is initialised before
    # the forward/backward pass runs.
    data = ops.depend(data, status)
    loss, grads = grad_fn(data, label)
    loss = loss_scaler.unscale(loss)

    is_finite = all_finite(grads, status)
    if is_finite:
        grads = loss_scaler.unscale(grads)
        # Attach the accumulator call to the returned loss so it is executed.
        loss = ops.depend(loss, accumulator(grads))
    # Adjust the loss scale on every step, finite or not.
    loss = ops.depend(loss, loss_scaler.adjust(is_finite))
    return loss, is_finite

In [None]:
from tqdm import tqdm

# Size of the batched training dataset, used as the progress-bar length.
total = train_dataset.get_dataset_size()

for epoch in range(epochs):
    with tqdm(total=total) as progress:
        progress.set_description(f'Epoch {epoch}')
        # Running sum of step losses and number of steps seen in this epoch;
        # their ratio is the mean loss shown in the progress bar.
        loss_total = 0
        cur_step_nums = 0
        for batch_idx, (input_ids,) in enumerate(train_dataset.create_tuple_iterator()):
            cur_step_nums += 1
            # The same tensor is passed as both the input and the labels.
            # NOTE(review): forward_fn divides the loss by accumulate_step, so
            # the value reported here is the raw loss / accumulate_step — confirm
            # that this is the intended display value.
            loss, is_finite = train_step(input_ids, input_ids)
            loss_total += loss

            progress.set_postfix(loss=loss_total/cur_step_nums, finite=is_finite, scale_value=loss_scaler.scale_value.asnumpy())
            progress.update(1)
        # Save a checkpoint at the end of every epoch, one file per epoch.
        save_checkpoint(model, f'gpt_summarization_epoch_{epoch}.ckpt')

Epoch 0: 100%|██████████| 5000/5000 [1:08:18<00:00,  1.22it/s, finite=True, loss=2.9114637, scale_value=32768.0]
Epoch 1: 100%|██████████| 5000/5000 [1:06:15<00:00,  1.26it/s, finite=True, loss=2.100337, scale_value=1048576.0]
Epoch 2:  60%|█████▉    | 2997/5000 [39:09<25:40,  1.30it/s, finite=True, loss=1.8634704, scale_value=4194304.0]  