Skip to content

Commit

Permalink
Cleanup LM + Flake8
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #720

Differential Revision: D15259091

Pulled By: myleott

fbshipit-source-id: 06a35996c06ccddb49fdc9e01e348ff3c9da334e
  • Loading branch information
myleott authored and facebook-github-bot committed May 8, 2019
1 parent eddcdf0 commit f2563c2
Show file tree
Hide file tree
Showing 34 changed files with 649 additions and 501 deletions.
1 change: 1 addition & 0 deletions fairseq/data/__init__.py
Expand Up @@ -47,4 +47,5 @@
'TokenBlockDataset',
'TransformEosDataset',
'TransformEosLangPairDataset',
'TruncatedDictionary',
]
3 changes: 2 additions & 1 deletion fairseq/data/data_utils.py
Expand Up @@ -10,6 +10,7 @@
import numpy as np
from collections import Iterable


def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
Expand Down Expand Up @@ -182,7 +183,7 @@ def is_batch_full(num_tokens):

def process_bpe_symbol(sentence: str, bpe_symbol: str):
if bpe_symbol == 'sentencepiece':
sentence = sentence.replace(' ','').replace('\u2581', ' ').strip()
sentence = sentence.replace(' ', '').replace('\u2581', ' ').strip()
elif bpe_symbol is not None:
sentence = (sentence + ' ').replace(bpe_symbol, '').rstrip()
return sentence
2 changes: 2 additions & 0 deletions fairseq/data/dictionary.py
Expand Up @@ -18,6 +18,7 @@

class Dictionary(object):
"""A mapping from symbols to consecutive integers"""

def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
Expand Down Expand Up @@ -282,6 +283,7 @@ def merge_result(counter):
else:
merge_result(Dictionary._add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))


class TruncatedDictionary(object):

def __init__(self, wrapped_dict, length):
Expand Down
2 changes: 0 additions & 2 deletions fairseq/data/iterators.py
Expand Up @@ -7,8 +7,6 @@

import itertools
import math
import queue
import threading

import numpy as np
import torch
Expand Down
2 changes: 0 additions & 2 deletions fairseq/data/language_pair_dataset.py
Expand Up @@ -8,8 +8,6 @@
import numpy as np
import torch

from fairseq import utils

from . import data_utils, FairseqDataset


Expand Down
2 changes: 1 addition & 1 deletion fairseq/data/masked_lm_dataset.py
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import torch

from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple

from . import FairseqDataset, data_utils

Expand Down
32 changes: 23 additions & 9 deletions fairseq/models/__init__.py
Expand Up @@ -9,19 +9,33 @@
import importlib
import os

from .fairseq_decoder import FairseqDecoder # noqa: F401
from .fairseq_encoder import FairseqEncoder # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
BaseFairseqModel,
FairseqModel, # noqa: F401
FairseqMultiModel, # noqa: F401
FairseqLanguageModel, # noqa: F401
FairseqEncoderModel, # noqa: F401
FairseqModel,
FairseqMultiModel,
FairseqLanguageModel,
FairseqEncoderModel,
)

from .composite_encoder import CompositeEncoder # noqa: F401
from .distributed_fairseq_model import DistributedFairseqModel # noqa: F401
from .composite_encoder import CompositeEncoder
from .distributed_fairseq_model import DistributedFairseqModel


__all__ = [
'BaseFairseqModel',
'CompositeEncoder',
'DistributedFairseqModel',
'FairseqDecoder',
'FairseqEncoder',
'FairseqEncoderModel',
'FairseqIncrementalDecoder',
'FairseqLanguageModel',
'FairseqModel',
'FairseqMultiModel',
]


MODEL_REGISTRY = {}
Expand Down
2 changes: 1 addition & 1 deletion fairseq/models/composite_encoder.py
Expand Up @@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from . import FairseqEncoder
from fairseq.models import FairseqEncoder


class CompositeEncoder(FairseqEncoder):
Expand Down
5 changes: 1 addition & 4 deletions fairseq/models/distributed_fairseq_model.py
Expand Up @@ -6,14 +6,11 @@
# can be found in the PATENTS file in the same directory.

import inspect
import socket

from torch.nn import parallel

from fairseq import distributed_utils
from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel

from . import BaseFairseqModel
from fairseq.models import BaseFairseqModel


def DistributedFairseqModel(args, model):
Expand Down
4 changes: 2 additions & 2 deletions fairseq/models/fairseq_incremental_decoder.py
Expand Up @@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from . import FairseqDecoder
from fairseq.models import FairseqDecoder


class FairseqIncrementalDecoder(FairseqDecoder):
Expand All @@ -25,7 +25,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
The :class:`FairseqIncrementalDecoder` interface also defines the
:func:`reorder_incremental_state` method, which is used during beam search
to select and reorder the incremental state based on the selection of beams.
To learn more about how incremental decoding works, refer to `this blog
<http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
"""
Expand Down
3 changes: 2 additions & 1 deletion fairseq/models/fairseq_model.py
Expand Up @@ -4,14 +4,15 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import FairseqDecoder, FairseqEncoder
from fairseq.data import Dictionary
from fairseq.models import FairseqDecoder, FairseqEncoder


class BaseFairseqModel(nn.Module):
Expand Down
106 changes: 8 additions & 98 deletions fairseq/models/fconv.py
Expand Up @@ -10,17 +10,19 @@
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
register_model,
register_model_architecture,
)
from fairseq.modules import (
AdaptiveSoftmax, BeamableMM, GradMultiply, LearnedPositionalEmbedding,
LinearizedConvolution,
)

from . import (
FairseqEncoder, FairseqIncrementalDecoder, FairseqModel,
FairseqLanguageModel, register_model, register_model_architecture,
)


@register_model('fconv')
class FConvModel(FairseqModel):
Expand Down Expand Up @@ -111,58 +113,6 @@ def build_model(cls, args, task):
return FConvModel(encoder, decoder)


@register_model('fconv_lm')
class FConvLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)

@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
help='decoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')

@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_lm_architecture(args)

if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
args.tokens_per_sample = args.max_target_positions

decoder = FConvDecoder(
dictionary=task.target_dictionary,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_embed_dim,
attention=eval(args.decoder_attention),
dropout=args.dropout,
max_positions=args.tokens_per_sample,
share_embed=False,
positional_embeddings=False,
adaptive_softmax_cutoff=(
options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
if args.criterion == 'adaptive_loss' else None
),
adaptive_softmax_dropout=args.adaptive_softmax_dropout,
)
return FConvLanguageModel(decoder)


class FConvEncoder(FairseqEncoder):
"""
Convolutional encoder consisting of `len(convolutions)` layers.
Expand Down Expand Up @@ -643,46 +593,6 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
return nn.utils.weight_norm(m, dim=2)


@register_model_architecture('fconv_lm', 'fconv_lm')
def base_lm_architecture(args):
args.dropout = getattr(args, 'dropout', 0.1)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)
args.decoder_layers = getattr(args, 'decoder_layers', '[(1268, 4)] * 13')
args.decoder_attention = getattr(args, 'decoder_attention', 'False')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)


@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
def fconv_lm_dauphin_wikitext103(args):
layers = '[(850, 6)] * 3'
layers += ' + [(850, 1)] * 1'
layers += ' + [(850, 5)] * 4'
layers += ' + [(850, 1)] * 1'
layers += ' + [(850, 4)] * 3'
layers += ' + [(1024, 4)] * 1'
layers += ' + [(2048, 4)] * 1'
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 280)
args.decoder_layers = getattr(args, 'decoder_layers', layers)
args.decoder_attention = getattr(args, 'decoder_attention', 'False')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
base_lm_architecture(args)


@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_gbw')
def fconv_lm_dauphin_gbw(args):
layers = '[(512, 5)]'
layers += ' + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3'
layers += ' + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3'
layers += ' + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6'
layers += ' + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]'
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)
args.decoder_layers = getattr(args, 'decoder_layers', layers)
args.decoder_attention = getattr(args, 'decoder_attention', 'False')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
base_lm_architecture(args)


@register_model_architecture('fconv', 'fconv')
def base_architecture(args):
args.dropout = getattr(args, 'dropout', 0.1)
Expand Down

0 comments on commit f2563c2

Please sign in to comment.