This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1fdebcb
commit 838a042
Showing
5 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. | ||
|
||
from parlai.core.dialog_teacher import DialogTeacher | ||
from parlai.core.agents import MultiTaskTeacher | ||
from .build import build | ||
|
||
import copy | ||
import json | ||
import os | ||
import random | ||
|
||
def _path(opt): | ||
build(opt) | ||
|
||
return (os.path.join(opt['datapath'], 'TriviaQA', 'qa'), | ||
os.path.join(opt['datapath'], 'TriviaQA', 'evidence')) | ||
|
||
|
||
class WebTeacher(DialogTeacher): | ||
def __init__(self, opt, shared=None): | ||
if not hasattr(self, 'prefix'): | ||
self.prefix = '' | ||
if opt['datatype'].startswith('train'): | ||
self.suffix = 'train' | ||
else: | ||
self.suffix = 'dev' | ||
|
||
qa_dir, self.evidence_dir = _path(opt) | ||
opt['datafile'] = os.path.join(qa_dir, self.prefix + 'web-' + | ||
self.suffix + '.json') | ||
self.id = 'triviaqa' | ||
super().__init__(opt, shared) | ||
|
||
def setup_data(self, path): | ||
print('loading: ' + path) | ||
with open(path) as data_file: | ||
data = json.load(data_file)['Data'] | ||
for datapoint in data: | ||
question = datapoint['Question'] | ||
answers = datapoint['Answer']['Aliases'] | ||
evidence_list = datapoint['SearchResults'] | ||
|
||
if len(evidence_list) == 0: | ||
continue | ||
|
||
for evidence_item in evidence_list: | ||
evidence_file_path = os.path.join(self.evidence_dir, 'web', | ||
evidence_item['Filename']) | ||
with open(evidence_file_path) as evidence_file: | ||
evidence = 'Title: %s\n' % evidence_item['Title'] | ||
evidence += evidence_file.read() | ||
yield (evidence + '\n' + question, answers), True | ||
|
||
|
||
class VerifiedWebTeacher(WebTeacher): | ||
def __init__(self, opt, shared=None): | ||
self.prefix = 'verified-' | ||
self.suffix = 'dev' | ||
if opt['datatype'] != 'valid': | ||
print('WARNING: Verified teacher only provides dev data') | ||
|
||
opt['datafile'], self.evidence_dir = _path(opt) | ||
self.id = 'triviaqa' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class WikipediaTeacher(DialogTeacher): | ||
def __init__(self, opt, shared=None): | ||
if not hasattr(self, 'prefix'): | ||
self.prefix = '' | ||
if opt['datatype'].startswith('train'): | ||
self.suffix = 'train' | ||
else: | ||
self.suffix = 'dev' | ||
|
||
qa_dir, self.evidence_dir = _path(opt) | ||
opt['datafile'] = os.path.join(qa_dir, self.prefix + 'wikipedia-' + | ||
self.suffix + '.json') | ||
|
||
self.id = 'triviaqa' | ||
super().__init__(opt, shared) | ||
|
||
def setup_data(self, path): | ||
print('loading: ' + path) | ||
with open(path) as data_file: | ||
data = json.load(data_file)['Data'] | ||
for datapoint in data: | ||
question = datapoint['Question'] | ||
answers = datapoint['Answer']['Aliases'] | ||
evidence_list = datapoint['EntityPages'] | ||
|
||
if len(evidence_list) == 0: | ||
continue | ||
|
||
evidence = '' | ||
for evidence_item in evidence_list: | ||
evidence_file_path = os.path.join(self.evidence_dir, | ||
'wikipedia', | ||
evidence_item['Filename']) | ||
with open(evidence_file_path) as evidence_file: | ||
evidence += 'Title: %s\n' % evidence_item['Title'] | ||
evidence += evidence_file.read() + '\n\n' | ||
|
||
yield (evidence + question, answers), True | ||
|
||
|
||
class VerifiedWikipediaTeacher(WikipediaTeacher): | ||
def __init__(self, opt, shared=None): | ||
self.prefix = 'verified-' | ||
self.suffix = 'dev' | ||
if opt['datatype'] != 'valid': | ||
print('WARNING: Verified teacher only provides dev data') | ||
|
||
opt['datafile'], self.evidence_dir = _path(opt) | ||
self.id = 'triviaqa' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class VerifiedTeacher(MultiTaskTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt = copy.deepcopy(opt) | ||
opt['task'] = 'triviaqa:VerifiedWikipedia,triviaqa:VerifiedWeb' | ||
super().__init__(opt, shared) | ||
|
||
class DefaultTeacher(MultiTaskTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt = copy.deepcopy(opt) | ||
opt['task'] = 'triviaqa:wikipedia,triviaqa:web' | ||
super().__init__(opt, shared) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. | ||
# | ||
# Download and build the data if it does not exist. | ||
|
||
import parlai.core.build_data as build_data | ||
import os | ||
|
||
|
||
def build(opt): | ||
dpath = os.path.join(opt['datapath'], 'TriviaQA') | ||
version = None | ||
|
||
if not build_data.built(dpath, version_string=version): | ||
print('[building data: ' + dpath + ']') | ||
if build_data.built(dpath): | ||
# An older version exists, so remove these outdated files. | ||
build_data.remove_dir(dpath) | ||
build_data.make_dir(dpath) | ||
|
||
# Download the data. | ||
fname = 'triviaqa-rc.tar.gz' | ||
url = 'http://nlp.cs.washington.edu/triviaqa/data/' | ||
build_data.download(url + fname, dpath, fname) | ||
build_data.untar(dpath, fname) | ||
|
||
# Mark the data as built. | ||
build_data.mark_done(dpath, version_string=version) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters