Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add TriviaQA task #204

Merged
merged 7 commits into from Jul 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions parlai/tasks/task_list.py
Expand Up @@ -157,6 +157,13 @@
"tags": [ "All", "QA" ],
"description": "Open-domain QA dataset answerable from a given paragraph from Wikipedia, from Rajpurkar et al. '16. Link: https://arxiv.org/abs/1606.05250"
},
{
"id": "TriviaQA",
"display_name": "TriviaQA",
"task": "triviaqa",
"tags": [ "All", "QA" ],
"description": "Open-domain QA dataset with question-answer-evidence triples, from Joshi et al. '17. Link: https://arxiv.org/abs/1705.03551"
},
{
"id": "Ubuntu",
"display_name": "Ubuntu",
Expand Down
5 changes: 5 additions & 0 deletions parlai/tasks/triviaqa/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
133 changes: 133 additions & 0 deletions parlai/tasks/triviaqa/agents.py
@@ -0,0 +1,133 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.dialog_teacher import DialogTeacher
from parlai.core.agents import MultiTaskTeacher
from .build import build

import copy
import json
import os
import random

def _path(opt):
    """Make sure the TriviaQA data is built and locate its sub-directories.

    Returns a ``(qa_dir, evidence_dir)`` tuple: the ``qa`` and ``evidence``
    folders underneath ``<datapath>/TriviaQA``.
    """
    build(opt)

    root = os.path.join(opt['datapath'], 'TriviaQA')
    qa_dir = os.path.join(root, 'qa')
    evidence_dir = os.path.join(root, 'evidence')
    return (qa_dir, evidence_dir)


class WebTeacher(DialogTeacher):
    """Teacher for the web-search-result portion of TriviaQA.

    Every (question, search result) pair is yielded as its own example: the
    text is the result's title and document body followed by the question,
    and the second element is the list of answer aliases from the dataset.
    """

    def __init__(self, opt, shared=None):
        # Subclasses may preset ``prefix`` and/or ``suffix`` before
        # delegating here. Each is defaulted independently so a preset value
        # is never overwritten and neither attribute can be left undefined.
        if not hasattr(self, 'prefix'):
            self.prefix = ''
        if not hasattr(self, 'suffix'):
            if opt['datatype'].startswith('train'):
                self.suffix = 'train'
            else:
                self.suffix = 'dev'

        qa_dir, self.evidence_dir = _path(opt)
        opt['datafile'] = os.path.join(
            qa_dir, self.prefix + 'web-' + self.suffix + '.json')
        self.id = 'triviaqa'
        super().__init__(opt, shared)

    def setup_data(self, path):
        """Yield ``((text, answers), True)`` for each search result.

        ``path`` is the question/answer JSON file; the evidence documents it
        references are read from ``<evidence_dir>/web``. Questions without
        any search result produce no examples.
        """
        print('loading: ' + path)
        # Be explicit about the encoding instead of depending on the
        # platform's locale default (assumes the data files are UTF-8).
        with open(path, encoding='utf-8') as data_file:
            data = json.load(data_file)['Data']

        for datapoint in data:
            question = datapoint['Question']
            answers = datapoint['Answer']['Aliases']

            # Iterating an empty result list is a no-op, so no explicit
            # emptiness check is needed.
            for evidence_item in datapoint['SearchResults']:
                evidence_file_path = os.path.join(
                    self.evidence_dir, 'web', evidence_item['Filename'])
                with open(evidence_file_path,
                          encoding='utf-8') as evidence_file:
                    evidence = 'Title: %s\n' % evidence_item['Title']
                    evidence += evidence_file.read()
                # Yield only after the file is closed so no handle stays
                # open while the consumer holds the generator suspended.
                yield (evidence + '\n' + question, answers), True


class VerifiedWebTeacher(WebTeacher):
    """Web teacher that reads the ``verified-`` dev question file."""

    def __init__(self, opt, shared=None):
        # Preset before delegating: the parent constructor builds the data
        # file name from these two attributes.
        self.prefix = 'verified-'
        self.suffix = 'dev'
        if opt['datatype'] != 'valid':
            print('WARNING: Verified teacher only provides dev data')

        # The parent constructor resolves the data paths (triggering the
        # build), sets ``opt['datafile']`` and assigns ``self.id``, so none
        # of that is repeated here. Previously ``opt['datafile']`` was
        # briefly set to the qa *directory*, which was never a valid file.
        super().__init__(opt, shared)


class WikipediaTeacher(DialogTeacher):
    """Teacher for the Wikipedia portion of TriviaQA.

    Each question yields a single example whose text is the concatenation of
    all of its entity pages (title plus page body) followed by the question;
    the second element is the list of answer aliases from the dataset.
    """

    def __init__(self, opt, shared=None):
        # Subclasses may preset ``prefix`` and/or ``suffix`` before
        # delegating here. Each is defaulted independently so a preset value
        # is never overwritten and neither attribute can be left undefined.
        if not hasattr(self, 'prefix'):
            self.prefix = ''
        if not hasattr(self, 'suffix'):
            if opt['datatype'].startswith('train'):
                self.suffix = 'train'
            else:
                self.suffix = 'dev'

        qa_dir, self.evidence_dir = _path(opt)
        opt['datafile'] = os.path.join(
            qa_dir, self.prefix + 'wikipedia-' + self.suffix + '.json')

        self.id = 'triviaqa'
        super().__init__(opt, shared)

    def setup_data(self, path):
        """Yield ``((text, answers), True)`` once per question.

        ``path`` is the question/answer JSON file; the entity pages it
        references are read from ``<evidence_dir>/wikipedia``. Questions with
        no entity page are skipped entirely.
        """
        print('loading: ' + path)
        # Be explicit about the encoding instead of depending on the
        # platform's locale default (assumes the data files are UTF-8).
        with open(path, encoding='utf-8') as data_file:
            data = json.load(data_file)['Data']

        for datapoint in data:
            question = datapoint['Question']
            answers = datapoint['Answer']['Aliases']
            evidence_list = datapoint['EntityPages']

            # Without this guard a question with no evidence would be
            # emitted as a bare question with empty context.
            if not evidence_list:
                continue

            # Collect the pages and join once rather than growing one large
            # string with repeated ``+=``.
            pages = []
            for evidence_item in evidence_list:
                evidence_file_path = os.path.join(
                    self.evidence_dir, 'wikipedia',
                    evidence_item['Filename'])
                with open(evidence_file_path,
                          encoding='utf-8') as evidence_file:
                    pages.append('Title: %s\n' % evidence_item['Title']
                                 + evidence_file.read() + '\n\n')

            yield (''.join(pages) + question, answers), True


class VerifiedWikipediaTeacher(WikipediaTeacher):
    """Wikipedia teacher that reads the ``verified-`` dev question file."""

    def __init__(self, opt, shared=None):
        # Preset before delegating: the parent constructor builds the data
        # file name from these two attributes.
        self.prefix = 'verified-'
        self.suffix = 'dev'
        if opt['datatype'] != 'valid':
            print('WARNING: Verified teacher only provides dev data')

        # The parent constructor resolves the data paths (triggering the
        # build), sets ``opt['datafile']`` and assigns ``self.id``, so none
        # of that is repeated here. Previously ``opt['datafile']`` was
        # briefly set to the qa *directory*, which was never a valid file.
        super().__init__(opt, shared)


class VerifiedTeacher(MultiTaskTeacher):
    """Multi-task teacher combining both verified TriviaQA tasks."""

    def __init__(self, opt, shared=None):
        # Operate on a private copy so the caller's options are not altered.
        task_opt = copy.deepcopy(opt)
        task_opt['task'] = 'triviaqa:VerifiedWikipedia,triviaqa:VerifiedWeb'
        super().__init__(task_opt, shared)

class DefaultTeacher(MultiTaskTeacher):
    """Default TriviaQA teacher: the wikipedia and web tasks together."""

    def __init__(self, opt, shared=None):
        # Operate on a private copy so the caller's options are not altered.
        task_opt = copy.deepcopy(opt)
        task_opt['task'] = 'triviaqa:wikipedia,triviaqa:web'
        super().__init__(task_opt, shared)
31 changes: 31 additions & 0 deletions parlai/tasks/triviaqa/build.py
@@ -0,0 +1,31 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
#
# Download and build the data if it does not exist.

import parlai.core.build_data as build_data
import os


def build(opt):
    """Download and unpack the TriviaQA archive unless it is already built.

    The data lives in ``<datapath>/TriviaQA``; a build marker written at the
    end lets later calls return without doing any work.
    """
    version = None
    dpath = os.path.join(opt['datapath'], 'TriviaQA')

    # Nothing to do when the current version is already in place.
    if build_data.built(dpath, version_string=version):
        return

    print('[building data: ' + dpath + ']')
    if build_data.built(dpath):
        # An older version exists, so remove these outdated files.
        build_data.remove_dir(dpath)
    build_data.make_dir(dpath)

    # Download the data.
    fname = 'triviaqa-rc.tar.gz'
    url = 'http://nlp.cs.washington.edu/triviaqa/data/'
    build_data.download(url + fname, dpath, fname)
    build_data.untar(dpath, fname)

    # Mark the data as built.
    build_data.mark_done(dpath, version_string=version)
16 changes: 16 additions & 0 deletions tests/test_downloads.py
Expand Up @@ -284,6 +284,22 @@ def test_squad(self):

shutil.rmtree(self.TMP_PATH)

def test_triviaqa(self):
    """Check the first act of the TriviaQA web and Wikipedia teachers."""
    from parlai.core.params import ParlaiParser
    from parlai.tasks.triviaqa.agents import WebTeacher, WikipediaTeacher

    opt = ParlaiParser().parse_args(args=self.args)
    datatypes = ['train:ordered', 'valid']

    for teacher_class in (WebTeacher, WikipediaTeacher):
        for datatype in datatypes:
            opt['datatype'] = datatype
            # A fresh teacher is built for every (class, datatype) pair and
            # only its first reply is validated.
            check(opt, teacher_class(opt).act())

    shutil.rmtree(self.TMP_PATH)

def test_ubuntu(self):
from parlai.core.params import ParlaiParser
from parlai.tasks.ubuntu.agents import DefaultTeacher
Expand Down