In [None]:
import mindspore
import argparse
import numpy as np
import logging
import mindspore.dataset as ds
import os

import json

from tqdm import tqdm
from datetime import datetime
from mindspore.nn import CrossEntropyLoss
from mindspore import nn, ops
from mindspore.train.serialization import save_checkpoint
from mindspore.dataset import TextFileDataset

from mindnlp.transforms import BertTokenizer
from mindnlp.modules import Accumulator
from mindnlp.models import GPT2Config, GPT2LMHeadModel

In [None]:
# ---- Training hyper-parameters ----
epochs, batch_size = 5, 8   # NOTE(review): confirm batch_size is actually passed to process_dataset

# Optimisation
lr = 1.0e-4
warmup_steps = 2_000        # NOTE(review): not referenced by any cell in this notebook
accumulate_step = 2         # gradient-accumulation steps: forward_fn divides the loss by it; given to Accumulator
max_grad_norm = 1.          # forwarded to Accumulator (presumably a gradient-norm clipping threshold)

# Logging
log_step = 100              # NOTE(review): not referenced by any cell in this notebook

In [None]:
from mindnlp.utils import cached_file

# Fetch the NLPCC2017 summarization training file into './' (cached_file
# presumably reuses an existing local copy) and keep its local path.
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path, _ = cached_file('train_with_summ.txt', './', url)

In [None]:
# One row per line of the file; each line is a JSON record that is decoded
# later by read_map. shuffle=False keeps file order so the split is reproducible.
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()  # displayed: number of records

In [None]:
# Deterministic 90% / 10% train/test split (randomize=False keeps file order).
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

In [None]:
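# Token layouts after BERT tokenization (as assumed by the processing code below):
#   article: [CLS] article tokens [SEP]
#   summary: [CLS] summary tokens [SEP]
# Training samples join them as: [CLS] article tokens [SEP] summary tokens [SEP] [PAD] ...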

In [None]:
import numpy as np

def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    """Build the batched training pipeline for GPT-2 summarization.

    Each raw line is a JSON object with ``article`` and ``summarization`` text
    fields. Both are tokenized with ``tokenizer`` (assumed to wrap each text as
    ``[CLS] ... [SEP]``) and merged into one fixed-length sequence::

        [CLS] article [SEP] summary [SEP] [PAD] ...

    Over-long samples are truncated on the article side. A summary that alone
    would not fit is clipped so the article keeps at least its first token.

    Args:
        dataset: MindSpore dataset with a single ``text`` column.
        tokenizer: used directly as a map operation; must expose
            ``sep_token_id`` and ``pad_token_id``.
        batch_size: samples per batch.
        max_seq_len: exact length of every merged sequence (must be >= 3).
        shuffle: if True, call ``shuffle(batch_size)`` on the *batched*
            dataset, i.e. whole batches are shuffled with a small buffer.

    Returns:
        The batched dataset with a single ``input_ids`` column of int32
        sequences of length ``max_seq_len``.
    """
    sep_id = np.array([tokenizer.sep_token_id])

    def read_map(text):
        # `text` holds one raw line; decode it and pull out both text fields.
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # Guard: a summary of max_seq_len tokens or more would leave the article
        # slice below with a zero/negative bound (wrong-length or [SEP]-first
        # output). Clip it to max_seq_len - 1 tokens, still ending in [SEP].
        if len(summary) > max_seq_len - 1:
            summary = np.concatenate([summary[:max_seq_len - 2], sep_id])

        # summary[0] is its own [CLS]; it is dropped when appended to the article.
        merged_len = len(article) + len(summary) - 1
        if merged_len > max_seq_len:
            # Truncate the article so the result is exactly max_seq_len long,
            # re-terminating the truncated article with [SEP].
            article_budget = max_seq_len - len(summary)
            merged = np.concatenate([article[:article_budget], sep_id, summary[1:]])
        else:
            # Fits (possibly exactly): right-pad to max_seq_len.
            padding = np.full(max_seq_len - merged_len, tokenizer.pad_token_id)
            merged = np.concatenate([article, summary[1:], padding])

        return merged.astype(np.int32)

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(tokenizer, 'article')
    dataset = dataset.map(tokenizer, 'summary')
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], 'input_ids')

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

In [None]:
# Chinese BERT tokenizer; its vocabulary size (len(tokenizer)) also sizes the GPT-2 model below.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [None]:
# Pass the configured batch size explicitly; without it process_dataset falls
# back to its own default (6) and the config-cell value is silently ignored.
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=batch_size)

In [None]:
# Sanity check: display one batch of merged, padded input_ids.
next(train_dataset.create_tuple_iterator())

In [None]:
# Tokenizer vocabulary size (used as GPT-2's vocab_size).
len(tokenizer)

In [None]:
from mindnlp._legacy.amp import auto_mixed_precision

# GPT-2 whose vocabulary matches the BERT tokenizer; pad tokens are passed as
# ignore_index (presumably excluded from the LM loss — confirm in mindnlp).
config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2LMHeadModel(config, ignore_index=tokenizer.pad_token_id)
# O1 mixed precision; see mindnlp._legacy.amp for which ops are cast.
model = auto_mixed_precision(model, 'O1')

# NOTE(review): constant learning rate — `warmup_steps` from the config cell
# is not applied anywhere in this notebook.
optimizer = nn.AdamWeightDecay(model.trainable_params(), lr)
# Wraps the optimizer for gradient accumulation over `accumulate_step` calls,
# with `max_grad_norm` (presumably gradient-norm clipping — confirm in mindnlp).
accumulator = Accumulator(optimizer, accumulate_step, max_grad_norm)

In [None]:
from mindspore import ops, ms_function
from mindnlp._legacy.amp import DynamicLossScaler, all_finite
# Mixed-precision training step with dynamic loss scaling and gradient accumulation.

# Loss scaler for the O1 model: starts at 2**10; scale_factor / scale_window
# presumably control shrinking on overflow and growth after 50 clean steps.
loss_scaler = DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50)

# Forward function: LM loss for one micro-batch, divided by accumulate_step so
# accumulated gradients average over micro-batches, then multiplied by the
# current loss scale.
def forward_fn(input_ids, labels):
    outputs = model(input_ids, labels=labels)
    loss = outputs[0]  # first output is taken as the loss when labels are given (assumed)
    return loss_scaler.scale(loss / accumulate_step)

# Gradient function w.r.t. the trainable parameters (grad_position=None: no
# gradients for the inputs); called as (loss, grads) = grad_fn(...).
grad_fn = ops.value_and_grad(forward_fn, None, model.trainable_params())

# One compiled training step. Returns the unscaled loss, which is still divided
# by accumulate_step (i.e. not the raw per-batch loss).
@ms_function
def train_step(data, label):
    loss, grads = grad_fn(data, label)
    loss = loss_scaler.unscale(loss)

    is_finite = all_finite(grads)
    # Only feed gradients from steps without overflow to the accumulator.
    if is_finite:
        grads = loss_scaler.unscale(grads)
        # ops.depend ties the accumulator update to the returned value so it runs.
        loss = ops.depend(loss, accumulator(grads))
    # Update the loss scale according to whether this step overflowed.
    loss = ops.depend(loss, loss_scaler.adjust(is_finite))
    return loss

In [None]:
# Count the model's trainable parameters (total number of scalar weights).
parameters = model.trainable_params()
num_parameters = sum(param.numel() for param in parameters)
print('number of model parameters: {}'.format(num_parameters))

In [None]:
from tqdm import tqdm

total = train_dataset.get_dataset_size()

for epoch in range(epochs):
    # Switch to training mode at the start of every epoch. The inference cell
    # further down calls model.set_train(False), so re-running this cell after
    # it would otherwise train in inference mode (e.g. dropout off); the cell
    # may also not start in training mode (see MindSpore Cell.set_train).
    model.set_train(True)
    with tqdm(total=total) as progress:
        progress.set_description(f'Epoch {epoch}')
        loss_total = 0.0
        cur_step_nums = 0
        for (input_ids,) in train_dataset.create_tuple_iterator():
            cur_step_nums += 1
            # Causal LM: the inputs are their own labels.
            loss = train_step(input_ids, input_ids)
            # train_step returns loss / accumulate_step (forward_fn divides it
            # for gradient accumulation); undo that so the bar shows the real loss.
            loss_total += float(loss.asnumpy()) * accumulate_step

            progress.set_postfix(loss=loss_total / cur_step_nums)
            progress.update(1)
        save_checkpoint(model, f"gpt2_summarization_epoch_{epoch}.ckpt")

In [None]:
def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    """Build the batched test pipeline used for summary generation.

    Each JSON line's ``article`` is tokenized and laid out as a fixed-length
    prompt ``[CLS] article [SEP] [PAD] ...`` of ``max_seq_len`` tokens. At least
    ``max_summary_len`` trailing pad slots are always left free for generated
    tokens; longer articles are truncated and re-terminated with [SEP]. The
    reference ``summary`` column is kept as untokenized text.

    Args:
        dataset: MindSpore dataset with a single ``text`` column.
        tokenizer: used directly as a map operation; must expose
            ``sep_token_id`` and ``pad_token_id``.
        batch_size: samples per batch.
        max_seq_len: total length of every prompt buffer.
        max_summary_len: pad slots reserved for the generated summary.

    Returns:
        Batched dataset with columns ``input_ids`` (int32, consistent with the
        training pipeline), ``article_len`` (prompt length including [SEP]) and
        ``summary``.
    """
    max_article_len = max_seq_len - max_summary_len

    def read_map(text):
        # `text` holds one raw line; decode it and pull out both text fields.
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        article_len = len(article)
        if article_len >= max_article_len:
            # Too long: keep the first max_article_len - 1 tokens and end with [SEP].
            prompt = np.concatenate([article[:max_article_len - 1], np.array([tokenizer.sep_token_id])])
            article_len = max_article_len
        else:
            prompt = article
        padding = np.full(max_seq_len - article_len, tokenizer.pad_token_id)
        input_ids = np.concatenate([prompt, padding]).astype(np.int32)
        return input_ids, article_len

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(tokenizer, 'article')
    dataset = dataset.map(pad, 'article', ['input_ids', 'article_len'])

    dataset = dataset.batch(batch_size)

    return dataset

In [None]:
# batch_size=1: generate() below decodes one article (one row) at a time.
batched_test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)

In [None]:
# Inspect one processed test sample as numpy arrays (input_ids, article_len, summary).
print(next(batched_test_dataset.create_tuple_iterator(output_numpy=True)))

In [None]:
from mindspore import Tensor
# Prompt layout (single row): [CLS] article [SEP] [PAD] ... [PAD]
# Generated summary tokens overwrite the [PAD] slots following the [SEP].
def generate(input_ids, article_len, model, tokenizer, max_summary_len=100):
    """Greedily decode a summary that continues the article prompt.

    Args:
        input_ids: integer array of shape (1, seq_len) holding the padded prompt.
            It is copied, so the caller's array is left untouched.
        article_len: prompt length including [SEP]; a Python int or a size-1
            array (as produced by the batched test dataset).
        model: called as ``model(Tensor(ids))``; ``outputs[0]`` is indexed as
            [batch, position, vocab] logits (assumed).
        tokenizer: provides ``sep_token_id``, which terminates decoding.
        max_summary_len: maximum number of tokens to generate; also capped by
            the free positions remaining in ``input_ids``.

    Returns:
        Array of shape (1, n_generated) with the generated ids, excluding the
        terminating [SEP].

    Raises:
        ValueError: if ``input_ids`` contains more than one row.
    """
    # Work on a copy so the caller's batch is not mutated in place.
    input_ids = np.array(input_ids, copy=True)
    if input_ids.shape[0] != 1:
        raise ValueError(f'generate expects a batch of size 1, got {input_ids.shape[0]}')

    start_idx = int(np.asarray(article_len).item())
    # Never write past the end of the fixed-length buffer.
    num_steps = min(max_summary_len, input_ids.shape[1] - start_idx)

    curr_idx = start_idx
    for _ in range(num_steps):
        logits = model(Tensor(input_ids))[0].asnumpy()
        # The token for position curr_idx is predicted from the logits at curr_idx - 1.
        output_id = logits[0, curr_idx - 1].argmax()
        if output_id == tokenizer.sep_token_id:
            break
        input_ids[0, curr_idx] = output_id
        curr_idx += 1
    return input_ids[:, start_idx:curr_idx]

In [None]:
model.set_train(False)  # inference mode for generation

# Decode and print generated summaries for the first 10 test articles.
for i, (input_ids, article_len, raw_summary) in enumerate(
        batched_test_dataset.create_tuple_iterator(output_numpy=True), start=1):
    output_ids = generate(input_ids, article_len, model, tokenizer)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    if i == 10:
        break