
Update readme and example script
dieuwkehupkes committed Jun 5, 2018
1 parent 94039cb commit c747c7e
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -8,7 +8,7 @@ This is a pytorch implementation of a sequence to sequence learning toolkit for

# Requirements

-This library runs with PyTorch 0.3.0. We refer to the [PyTorch website](http://pytorch.org/) to install the right version for your environment.
+This library runs with PyTorch 0.4.0. We refer to the [PyTorch website](http://pytorch.org/) to install the right version for your environment.
To install additional requirements (including numpy and torchtext), run:

`pip install -r requirements.txt`
@@ -29,6 +29,7 @@ The script `train_model.py` can be used to train a new model, resume the trainin
`python train_model.py --train $train_path --dev $dev_path --output_dir $expt_dir --embedding_size 128 --hidden_size 256 --rnn_cell gru --epoch 20`

Several options are available from the command line, including changing the optimizer, batch size, using attention/bidirectionality and using teacher forcing.
+Additionally, models can be trained using *attentive guidance* [1](https://arxiv.org/abs/1805.09657), which can be provided in the training file.
For a complete overview, use the *help* function of the script.

## Evaluation and inference
@@ -50,6 +51,7 @@ Once training is complete, you will be prompted to enter a new sequence to trans
Input: 1 3 5 7 9
Expected output: 9 7 5 3 1 EOS

+Additional example scripts, as well as different datasets, can be found in our auxiliary repository [machine-tasks](https://github.com/i-machine-think/machine-tasks).

## Checkpoints

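A note on the *attentive guidance* line added to the README above: per the linked paper (arXiv:1805.09657), guidance supervises the model's attention distributions with target alignments supplied alongside the training data. Below is a minimal sketch of that idea, assuming `(batch, tgt_len, src_len)` attention weights and a per-step target source position; it is an illustration only, not this repository's implementation.

```python
# Sketch of the idea behind attentive guidance (arXiv:1805.09657).
# Assumption: this is NOT the machine toolkit's actual code; the shapes
# and names here are illustrative only.
import torch
import torch.nn.functional as F

def attentive_guidance_loss(attn_weights, target_positions):
    # attn_weights: (batch, tgt_len, src_len) attention distributions,
    # one per decoding step.
    # target_positions: (batch, tgt_len) index of the source token each
    # decoding step should attend to (the provided guidance).
    b, t, s = attn_weights.shape
    log_attn = torch.log(attn_weights.clamp(min=1e-8))  # avoid log(0)
    # Cross-entropy between each attention distribution and its guided position.
    return F.nll_loss(log_attn.reshape(b * t, s), target_positions.reshape(b * t))
```

Such an auxiliary loss would be added to the usual token-prediction loss during training.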
12 changes: 6 additions & 6 deletions example.sh
@@ -5,8 +5,8 @@ DEV_PATH=test/test_data/dev.txt
EXPT_DIR=example

# set values
-EMB_SIZE=128
-H_SIZE=128
+EMB_SIZE=16
+H_SIZE=64
N_LAYERS=1
CELL='gru'
EPOCH=6
@@ -15,10 +15,10 @@ TF=0.5

# Start training
echo "Train model on example data"
-python train_model.py --train $TRAIN_PATH --dev $DEV_PATH --output_dir $EXPT_DIR --print_every $PRINT_EVERY --embedding_size $EMB_SIZE --hidden_size $H_SIZE --rnn_cell $CELL --n_layers $N_LAYERS --epoch $EPOCH --print_every $PRINT_EVERY --teacher_forcing $TF --bidirectional --attention
+python train_model.py --train $TRAIN_PATH --output_dir $EXPT_DIR --print_every $PRINT_EVERY --embedding_size $EMB_SIZE --hidden_size $H_SIZE --rnn_cell $CELL --n_layers $N_LAYERS --epoch $EPOCH --print_every $PRINT_EVERY --teacher_forcing $TF --attention 'pre-rnn' --attention_method 'mlp'

echo "Evaluate model on test data"
python evaluate.py --checkpoint_path $EXPT_DIR/$(ls -t $EXPT_DIR/ | head -1) --test_data $DEV_PATH
echo "\n\nEvaluate model on test data"
python evaluate.py --checkpoint_path $EXPT_DIR/$(ls -t $EXPT_DIR/ | head -1) --test_data $TRAIN_PATH

echo "Run in inference mode"
echo "\n\nRun in inference mode"
python infer.py --checkpoint_path $EXPT_DIR/$(ls -t $EXPT_DIR/ | head -1)
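Two details of the script above are easy to miss. `$EXPT_DIR/$(ls -t $EXPT_DIR/ | head -1)` hands `--checkpoint_path` the most recently modified entry in the experiment directory, i.e. the newest checkpoint. And in bash, a plain `echo "\n\n..."` prints the backslash escapes literally; `echo -e` or `printf` would be needed to produce actual blank lines (other shells, such as dash, do interpret the escapes, so the output depends on which shell runs the script).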
4 changes: 3 additions & 1 deletion infer.py
@@ -49,6 +49,8 @@
exit()

while True:
-    seq_str = raw_input("Type in a source sequence:")
+    seq_str = raw_input("\n\nType in a source sequence: ")
    if seq_str == 'q':
        exit()
    seq = seq_str.strip().split()
    print(predictor.predict(seq))
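Note that `raw_input` exists only on Python 2; under Python 3 this loop raises a `NameError`. A minimal compatibility sketch, assuming the script is meant to run on both (this shim is not part of the commit):

```python
# Hypothetical Python 2/3 compatibility shim, not part of this commit.
try:
    read_line = raw_input  # Python 2
except NameError:
    read_line = input      # Python 3 renamed raw_input to input

seq_str = read_line("\n\nType in a source sequence: ")
```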
