Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[GenderBias] Add Controllable Gender Bias Task #3146

Merged
merged 6 commits into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions parlai/tasks/genderation_bias/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
Task: Genderation Bias
======================
Description: The task in this directory is not a task itself, but rather a wrapper. The task will flatten a specified other ParlAI task (that is, turn multi-turn episodes into single-turn examples), and append a control token that corresponds with the level of gender present in the label, where word lists from https://github.com/uclanlp/gn_glove/blob/master/wordlist/ are used to count the number of gendered words. Depending on the counts, the control token will be one of the following:

- `f0m0` - no gender words in the label
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would note which word lists were used to generate these labels

- `f0m1` - there is at least one male-specific word in the label
- `f1m0` - there is at least one female-specific word in the label
- `f1m1` - there is at least one male-specific word AND one female-specific word in the label

For example, one could run the following command:

```
$ parlai display_data -t genderation_bias:controllable_task:convai2
```

Which would yield the following:

```
- - - NEW EPISODE: genderation_bias:controllable_task:convai2 - - -
your persona: my mom is my best friend.
your persona: i have four sisters.
your persona: i believe that mermaids are real.
your persona: i love iced tea.
hi , how are you doing today ? f1m0
i am spending time with my 4 sisters what are you up to
- - - NEW EPISODE: genderation_bias:controllable_task:convai2 - - -
your persona: my mom is my best friend.
your persona: i have four sisters.
your persona: i believe that mermaids are real.
your persona: i love iced tea.
hi , how are you doing today ?
i am spending time with my 4 sisters what are you up to
wow , four sisters . just watching game of thrones . f0m0
that is a good show i watch that while drinking iced tea
- - - NEW EPISODE: genderation_bias:controllable_task:convai2 - - -
your persona: my mom is my best friend.
your persona: i have four sisters.
your persona: i believe that mermaids are real.
your persona: i love iced tea.
hi , how are you doing today ?
i am spending time with my 4 sisters what are you up to
wow , four sisters . just watching game of thrones .
that is a good show i watch that while drinking iced tea
i agree . what do you do for a living ? f0m0
i'm a researcher i'm researching the fact that mermaids are real
16:33:19 | loaded 131438 episodes with a total of 131438 examples
```
317 changes: 317 additions & 0 deletions parlai/tasks/genderation_bias/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Generates a controllable_gen version of a ParlAI task, i.e., for tasks with multi-turn
dialogues (episodes), this will generate a task with single example episodes, in which
we append to the context a special classification token.

In order to use this teacher, specify the task flag as follows:
`--task genderation_bias:controllable_task:<ORIGINAL TASK NAME>`.

As an example, try running:

`parlai display_data -t genderation_bias:controllable_task:convai2`
"""

from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.core.teachers import FixedDialogTeacher
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
from parlai.utils.typing import TShared

from parlai.tasks.genderation_bias.build import build
from parlai.tasks.genderation_bias.utils import (
flatten_and_classify,
get_original_task_module,
)

from copy import deepcopy
import datetime
import glob
import json
import os
from tqdm import tqdm
from typing import List, Optional, Tuple


class ControllableTaskTeacher(FixedDialogTeacher):
"""
Generates a controllable_gen version of a ParlAI task, i.e., for tasks with multi-
turn dialogues (episodes), this will generate a task with single example episodes,
in which we append to the context a special classification token.
"""

@staticmethod
def add_cmdline_args(parser):
flattened = parser.add_argument_group('ControllableTaskTeacher Flattening Args')
flattened.add_argument(
'--flatten-include-labels',
type='bool',
default=True,
help='Include labels in the history when flattening an episode',
)
flattened.add_argument(
'--flatten-delimiter',
type=str,
default='\n',
help='How to join the dialogue history from previous turns.',
)
flattened.add_argument(
'--flatten-max-context-length',
type=int,
default=-1,
help='Maximum number of utterances to include per episode. '
'Default -1 keeps all.',
)
agent = parser.add_argument_group('ControllableTaskTeacher Args')
agent.add_argument(
'--invalidate-cache',
type='bool',
default=False,
help='Set this to True to rebuild the data (may want to do this if '
'original data has changed or you want to rebuild with new options)',
)
agent.add_argument(
'--max-examples',
type=int,
default=-1,
help='If greater than zero, will stop building after a certain num of exs',
)
agent.add_argument(
'--fixed-control',
type=str,
default='',
help='Always append this fixed control string, good for deploy time.',
)
# Add the arguments for the task teacher
opt = parser.parse_and_process_known_args()[0]
tasks = get_original_task_module(opt, multi_possible=True)
for task in tasks:
if hasattr(task, 'add_cmdline_args'):
task.add_cmdline_args(parser)

return parser

def __init__(self, opt: Opt, shared: TShared = None):
assert opt['flatten_delimiter'] == opt.get(
'delimiter', '\n'
), '--flatten-delimiter and --delimiter are set differently, please inspect and set to the same to avoid unexpected results'
klshuster marked this conversation as resolved.
Show resolved Hide resolved
self.opt = opt

if shared and 'data' in shared:
self.data = shared['data']
else:
self.word_lists = self.build_wordlists(opt)
self.data = self._setup_data(opt)

super().__init__(opt, shared)
self.reset()

def num_episodes(self) -> int:
return len(self.data)

def num_examples(self) -> int:
return len(self.data)

def _get_save_path(self, datapath: str, date: str) -> str:
"""
Return save path for the controllable gen data.

:param datapath:
path to ParlAI Data
:param date:
current date

:return path:
return path to save
"""
return os.path.join(
datapath,
f"{self.original_task_name.replace(':', '_')}_flattened_controllable_gen_{date}",
)

@classmethod
def build_wordlists(cls, opt: Opt) -> Tuple[List[str], List[str]]:
"""
Load list of explicitly gendered words.

Words taken from <https://github.com/uclanlp/gn_glove/blob/master/wordlist/>.

Examples include brother, girl, actress, husbands, etc.
"""
build(opt['datapath'])
folder = os.path.join(opt['datapath'], 'genderation_bias')
male_words = os.path.join(folder, 'male_word_file.txt')
female_words = os.path.join(folder, 'female_word_file.txt')

with open(male_words, 'r') as f:
male = f.read().splitlines()

with open(female_words, 'r') as f:
female = f.read().splitlines()

return male, female

def _setup_data(self, opt: Opt) -> List[List[Message]]:
"""
Flatten and classify the normal task data.

Save/load where applicable.

:param opt:
options dict.
"""
# create save directory, if it does not already exist
self.original_task_name = ':'.join(opt['task'].split(':')[2:])
self.save_dir = self._get_save_path(
opt['datapath'], str(datetime.datetime.today())
)
os.makedirs(self.save_dir, exist_ok=True)

fname = f"{opt['datatype'].split(':')[0]}.json"
self.save_path = os.path.join(self.save_dir, fname)

data = self.load_data(opt, fname)
if data is not None:
# successfully load data
return data

# build the original teacher
original_task_module = get_original_task_module(opt)
teacher_opt = deepcopy(opt)
teacher_opt['task'] = self.original_task_name
teacher = original_task_module(teacher_opt)

total_exs = teacher.num_examples()
if self.opt['max_examples'] > 0:
total_exs = min(self.opt['max_examples'], total_exs)

progress_bar = tqdm(
total=total_exs, unit='ex', unit_scale=True, desc='Building flattened data'
)

all_episodes = []
num_exs = 0
while num_exs < total_exs:
current_episode = []
episode_done = False

while not episode_done:
action = Message(teacher.act())
current_episode.append(action)
episode_done = action.get('episode_done', False)
num_exs += 1

# flatten the episode into 1-example episodes with context
flattened_ep = flatten_and_classify(
current_episode,
opt['flatten_max_context_length'],
include_labels=opt['flatten_include_labels'],
delimiter=opt['flatten_delimiter'],
word_lists=self.word_lists,
)
all_episodes += flattened_ep

progress_bar.update(len(flattened_ep))

# save data for future use
self.save_data(all_episodes)

return all_episodes

def load_data(self, opt: Opt, filename: str) -> Optional[List[List[Message]]]:
"""
Attempt to load pre-build data.

Checks for the most recently build data via the date string.

:param opt:
options dict
:param filename:
name of (potentially) saved data

:return episodes:
return list of episodes, if available
"""
# first check for the most recent date
save_dir = self._get_save_path(opt['datapath'], '*')
all_dates = []
for fname in glob.glob(os.path.join(save_dir, filename)):
date = os.path.split(fname)[0].split('_')[-1]
all_dates.append(date)

if len(all_dates) > 0:
most_recent = os.path.join(
self._get_save_path(opt['datapath'], sorted(all_dates)[-1]), filename
)
else:
# data has not been built yet
return None

if opt['invalidate_cache']:
# invalidate the cache and remove the existing data
logging.warn(
f' [ WARNING: invalidating cache at {self.save_path} and rebuilding the data. ]'
)
if self.save_path == most_recent:
os.remove(self.save_path)
return None

# Loading from most recent date
self.save_path = most_recent
logging.info(f' [ Data already exists. Loading from: {self.save_path} ]')
with PathManager.open(self.save_path, 'rb') as f:
data = json.load(f)

return data

def save_data(self, data: List[List[Message]]):
"""
Save the data via dumping to a json file.

:param data:
list of episodes
"""
try:
json_data = json.dumps(data)
with PathManager.open(self.save_path, 'w') as f:
f.write(json_data)
logging.info(f'[ Data successfully saved to path: {self.save_path} ]')
except Exception:
logging.warn('Data is not json serializable; not saving')

def get(self, episode_idx: int, entry_idx: int = 0) -> Message:
"""
Return a flattened example.

If using a fixed control, put that in instead of what was originally in the text.

:param episode_idx:
index of ep in data
:param entry_idx:
index of ex in ep

:return ex:
return an example
"""
ex = Message(self.data[episode_idx])

if self.opt['fixed_control'] != '':
old_text = ' '.join(ex['text'].split(' ')[:-1])
text = f"{old_text} {self.opt['fixed_control']}"
ex.force_set('text', text)

return ex

def share(self):
shared = super().share()
shared['data'] = self.data
return shared


class DefaultTeacher(ControllableTaskTeacher):
pass
Loading