Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Color and Exit point. #111

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ You can then run the `interact.py` script on the pretrained model:
python3 interact.py --model models/
```

After running `interact.py` you can easily exit/stop talking with bot by typing:

```bash
quit
```

## Pretrained model

We make a pretrained and fine-tuned model available on our S3 [here](https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/finetuned_chatbot_gpt.tar.gz). The easiest way to download and use this model is just to run the `interact.py` script to talk with the model. Without any argument, this script will automatically download and cache our model.
Expand Down
12 changes: 8 additions & 4 deletions interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain
from pprint import pformat
import warnings

from termcolor import colored
import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -136,18 +136,22 @@ def run():
logger.info("Selected personality: %s", tokenizer.decode(chain(*personality)))

history = []
print("\n===============================================")
print("================== Conv AI ====================")
print("===============================================\n")
while True:
raw_text = input(">>> ")
raw_text = input(colored("You: ",'green'))
if raw_text == 'quit': break
while not raw_text:
print('Prompt should not be empty!')
raw_text = input(">>> ")
raw_text = input(colored("You: ",'green'))
history.append(tokenizer.encode(raw_text))
with torch.no_grad():
out_ids = sample_sequence(personality, history, tokenizer, model, args)
history.append(out_ids)
history = history[-(2*args.max_history+1):]
out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
print(out_text)
print(colored("Bot:",'red'),out_text)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ transformers==2.5.1
tensorboardX==1.8
tensorflow # for tensorboardX
spacy
termcolor