[bpe] Support BPE dropout (#3232)
* Implement BPE dropout.

* Only BPE dropout on text, not labels.

* Add a unit test.

* Notes for the future.

* Dictionary save works for slow bytelevel BPE.

* Finish adding tests.

* Reviewer comments.

* Rip out unrelated change.
stephenroller committed Nov 11, 2020
1 parent 8dd1588 commit 2930a64
Showing 4 changed files with 153 additions and 10 deletions.
48 changes: 47 additions & 1 deletion parlai/core/dict.py
@@ -21,6 +21,15 @@
import parlai.utils.logging as logging
from parlai.utils.io import PathManager
from typing import List
import enum


class TokenizationMode(enum.Enum):
TRAIN_TIME_TEXT = 0
TRAIN_TIME_LABEL = 1
TEST_TIME_TEXT = 2
TEST_TIME_LABEL = 3


RETOK = re.compile(r'\w+|[^\w\s]|\n', re.UNICODE)

@@ -235,6 +244,9 @@ def __init__(self, opt: Opt, shared=None):
'dict_textfields', DictionaryAgent.default_textfields
).split(",")

        # used to signal whether we should use training-time tricks, like BPE dropout
self._tokenization_mode = TokenizationMode.TEST_TIME_LABEL

try:
self.tokenizer_fun = getattr(self, self.tokenizer + '_tokenize')
except AttributeError:
@@ -663,7 +675,7 @@ def save(self, filename=None, append=False, sort=True):
with PathManager.open(filename + '.opt', 'w', encoding='utf-8') as handle:
json.dump(self.opt, handle, indent=4)
# save the byte level bpe model file as well
-        if self.tokenizer == 'bytelevelbpe':
+        if self.tokenizer == 'bytelevelbpe' or self.tokenizer == 'slow_bytelevel_bpe':
            # This saves filename-vocab.json and filename-merges.txt, as the
            # Hugging Face tokenizer does
self.bpe.save(os.path.dirname(filename), os.path.basename(filename))
@@ -702,6 +714,21 @@ def sort(self, trim=True):
assert len(self.freq) == len(self.ind2tok) == len(self.tok2ind)
return sorted_pairs

    def parse(self, txt_or_vec, vec_type=list):
        """
        Parse either text or a vector of indices.

        Calls `~txt2vec` if `txt_or_vec` is a string, or `~vec2txt` otherwise.

        :param vec_type:
            type of the returned vector if the input is a string.
        """
        # TODO: try to deprecate this, preferring straight txt2vec
        if isinstance(txt_or_vec, str):
            return self.txt2vec(txt_or_vec, vec_type)
        else:
            return self.vec2txt(txt_or_vec)

def txt2vec(self, text, vec_type=list):
"""
Convert a string to a vector (list of ints).
@@ -791,3 +818,22 @@ def __str__(self):
Return string representation of frequencies in dictionary.
"""
return str(self.freq)

    def set_tokenization_mode(self, mode: TokenizationMode):
        """
        Indicate what "kind" of tokenization is being done.

        This can be Training Time / Testing Time, and it can be over
        context or labels.

        This is used to signal from TorchAgent to the dict that it's allowed
        to enable things like BPE dropout. It is NOT used to indicate whether
        the dictionary itself is in training time.
        """
        self._tokenization_mode = mode
        if hasattr(self, 'bpe'):
            # enable BPE dropout only on text at training time; disable it at
            # all other times
            self.bpe.enable_bpe_dropout(mode == TokenizationMode.TRAIN_TIME_TEXT)
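
A rough sketch of how this hook is meant to be driven (assuming `opt` carries a BPE tokenizer and a nonzero --bpe-dropout; the input string is arbitrary):

    d = DictionaryAgent(opt)
    d.set_tokenization_mode(TokenizationMode.TRAIN_TIME_TEXT)
    train_vec = d.txt2vec('hello world')  # stochastic: merges may be dropped
    d.set_tokenization_mode(TokenizationMode.TEST_TIME_TEXT)
    eval_vec = d.txt2vec('hello world')   # deterministic: dropout disabled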
20 changes: 19 additions & 1 deletion parlai/core/torch_agent.py
@@ -28,7 +28,7 @@

from parlai.core.opt import Opt
from parlai.core.agents import Agent
-from parlai.core.dict import DictionaryAgent
+from parlai.core.dict import DictionaryAgent, TokenizationMode
from parlai.nn.lr_scheduler import ParlAILRScheduler
from parlai.core.message import Message
from parlai.utils.distributed import is_distributed
@@ -1660,14 +1660,32 @@ def observe(self, observation):
# make sure we note that we're expecting a reply in the future
self.__expecting_to_reply = True

# keep around the observation for updating history based on label
self.observation = observation

        # possibly change the tokenization methodology based on whether this
        # is a training example
is_training_mode = 'labels' in observation
if hasattr(self.dict, 'set_tokenization_mode'):
if is_training_mode:
self.dict.set_tokenization_mode(TokenizationMode.TRAIN_TIME_TEXT)
else:
self.dict.set_tokenization_mode(TokenizationMode.TEST_TIME_TEXT)

# Update the history using the observation.
# We may also consider adding a temporary string to the history
# using the `get_temp_history()` function: this string will
# persist until it is updated.
self.history.update_history(
observation, temp_history=self.get_temp_history(observation)
)

if hasattr(self.dict, 'set_tokenization_mode'):
if is_training_mode:
self.dict.set_tokenization_mode(TokenizationMode.TRAIN_TIME_LABEL)
else:
self.dict.set_tokenization_mode(TokenizationMode.TEST_TIME_LABEL)

return self.vectorize(
observation,
self.history,
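To make the sequencing concrete, a hedged sketch of one training-time interaction (`agent` and the observation fields here are illustrative, not part of the diff):

    obs = {'text': 'hello world', 'labels': ['hi there'], 'episode_done': True}
    agent.observe(obs)  # context is tokenized under TRAIN_TIME_TEXT (dropout on);
                        # vectorize() then runs under TRAIN_TIME_LABEL (dropout off)
    agent.act()

At evaluation time the observation carries 'eval_labels' rather than 'labels', so is_training_mode is False and both phases fall back to the TEST_TIME_* modes, keeping tokenization deterministic.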
50 changes: 46 additions & 4 deletions parlai/utils/bpe.py
@@ -13,6 +13,7 @@
from functools import lru_cache
import json
import os
import random
import re
from typing import Dict, List, Optional, Set, Tuple
from typing_extensions import final
@@ -108,6 +109,8 @@ def __init__(self, opt: Opt, shared: TShared = None):
self.debug = opt.get('bpe_debug', False)
self.add_prefix_space = opt.get('bpe_add_prefix_space', False)
self._special_tokens: Dict[str, int] = {}
self.bpe_dropout: Optional[float] = opt.get('bpe_dropout')
self._bpe_dropout_enabled = False

@staticmethod
def add_cmdline_args(argparser):
@@ -124,8 +127,20 @@ def add_cmdline_args(argparser):
hidden=True,
help='add prefix space before encoding',
)
parser.add_argument(
'--bpe-dropout',
type=float,
default=None,
help='Use BPE dropout during training.',
)
return parser
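
For reference, a minimal way to switch the new flag on programmatically, mirroring how the unit tests below construct their options (the tokenizer choice and dropout value are only examples):

    from parlai.core.params import ParlaiParser
    from parlai.core.dict import DictionaryAgent

    pp = ParlaiParser(False, False)
    DictionaryAgent.add_cmdline_args(pp)
    opt = pp.parse_kwargs(dict_tokenizer='gpt2', bpe_dropout=0.1)
    d = DictionaryAgent(opt)  # dropout stays inert until TRAIN_TIME_TEXT is set

Note that bytelevelbpe rejects the flag (see the ByteLevelBPE change below); gpt2 and slow_bytelevel_bpe support it.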

    def enable_bpe_dropout(self, enabled: bool):
        """
        Toggle BPE dropout on (True) or off (False).
        """
        self._bpe_dropout_enabled = enabled

@final
def encode(self, text: str) -> List[str]:
"""
@@ -610,14 +625,31 @@ def get_pairs(self, word: Tuple[str, ...]) -> Set[Tuple[str, str]]:
        :return pairs:
            list of pairs of adjacent symbols (a list rather than a set, so
            that `random.choice` can sample from it)
        """
-        pairs = set()
+        pairs = []
        prev_char = word[0]
        for char in word[1:]:
-            pairs.add((prev_char, char))
+            pairs.append((prev_char, char))
            prev_char = char
return pairs

-    @lru_cache(maxsize=10240)
    def _dropout_pairs(self, pairs):
        """
        Implements BPE dropout (Provilkov et al., 2019).

        https://arxiv.org/abs/1910.13267

        Randomly removes merges from the list of possible merges. This can
        result in different subwords being used to realize the same string,
        and effectively regularizes representations.
        """
        if not self.bpe_dropout or not self._bpe_dropout_enabled:
            return pairs

        # drop each candidate merge with probability bpe_dropout; if every
        # candidate was dropped, keep one at random so a merge can still apply
        dropped_pairs = [p for p in pairs if random.random() > self.bpe_dropout]
        if not dropped_pairs:
            dropped_pairs = [random.choice(pairs)]
        return dropped_pairs

def bpe(self, token: str) -> str:
"""
Convert token to BPE.
@@ -635,7 +667,10 @@ def bpe(self, token: str) -> str:
return token

while True:
-            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            dropped_pairs = self._dropout_pairs(pairs)
+            bigram = min(
+                dropped_pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))
+            )
if bigram not in self.bpe_ranks:
break
first, second = bigram
@@ -771,6 +806,13 @@ def __init__(self, opt: Opt, shared: TShared = None):
'Please install HuggingFace tokenizer with: pip install tokenizers'
)

        if self.bpe_dropout:
            raise NotImplementedError(
                '--bpe-dropout is not supported with ByteLevelBPE because the '
                'tokenizers library does not allow dynamically turning dropout '
                'on and off. You can use --dict-tokenizer slow_bytelevel_bpe '
                'to get this feature.'
            )

if self.lower:
warn_once('Are you sure you want to lower case your BPE dictionary?')
if self.maxtokens > 0 or self.minfreq > 0:
45 changes: 41 additions & 4 deletions tests/test_dict.py
@@ -11,7 +11,7 @@
from parlai.core.build_data import modelzoo_path
from parlai.core.dict import find_ngrams
from parlai.core.params import ParlaiParser
-from parlai.core.dict import DictionaryAgent
+from parlai.core.dict import DictionaryAgent, TokenizationMode
from parlai.core.opt import Opt
import parlai.scripts.build_dict as build_dict

@@ -152,17 +152,17 @@ def test_basic_parse(self):
dictionary.act()
assert len(dictionary) - num_builtin == 2

-        vec = dictionary.txt2vec('hello world')
+        vec = dictionary.parse('hello world')
assert len(vec) == 2
assert vec[0] == num_builtin
assert vec[1] == num_builtin + 1

-        vec = dictionary.txt2vec('hello world', vec_type=list)
+        vec = dictionary.parse('hello world', vec_type=list)
assert len(vec) == 2
assert vec[0] == num_builtin
assert vec[1] == num_builtin + 1

-        vec = dictionary.txt2vec('hello world', vec_type=tuple)
+        vec = dictionary.parse('hello world', vec_type=tuple)
assert len(vec) == 2
assert vec[0] == num_builtin
assert vec[1] == num_builtin + 1
@@ -602,3 +602,40 @@ def test_specialtok_nonsupport(self):
for tokenizer in ["bpe"]:
with self.assertRaises(NotImplementedError):
self._run_specialtok_test(dict_tokenizer=tokenizer)


class TestBpeDropout(unittest.TestCase):
def _test_bpe_dropout(self, **dict_args):
pp = ParlaiParser(False, False)
DictionaryAgent.add_cmdline_args(pp)
opt = pp.parse_kwargs(bpe_dropout=0.5, **dict_args)
da = DictionaryAgent(opt)
da.set_tokenization_mode(TokenizationMode.TEST_TIME_TEXT)
s = (
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
"Donec vitae metus sollicitudin, ullamcorper tortor ut, rhoncus lacus. "
"Praesent sollicitudin commodo turpis, ut pharetra tortor gravida nec."
)
no_dropout = da.txt2vec(s)
da.set_tokenization_mode(TokenizationMode.TRAIN_TIME_TEXT)
not_the_same = 0
for _ in range(30):
r = da.txt2vec(s)
assert da.vec2txt(r) == s
if r != no_dropout:
not_the_same += 1
assert not_the_same > 0

def test_gpt2_bpe_dropout(self):
self._test_bpe_dropout(dict_tokenizer='gpt2')

def test_slowbytelevel_dropout(self):
self._test_bpe_dropout(
dict_tokenizer="slow_bytelevel_bpe", dict_file="zoo:blender/dict_3B/dict"
)

def test_bytelevelbpe_dropout(self):
with self.assertRaises(NotImplementedError):
self._test_bpe_dropout(
dict_tokenizer="bytelevelbpe", dict_file="zoo:blender/dict_3B/dict"
)
