diff --git a/README.md b/README.md index bc9d09c..1e9d709 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,12 @@ You can then run the `interact.py` script on the pretrained model: python3 interact.py --model models/ ``` +After running `interact.py` you can easily exit/stop talking with bot by typing: + +```bash +quit +``` + ## Pretrained model We make a pretrained and fine-tuned model available on our S3 [here](https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/finetuned_chatbot_gpt.tar.gz). The easiest way to download and use this model is just to run the `interact.py` script to talk with the model. Without any argument, this script will automatically download and cache our model. diff --git a/interact.py b/interact.py index d368204..198a6fe 100644 --- a/interact.py +++ b/interact.py @@ -8,7 +8,7 @@ from itertools import chain from pprint import pformat import warnings - +from termcolor import colored import torch import torch.nn.functional as F @@ -136,18 +136,22 @@ def run(): logger.info("Selected personality: %s", tokenizer.decode(chain(*personality))) history = [] + print("\n===============================================") + print("================== Conv AI ====================") + print("===============================================\n") while True: - raw_text = input(">>> ") + raw_text = input(colored("You: ",'green')) + if raw_text == 'quit': break while not raw_text: print('Prompt should not be empty!') - raw_text = input(">>> ") + raw_text = input(colored("You: ",'green')) history.append(tokenizer.encode(raw_text)) with torch.no_grad(): out_ids = sample_sequence(personality, history, tokenizer, model, args) history.append(out_ids) history = history[-(2*args.max_history+1):] out_text = tokenizer.decode(out_ids, skip_special_tokens=True) - print(out_text) + print(colored("Bot:",'red'),out_text) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index a931b0c..30f9d61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ transformers==2.5.1 tensorboardX==1.8 tensorflow # for tensorboardX spacy +termcolor