Add dummy tasks and model for benchmarking (#1026)
Summary: Pull Request resolved: fairinternal/fairseq-py#1026

Differential Revision: D19834667

Pulled By: myleott

fbshipit-source-id: 56ab6df5d8145dc37431252de444a2a9728e7898
myleott authored and facebook-github-bot committed Feb 12, 2020
1 parent 4cae680 commit 91f0534
Showing 6 changed files with 338 additions and 8 deletions.
2 changes: 2 additions & 0 deletions fairseq/__init__.py
@@ -13,3 +13,5 @@
import fairseq.optim.lr_scheduler # noqa
import fairseq.pdb # noqa
import fairseq.tasks # noqa

import fairseq.benchmark # noqa
107 changes: 107 additions & 0 deletions fairseq/benchmark/dummy_lm.py
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task


@register_task('dummy_lm')
class DummyLMTask(FairseqTask):

@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=50000, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')

def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed

seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1

self.dummy_src = seq[:-1]
self.dummy_tgt = seq[1:]

@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
print('| dictionary: {} types'.format(len(dictionary)))

return cls(args, dictionary)

def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
bsz = self.args.max_sentences
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full((bsz, ), self.args.tokens_per_sample),
},
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
'nsentences': bsz,
'ntokens': bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
)

@property
def source_dictionary(self):
return self.dictionary

@property
def target_dictionary(self):
return self.dictionary


class DummyDataset(FairseqDataset):

def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
self.num_items = num_items
self.item_size = item_size

def __getitem__(self, index):
return index

def __len__(self):
return self.num_items

def collater(self, samples):
return self.batch

@property
def sizes(self):
return np.array([self.item_size] * self.num_items)

def num_tokens(self, index):
return self.item_size

def size(self, index):
return self.item_size

def ordered_indices(self):
return np.arange(self.num_items)

@property
def supports_prefetch(self):
return False
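
A quick sanity check (a sketch, not part of this commit; it only assumes fairseq with this change is importable and reuses the argument names from add_args above) shows how the dummy LM task can be driven directly from Python:

import argparse

from fairseq.benchmark.dummy_lm import DummyLMTask

# argparse.Namespace stands in for fairseq's parsed command-line arguments.
args = argparse.Namespace(
    seed=1,
    dict_size=50000,
    dataset_size=100000,
    tokens_per_sample=512,
    max_sentences=2,  # batch size used by load_dataset
)
task = DummyLMTask.setup_task(args)
task.load_dataset('train')

# Every index collates to the same pre-built batch, so benchmarking a model with
# this task measures compute without any data-loading or batching overhead.
batch = task.dataset('train').collater([0, 1])
print(batch['net_input']['src_tokens'].shape)  # torch.Size([2, 512])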
118 changes: 118 additions & 0 deletions fairseq/benchmark/dummy_masked_lm.py
@@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task


@register_task('dummy_masked_lm')
class DummyMaskedLMTask(FairseqTask):

@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=50000, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')

def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed

# add mask token
self.mask_idx = dictionary.add_symbol('<mask>')
assert len(dictionary) % 8 == 0

mask_idx = 0
pad_idx = 1
seq = torch.arange(args.tokens_per_sample) + pad_idx + 1
mask = torch.arange(2, args.tokens_per_sample, 7) # ~15%
src = seq.clone()
src[mask] = mask_idx
tgt = torch.full_like(seq, pad_idx)
tgt[mask] = seq[mask]

self.dummy_src = src
self.dummy_tgt = tgt

@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
print('| dictionary: {} types'.format(len(dictionary)))

return cls(args, dictionary)

def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
bsz = self.args.max_sentences
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full((bsz, ), self.args.tokens_per_sample),
},
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
'nsentences': bsz,
'ntokens': bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
)

@property
def source_dictionary(self):
return self.dictionary

@property
def target_dictionary(self):
return self.dictionary


class DummyDataset(FairseqDataset):

def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
self.num_items = num_items
self.item_size = item_size

def __getitem__(self, index):
return index

def __len__(self):
return self.num_items

def collater(self, samples):
return self.batch

@property
def sizes(self):
return np.array([self.item_size] * self.num_items)

def num_tokens(self, index):
return self.item_size

def size(self, index):
return self.item_size

def ordered_indices(self):
return np.arange(self.num_items)

@property
def supports_prefetch(self):
return False
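
As an aside on the masking scheme above (a sketch, not from the commit): masking every 7th position starting at index 2 is deterministic and covers roughly 15% of the sequence, close to the usual BERT masking ratio, while keeping every batch identical:

import torch

tokens_per_sample = 512
mask = torch.arange(2, tokens_per_sample, 7)
print(mask.numel(), mask.numel() / tokens_per_sample)  # 73 masked positions, ~14.3%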
93 changes: 93 additions & 0 deletions fairseq/benchmark/dummy_model.py
@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import Dictionary
from fairseq.models import (
FairseqDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)


@register_model('dummy_model')
class DummyModel(FairseqLanguageModel):

def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args

@staticmethod
def add_args(parser):
parser.add_argument('--num-layers', type=int, default=24)
parser.add_argument('--embed-dim', type=int, default=1024)

@classmethod
def build_model(cls, args, task):
encoder = DummyEncoder(
num_embed=len(task.target_dictionary),
embed_dim=args.embed_dim,
num_layers=args.num_layers,
)
return cls(args, encoder)

def forward(self, src_tokens, **kwargs):
return self.decoder(src_tokens)


class DummyEncoder(FairseqDecoder):

def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
super().__init__(Dictionary())
self.embed = nn.Embedding(
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
)
self.layers_a = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 3*embed_dim), # q, k, v input projection
nn.Linear(3*embed_dim, embed_dim), # skip self-attention
nn.Linear(embed_dim, embed_dim), # output projection
nn.Dropout(),
)
for i in range(num_layers)
])
self.layers_b = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 4*embed_dim), # FFN
nn.ReLU(),
nn.Linear(4*embed_dim, embed_dim), # FFN
nn.Dropout(0.1),
)
for i in range(num_layers)
])
self.out_proj = nn.Linear(embed_dim, num_embed)

def forward(self, tokens):
x = self.embed(tokens)
for layer_a, layer_b in zip(self.layers_a, self.layers_b):
x = x + layer_a(x)
x = x + layer_b(x)
x = self.out_proj(x)
return (x,)

def max_positions(self):
return 1024

def get_normalized_probs(self, net_output, log_probs, sample=None):
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(logits, dim=-1)


@register_model_architecture('dummy_model', 'dummy_model')
def base_architecture(args):
pass
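
To see how the dummy model and task fit together, here is a minimal forward-pass sketch (not part of the commit; the small num_layers/embed_dim values are chosen only to keep the example cheap):

import argparse

import torch

from fairseq.benchmark.dummy_lm import DummyLMTask
from fairseq.benchmark.dummy_model import DummyModel

args = argparse.Namespace(
    seed=1, dict_size=50000, dataset_size=100, tokens_per_sample=128,
    max_sentences=2, num_layers=2, embed_dim=64,
)
task = DummyLMTask.setup_task(args)
model = DummyModel.build_model(args, task)

tokens = torch.randint(2, len(task.source_dictionary), (2, 128))
logits, = model(tokens)  # DummyModel.forward just runs the dummy decoder stack
print(logits.shape)      # torch.Size([2, 128, len(task.source_dictionary)])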
6 changes: 5 additions & 1 deletion fairseq/models/__init__.py
@@ -123,7 +123,11 @@ def register_model_arch_fn(fn):
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)):
if (
not file.startswith('_')
and not file.startswith('.')
and (file.endswith('.py') or os.path.isdir(path))
):
model_name = file[:file.find('.py')] if file.endswith('.py') else file
module = importlib.import_module('fairseq.models.' + model_name)

20 changes: 13 additions & 7 deletions fairseq/tasks/__init__.py
@@ -53,10 +53,20 @@ def register_task_cls(cls):
return register_task_cls


def get_task(name):
return TASK_REGISTRY[name]


# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
task_name = file[:file.find('.py')]
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
path = os.path.join(tasks_dir, file)
if (
not file.startswith('_')
and not file.startswith('.')
and (file.endswith('.py') or os.path.isdir(path))
):
task_name = file[:file.find('.py')] if file.endswith('.py') else file
importlib.import_module('fairseq.tasks.' + task_name)

# expose `task_parser` for sphinx
@@ -70,7 +80,3 @@ def register_task_cls(cls):
group_args = parser.add_argument_group('Additional command-line arguments')
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + '_parser'] = parser


def get_task(name):
return TASK_REGISTRY[name]
