[Wizard of Wikipedia] Interactive retrieval and MTurk evaluation (#1623)
* add interactive retrieval model and mturk evaluation task for wizard

* refactor retriever setup

* lint

* more linting, update READMES

* another readme

* i'm never going to pass lint
Emily Dinan committed Apr 9, 2019
1 parent 4efdd24 commit b618f5c
Showing 13 changed files with 1,747 additions and 3 deletions.
19 changes: 19 additions & 0 deletions parlai/mturk/tasks/wizard_of_wikipedia/README.md
@@ -0,0 +1,19 @@
# Wizard of Wikipedia Data Collection Task

The task involves two people holding a conversation. One dialog partner
chooses a topic to discuss, and the dialog then proceeds.

One partner is the Wizard, who has access to retrieved external
information conditioned on the last two utterances, as well as
information regarding the chosen topic.

The other partner is the Apprentice, who assumes the role of someone
eager to learn about the chosen topic.

This task was used to collect the dataset for the wizard_of_wikipedia task
contained in the tasks folder. A detailed description of the project may
be found in [Dinan et al. (ICLR 2019)](https://arxiv.org/abs/1811.01241).
For more details and pre-trained models, please see the
[Wizard of Wikipedia project page](https://github.com/facebookresearch/ParlAI/tree/master/projects/wizard_of_wikipedia).
The project page also contains the Mechanical Turk task used to evaluate the
pre-trained models from the paper.
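
As a rough illustration of what "retrieved external information conditioned on the
last two utterances" means in practice, the interactive agent added in this commit
issues one TFIDF query per context source (chosen topic, last Wizard message, last
Apprentice message) and concatenates the results. Below is a minimal sketch of that
idea, assuming a hypothetical `tfidf_retrieve` callable; the real logic lives in
`retriever_act` in the agent code further down.

```python
def gather_knowledge(chosen_topic, last_wizard_msg, last_apprentice_msg,
                     tfidf_retrieve):
    """Collect candidate knowledge passages for the Wizard.

    `tfidf_retrieve` is a hypothetical stand-in for a TFIDF retriever call;
    each query is assumed to return newline-separated passages (or '').
    """
    passages = []
    for query in (chosen_topic, last_wizard_msg, last_apprentice_msg):
        if query:  # skip sources that are empty at this point in the dialog
            passages.append(tfidf_retrieve(query))
    return '\n'.join(p for p in passages if p)
```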
@@ -11,6 +11,6 @@

 def download(datapath):
     opt = {'datapath': datapath}
-    fnames = ['full_dialogue_retrieval_model.tgz']
+    fnames = ['wizard_retrieval_2.tgz']
     opt['model_type'] = 'full_dialogue_retrieval_model'
-    download_models(opt, fnames, 'wizard_of_wikipedia')
+    download_models(opt, fnames, 'wizard_of_wikipedia', version='v2.0')
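
For reference, a minimal sketch of how this download helper might be invoked directly.
The import path below is an assumption, since the diff header above does not show the
module's file name; in normal use ParlAI triggers the download automatically when the
model is requested.

```python
# Assumed import path: the zoo module's file name is not shown in this diff.
from parlai.zoo.wizard_of_wikipedia.full_dialogue_retrieval_model import download

# After this call, the retrieval model and its supporting files (e.g.
# title_to_passage.json and wizard_cands.txt, used by the interactive agent
# below) live under
# <datapath>/models/wizard_of_wikipedia/full_dialogue_retrieval_model/
download('/path/to/ParlAI/data')
```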
4 changes: 3 additions & 1 deletion projects/wizard_of_wikipedia/README.md
@@ -33,7 +33,9 @@ See the [ParlAI quickstart for help](http://www.parl.ai/static/docs/tutorial_qui

 The ParlAI MTurk collection scripts are also
 [made available](https://github.com/facebookresearch/ParlAI/tree/master/parlai/mturk/tasks/wizard_of_wikipedia),
-for those interested in replication, analysis, or additional data collection
+for those interested in replication, analysis, or additional data collection.
+The MTurk task for evaluating pre-trained models is made available in this
+directory.

## Leaderboard

@@ -0,0 +1,282 @@
#!/usr/bin/env python

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Wizard agent with 2 parts:
1. TFIDF retriever (optional, task may already provide knowledge)
2. Retrieval model, retrieves on possible responses and conditions on
retrieved knowledge
NOTE: this model only works for eval, it assumes all training is already done.
"""

from parlai.core.agents import Agent, create_agent
from projects.wizard_of_wikipedia.wizard_transformer_ranker.\
wizard_transformer_ranker import WizardTransformerRankerAgent

import json
import os


class InteractiveRetrievalAgent(Agent):
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        self.debug = opt['debug']
        self.get_unique = opt['get_unique']
        if self.get_unique:
            self.used_messages = []
        self.model_path = os.path.join(opt['datapath'], 'models',
                                       'wizard_of_wikipedia',
                                       'full_dialogue_retrieval_model')

        if not shared:
            # Create retriever
            self._set_up_retriever(opt)
        else:
            self.opt = shared['opt']
            self.retriever = shared['retriever']
            self.sent_tok = shared['sent_tok']
            self.wiki_map = shared['wiki_map']

        # Create responder
        self._set_up_responder(opt)

        self.id = 'WizardRetrievalInteractiveAgent'
        self.ret_history = {}

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        WizardTransformerRankerAgent.add_cmdline_args(argparser)
        parser = argparser.add_argument_group('WizardRetrievalInteractive Arguments')
        parser.add_argument('--retriever-model-file', type=str, default=None)
        parser.add_argument('--responder-model-file', type=str, default=None)
        parser.add_argument('--get-unique', type='bool', default=True,
                            help='get unique responses from the bot')
        parser.add_argument('--num-retrieved', type=int, default=7,
                            help='how many passages to retrieve for each '
                                 'category')
        parser.add_argument('--debug', type='bool', default=False)
        return parser

    def _set_up_retriever(self, opt):
        retriever_opt = {'model_file': opt['retriever_model_file'],
                         'remove_title': False,
                         'datapath': opt['datapath'],
                         'override': {'remove_title': False}}
        self.retriever = create_agent(retriever_opt)

        self._set_up_sent_tok()
        wiki_map_path = os.path.join(self.model_path,
                                     'title_to_passage.json')
        with open(wiki_map_path, 'r') as f:
            self.wiki_map = json.load(f)

    def _set_up_responder(self, opt):
        responder_opts = opt.copy()
        # override these opts to build the responder model
        override_opts = {
            'model_file': opt['responder_model_file'],
            'datapath': opt['datapath'],
            'model': 'projects:wizard_of_wikipedia:wizard_transformer_ranker',
            'fixed_candidates_path': os.path.join(self.model_path,
                                                  'wizard_cands.txt'),
            'eval_candidates': 'fixed',
            'n_heads': 6,
            'ffn_size': 1200,
            'embeddings_scale': False,
            'delimiter': ' __SOC__ ',
            'n_positions': 1000,
            'legacy': True,
            'no_cuda': True,
            'encode_candidate_vecs': True,
        }
        for k, v in override_opts.items():
            responder_opts[k] = v
            responder_opts['override'][k] = v
        self.responder = create_agent(responder_opts)

    def _set_up_sent_tok(self):
        try:
            import nltk
        except ImportError:
            raise ImportError('Please install nltk (e.g. pip install nltk).')
        # nltk-specific setup
        st_path = 'tokenizers/punkt/{0}.pickle'.format('english')
        try:
            self.sent_tok = nltk.data.load(st_path)
        except LookupError:
            nltk.download('punkt')
            self.sent_tok = nltk.data.load(st_path)

    def get_chosen_topic_passages(self, chosen_topic):
        retrieved_txt_format = []
        if chosen_topic in self.wiki_map:
            retrieved_txt = self.wiki_map[chosen_topic]
            retrieved_txts = retrieved_txt.split('\n')

            if len(retrieved_txts) > 1:
                combined = ' '.join(retrieved_txts[2:])
                sentences = self.sent_tok.tokenize(combined)
                total = 0
                for sent in sentences:
                    if total >= 10:
                        break
                    if len(sent) > 0:
                        retrieved_txt_format.append(
                            ' '.join([chosen_topic, sent])
                        )
                        total += 1

        if len(retrieved_txt_format) > 0:
            passages = '\n'.join(retrieved_txt_format)
        else:
            passages = ''

        return passages

    def get_passages(self, act):
        """Format passages retrieved by taking the first paragraph of the
        top `num_retrieved` passages.
        """
        retrieved_txt = act.get('text', '')
        cands = act.get('text_candidates', [])
        if len(cands) > 0:
            retrieved_txts = cands[:self.opt['num_retrieved']]
        else:
            retrieved_txts = [retrieved_txt]

        retrieved_txt_format = []
        for ret_txt in retrieved_txts:
            paragraphs = ret_txt.split('\n')
            if len(paragraphs) > 2:
                sentences = self.sent_tok.tokenize(paragraphs[2])
                for sent in sentences:
                    retrieved_txt_format.append(
                        ' '.join([paragraphs[0], sent])
                    )

        if len(retrieved_txt_format) > 0:
            passages = '\n'.join(retrieved_txt_format)
        else:
            passages = ''

        return passages
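
    # Note on the format assumed by get_passages() above: each retrieved
    # document appears to be laid out as "<title>\n\n<first paragraph>...",
    # so paragraphs[0] is the article title and paragraphs[2] is the first
    # body paragraph. The output is one "<title> <sentence>" line per
    # sentence, mirroring the knowledge format used during data collection.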

    def retriever_act(self, history):
        """Combines and formats texts retrieved by the TFIDF retriever for the
        chosen topic, the last thing the wizard said, and the last thing the
        apprentice said.
        """
        # retrieve on chosen topic
        chosen_topic_txts = None
        if self.ret_history.get('chosen_topic'):
            chosen_topic_txts = self.get_chosen_topic_passages(
                self.ret_history['chosen_topic'],
            )

        # retrieve on apprentice
        apprentice_txts = None
        if self.ret_history.get('apprentice'):
            apprentice_act = {
                'text': self.ret_history['apprentice'],
                'episode_done': True
            }
            self.retriever.observe(apprentice_act)
            apprentice_txts = self.get_passages(self.retriever.act())

        # retrieve on wizard
        wizard_txts = None
        if self.ret_history.get('wizard'):
            wizard_act = {
                'text': self.ret_history['wizard'],
                'episode_done': True
            }
            self.retriever.observe(wizard_act)
            wizard_txts = self.get_passages(self.retriever.act())

        # combine everything
        combined_txt = ''
        if chosen_topic_txts:
            combined_txt += chosen_topic_txts
        if wizard_txts:
            combined_txt += ('\n' + wizard_txts)
        if apprentice_txts:
            combined_txt += ('\n' + apprentice_txts)

        return combined_txt

    def observe(self, observation):
        obs = observation.copy()
        self.maintain_retrieved_texts(self.ret_history, obs)
        if self.debug:
            print('DEBUG: Retriever history:\n{}'.format(self.ret_history))
        responder_knowledge = self.retriever_act(self.ret_history)
        obs['knowledge'] = responder_knowledge
        self.observation = obs

    def maintain_retrieved_texts(self, history, observation):
        """Maintain texts retrieved by the retriever to mimic the set-up
        from the data collection for the task.
        """
        if 'chosen_topic' not in history:
            history['episode_done'] = False
            history['chosen_topic'] = ''
            history['wizard'] = ''
            history['apprentice'] = ''

        if history['episode_done']:
            history['chosen_topic'] = ''
            history['wizard'] = ''
            history['apprentice'] = ''
            if 'next_wizard' in history:
                del history['next_wizard']
            history['episode_done'] = False

        # save chosen topic
        if 'chosen_topic' in observation:
            history['chosen_topic'] = observation['chosen_topic']
        if 'text' in observation:
            history['apprentice'] = observation['text']
            if 'next_wizard' in history:
                history['wizard'] = history['next_wizard']
        # save last thing wizard said (for next time)
        if 'labels' in observation:
            history['next_wizard'] = observation['labels'][0]
        elif 'eval_labels' in observation:
            history['next_wizard'] = observation['eval_labels'][0]

        history['episode_done'] = observation['episode_done']

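    # Illustration of how ret_history evolves across turns (values are
    # hypothetical). After the Apprentice picks a topic and speaks:
    #   {'chosen_topic': 'Jazz', 'apprentice': 'I love jazz!', 'wizard': '',
    #    'next_wizard': <wizard label, if provided>, 'episode_done': False}
    # On the next Apprentice turn, 'next_wizard' is promoted to 'wizard', so
    # retriever_act() can condition on the chosen topic plus the last two
    # utterances, as in the original data collection set-up.
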
    def get_unique_reply(self, act):
        # iterate through text candidates until we find a reply that we
        # have not used yet
        for txt in act['text_candidates']:
            if txt not in self.used_messages:
                self.used_messages.append(txt)
                return txt

    def act(self):
        obs = self.observation
        # choose a knowledge sentence
        responder_obs = obs.copy()
        if self.debug:
            print('DEBUG: Responder is observing:\n{}'.format(responder_obs))
        self.responder.observe(responder_obs)
        responder_act = self.responder.act()
        if self.debug:
            print('DEBUG: Responder is acting:\n{}'.format(responder_act))
        responder_act['id'] = 'WizardRetrievalInteractiveAgent'
        if self.get_unique:
            responder_act['text'] = self.get_unique_reply(responder_act)
        return responder_act

    def share(self):
        """Share internal state (retriever, tokenizer, wiki map) between
        parent and child instances."""
        shared = super().share()
        shared['opt'] = self.opt
        shared['retriever'] = self.retriever
        shared['sent_tok'] = self.sent_tok
        shared['wiki_map'] = self.wiki_map

        return shared
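
A rough sketch of how this agent might be instantiated outside of the MTurk task. The
`model` string is an assumption (the file's registered import path is not shown in this
commit), the two model-file paths are placeholders, and a full ParlAI opt would normally
be built by the argument parser rather than by hand.

```python
from parlai.core.agents import create_agent

# Illustrative opt only: 'model' is an assumed path for this agent, and the
# two *_model_file values are placeholders for the TFIDF retriever and the
# downloaded full_dialogue_retrieval_model.
opt = {
    'model': 'projects:wizard_of_wikipedia:interactive_retrieval',  # assumed
    'retriever_model_file': '/path/to/tfidf/retriever/model',
    'responder_model_file': '/path/to/wizard/ranker/model',
    'get_unique': True,
    'num_retrieved': 7,
    'debug': False,
    'datapath': '/path/to/ParlAI/data',
    'override': {},
}
wizard = create_agent(opt)

# One simulated Apprentice turn: observe() gathers knowledge via the TFIDF
# retriever, and act() ranks the fixed candidate responses against it.
wizard.observe({'chosen_topic': 'Jazz',
                'text': 'I love jazz music!',
                'episode_done': False})
print(wizard.act()['text'])
```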
5 changes: 5 additions & 0 deletions projects/wizard_of_wikipedia/mturk_evaluation_task/README.md
@@ -0,0 +1,5 @@
This is the Mechanical Turk task for evaluating models trained on the Wizard of Wikipedia task.

As an example, we have one of the pre-trained models loaded inside the task. Please edit `config` in `run.py` to swap out the model for one of yours.

To run the task with two humans speaking to each other instead of a human and a model, pass the flag `--human-eval True`.
@@ -0,0 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
