MANGA-UOFA/EqHard-EM

An Equal-Size Hard EM Algorithm for Diverse Dialogue Generation

Training

To train a single-decoder model, use base_trainer.py:

python base_train.py \
    --train-path=path-to-the-training-csv-file \
    --val-path=path-to-the-validation-csv-file \
    --model-str=t5-small

The resulting checkpoint serves as the initial checkpoint for multi-decoder training. For Weibo, set --model-str=uer/t5-small-chinese-cluecorpussmall and additionally pass --language=zh and --multi-ref.
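
For scripted runs, the dataset-specific flags above can be collected in a small helper. The following is a minimal sketch, not part of the repository: the function name and the dataset keys "ost" and "weibo" are hypothetical, the script name is copied from the command above, and it assumes that the t5-small example corresponds to OST.

```python
from typing import List


def base_train_command(train_path: str, val_path: str, dataset: str = "ost") -> List[str]:
    """Build the argument list for single-decoder training.

    The dataset keys are hypothetical labels chosen for this sketch;
    the flags themselves are the ones listed in this README.
    """
    cmd = [
        "python",
        "base_train.py",
        f"--train-path={train_path}",
        f"--val-path={val_path}",
    ]
    if dataset == "ost":
        # Assumption: the t5-small example in the README targets OST.
        cmd.append("--model-str=t5-small")
    elif dataset == "weibo":
        # Weibo needs a Chinese T5 checkpoint plus two extra flags.
        cmd += [
            "--model-str=uer/t5-small-chinese-cluecorpussmall",
            "--language=zh",
            "--multi-ref",
        ]
    else:
        raise ValueError(f"unknown dataset: {dataset!r}")
    return cmd
```

The returned list can be passed to subprocess.run(cmd, check=True) from the repository root.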

To train a multi-decoder model, use the script multi-decoder_trainer.py:

python multi-decoder_trainer.py \
    --train-path=path-to-the-training-csv-file \
    --val-path=path-to-the-validation-csv-file \
    --model-str=t5-small \
    --init-ckpt=path-to-warmstart-ckpt \
    --freeze \
    --num-modes=10 \
    --trainer=eqhem \
    --decoder=adapter \

where

  1. --num-modes specifies the number of decoders
  2. --trainer specifies the training algorithm.
    • eqhem: EqHard-EM
    • sem: Soft-EM
    • trick-sem: Soft-EM with recurrent dropout trick
    • hem: Hard-EM
    • trick-hem: Hard-EM with recurrent dropout trick
    • random: EqRandom-Fixed
    • drandom: EqRandom-Dynamic
  3. --decoder specifies the decoder architecture.
  4. --lp enables learned priors (a uniform prior is used by default).
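
To give an intuition for the difference between hem and eqhem, the toy sketch below contrasts plain hard assignment with an equal-size variant. It is an illustration under stated assumptions, not the repository's implementation: it assumes each sample is assigned to exactly one decoder based on a per-decoder loss, and it uses a simple greedy heuristic to balance the assignment, which may differ from the solver the training code actually uses. All function names are hypothetical.

```python
from typing import List, Sequence


def hard_assign(losses: Sequence[Sequence[float]]) -> List[int]:
    """Plain hard assignment: each sample picks its lowest-loss decoder."""
    return [min(range(len(row)), key=row.__getitem__) for row in losses]


def equal_size_greedy_assign(losses: Sequence[Sequence[float]]) -> List[int]:
    """Greedy balanced assignment: every decoder receives the same number of samples.

    Assumes the number of samples is a multiple of the number of decoders.
    This greedy pass is a stand-in for illustration and is not guaranteed
    to find the minimum-total-loss balanced assignment.
    """
    num_samples = len(losses)
    num_modes = len(losses[0])
    if num_samples % num_modes != 0:
        raise ValueError("number of samples must be a multiple of num_modes")
    capacity = [num_samples // num_modes] * num_modes

    # Visit (sample, decoder) pairs from lowest to highest loss.
    pairs = sorted(
        (losses[i][k], i, k) for i in range(num_samples) for k in range(num_modes)
    )
    assignment = [-1] * num_samples
    for _, i, k in pairs:
        # Assign only if the sample is still free and the decoder has room.
        if assignment[i] == -1 and capacity[k] > 0:
            assignment[i] = k
            capacity[k] -= 1
    return assignment


if __name__ == "__main__":
    # Toy losses for 4 samples and 2 decoders; decoder 0 is best for every sample.
    toy = [[0.1, 0.9], [0.2, 0.8], [0.3, 0.4], [0.35, 0.5]]
    print(hard_assign(toy))               # [0, 0, 0, 0]
    print(equal_size_greedy_assign(toy))  # [0, 0, 1, 1]
```

In the toy example, plain hard assignment sends every sample to decoder 0, so decoder 1 would receive no training signal, whereas the equal-size variant gives each decoder two samples.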

Monitoring Performance

The training script automatically creates a timestamped logging directory that stores the checkpoints and log files. Validation performance can be monitored during training with TensorBoard:

tensorboard --logdir=path-to-the-timestamped-logging-folder

Continue Training

If the performance is still improving at the end of training, you can resume with the following command:

python base_train.py \
    --the-original-arguments-that-you-started-training-with \
    --resume-path=path-to-the-timestamped-logging-folder
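
When several runs have accumulated, a small helper can locate the most recent logging folder to pass to --resume-path. This is a sketch under assumptions: the README does not state where the timestamped directories are created, so logs_root is a hypothetical parent directory, and the helper relies on filesystem modification time rather than parsing the timestamp in the folder name.

```python
from pathlib import Path
from typing import Optional


def latest_run_dir(logs_root: str) -> Optional[Path]:
    """Return the most recently modified subdirectory of logs_root, or None.

    logs_root is a hypothetical parent folder holding the timestamped
    logging directories; plain files inside it are ignored.
    """
    root = Path(logs_root)
    run_dirs = [p for p in root.iterdir() if p.is_dir()]
    if not run_dirs:
        return None
    return max(run_dirs, key=lambda p: p.stat().st_mtime)
```

The returned path can then be formatted into the resume flag, for example f"--resume-path={latest_run_dir('path-to-logs')}".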

Evaluation

After the performance has peaked, you can evaluate the model using evaluate_generations.py:

python evaluate_generations.py --ckpt-path=path-to-the-best-validation-checkpoint --eval-path=path-to-the-test-csv-file

When evaluating on Weibo, additionally pass --language=zh and --multi-ref.

Datasets

OST, Weibo
