From c747c7e4147dec34f8f83b631532f54c209ba544 Mon Sep 17 00:00:00 2001
From: Dieuwke
Date: Tue, 5 Jun 2018 16:36:18 +0200
Subject: [PATCH] Update readme and example script

---
 README.md  |  4 +++-
 example.sh | 12 ++++++------
 infer.py   |  4 +++-
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index b977bf02..eb977a40 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ This is a pytorch implementation of a sequence to sequence learning toolkit for
 
 # Requirements
 
-This library runs with PyTorch 0.3.0. We refer to the [PyTorch website](http://pytorch.org/) to install the right version for your environment.
+This library runs with PyTorch 0.4.0. We refer to the [PyTorch website](http://pytorch.org/) to install the right version for your environment.
 To install additional requirements (including numpy and torchtext), run:
 
 `pip install -r requirements.txt`
@@ -29,6 +29,7 @@ The script `train_model.py` can be used to train a new model, resume the trainin
 `python train_model.py --train $train_path --dev $dev_path --output_dir $expt_dir --embedding_size 128 --hidden_size 256 --rnn_cell gru --epoch 20
 
 Several options are available from the command line, including changing the optimizer, batch size, using attention/bidirectionality and using teacher forcing.
+Additionally, models can be trained using *attentive guidance* [1](https://arxiv.org/abs/1805.09657), which can be provided in the training file.
 For a complete overview, use the *help* function of the script.
 
 ## Evaluation and inference
@@ -50,6 +51,7 @@ Once training is complete, you will be prompted to enter a new sequence to trans
     Input:  1 3 5 7 9
     Expected output: 9 7 5 3 1 EOS
 
+Additional example scripts, as well as different datasets, can be found in our auxiliary repository [machine-tasks](https://github.com/i-machine-think/machine-tasks).
 
 ## Checkpoints
 
diff --git a/example.sh b/example.sh
index 64e09b77..210a19f6 100755
--- a/example.sh
+++ b/example.sh
@@ -5,8 +5,8 @@ DEV_PATH=test/test_data/dev.txt
 EXPT_DIR=example
 
 # set values
-EMB_SIZE=128
-H_SIZE=128
+EMB_SIZE=16
+H_SIZE=64
 N_LAYERS=1
 CELL='gru'
 EPOCH=6
@@ -15,10 +15,10 @@ TF=0.5
 
 # Start training
 echo "Train model on example data"
-python train_model.py --train $TRAIN_PATH --dev $DEV_PATH --output_dir $EXPT_DIR --print_every $PRINT_EVERY --embedding_size $EMB_SIZE --hidden_size $H_SIZE --rnn_cell $CELL --n_layers $N_LAYERS --epoch $EPOCH --print_every $PRINT_EVERY --teacher_forcing $TF --bidirectional --attention
+python train_model.py --train $TRAIN_PATH --output_dir $EXPT_DIR --print_every $PRINT_EVERY --embedding_size $EMB_SIZE --hidden_size $H_SIZE --rnn_cell $CELL --n_layers $N_LAYERS --epoch $EPOCH --print_every $PRINT_EVERY --teacher_forcing $TF --attention 'pre-rnn' --attention_method 'mlp'
 
-echo "Evaluate model on test data"
-python evaluate.py --checkpoint_path $EXPT_DIR/$(ls -t $EXPT_DIR/ | head -1) --test_data $DEV_PATH
+echo "\n\nEvaluate model on test data"
+python evaluate.py --checkpoint_path $EXPT_DIR/$(ls -t $EXPT_DIR/ | head -1) --test_data $TRAIN_PATH
 
-echo "Run in inference mode"
+echo "\n\nRun in inference mode"
 python infer.py --checkpoint_path $EXPT_DIR/$(ls -t $EXPT_DIR/ | head -1)
diff --git a/infer.py b/infer.py
index 92d68848..2143e548 100644
--- a/infer.py
+++ b/infer.py
@@ -49,6 +49,8 @@
     exit()
 
 while True:
-    seq_str = raw_input("Type in a source sequence:")
+    seq_str = raw_input("\n\nType in a source sequence: ")
+    if seq_str == 'q':
+        exit()
     seq = seq_str.strip().split()
     print(predictor.predict(seq))
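
The infer.py hunk above shows only the tail of the script. As a rough sketch of the surrounding context, the full interactive loop after this patch could look as follows; the Checkpoint/Predictor usage follows IBM's pytorch-seq2seq, on which this toolkit is based, while the import paths and argument parsing here are assumptions rather than the repository's exact code:

    # Sketch only: plausible surrounding context for the patched loop in
    # infer.py. Checkpoint/Predictor usage follows IBM's pytorch-seq2seq;
    # exact import paths in this repository may differ.
    import argparse

    from seq2seq.util.checkpoint import Checkpoint
    from seq2seq.evaluator import Predictor

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path',
                        help='path to a saved checkpoint directory')
    opt = parser.parse_args()

    if opt.checkpoint_path is None:
        print("Provide --checkpoint_path to load a trained model")
        exit()

    # Restore the trained model and its vocabularies from the checkpoint.
    checkpoint = Checkpoint.load(opt.checkpoint_path)
    predictor = Predictor(checkpoint.model, checkpoint.input_vocab,
                          checkpoint.output_vocab)

    while True:
        # New in this patch: a blank-line prefix for readability and a
        # 'q' escape to leave the loop (the script is Python 2, hence
        # raw_input).
        seq_str = raw_input("\n\nType in a source sequence: ")
        if seq_str == 'q':
            exit()
        seq = seq_str.strip().split()
        print(predictor.predict(seq))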