## Fine tuning text simplification with BART
Use pseudo wiki parallel data to fine tune Chinese BART for TS task

In [None]:
!pip install jieba evaluate sacrebleu sacremoses datasets

In [1]:
#runtime variables
HSK_dir = "../data/HSK/"
pseudo_dir = "../data/mcts-pseudo/"

%load_ext autoreload
%autoreload 2
# import pandas as pd
import numpy as np
import jieba
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
from datasets import Dataset
import pickle
# from easse.sari import sentence_sari
from evaluate import load
sari = load("sari")
with open(HSK_dir+"HSK_levels.pickle", 'rb') as handle:
    HSK_dict = pickle.load(handle)

In [None]:
from transformers import BertTokenizer, BartForConditionalGeneration
tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")

In [9]:
from transformers import Text2TextGenerationPipeline
text2text_generator = Text2TextGenerationPipeline(model, tokenizer)
text2text_generator("这是一个句子", max_length=128, do_sample=False)[0]['generated_text'].replace(" ","")

Device set to use cpu


'这是一句话。'

In [None]:
input_ids = tokenizer("这是一个句子", return_tensors='pt', max_length=256, padding="max_length", truncation=True)
outputs = model.generate(input_ids=input_ids["input_ids"], max_length=256, num_beams=4, penalty_alpha=0.6, top_k=4)
print(str(tokenizer.batch_decode(outputs, skip_special_tokens=True)).replace(" ",""))

AttributeError: 'list' object has no attribute 'shape'

In [3]:
def tokenize_with_HSK(sentence, HSK_dict):
    return " ".join(jieba.cut(sentence))
    # split_sentence = jieba.lcut(sentence)
    # HSK_sentence = ""
    # for word in split_sentence:
    #     score = HSK_dict.get(word, 0)
    #     if score>2:
    #         HSK_sentence += f" {word}[{score}]"
    #     else:
    #         HSK_sentence += f" {word}"
    # return HSK_sentence

# tokenizer.add_tokens(["[3]", "[4]", "[5]", "[6]", "[7]"])
# model.resize_token_embeddings(len(tokenizer))

def preprocess_data(filename: str, start: int, stop: int):
    lines_HSK = []
    with open(filename, encoding="utf8") as f:
        lines_orig = f.read().splitlines()
        for line in lines_orig[start:stop]:
            lines_HSK.append(tokenize_with_HSK(line, HSK_dict))
            if len(lines_HSK)%1000==0:
                print(len(lines_HSK))
    return lines_HSK

In [4]:
start = 0
stop = 100
split = 75

lines_complex = preprocess_data(pseudo_dir+"zh_selected.ori", start, stop)
lines_simple = preprocess_data(pseudo_dir+"zh_selected.sim", start, stop)

data_dict = {'complex': lines_complex[start:split], 'simple': lines_simple[start:split]}
ds_train = Dataset.from_dict(data_dict)
data_dict = {'complex': lines_complex[split:stop], 'simple': lines_simple[split:stop]}
ds_eval = Dataset.from_dict(data_dict)

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\tempu\AppData\Local\Temp\jieba.cache
Loading model cost 1.469 seconds.
Prefix dict has been built successfully.


In [1]:
lines_simple

NameError: name 'lines_simple' is not defined

In [6]:
# tokenize data
max_length = 128
def batch_tokenize_data(data):
    inputs = [example for example in data["complex"]]
    targets = [example for example in data["simple"]]

    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True)
    labels = tokenizer(targets, max_length=max_length, padding="max_length", truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_data_train = ds_train.map(batch_tokenize_data, batched=True)
tokenized_data_eval = ds_eval.map(batch_tokenize_data, batched=True)

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [None]:
import pickle
with open("../data/pseudo_train", "wb") as handle:
    pickle.dump(tokenized_data_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open("../data/pseudo_eval", "wb") as handle:
    pickle.dump(tokenized_data_eval, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [22]:
# optuna fine-tuning
import optuna
import torch

def search_space(trial):
    """ Define hyperparameter search space """
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16]),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 3, 10),
        "weight_decay": trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    }

# Define training arguments
def model_init():
    """ Function to initialize the model for Trainer """
    from transformers import BartForConditionalGeneration
    return BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese", local_files_only=True)

def compute_metrics(eval_preds):
    prediction_tokens = eval_preds.predictions
    prediction_text = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in prediction_tokens]

    sari_score = sari.compute(
            predictions=prediction_text, # model output
            references=[[simple] for simple in tokenized_data_eval.select(range(15))['simple']], # reference simple sentences
            sources=tokenized_data_eval.select(range(15))['complex'] # complex sentence
        )
    return {"sari": sari_score["sari"]}

def preprocess_logits_for_metrics(logits, labels):
        """
        Original Trainer may have a memory leak. 
        This is a workaround to avoid storing too many tensors that are not needed.
        """
        pred_ids = torch.argmax(logits[0], dim=-1)
        return pred_ids

training_args = TrainingArguments(
    output_dir = "./bart_hypersearch",
    eval_strategy = "epoch",
    per_device_eval_batch_size = 1,
    eval_accumulation_steps = 1, 
    logging_dir = "./logs",
    greater_is_better = True,
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=tokenized_data_train.select(range(5)),  # Use a subset for quick tuning
    eval_dataset=tokenized_data_eval.select(range(15)),  
    processing_class = tokenizer,
    compute_metrics = compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Run Optuna hyperparameter search
best_trial = trainer.hyperparameter_search(
    direction="maximize",  # For maximizing performance (adjust as needed)
    hp_space=search_space,
    n_trials=10  # Number of trials
)

print(best_trial)

[I 2025-02-26 10:30:40,742] A new study created in memory with name: no-name-f50bf91f-1ad4-458a-b33b-aff458e7f6c8


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

{'eval_loss': 9.780804634094238, 'eval_sari': 30.42471376377177, 'eval_runtime': 8.5501, 'eval_samples_per_second': 1.754, 'eval_steps_per_second': 1.754, 'epoch': 1.0}


  0%|          | 0/15 [00:00<?, ?it/s]

{'eval_loss': 7.8586106300354, 'eval_sari': 30.437257337160588, 'eval_runtime': 7.7343, 'eval_samples_per_second': 1.939, 'eval_steps_per_second': 1.939, 'epoch': 2.0}


[W 2025-02-26 10:31:19,440] Trial 0 failed with parameters: {'learning_rate': 5.710436455910181e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 8, 'weight_decay': 0.00031625137052626964} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\optuna\study\_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\integrations\integration_utils.py", line 249, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\trainer.py", line 2164, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\tempu\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\trainer

KeyboardInterrupt: 

In [None]:
import logging
logging.basicConfig(level=logging.INFO)

def compute_metrics(eval_preds):
    prediction_tokens = np.argmax(eval_preds.predictions[0], axis=-1)
    prediction_text = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in prediction_tokens]

    sari_score = sari.compute(
            predictions=prediction_text, # model output
            references=[[simple] for simple in tokenized_data_eval['simple']], # reference simple sentences
            sources=tokenized_data_eval['complex'] # complex sentence
        )
    return {"sari": sari_score["sari"]}

best_params = best_trial.hyperparameters
training_args = TrainingArguments(
    output_dir = "./bart_simplification",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate = best_params["learning_rate"],
    per_device_train_batch_size = best_params["per_device_train_batch_size"],
    num_train_epochs = best_params["num_train_epochs"],
    weight_decay = best_params["weight_decay"],
    logging_dir = "./logs",
    logging_steps = 500,
    greater_is_better = True,
)

training_args = TrainingArguments(
    output_dir="./bart_simplification",
    eval_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_accumulation_steps = 1,
    num_train_epochs=5,
    weight_decay=0.01,
    save_total_limit=2,
    logging_steps=500,
    greater_is_better=True  # maximize SARI
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = tokenized_data_train,
    eval_dataset = tokenized_data_eval,
    processing_class = tokenizer,
    compute_metrics = compute_metrics
)

  trainer = Trainer(


In [70]:
trainer.train()

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.47325244545936584, 'eval_sari': 34.432335363128274, 'eval_runtime': 11.419, 'eval_samples_per_second': 2.189, 'eval_steps_per_second': 0.35, 'epoch': 1.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.520830512046814, 'eval_sari': 34.25620275223107, 'eval_runtime': 11.2691, 'eval_samples_per_second': 2.218, 'eval_steps_per_second': 0.355, 'epoch': 2.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5226666331291199, 'eval_sari': 34.384029474241494, 'eval_runtime': 13.3294, 'eval_samples_per_second': 1.876, 'eval_steps_per_second': 0.3, 'epoch': 3.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5167384147644043, 'eval_sari': 34.578188307775484, 'eval_runtime': 13.252, 'eval_samples_per_second': 1.887, 'eval_steps_per_second': 0.302, 'epoch': 4.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5156217813491821, 'eval_sari': 34.538338728313626, 'eval_runtime': 10.8963, 'eval_samples_per_second': 2.294, 'eval_steps_per_second': 0.367, 'epoch': 5.0}
{'train_runtime': 615.4155, 'train_samples_per_second': 0.609, 'train_steps_per_second': 0.081, 'train_loss': 0.04471360206604004, 'epoch': 5.0}


TrainOutput(global_step=50, training_loss=0.04471360206604004, metrics={'train_runtime': 615.4155, 'train_samples_per_second': 0.609, 'train_steps_per_second': 0.081, 'total_flos': 28581396480000.0, 'train_loss': 0.04471360206604004, 'epoch': 5.0})

In [13]:
predictions = trainer.predict(tokenized_data_train)

  0%|          | 0/1 [00:00<?, ?it/s]

In [63]:
prediction_tokens = np.argmax(predictions.predictions[0], axis=-1)
prediction_text = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in prediction_tokens]

sari_score = sari.compute(
        predictions=prediction_text, # model output
        references=[[simple] for simple in tokenized_data_train['simple']], # reference simple sentences
        sources=tokenized_data_train['complex'] # complex sentence
    )

In [None]:
idx = 0
sentence = lines_complex[idx]
from transformers import Text2TextGenerationPipeline
text2text_generator = Text2TextGenerationPipeline(model, tokenizer)
output = text2text_generator(sentence, max_length=128, do_sample=False)[0]['generated_text'].replace(" ","")
print(sentence)
print(output)

Device set to use cpu


75公斤级比赛3个项目[4]的第一名均为中国选手李顺柱获得[4]。
75公斤级三个项目[4]的第一名均由中国选手李顺柱获得[4]。


In [67]:
# Another way to do the same thing:
input_ids = tokenizer(sentence, return_tensors='pt', max_length=256, padding="max_length", truncation=True)
outputs = model.generate(input_ids=input_ids["input_ids"], max_length=256, num_beams=4, penalty_alpha=0.6, top_k=4)
print(str(tokenizer.batch_decode(outputs, skip_special_tokens=True)).replace(" ",""))

['最近几天，，在一些交通要道的交叉口，发生爆炸事件[[]']
