add masked_lm test
azzhipa committed Apr 18, 2022
1 parent f862ff5 commit 5c85c0d
Showing 5 changed files with 102 additions and 13 deletions.
.github/workflows/build.yml (2 changes: 1 addition & 1 deletion)
@@ -56,5 +56,5 @@ jobs:
       - name: Lint with black
         run: |
-          pip install black
+          pip install black==22.3.0
           black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@ repos:
       - id: end-of-file-fixer
 
   - repo: https://github.com/ambv/black
-    rev: 22.1.0
+    rev: 22.3.0
     hooks:
       - id: black
         language_version: python3.8
fairseq/tasks/masked_lm.py (24 changes: 13 additions & 11 deletions)
@@ -3,13 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass, field
 import logging
 import os
 
-from omegaconf import MISSING, II, OmegaConf
-from dataclasses import dataclass, field
-
 import numpy as np
+from omegaconf import II, MISSING, OmegaConf
 
 from fairseq import utils
 from fairseq.data import (
     Dictionary,
@@ -31,7 +31,6 @@
 from .language_modeling import SAMPLE_BREAK_MODE_CHOICES, SHORTEN_METHOD_CHOICES
 
-
 logger = logging.getLogger(__name__)


@@ -131,12 +130,7 @@ def setup_task(cls, cfg: MaskedLMConfig, **kwargs):
         logger.info("dictionary: {} types".format(len(dictionary)))
         return cls(cfg, dictionary)
 
-    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
-        """Load a given dataset split.
-        Args:
-            split (str): name of the split (e.g., train, valid, test)
-        """
+    def _load_dataset_split(self, split, epoch, combine):
         paths = utils.split_paths(self.cfg.data)
         assert len(paths) > 0
         data_path = paths[(epoch - 1) % len(paths)]
@@ -173,7 +167,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
         logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
 
         # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
-        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
+        return PrependTokenDataset(dataset, self.source_dictionary.bos())
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        dataset = self._load_dataset_split(split, epoch, combine)
 
         # create masked input and targets
         mask_whole_words = (
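This refactor is what the new test relies on: _load_dataset_split returns the raw, unmasked token blocks, while load_dataset builds the masked inputs and targets on top of them. A minimal sketch of the two entry points from a caller's perspective (cfg and dictionary are placeholders, constructed as in the test below):

# Sketch only: cfg and dictionary stand in for a real MaskedLMConfig
# and fairseq Dictionary, built as in the test below.
task = MaskedLMTask(cfg, dictionary)

# Raw token blocks, before any masking is applied.
original = task._load_dataset_split("train", epoch=1, combine=False)

# Standard task API: applies the masking pipeline on top.
task.load_dataset("train")
masked = task.dataset("train")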
tests/tasks/test_masked_lm.py (78 changes: 78 additions & 0 deletions)
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from tempfile import TemporaryDirectory

from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from tests.utils import build_vocab, make_data


class TestMaskedLM(unittest.TestCase):
    def test_masks_tokens(self):
        with TemporaryDirectory() as dirname:

            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl="mmap",
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # setup task
            cfg = MaskedLMConfig(
                data=dirname,
                seed=42,
                mask_prob=0.5,  # increasing the odds of masking
                random_token_prob=0,  # avoiding random tokens for exact match
                leave_unmasked_prob=0,  # always masking for exact match
            )
            task = MaskedLMTask(cfg, binarizer.dict)

            original_dataset = task._load_dataset_split(bin_file, 1, False)

            # load datasets
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            mask_index = task.source_dictionary.index("<mask>")
            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            for batch in iterator:
                for sample in range(len(batch)):
                    net_input = batch["net_input"]
                    masked_src_tokens = net_input["src_tokens"][sample]
                    masked_src_length = net_input["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    original_tokens = original_tokens.masked_select(
                        masked_src_tokens[:masked_src_length] == mask_index
                    )
                    masked_tokens = masked_tgt_tokens.masked_select(
                        masked_tgt_tokens != task.source_dictionary.pad()
                    )

                    assert masked_tokens.equal(original_tokens)


if __name__ == "__main__":
    unittest.main()
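The closing assertion checks a round-trip invariant: selecting the original tokens at positions where the masked input equals <mask> must yield exactly the non-pad entries of the target. A toy illustration of that invariant with hypothetical tensor values (not taken from the test data):

import torch

# Hypothetical token indices and values, for illustration only.
mask_idx, pad_idx = 4, 1
original = torch.tensor([10, 11, 12, 13])
masked_src = torch.tensor([10, mask_idx, 12, mask_idx])  # tokens 11 and 13 masked
target = torch.tensor([pad_idx, 11, pad_idx, 13])  # pad except at masked slots

# Original tokens at masked positions == non-pad targets.
assert original.masked_select(masked_src == mask_idx).equal(
    target.masked_select(target != pad_idx)
)

Setting random_token_prob=0 and leave_unmasked_prob=0 in the config is what makes exact equality possible: every selected position is replaced by <mask> rather than left unchanged or swapped for a random token.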
tests/utils.py (9 changes: 9 additions & 0 deletions)
@@ -786,3 +786,12 @@ def make_data(length=1000, out_file=None) -> tp.List[tp.List[str]]:
print(" ".join(s), file=out)

return data


def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
d = Dictionary()
for s in data:
for token in s:
d.add_symbol(token)
d.finalize()
return d
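A quick usage sketch pairing build_vocab with make_data, mirroring the test above (the temporary directory and file name are illustrative):

import os
from tempfile import TemporaryDirectory

from tests.utils import build_vocab, make_data

with TemporaryDirectory() as dirname:
    raw_file = os.path.join(dirname, "raw")
    data = make_data(out_file=raw_file)  # synthetic sentences, also written to disk
    vocab = build_vocab(data)  # Dictionary covering every token in the data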
