This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: add friends dataset for multiparty convo * feat: generate examples for all 6 main characters in the friends corpus by default * fix: remove unused file * fix: add testing * fix: add speaker label inside label * feat: add support for data folds; clean up code * fix: skip __MACOS folder for zipped files to avoid exception * fix: formatting with autoformat.sh * feat: add convenience teacher classes * feat: add command line option to specify list of characters * undo changes to build_data.py * cleanup * style fix
- Loading branch information
1 parent
97377a3
commit 8fc3fec
Showing
29 changed files
with
1,872 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Optional | ||
from parlai.core.opt import Opt | ||
from parlai.core.teachers import DialogTeacher | ||
from parlai.core.params import ParlaiParser | ||
from .build import build | ||
from collections import defaultdict | ||
import jsonlines | ||
from parlai.utils.data import DatatypeHelper | ||
|
||
import copy | ||
import os | ||
|
||
START_TOKEN = '__START__' | ||
SILENCE_TOKEN = '__SILENCE__' | ||
|
||
|
||
def _path(opt, filename): | ||
return os.path.join(opt['datapath'], 'Friends', filename) | ||
|
||
|
||
class DefaultTeacher(DialogTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt = copy.deepcopy(opt) | ||
build(opt) | ||
self.fold = DatatypeHelper.fold(opt['datatype']) | ||
opt['datafile'] = _path(opt, self.fold + '.jsonl') | ||
self.characters = opt['characters'].split(',') | ||
self.character = opt['character'] | ||
self.use_silence_token = opt['use_silence_token'] | ||
self.silence_token = opt['silence_token'] | ||
self.use_start_token = opt['use_start_token'] | ||
self.start_token = opt['start_token'] | ||
super().__init__(opt, shared) | ||
|
||
def setup_data(self, datafile): | ||
conversations = defaultdict(list) | ||
|
||
with jsonlines.open(datafile) as reader: | ||
for utterance in reader: | ||
text = utterance['text'] | ||
speaker = utterance['speaker'] | ||
conversation_id = utterance['conversation_id'] | ||
|
||
conversations[conversation_id].append( | ||
{"text": text, "speaker": speaker} | ||
) | ||
|
||
for conversation_id in conversations: | ||
utterances = conversations[conversation_id] | ||
characters = set( | ||
[u['speaker'] for u in utterances if u['speaker'] in self.characters] | ||
) | ||
characters_string = ','.join( | ||
sorted(list(characters)) | ||
) # sorted to ensure same order across runs | ||
last_utterance_index = len(utterances) - 1 | ||
|
||
for index, utterance in enumerate(utterances): | ||
if index == 0: | ||
if self.use_start_token: | ||
context = self.start_token | ||
|
||
else: # skip the first utterance since there's no context | ||
speaker = utterance['speaker'] | ||
text = utterance['text'] | ||
context = f'{speaker}: {text}' | ||
continue | ||
|
||
speaker = utterance['speaker'] | ||
text = utterance['text'] | ||
|
||
prev_context = context | ||
context += '\n' + f'{speaker}: {text}' | ||
|
||
isConversationDone = index == last_utterance_index | ||
|
||
# By default, generate training examples for all 6 main characters. | ||
# Otherwise only generate training examples for the chosen character. | ||
if ( | ||
self.character == 'All' and speaker in self.characters | ||
) or speaker == self.character: | ||
yield { | ||
"text": prev_context, | ||
"label": f'{speaker}: {text}', | ||
"characters": characters_string, | ||
}, isConversationDone | ||
elif self.use_silence_token: | ||
yield { | ||
"text": prev_context, | ||
"label": f'{self.character}: {self.silence_token}', | ||
"characters": characters_string, | ||
}, isConversationDone | ||
|
||
@classmethod | ||
def add_cmdline_args( | ||
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None | ||
) -> ParlaiParser: | ||
super().add_cmdline_args(parser, partial_opt) | ||
agent = parser.add_argument_group('Friends Corpus Arguments') | ||
agent.add_argument( | ||
'--character', | ||
type=str, | ||
default='All', | ||
choices=[ | ||
'All', | ||
'Rachel Green', | ||
'Monica Geller', | ||
'Phoebe Buffay', | ||
'Joey Tribbiani', | ||
'Chandler Bing', | ||
'Ross Geller', | ||
], | ||
help='Which speaker labels to train on', | ||
) | ||
agent.add_argument( | ||
'--characters', | ||
type=str, | ||
default='Rachel Green,Monica Geller,Phoebe Buffay,Joey Tribbiani,Chandler Bing,Ross Geller', | ||
help='A comma-separated list of characters to train on when `--character` == `All`', | ||
) | ||
agent.add_argument( | ||
'--use-silence-token', | ||
type='bool', | ||
default=True, | ||
help='Use silence token to generate training example for sentences where the chosen speaker is not speaking. Defaults to True.', | ||
) | ||
agent.add_argument( | ||
'--silence-token', | ||
type=str, | ||
default=SILENCE_TOKEN, | ||
help='The token to use to indicate the chosen speaker is silent. Defaults to __SILENCE__', | ||
) | ||
agent.add_argument( | ||
'--use-start-token', | ||
type='bool', | ||
default=False, | ||
help='Use start token at the beginning of each conversation, and include the first sentence as a training example. Defaults to False.', | ||
) | ||
agent.add_argument( | ||
'--start-token', | ||
type=str, | ||
default=START_TOKEN, | ||
help='The token to use to indicate the beginning of a conversation. Defaults to __START__', | ||
) | ||
return parser | ||
|
||
|
||
class AllCharactersTeacher(DefaultTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt['character'] = 'All' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class RachelTeacher(DefaultTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt['character'] = 'Rachel Green' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class MonicaTeacher(DefaultTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt['character'] = 'Monica Geller' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class PhoebeTeacher(DefaultTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt['character'] = 'Phoebe Buffay' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class JoeyTeacher(DefaultTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt['character'] = 'Joey Tribbiani' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class ChandlerTeacher(DefaultTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt['character'] = 'Chandler Bing' | ||
super().__init__(opt, shared) | ||
|
||
|
||
class RossTeacher(DefaultTeacher): | ||
def __init__(self, opt, shared=None): | ||
opt['character'] = 'Ross Geller' | ||
super().__init__(opt, shared) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# Download and build the data if it does not exist. | ||
|
||
import parlai.core.build_data as build_data | ||
from parlai.core.build_data import DownloadableFile | ||
import os | ||
import jsonlines | ||
from collections import defaultdict | ||
from sklearn.model_selection import train_test_split | ||
|
||
RANDOM_SEED = 123 | ||
|
||
RESOURCES = [ | ||
DownloadableFile( | ||
'http://zissou.infosci.cornell.edu/convokit/datasets/friends-corpus/friends-corpus.zip', | ||
'friends-corpus.zip', | ||
'51ae80ce345212839d256b59b4982e9b40229ff6049115bd54d885a285d2b921', | ||
zipped=True, | ||
) | ||
] | ||
|
||
|
||
def generate_folds(dpath): | ||
""" | ||
Generate Data Folds based on the scene id. | ||
""" | ||
datafile = os.path.join(dpath, 'friends-corpus/utterances.jsonl') | ||
train_datafile = os.path.join(dpath, 'train.jsonl') | ||
valid_datafile = os.path.join(dpath, 'valid.jsonl') | ||
test_datafile = os.path.join(dpath, 'test.jsonl') | ||
|
||
# Load the dataset | ||
conversations = defaultdict(list) | ||
with jsonlines.open(datafile) as reader: | ||
for utterance in reader: | ||
text = utterance['text'] | ||
speaker = utterance['speaker'] | ||
conversation_id = utterance['conversation_id'] | ||
|
||
if speaker != 'TRANSCRIPT_NOTE': | ||
conversations[conversation_id].append( | ||
{ | ||
"text": text, | ||
"speaker": speaker, | ||
"conversation_id": conversation_id, | ||
} | ||
) | ||
|
||
# Split the dataset into 80% train, 10% valid, 10% test | ||
train, valid_and_test = train_test_split( | ||
list(conversations.keys()), test_size=0.2, random_state=RANDOM_SEED | ||
) | ||
valid, test = train_test_split( | ||
valid_and_test, test_size=0.5, random_state=RANDOM_SEED | ||
) | ||
|
||
# Save the data folds into separate files | ||
with jsonlines.open(train_datafile, mode='w') as writer: | ||
for conversation_id in train: | ||
for utterance in conversations[conversation_id]: | ||
writer.write(utterance) | ||
|
||
with jsonlines.open(valid_datafile, mode='w') as writer: | ||
for conversation_id in valid: | ||
for utterance in conversations[conversation_id]: | ||
writer.write(utterance) | ||
|
||
with jsonlines.open(test_datafile, mode='w') as writer: | ||
for conversation_id in test: | ||
for utterance in conversations[conversation_id]: | ||
writer.write(utterance) | ||
|
||
|
||
def build(opt): | ||
dpath = os.path.join(opt['datapath'], 'Friends') | ||
version = '1.00' | ||
|
||
if not build_data.built(dpath, version_string=version): | ||
print('[building data: ' + dpath + ']') | ||
if build_data.built(dpath): | ||
# An older version exists, so remove these outdated files. | ||
build_data.remove_dir(dpath) | ||
build_data.make_dir(dpath) | ||
|
||
# Download the data. | ||
for downloadable_file in RESOURCES: | ||
downloadable_file.download_file(dpath) | ||
|
||
generate_folds(dpath) | ||
|
||
# Mark the data as built. | ||
build_data.mark_done(dpath, version_string=version) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.utils.testing import AutoTeacherTest # noqa: F401 | ||
|
||
|
||
class TestDefaultTeacher(AutoTeacherTest): | ||
task = 'friends' | ||
|
||
|
||
class TestAllCharactersTeacher(AutoTeacherTest): | ||
task = 'friends:all_characters' | ||
|
||
|
||
class TestRachelTeacher(AutoTeacherTest): | ||
task = 'friends:rachel' | ||
|
||
|
||
class TestMonicaTeacher(AutoTeacherTest): | ||
task = 'friends:monica' | ||
|
||
|
||
class TestPhoebeTeacher(AutoTeacherTest): | ||
task = 'friends:phoebe' | ||
|
||
|
||
class TestJoeyTeacher(AutoTeacherTest): | ||
task = 'friends:joey' | ||
|
||
|
||
class TestChandlerTeacher(AutoTeacherTest): | ||
task = 'friends:chandler' | ||
|
||
|
||
class TestRossTeacher(AutoTeacherTest): | ||
task = 'friends:ross' |
Oops, something went wrong.