## Fine tuning text simplification with BART
Use pseudo wiki parallel data to fine tune Chinese BART for TS task

In [None]:
!pip install jieba evaluate sacrebleu sacremoses datasets

In [1]:
#runtime variables
HSK_dir = "../data/HSK/"
pseudo_dir = "../data/mcts-pseudo/"

%load_ext autoreload
%autoreload 2
# import pandas as pd
import numpy as np
import jieba
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
from datasets import Dataset
import pickle
# from easse.sari import sentence_sari
from evaluate import load
sari = load("sari")
with open(HSK_dir+"HSK_levels.pickle", 'rb') as handle:
    HSK_dict = pickle.load(handle)

In [2]:
from transformers import BertTokenizer, BartForConditionalGeneration
tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese", local_files_only=True)
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese", local_files_only=True)

In [5]:
def tokenize_with_HSK(sentence, HSK_dict):
    return " ".join(jieba.cut(sentence))
    # split_sentence = jieba.lcut(sentence)
    # HSK_sentence = ""
    # for word in split_sentence:
    #     score = HSK_dict.get(word, 0)
    #     if score>2:
    #         HSK_sentence += f" {word}[{score}]"
    #     else:
    #         HSK_sentence += f" {word}"
    # return HSK_sentence

# tokenizer.add_tokens(["[3]", "[4]", "[5]", "[6]", "[7]"])
# model.resize_token_embeddings(len(tokenizer))

def preprocess_data(filename: str, start: int, stop: int):
    lines_HSK = []
    with open(filename, encoding="utf8") as f:
        lines_orig = f.read().splitlines()
        for line in lines_orig[start:stop]:
            lines_HSK.append(tokenize_with_HSK(line, HSK_dict))
            if len(lines_HSK)%1000==0:
                print(len(lines_HSK))
    return lines_HSK

In [6]:
start = 0
stop = 500000
split = 450000

lines_complex = preprocess_data(pseudo_dir+"zh_selected.ori", start, stop)
lines_simple = preprocess_data(pseudo_dir+"zh_selected.sim", start, stop)

data_dict = {'complex': lines_complex[start:split], 'simple': lines_simple[start:split]}
ds_train = Dataset.from_dict(data_dict)
data_dict = {'complex': lines_complex[split:stop], 'simple': lines_simple[split:stop]}
ds_eval = Dataset.from_dict(data_dict)

1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
100000
101000
102000
103000
104000
105000
106000
107000
108000
109000
110000
111000
112000
113000
114000
115000
116000
117000
118000
119000
120000
121000
122000
123000
124000
125000
126000
127000
128000
129000
130000
131000
132000
133000
134000
135000
136000
137000
138000
139000
140000
141000
142000
143000
144000
145000
146000
147000
148000
149000
150000
151000
152000
153000
154000
155000
156000
157000
158000
15

In [7]:
ds_eval

Dataset({
    features: ['complex', 'simple'],
    num_rows: 50000
})

In [8]:
# tokenize data
max_length = 128
def batch_tokenize_data(data):
    inputs = [example for example in data["complex"]]
    targets = [example for example in data["simple"]]

    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True)
    labels = tokenizer(targets, max_length=max_length, padding="max_length", truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_data_train = ds_train.map(batch_tokenize_data, batched=True)
tokenized_data_eval = ds_eval.map(batch_tokenize_data, batched=True)

Map:   0%|          | 0/450000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [10]:
import pickle
with open("../data/pseudo_train", "wb") as handle:
    pickle.dump(tokenized_data_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open("../data/pseudo_eval", "wb") as handle:
    pickle.dump(tokenized_data_eval, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [78]:
# optuna fine-tuning
import optuna

def search_space(trial):
    """ Define hyperparameter search space """
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16]),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 3, 10),
        "weight_decay": trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    }

# Define training arguments
def model_init():
    """ Function to initialize the model for Trainer """
    from transformers import BartForConditionalGeneration
    return BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese", local_files_only=True)

def compute_metrics(eval_preds):
    prediction_tokens = np.argmax(eval_preds.predictions[0], axis=-1)
    prediction_text = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in prediction_tokens]

    sari_score = sari.compute(
            predictions=prediction_text, # model output
            references=[[simple] for simple in tokenized_data_eval.select(range(20))['simple']], # reference simple sentences
            sources=tokenized_data_eval.select(range(20))['complex'] # complex sentence
        )
    return {"sari": sari_score["sari"]}

training_args = TrainingArguments(
    output_dir="./bart_hypersearch",
    eval_strategy="epoch",
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    greater_is_better = True,
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=tokenized_data_train.select(range(50)),  # Use a subset for quick tuning
    eval_dataset=tokenized_data_eval.select(range(20)),  
    processing_class = tokenizer,
    compute_metrics = compute_metrics,
)

# Run Optuna hyperparameter search
best_trial = trainer.hyperparameter_search(
    direction="maximize",  # For maximizing performance (adjust as needed)
    hp_space=search_space,
    n_trials=10  # Number of trials
)

print(best_trial)

[I 2025-02-24 10:53:09,532] A new study created in memory with name: no-name-8a3273ce-0de3-4dd8-bd3d-593dd95f35b2


  0%|          | 0/91 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 0.46418604254722595, 'eval_sari': 31.37398563638061, 'eval_runtime': 10.275, 'eval_samples_per_second': 1.946, 'eval_steps_per_second': 0.292, 'epoch': 1.0}


  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 0.34705060720443726, 'eval_sari': 31.811298312934973, 'eval_runtime': 7.9967, 'eval_samples_per_second': 2.501, 'eval_steps_per_second': 0.375, 'epoch': 2.0}


  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 0.3692977726459503, 'eval_sari': 32.276282577824986, 'eval_runtime': 8.2613, 'eval_samples_per_second': 2.421, 'eval_steps_per_second': 0.363, 'epoch': 3.0}


  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 0.4066435694694519, 'eval_sari': 32.20763792522312, 'eval_runtime': 8.0842, 'eval_samples_per_second': 2.474, 'eval_steps_per_second': 0.371, 'epoch': 4.0}


  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 0.42069554328918457, 'eval_sari': 32.64845580637466, 'eval_runtime': 9.3641, 'eval_samples_per_second': 2.136, 'eval_steps_per_second': 0.32, 'epoch': 5.0}


  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 0.4440982937812805, 'eval_sari': 32.62074075718358, 'eval_runtime': 8.554, 'eval_samples_per_second': 2.338, 'eval_steps_per_second': 0.351, 'epoch': 6.0}




  0%|          | 0/3 [00:00<?, ?it/s]

[I 2025-02-24 11:04:25,588] Trial 0 finished with value: 32.68186373525397 and parameters: {'learning_rate': 7.014021052299661e-05, 'per_device_train_batch_size': 4, 'num_train_epochs': 7, 'weight_decay': 6.136055886315135e-05}. Best is trial 0 with value: 32.68186373525397.


{'eval_loss': 0.4579606056213379, 'eval_sari': 32.68186373525397, 'eval_runtime': 8.5797, 'eval_samples_per_second': 2.331, 'eval_steps_per_second': 0.35, 'epoch': 7.0}
{'train_runtime': 674.4202, 'train_samples_per_second': 0.519, 'train_steps_per_second': 0.135, 'train_loss': 0.7262989924504206, 'epoch': 7.0}


  0%|          | 0/36 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 7.0060625076293945, 'eval_sari': 31.94175703521805, 'eval_runtime': 11.449, 'eval_samples_per_second': 1.747, 'eval_steps_per_second': 0.262, 'epoch': 1.0}


[W 2025-02-24 11:07:05,083] Trial 1 failed with parameters: {'learning_rate': 2.451285646150543e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 9, 'weight_decay': 0.00021426933891344972} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\optuna\study\_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\integrations\integration_utils.py", line 249, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\trainer.py", line 2164, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\trainer

KeyboardInterrupt: 

In [None]:
import logging
logging.basicConfig(level=logging.INFO)

def compute_metrics(eval_preds):
    prediction_tokens = np.argmax(eval_preds.predictions[0], axis=-1)
    prediction_text = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in prediction_tokens]

    sari_score = sari.compute(
            predictions=prediction_text, # model output
            references=[[simple] for simple in tokenized_data_eval['simple']], # reference simple sentences
            sources=tokenized_data_eval['complex'] # complex sentence
        )
    return {"sari": sari_score["sari"]}

best_params = best_trial.hyperparameters
training_args = TrainingArguments(
    output_dir = "./bart_simplification",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate = best_params["learning_rate"],
    per_device_train_batch_size = best_params["per_device_train_batch_size"],
    num_train_epochs = best_params["num_train_epochs"],
    weight_decay = best_params["weight_decay"],
    logging_dir = "./logs",
    logging_steps = 500,
    greater_is_better = True,
)

training_args = TrainingArguments(
    output_dir="./bart_simplification",
    eval_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_accumulation_steps = 1,
    num_train_epochs=5,
    weight_decay=0.01,
    save_total_limit=2,
    logging_steps=500,
    greater_is_better=True  # maximize SARI
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = tokenized_data_train,
    eval_dataset = tokenized_data_eval,
    processing_class = tokenizer,
    compute_metrics = compute_metrics
)

  trainer = Trainer(


In [70]:
trainer.train()

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.47325244545936584, 'eval_sari': 34.432335363128274, 'eval_runtime': 11.419, 'eval_samples_per_second': 2.189, 'eval_steps_per_second': 0.35, 'epoch': 1.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.520830512046814, 'eval_sari': 34.25620275223107, 'eval_runtime': 11.2691, 'eval_samples_per_second': 2.218, 'eval_steps_per_second': 0.355, 'epoch': 2.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5226666331291199, 'eval_sari': 34.384029474241494, 'eval_runtime': 13.3294, 'eval_samples_per_second': 1.876, 'eval_steps_per_second': 0.3, 'epoch': 3.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5167384147644043, 'eval_sari': 34.578188307775484, 'eval_runtime': 13.252, 'eval_samples_per_second': 1.887, 'eval_steps_per_second': 0.302, 'epoch': 4.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5156217813491821, 'eval_sari': 34.538338728313626, 'eval_runtime': 10.8963, 'eval_samples_per_second': 2.294, 'eval_steps_per_second': 0.367, 'epoch': 5.0}
{'train_runtime': 615.4155, 'train_samples_per_second': 0.609, 'train_steps_per_second': 0.081, 'train_loss': 0.04471360206604004, 'epoch': 5.0}


TrainOutput(global_step=50, training_loss=0.04471360206604004, metrics={'train_runtime': 615.4155, 'train_samples_per_second': 0.609, 'train_steps_per_second': 0.081, 'total_flos': 28581396480000.0, 'train_loss': 0.04471360206604004, 'epoch': 5.0})

In [13]:
predictions = trainer.predict(tokenized_data_train)

  0%|          | 0/1 [00:00<?, ?it/s]

In [63]:
prediction_tokens = np.argmax(predictions.predictions[0], axis=-1)
prediction_text = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in prediction_tokens]

sari_score = sari.compute(
        predictions=prediction_text, # model output
        references=[[simple] for simple in tokenized_data_train['simple']], # reference simple sentences
        sources=tokenized_data_train['complex'] # complex sentence
    )

In [None]:
idx = 0
sentence = lines_complex[idx]
from transformers import Text2TextGenerationPipeline
text2text_generator = Text2TextGenerationPipeline(model, tokenizer)
output = text2text_generator(sentence, max_length=128, do_sample=False)[0]['generated_text'].replace(" ","")
print(sentence)
print(output)

Device set to use cpu


75公斤级比赛3个项目[4]的第一名均为中国选手李顺柱获得[4]。
75公斤级三个项目[4]的第一名均由中国选手李顺柱获得[4]。


In [67]:
# Another way to do the same thing:
input_ids = tokenizer(sentence, return_tensors='pt', max_length=256, padding="max_length", truncation=True)
outputs = model.generate(input_ids=input_ids["input_ids"], max_length=256, num_beams=4, penalty_alpha=0.6, top_k=4)
print(str(tokenizer.batch_decode(outputs, skip_special_tokens=True)).replace(" ",""))

['最近几天，，在一些交通要道的交叉口，发生爆炸事件[[]']
