In [1]:
import pandas as pd
import numpy as np
import torch
import transformers
import datasets
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
from datasets import Dataset
from easse.bleu import sentence_bleu

In [None]:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text2text-generation", model="fnlp/bart-base-chinese")
#pipe(lines_orig[0:10], max_length=30, min_length=10) #example of how to run the pipeline on input sequences

Device set to use cpu


In [2]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = AutoModelForSeq2SeqLM.from_pretrained("fnlp/bart-base-chinese")

In [3]:
lines_orig = []
with open('../mcts-main/dataset/mcts.dev.orig', encoding="utf8") as f:
    lines_orig = f.read().splitlines()

lines_ref = []
for dataset in range(0,5):
    filename = str('../mcts-main/dataset/mcts.dev.simp.'+str(dataset))
    with open(filename, encoding="utf8") as f:
        lines_ref.append(f.read().splitlines())

data_dict = {'orig': lines_orig, 'ref': lines_ref[0]}

In [4]:
ds = Dataset.from_dict(data_dict)

In [5]:
max_input = 512
max_target = 512

In [6]:
# tokenize data
def preprocess_data(data):
    input_data = [dialogue for dialogue in data['orig']]
    targets = [dialogue for dialogue in data['ref']]
    inputs = tokenizer(input_data, max_length=max_input, padding='max_length', truncation=True)
    targets = tokenizer(targets, max_length=max_target, padding='max_length', truncation=True, text_target='List[str]')
    #set labels
    inputs['labels'] = targets['input_ids']
    #return the tokenized data
    #input_ids, attention_mask and labels
    return inputs

tokenized_data = ds.map(preprocess_data, batched=True)

Map:   0%|          | 0/366 [00:00<?, ? examples/s]

In [22]:
# sentence_bleu(
#     sys_sent: str,
#     ref_sents: List[str],
#     smooth_method: str = "floor",
#     smooth_value: float = None,
#     lowercase: bool = False,
#     tokenizer: str = "13a",
#     effective_order: bool = True,
# ):

# return bleu_scorer.corpus_score(
#         sys_sents,
#         refs_sents,
#     ).score

In [7]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    
    bleu_score = sentence_bleu(sys_sent=preds, ref_sents=[labels])
    return {
        'bleu': bleu_score
    }

In [8]:
args = Seq2SeqTrainingArguments(
    'simplified_save', #save directory
    #eval_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size= 10,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,
    predict_with_generate=True,
    eval_accumulation_steps=3,
    fp16=False #available only with CUDA
    )

trainer = Seq2SeqTrainer(
    model, 
    args,
    train_dataset=tokenized_data,
    #eval_dataset=tokenized_data['labels'],
    #data_collator=collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

  trainer = Seq2SeqTrainer(


In [9]:
trainer.train()

  0%|          | 0/54 [00:00<?, ?it/s]



{'train_runtime': 7885.0894, 'train_samples_per_second': 0.139, 'train_steps_per_second': 0.007, 'train_loss': 2.42289959942853, 'epoch': 2.86}


TrainOutput(global_step=54, training_loss=2.42289959942853, metrics={'train_runtime': 7885.0894, 'train_samples_per_second': 0.139, 'train_steps_per_second': 0.007, 'total_flos': 320721377034240.0, 'train_loss': 2.42289959942853, 'epoch': 2.864864864864865})

In [33]:
print(lines_orig[0])

一、在中华人民共和国领土内及中华人民共和国注册的运输工具上就业或者工作的最低年龄为十六周岁；


In [35]:
#tokenize the conversation
sentence = lines_orig[0]
model_inputs = tokenizer(sentence,  max_length=max_input, padding='max_length', truncation=True, return_token_type_ids=False)
#make prediction
raw_pred, _, _ = trainer.predict([model_inputs])
#decode the output
print(lines_orig[0])
print(tokenizer.decode(raw_pred[0]))

  0%|          | 0/1 [00:00<?, ?it/s]

一、在中华人民共和国领土内及中华人民共和国注册的运输工具上就业或者工作的最低年龄为十六周岁；
[SEP] [CLS] 一 、 在 中 华 人 民 共 和 国 领 土 内 的 运 输 工 具 [SEP]
