add masked_lm test #4344

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -56,5 +56,5 @@ jobs:
       - name: Lint with black
         run: |
-          pip install black
+          pip install black==22.3.0
           black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
       - id: end-of-file-fixer

   - repo: https://github.com/ambv/black
-    rev: 22.1.0
+    rev: 22.3.0
     hooks:
       - id: black
         language_version: python3.8
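
Both hunks pin the same black release, keeping CI and the pre-commit hook in lockstep. As a quick local sanity check (an illustration assumed here, not part of the PR), the installed formatter can be compared against the pin:

# Hedged sketch: confirm the locally installed black matches the 22.3.0 pin
# now shared by .github/workflows/build.yml and .pre-commit-config.yaml.
import black

assert black.__version__ == "22.3.0", f"unexpected black version: {black.__version__}"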
24 changes: 13 additions & 11 deletions fairseq/tasks/masked_lm.py
@@ -3,13 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass, field
 import logging
 import os

-from omegaconf import MISSING, II, OmegaConf
+from dataclasses import dataclass, field

 import numpy as np
+from omegaconf import II, MISSING, OmegaConf

 from fairseq import utils
 from fairseq.data import (
     Dictionary,
@@ -31,7 +31,6 @@

 from .language_modeling import SAMPLE_BREAK_MODE_CHOICES, SHORTEN_METHOD_CHOICES

-
 logger = logging.getLogger(__name__)


@@ -131,12 +130,7 @@ def setup_task(cls, cfg: MaskedLMConfig, **kwargs):
         logger.info("dictionary: {} types".format(len(dictionary)))
         return cls(cfg, dictionary)

-    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
-        """Load a given dataset split.
-
-        Args:
-            split (str): name of the split (e.g., train, valid, test)
-        """
+    def _load_dataset_split(self, split, epoch, combine):
         paths = utils.split_paths(self.cfg.data)
         assert len(paths) > 0
         data_path = paths[(epoch - 1) % len(paths)]
@@ -173,7 +167,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
         logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

         # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
-        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
+        return PrependTokenDataset(dataset, self.source_dictionary.bos())
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        dataset = self._load_dataset_split(split, epoch, combine)

         # create masked input and targets
         mask_whole_words = (
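
This refactoring is what enables the new test below: the raw block loading moves into _load_dataset_split, which returns the unmasked dataset, while load_dataset reuses it and applies masking on top. A minimal sketch of the pattern (toy stand-ins with made-up values, not the fairseq APIs):

# Toy illustration of the refactoring pattern; token ids and the masking rule are made up.
class ToyMaskedLMTask:
    def _load_dataset_split(self, split):
        # stand-in for reading, tokenizing, and blocking a split from disk
        return [[11, 12, 13, 14], [21, 22, 23]]

    def load_dataset(self, split, mask_id=0):
        # reuse the helper, then layer masking on top of the raw blocks
        raw = self._load_dataset_split(split)
        return [[mask_id if i % 2 else tok for i, tok in enumerate(seq)] for seq in raw]

# a test can now fetch both views of the same split and compare them
task = ToyMaskedLMTask()
assert task.load_dataset("train") != task._load_dataset_split("train")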
78 changes: 78 additions & 0 deletions tests/tasks/test_masked_lm.py
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from tempfile import TemporaryDirectory

from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from tests.utils import build_vocab, make_data


class TestMaskedLM(unittest.TestCase):
    def test_masks_tokens(self):
        with TemporaryDirectory() as dirname:

            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl="mmap",
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # setup task
            cfg = MaskedLMConfig(
                data=dirname,
                seed=42,
                mask_prob=0.5,  # increasing the odds of masking
                random_token_prob=0,  # avoiding random tokens for exact match
                leave_unmasked_prob=0,  # always masking for exact match
            )
            task = MaskedLMTask(cfg, binarizer.dict)

            original_dataset = task._load_dataset_split(bin_file, 1, False)

            # load datasets
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            mask_index = task.source_dictionary.index("<mask>")
            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            for batch in iterator:
                for sample in range(len(batch)):
                    net_input = batch["net_input"]
                    masked_src_tokens = net_input["src_tokens"][sample]
                    masked_src_length = net_input["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    original_tokens = original_tokens.masked_select(
                        masked_src_tokens[:masked_src_length] == mask_index
                    )
                    masked_tokens = masked_tgt_tokens.masked_select(
                        masked_tgt_tokens != task.source_dictionary.pad()
                    )

                    assert masked_tokens.equal(original_tokens)


if __name__ == "__main__":
    unittest.main()
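
The assertion at the end of the loop hinges on masked_select: the original tokens at positions where the input became <mask> must match the non-pad entries of the target. A standalone illustration with hypothetical token ids (values are made up, not taken from the test):

import torch

MASK, PAD = 4, 1  # hypothetical special-token ids
original = torch.tensor([10, 11, 12, 13])        # unmasked block
masked_src = torch.tensor([10, MASK, 12, MASK])  # model input after masking
target = torch.tensor([PAD, 11, PAD, 13])        # labels carry originals at masked slots

recovered = original.masked_select(masked_src == MASK)
labels = target.masked_select(target != PAD)
assert recovered.equal(labels)  # the same per-sample check the test performs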
9 changes: 9 additions & 0 deletions tests/utils.py
@@ -786,3 +786,12 @@ def make_data(length=1000, out_file=None) -> tp.List[tp.List[str]]:
                 print(" ".join(s), file=out)

     return data
+
+
+def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
+    d = Dictionary()
+    for s in data:
+        for token in s:
+            d.add_symbol(token)
+    d.finalize()
+    return d
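
A plausible way the two helpers combine, mirroring the test above (a usage sketch under the assumption that build_vocab assigns every observed token a non-<unk> id):

from tests.utils import build_vocab, make_data

data = make_data(length=10)  # random whitespace-tokenized sentences
vocab = build_vocab(data)    # Dictionary covering every token in the data

# every token should now resolve to a real id rather than <unk>
assert all(vocab.index(tok) != vocab.unk() for s in data for tok in s)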