diff --git a/parlai/tasks/genderation_bias/README.md b/parlai/tasks/genderation_bias/README.md new file mode 100644 index 00000000000..19ca3f5ccd7 --- /dev/null +++ b/parlai/tasks/genderation_bias/README.md @@ -0,0 +1,47 @@ +Task: Genderation Bias +====================== +Description: The task in this directory is not a task itself, but rather a wrapper. The task will flatten a specified other ParlAI task (that is, turn multi-turn episodes into single-turn examples), and append a control token that corresponds with the level of gender present in the label, where word lists from https://github.com/uclanlp/gn_glove/blob/master/wordlist/ are used to count the number of gendered words. Depending on the counts, the control token will be one of the following: + +- `f0m0` - no gender words in the label +- `f0m1` - there is at least one male-specific word in the label +- `f1m0` - there is at least one female-specific word in the label +- `f1m1` - there is at least one male-specific word AND one female-specific word in the label + +For example, one could run the following command: + +``` +$ parlai display_data -t genderation_bias:controllable_task:convai2 +``` + +Which would yield the following: + +``` +- - - NEW EPISODE: genderation_bias:controllable_task:convai2 - - - +your persona: my mom is my best friend. +your persona: i have four sisters. +your persona: i believe that mermaids are real. +your persona: i love iced tea. +hi , how are you doing today ? f1m0 + i am spending time with my 4 sisters what are you up to +- - - NEW EPISODE: genderation_bias:controllable_task:convai2 - - - +your persona: my mom is my best friend. +your persona: i have four sisters. +your persona: i believe that mermaids are real. +your persona: i love iced tea. +hi , how are you doing today ? +i am spending time with my 4 sisters what are you up to +wow , four sisters . just watching game of thrones . 
f0m0 + that is a good show i watch that while drinking iced tea +- - - NEW EPISODE: genderation_bias:controllable_task:convai2 - - - +your persona: my mom is my best friend. +your persona: i have four sisters. +your persona: i believe that mermaids are real. +your persona: i love iced tea. +hi , how are you doing today ? +i am spending time with my 4 sisters what are you up to +wow , four sisters . just watching game of thrones . +that is a good show i watch that while drinking iced tea +i agree . what do you do for a living ? f0m0 + i'm a researcher i'm researching the fact that mermaids are real +16:33:19 | loaded 131438 episodes with a total of 131438 examples +``` \ No newline at end of file diff --git a/parlai/tasks/genderation_bias/__init__.py b/parlai/tasks/genderation_bias/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/tasks/genderation_bias/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/tasks/genderation_bias/agents.py b/parlai/tasks/genderation_bias/agents.py new file mode 100644 index 00000000000..a7602ed36cd --- /dev/null +++ b/parlai/tasks/genderation_bias/agents.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Generates a controllable_gen version of a ParlAI task, i.e., for tasks with multi-turn +dialogues (episodes), this will generate a task with single example episodes, in which +we append to the context a special classification token. + +In order to use this teacher, specify the task flag as follows: +`--task genderation_bias:controllable_task:ORIGINAL_TASK_NAME`. 
class ControllableTaskTeacher(FixedDialogTeacher):
    """
    Wrap another ParlAI task as single-turn, gender-controlled examples.

    Every multi-turn episode of the wrapped ("original") task is flattened into
    single-example episodes whose text carries the dialogue history, and a control
    token (``f0m0``, ``f0m1``, ``f1m0`` or ``f1m1``) describing the gendered words
    found in the label is appended to the text.

    The flattened data is cached as JSON under the ParlAI data path so that later
    runs can load it instead of rebuilding it.
    """

    @staticmethod
    def add_cmdline_args(parser):
        """
        Register the flattening/caching flags, then the wrapped task's own flags.

        :param parser:
            the ParlAI argument parser to extend

        :return parser:
            the same parser, with the extra arguments added
        """
        flattened = parser.add_argument_group('ControllableTaskTeacher Flattening Args')
        flattened.add_argument(
            '--flatten-include-labels',
            type='bool',
            default=True,
            help='Include labels in the history when flattening an episode',
        )
        flattened.add_argument(
            '--flatten-delimiter',
            type=str,
            default='\n',
            help='How to join the dialogue history from previous turns.',
        )
        flattened.add_argument(
            '--flatten-max-context-length',
            type=int,
            default=-1,
            help='Maximum number of utterances to include per episode. '
            'Default -1 keeps all.',
        )
        agent = parser.add_argument_group('ControllableTaskTeacher Args')
        agent.add_argument(
            '--invalidate-cache',
            type='bool',
            default=False,
            help='Set this to True to rebuild the data (may want to do this if '
            'original data has changed or you want to rebuild with new options)',
        )
        agent.add_argument(
            '--max-examples',
            type=int,
            default=-1,
            help='If greater than zero, will stop building after a certain num of exs',
        )
        agent.add_argument(
            '--fixed-control',
            type=str,
            default='',
            help='Always append this fixed control string, good for deploy time.',
        )
        # The wrapped task(s) can only be discovered from the already-known args,
        # so parse those first and then let each wrapped teacher add its own flags.
        opt = parser.parse_and_process_known_args()[0]
        for task in get_original_task_module(opt, multi_possible=True):
            if hasattr(task, 'add_cmdline_args'):
                task.add_cmdline_args(parser)

        return parser

    def __init__(self, opt: Opt, shared: TShared = None):
        """
        Load the flattened data (from ``shared``, from the cache, or by building it).

        :param opt:
            options dict
        :param shared:
            shared state produced by :meth:`share`, if this is a copy
        """
        assert opt['flatten_delimiter'] == opt.get(
            'delimiter', '\n'
        ), '--flatten-delimiter and --delimiter are set differently, please inspect and set to the same to avoid unexpected results'
        # _setup_data reads self.opt, so it has to be set before super().__init__.
        self.opt = opt

        if shared and 'data' in shared:
            self.data = shared['data']
        else:
            self.word_lists = self.build_wordlists(opt)
            self.data = self._setup_data(opt)

        super().__init__(opt, shared)
        self.reset()

    def num_episodes(self) -> int:
        """
        Return the number of (single-example) episodes.
        """
        return len(self.data)

    def num_examples(self) -> int:
        """
        Return the number of examples; equal to the number of episodes.
        """
        return len(self.data)

    def _get_save_path(self, datapath: str, date: str) -> str:
        """
        Return save path for the controllable gen data.

        :param datapath:
            path to ParlAI Data
        :param date:
            date string identifying the build (or ``'*'`` to build a glob pattern)

        :return path:
            return path to save
        """
        return os.path.join(
            datapath,
            f"{self.original_task_name.replace(':', '_')}_flattened_controllable_gen_{date}",
        )

    @staticmethod
    def _read_word_file(path: str) -> List[str]:
        """
        Read one word per line from ``path``, skipping blank lines.
        """
        with open(path, 'r', encoding='utf-8') as f:
            # A blank entry would otherwise be stored as the "word" ''.
            return [word for word in f.read().splitlines() if word]

    @classmethod
    def build_wordlists(cls, opt: Opt) -> Tuple[List[str], List[str]]:
        """
        Load list of explicitly gendered words.

        Words taken from https://github.com/uclanlp/gn_glove/blob/master/wordlist/.

        Examples include brother, girl, actress, husbands, etc.

        :param opt:
            options dict; only ``opt['datapath']`` is read

        :return word_lists:
            tuple of (male-specific words, female-specific words)
        """
        build(opt['datapath'])
        folder = os.path.join(opt['datapath'], 'genderation_bias')
        male = cls._read_word_file(os.path.join(folder, 'male_word_file.txt'))
        female = cls._read_word_file(os.path.join(folder, 'female_word_file.txt'))

        return male, female

    def _setup_data(self, opt: Opt) -> List[Message]:
        """
        Flatten and classify the normal task data.

        Save/load where applicable.

        :param opt:
            options dict.

        :return data:
            flat list of single-turn examples
        """
        self.original_task_name = ':'.join(opt['task'].split(':')[2:])
        self.save_dir = self._get_save_path(
            opt['datapath'], str(datetime.datetime.today())
        )

        fname = f"{opt['datatype'].split(':')[0]}.json"
        self.save_path = os.path.join(self.save_dir, fname)

        data = self.load_data(opt, fname)
        if data is not None:
            # successfully load data
            return data

        # Create the (timestamped) save directory only when we really build;
        # creating it up front would leave an empty directory on every cache hit.
        os.makedirs(self.save_dir, exist_ok=True)

        # build the original teacher
        # NOTE(review): this assumes the wrapped teacher walks its episodes in order
        # for the requested datatype -- confirm for training datatypes.
        original_task_module = get_original_task_module(opt)
        teacher_opt = deepcopy(opt)
        teacher_opt['task'] = self.original_task_name
        teacher = original_task_module(teacher_opt)

        total_exs = teacher.num_examples()
        if self.opt['max_examples'] > 0:
            total_exs = min(self.opt['max_examples'], total_exs)

        all_episodes = []
        num_exs = 0
        # The context manager closes the bar even if building fails half-way.
        with tqdm(
            total=total_exs, unit='ex', unit_scale=True, desc='Building flattened data'
        ) as progress_bar:
            while num_exs < total_exs:
                current_episode = []
                episode_done = False

                # Always finish the current episode, even past max_examples.
                while not episode_done:
                    action = Message(teacher.act())
                    current_episode.append(action)
                    episode_done = action.get('episode_done', False)
                    num_exs += 1

                # flatten the episode into 1-example episodes with context
                flattened_ep = flatten_and_classify(
                    current_episode,
                    opt['flatten_max_context_length'],
                    include_labels=opt['flatten_include_labels'],
                    delimiter=opt['flatten_delimiter'],
                    word_lists=self.word_lists,
                )
                all_episodes += flattened_ep

                progress_bar.update(len(flattened_ep))

        # save data for future use
        self.save_data(all_episodes)

        return all_episodes

    def load_data(self, opt: Opt, filename: str) -> Optional[List[Message]]:
        """
        Attempt to load pre-built data.

        Checks for the most recently built data via the date string that ends the
        name of each cache directory.

        :param opt:
            options dict
        :param filename:
            name of (potentially) saved data

        :return episodes:
            return list of examples, if available; None if nothing is cached or
            ``--invalidate-cache`` is set
        """
        # first check for the most recent date
        pattern = os.path.join(self._get_save_path(opt['datapath'], '*'), filename)
        all_dates = [
            os.path.split(fname)[0].split('_')[-1] for fname in glob.glob(pattern)
        ]
        if not all_dates:
            # data has not been built yet
            return None

        # The date strings sort chronologically, so the max is the newest build.
        most_recent = os.path.join(
            self._get_save_path(opt['datapath'], max(all_dates)), filename
        )

        if opt['invalidate_cache']:
            # invalidate the cache and remove the existing data
            logging.warn(
                f' [ WARNING: invalidating cache at {self.save_path} and rebuilding the data. ]'
            )
            # Only a cache at exactly the path we are about to write is deleted;
            # older builds are kept and are superseded by the newer date string.
            if self.save_path == most_recent:
                os.remove(self.save_path)
            return None

        # Loading from most recent date
        self.save_path = most_recent
        logging.info(f' [ Data already exists. Loading from: {self.save_path} ]')
        with PathManager.open(self.save_path, 'rb') as f:
            data = json.load(f)

        return data

    def save_data(self, data: List[Message]):
        """
        Save the data via dumping to a json file.

        Saving is best-effort: a failure is logged and the in-memory data is still
        used for the current run.

        :param data:
            list of examples
        """
        try:
            json_data = json.dumps(data)
        except (TypeError, ValueError):
            logging.warn('Data is not json serializable; not saving')
            return

        try:
            with PathManager.open(self.save_path, 'w') as f:
                f.write(json_data)
        except OSError as e:
            # Report I/O problems as such instead of blaming serialization.
            logging.warn(f'Could not write data to {self.save_path}: {e}')
            return

        logging.info(f'[ Data successfully saved to path: {self.save_path} ]')

    def get(self, episode_idx: int, entry_idx: int = 0) -> Message:
        """
        Return a flattened example.

        If using a fixed control, put that in instead of what was originally in the text.

        :param episode_idx:
            index of ep in data
        :param entry_idx:
            index of ex in ep (unused; every episode holds a single example)

        :return ex:
            return an example
        """
        ex = Message(self.data[episode_idx])

        if self.opt['fixed_control'] != '':
            # The stored control token is the last space-separated piece of text.
            old_text = ' '.join(ex['text'].split(' ')[:-1])
            text = f"{old_text} {self.opt['fixed_control']}"
            ex.force_set('text', text)

        return ex

    def share(self):
        """
        Share the flattened data so that copies do not rebuild or reload it.
        """
        shared = super().share()
        shared['data'] = self.data
        return shared


class DefaultTeacher(ControllableTaskTeacher):
    pass
def build(datapath):
    """
    Download the gendered word lists into ``datapath``/genderation_bias.

    Does nothing when ``build_data.built`` reports that version ``v1.0`` is already
    present; otherwise the directory is recreated and every file in ``RESOURCES``
    is downloaded into it.

    :param datapath:
        path to ParlAI Data
    """
    version = 'v1.0'
    dpath = os.path.join(datapath, 'genderation_bias')
    if build_data.built(dpath, version):
        return

    logging.info('[building data: ' + dpath + ']')
    if build_data.built(dpath):
        # An older version exists, so remove these outdated files.
        build_data.remove_dir(dpath)
    build_data.make_dir(dpath)

    # Download the data.
    for downloadable_file in RESOURCES:
        downloadable_file.download_file(dpath)

    # Mark the data as built.
    build_data.mark_done(dpath, version)


# (spaced form, compact form) pairs; format_text replaces compact with spaced.
PUNCTUATION_LST = [
    (' .', '.'),
    (' !', '!'),
    (' ?', '?'),
    (' ,', ','),
    (" ' ", "'"),
    (" . . . ", "... "),
    (" ( ", " ("),
    (" ) ", ") "),
    (" ; ", "; "),
]


def format_text(text: str, lower: bool = True) -> str:
    """
    Space punctuation and lowercase text.

    :param text:
        text to format
    :param lower:
        whether to lowercase or not

    :return text:
        return formatted text.
    """
    if lower:
        text = text.lower()
    for spaced, compact in PUNCTUATION_LST:
        text = text.replace(compact, spaced)

    return text


def _as_word_set(words):
    """
    Return ``words`` as a set, reusing it when it already is one.
    """
    return words if isinstance(words, (set, frozenset)) else set(words)


def get_word_list_token(text: str, word_lists: Tuple[List[str], List[str]]) -> str:
    """
    Return a control token corresponding to gender within text.

    :param text:
        text to consider for control token
    :param word_lists:
        tuple of male-specific and female-specific words (lists or sets)

    :return token:
        ``f<F>m<M>`` where F/M is 1 if at least one female-/male-specific word
        occurs in the text and 0 otherwise.
    """
    m_list, f_list = word_lists
    male_words = _as_word_set(m_list)
    female_words = _as_word_set(f_list)

    # split() with no argument also separates words around newlines and tabs and
    # never yields empty tokens.
    tokens = format_text(text).split()
    has_male = any(tok in male_words for tok in tokens)
    has_female = any(tok in female_words for tok in tokens)

    return f'f{int(has_female)}m{int(has_male)}'


def flatten_and_classify(
    episode: "List[Message]",
    context_length: int,
    word_lists: Tuple[List[str], List[str]],
    include_labels: bool = True,
    delimiter: str = '\n',
) -> "List[Message]":
    """
    Flatten the dialogue history of an episode, explode into N new examples.

    Additionally, add control token corresponding to gender identified in the
    label of each example. The examples are modified in place and returned.

    An example without labels is treated as having an empty label, so it receives
    the ``f0m0`` token and adds nothing to the history.

    :param episode:
        list of examples to flatten
    :param context_length:
        max number of utterances to use while flattening; values <= 0 keep all
    :param word_lists:
        tuple of lists for male-specific and female-specific words
    :param include_labels:
        whether to include labels while flattening
    :param delimiter:
        delimiter to use while flattening

    :return new_episode:
        the same examples, each now a single-turn episode
    """
    m_list, f_list = word_lists
    # Convert once per episode so every lookup below is O(1).
    word_sets = (_as_word_set(m_list), _as_word_set(f_list))

    context = deque(maxlen=context_length if context_length > 0 else None)
    new_episode = []

    for ex in episode:
        context.append(ex.get('text', ''))
        # add context
        if len(context) > 1:
            ex.force_set('text', delimiter.join(context))
        # set episode_done to be True
        ex.force_set('episode_done', True)

        labels = ex.get('labels', ex.get('eval_labels', None))
        # Pick the label once, so the label that enters the history is the same
        # one the control token is computed from.
        label = random.choice(labels) if labels else ''
        if labels and include_labels:
            context.append(label)

        # word list
        control_tok = get_word_list_token(label, word_sets)
        ex.force_set('text', ex.get('text', '') + ' ' + control_tok)
        new_episode.append(ex)

    return new_episode


def get_original_task_module(opt: "Opt", multi_possible: bool = False):
    """
    Returns task module of "original" task.

    Original task in this case means the task we want to use
    with the control teacher, i.e. everything after the second ``:`` of each
    comma-separated entry in ``opt['task']``.

    :param opt:
        opt dict
    :param multi_possible:
        specify True if multiple tasks are possible.

    :return task_module:
        return module associated with task; a list of modules if
        ``multi_possible`` is True.

    :raises RuntimeError:
        if a task entry does not name an original task.
    """
    tasks = opt['task'].split(',')
    if not multi_possible:
        assert len(tasks) == 1, 'expected exactly one task'

    modules = []
    for task in tasks:
        parts = task.split(':')
        if len(parts) < 3:
            raise RuntimeError(
                '\n\n********************************************************\n'
                'Must specify original task using the following format:\n'
                '`--task genderation_bias:controllable_task:ORIGINAL_TASK_NAME`'
                '\n********************************************************\n'
            )
        modules.append(load_teacher_module(':'.join(parts[2:])))

    if multi_possible:
        return modules

    return modules[0]
class TestGenderationBiasTeacher(unittest.TestCase):
    """
    Tests for the Genderation Bias Teacher.

    Currently limited to the flatten_and_classify utility function.
    """

    def test_flatten_and_classify(self):
        # Word lists are fetched through the teacher's own builder.
        default_opt = ParlaiParser().parse_args([])
        word_lists = ControllableTaskTeacher.build_wordlists(default_opt)

        utterances = [
            "hello there",
            "hi there dad, what's up",
            "not much, do you know where your sister is?",
            "I have not seen her, I thought she was with grandpa",
            "well, if you see her, let me know",
            "will do!",
            "ok, have a good day",
            "bye bye! tell mom I say hello",
        ]
        # One expected control token per label, in order.
        tokens = ['f0m1', 'f1m1', 'f0m0', 'f1m0']

        # Even-indexed utterances are the texts, odd-indexed ones their labels.
        episode = []
        for text, label in zip(utterances[::2], utterances[1::2]):
            turn = Message({'text': text, 'labels': [label], 'episode_done': False})
            episode.append(turn)
        episode[-1].force_set('episode_done', True)

        new_episode = flatten_and_classify(episode, -1, word_lists)

        assert len(new_episode) == 4
        suffix_ok = [ex['text'].endswith(tok) for ex, tok in zip(new_episode, tokens)]
        assert all(suffix_ok), f"new episode: {new_episode}"