This is the code repository for improving the performance of the baseline model from the LOT benchmark on its Chinese story generation task.
- Download: The checkpoints of the baseline model and the example data can be downloaded from THUCloud or the Hugging Face model card. The training and generation scripts are under the directory `./LOT-LongLM/longlm`.
- Model Loading:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('thu-coai/LongLM-base')
model = T5ForConditionalGeneration.from_pretrained('thu-coai/LongLM-base')
```
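If a GPU is available, the loaded model can be moved there before generation. This is a minimal usage sketch, not part of the original loading code; the `device` variable defined here is the one referenced in the generation snippet below.

```python
import torch

# pick a device and move the model onto it before calling generate()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # disable dropout for inference
```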
- Training:

Execute `bash ./finetune.sh` to fine-tune LongLM. If DeepSpeed is available, you can execute `bash ./finetune_deepspeed.sh` to accelerate training. You can also use the official script provided by Transformers to fine-tune the model:

```bash
env CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 CUDA_LAUNCH_BLOCKING=1 python3 -m torch.distributed.launch --nproc_per_node=8 \
    finetune_trainer.py \
    --data_dir=./data \ # directory of the data
    --train_name=train \ # file prefix of the training data
    --output_dir=./save_model \ # output directory to save the checkpoints
    --save_total_limit=10 \ # maximum number of saved checkpoints
    --per_gpu_train_batch_size=3 \ # batch size for training
    --per_gpu_eval_batch_size=3 \ # batch size for evaluation
    --num_train_epochs=1 \ # number of training epochs
    --logging_steps=5 \ # number of steps between logging the loss value
    --model_name_or_path=./LongLM-base \ # path to the pretrained model
    --warmup_steps=100 \ # number of warmup steps
    --learning_rate=1e-4 \ # learning rate
    --n_val=100 \ # number of examples for validation
    --do_train --do_eval \ # whether to run training/validation
    --evaluation_strategy steps \ # evaluation strategy
    --gradient_accumulation_steps=40 \ # number of steps for gradient accumulation
    --overwrite_output_dir \
    --load_best_model_at_end
```
- Generation:

```python
# `device` should be the device the model was moved to (see the snippet above)
input_ids = tokenizer("小咕噜对,<extra_id_1>", return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids.to(device)
gen = model.generate(input_ids, do_sample=True, decoder_start_token_id=1, top_p=0.9, max_length=512)
```
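To inspect the generated story, the ids returned by `generate` can be decoded back into text with the standard tokenizer call (a short usage note added here for convenience):

```python
# decode the sampled token ids back into a string
texts = tokenizer.batch_decode(gen, skip_special_tokens=True)
print(texts[0])
```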
Data statistics of the OutGen task in LOT. The abbreviations char/sent/len stand for character/sentence/length, respectively.
The datasets and evaluation scripts can be downloaded from THUCloud.
The python script `jsontrans.py` provides APIs to convert the jsonl files downloaded from THUCloud into the `.source` and `.target` files required by the training script.
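For reference, such a conversion can look roughly like the sketch below. This is only an illustration, not the code of `jsontrans.py`: the field names `outline` and `story` and the `#` separator are assumptions about the jsonl format.

```python
import json

def jsonl_to_source_target(jsonl_path, source_path, target_path):
    """Illustrative sketch: write one outline per line to .source and one story per line to .target."""
    with open(jsonl_path, encoding="utf-8") as fin, \
         open(source_path, "w", encoding="utf-8") as fsrc, \
         open(target_path, "w", encoding="utf-8") as ftgt:
        for line in fin:
            example = json.loads(line)
            fsrc.write("#".join(example["outline"]) + "\n")         # assumed: outline is a list of phrases
            ftgt.write(example["story"].replace("\n", " ") + "\n")  # assumed: story is a single string
```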
The training script of LongLM for the generation task is the same as the pretraining script. The generation script and example data can be found under `./LOT-LongLM/baseline/generation`. You can execute `bash ./gen.sh` for generation.
The python script `DependencyTagging.py` consumes the `.jsonl` files and adds dependency tokens to the stories in order to produce the `.source` and `.target` files.
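`DependencyTagging.py` relies on HanLP for parsing (see the dependency list below). The helper below is only a sketch of the general idea of interleaving tokens with their relation labels; the actual tagging scheme, parser model, and output format are defined in the script itself.

```python
def add_dependency_tokens(tokens, relations):
    """Illustrative sketch: interleave each token with its dependency relation label,
    e.g. (["她", "喜欢", "冒险"], ["nsubj", "root", "dobj"]) -> "她 [nsubj] 喜欢 [root] 冒险 [dobj]"."""
    return " ".join(f"{tok} [{rel}]" for tok, rel in zip(tokens, relations))
```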
The python script `boost_simbert.py` consumes the `.jsonl` files and expands the data to 6 times the original size in order to produce the `.source` and `.target` files.
- Run the python script `boost_simbert.py` with the following code to generate a new `train.jsonl`:

```python
data = load_file("./boosts_bert/train.jsonl")                   # read the training data from the jsonl file
data = boost_data(data, gen_synonyms, 5)                        # expand the data to 6 times the original size
write_jsonl_file_source("./boosts_bert/train_new.jsonl", data)  # save the augmented data as a jsonl file
```
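The call `boost_data(data, gen_synonyms, 5)` presumably keeps every original example and adds five SimBERT-generated variants, which is what yields the six-fold expansion. The function below is only a sketch of that pattern, assuming the paraphrase function returns a list of augmented examples; the real implementation and the SimBERT-based `gen_synonyms` live in `boost_simbert.py`.

```python
def boost_data(examples, paraphrase_fn, n_extra):
    """Illustrative sketch: grow the dataset to (n_extra + 1) times its original size."""
    boosted = []
    for example in examples:
        boosted.append(example)                          # keep the original example
        boosted.extend(paraphrase_fn(example, n_extra))  # add n_extra paraphrased copies
    return boosted
```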
- Run the python script `DependencyTagging.py` to label the dependency tags:

```python
data = load_file("./boosts_bert/train_new.jsonl")     # read the augmented training data from the jsonl file
write_txt_file_source("./outgen/train.source", data)  # save the outlines from the data to the .source file
write_txt_file_target("./outgen/train.target", data)  # save the stories from the data to the .target file
```
- Train the model.

Different scripts require different dependency environments. You can find the corresponding `requirements.txt` files in the `requirements` folder.
```
datasets            1.6.2
deepspeed           0.3.16
huggingface-hub     0.0.8
jieba               0.42.1
jsonlines           2.0.0
nltk                3.5
numpy               1.19.5
pytorch-lightning   1.2.0
regex               2020.11.13
rouge               1.0.1
rouge-score         0.0.4
sacrebleu           1.5.0
scipy               1.5.4
sentencepiece       0.1.95
tokenizers          0.10.1
torch               1.8.1
torchaudio          0.8.0
torchmetrics        0.2.0
torchvision         0.9.0
transformers        4.6.1
```
Dependencies for `boost_simbert.py`:

```
tensorflow   1.14
keras        2.3.1
bert4keras   0.10.6
```
Dependencies for `DependencyTagging.py`:

```
hanlp
```
Citation:

```
@misc{tang2022CSG,
  title={Improving Chinese Story Generation via Awareness of Syntactic Dependencies and Semantics},
  author={Henglin Huang and Chen Tang and Tyler Loakman and Frank Guerin and Chenghua Lin},
  year={2022}
}
```