Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Add TriviaQA task (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssatia authored and alexholdenmiller committed Jul 18, 2017
1 parent 1fdebcb commit 838a042
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 0 deletions.
7 changes: 7 additions & 0 deletions parlai/tasks/task_list.py
Expand Up @@ -157,6 +157,13 @@
"tags": [ "All", "QA" ],
"description": "Open-domain QA dataset answerable from a given paragraph from Wikipedia, from Rajpurkar et al. '16. Link: https://arxiv.org/abs/1606.05250"
},
{
"id": "TriviaQA",
"display_name": "TriviaQA",
"task": "triviaqa",
"tags": [ "All", "QA" ],
"description": "Open-domain QA dataset with question-answer-evidence triples, from Joshi et al. '17. Link: https://arxiv.org/abs/1705.03551"
},
{
"id": "Ubuntu",
"display_name": "Ubuntu",
Expand Down
5 changes: 5 additions & 0 deletions parlai/tasks/triviaqa/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
133 changes: 133 additions & 0 deletions parlai/tasks/triviaqa/agents.py
@@ -0,0 +1,133 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.dialog_teacher import DialogTeacher
from parlai.core.agents import MultiTaskTeacher
from .build import build

import copy
import json
import os
import random

def _path(opt):
build(opt)

return (os.path.join(opt['datapath'], 'TriviaQA', 'qa'),
os.path.join(opt['datapath'], 'TriviaQA', 'evidence'))


class WebTeacher(DialogTeacher):
def __init__(self, opt, shared=None):
if not hasattr(self, 'prefix'):
self.prefix = ''
if opt['datatype'].startswith('train'):
self.suffix = 'train'
else:
self.suffix = 'dev'

qa_dir, self.evidence_dir = _path(opt)
opt['datafile'] = os.path.join(qa_dir, self.prefix + 'web-' +
self.suffix + '.json')
self.id = 'triviaqa'
super().__init__(opt, shared)

def setup_data(self, path):
print('loading: ' + path)
with open(path) as data_file:
data = json.load(data_file)['Data']
for datapoint in data:
question = datapoint['Question']
answers = datapoint['Answer']['Aliases']
evidence_list = datapoint['SearchResults']

if len(evidence_list) == 0:
continue

for evidence_item in evidence_list:
evidence_file_path = os.path.join(self.evidence_dir, 'web',
evidence_item['Filename'])
with open(evidence_file_path) as evidence_file:
evidence = 'Title: %s\n' % evidence_item['Title']
evidence += evidence_file.read()
yield (evidence + '\n' + question, answers), True


class VerifiedWebTeacher(WebTeacher):
def __init__(self, opt, shared=None):
self.prefix = 'verified-'
self.suffix = 'dev'
if opt['datatype'] != 'valid':
print('WARNING: Verified teacher only provides dev data')

opt['datafile'], self.evidence_dir = _path(opt)
self.id = 'triviaqa'
super().__init__(opt, shared)


class WikipediaTeacher(DialogTeacher):
def __init__(self, opt, shared=None):
if not hasattr(self, 'prefix'):
self.prefix = ''
if opt['datatype'].startswith('train'):
self.suffix = 'train'
else:
self.suffix = 'dev'

qa_dir, self.evidence_dir = _path(opt)
opt['datafile'] = os.path.join(qa_dir, self.prefix + 'wikipedia-' +
self.suffix + '.json')

self.id = 'triviaqa'
super().__init__(opt, shared)

def setup_data(self, path):
print('loading: ' + path)
with open(path) as data_file:
data = json.load(data_file)['Data']
for datapoint in data:
question = datapoint['Question']
answers = datapoint['Answer']['Aliases']
evidence_list = datapoint['EntityPages']

if len(evidence_list) == 0:
continue

evidence = ''
for evidence_item in evidence_list:
evidence_file_path = os.path.join(self.evidence_dir,
'wikipedia',
evidence_item['Filename'])
with open(evidence_file_path) as evidence_file:
evidence += 'Title: %s\n' % evidence_item['Title']
evidence += evidence_file.read() + '\n\n'

yield (evidence + question, answers), True


class VerifiedWikipediaTeacher(WikipediaTeacher):
def __init__(self, opt, shared=None):
self.prefix = 'verified-'
self.suffix = 'dev'
if opt['datatype'] != 'valid':
print('WARNING: Verified teacher only provides dev data')

opt['datafile'], self.evidence_dir = _path(opt)
self.id = 'triviaqa'
super().__init__(opt, shared)


class VerifiedTeacher(MultiTaskTeacher):
def __init__(self, opt, shared=None):
opt = copy.deepcopy(opt)
opt['task'] = 'triviaqa:VerifiedWikipedia,triviaqa:VerifiedWeb'
super().__init__(opt, shared)

class DefaultTeacher(MultiTaskTeacher):
def __init__(self, opt, shared=None):
opt = copy.deepcopy(opt)
opt['task'] = 'triviaqa:wikipedia,triviaqa:web'
super().__init__(opt, shared)
31 changes: 31 additions & 0 deletions parlai/tasks/triviaqa/build.py
@@ -0,0 +1,31 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
#
# Download and build the data if it does not exist.

import parlai.core.build_data as build_data
import os


def build(opt):
dpath = os.path.join(opt['datapath'], 'TriviaQA')
version = None

if not build_data.built(dpath, version_string=version):
print('[building data: ' + dpath + ']')
if build_data.built(dpath):
# An older version exists, so remove these outdated files.
build_data.remove_dir(dpath)
build_data.make_dir(dpath)

# Download the data.
fname = 'triviaqa-rc.tar.gz'
url = 'http://nlp.cs.washington.edu/triviaqa/data/'
build_data.download(url + fname, dpath, fname)
build_data.untar(dpath, fname)

# Mark the data as built.
build_data.mark_done(dpath, version_string=version)
16 changes: 16 additions & 0 deletions tests/test_downloads.py
Expand Up @@ -284,6 +284,22 @@ def test_squad(self):

shutil.rmtree(self.TMP_PATH)

def test_triviaqa(self):
from parlai.core.params import ParlaiParser
from parlai.tasks.triviaqa.agents import WebTeacher, WikipediaTeacher

opt = ParlaiParser().parse_args(args=self.args)

for teacher_class in (WebTeacher, WikipediaTeacher):
for dt in ['train:ordered', 'valid']:
opt['datatype'] = dt

teacher = teacher_class(opt)
reply = teacher.act()
check(opt, reply)

shutil.rmtree(self.TMP_PATH)

def test_ubuntu(self):
from parlai.core.params import ParlaiParser
from parlai.tasks.ubuntu.agents import DefaultTeacher
Expand Down

0 comments on commit 838a042

Please sign in to comment.