Skip to content

Commit

Permalink
Move MoE files into examples (#1040)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#1040

Differential Revision: D20030279

Pulled By: myleott

fbshipit-source-id: 76b48a62409020039225cf98e8fcf7a494d0b7f8
  • Loading branch information
myleott authored and facebook-github-bot committed Feb 21, 2020
1 parent e1de989 commit 8845dcf
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 23 deletions.
4 changes: 2 additions & 2 deletions examples/speech_recognition/criterions/ASG_loss.py
Expand Up @@ -13,8 +13,6 @@
from fairseq.criterions import FairseqCriterion, register_criterion
from examples.speech_recognition.data.replabels import pack_replabels

from wav2letter.criterion import ASGLoss, CriterionScaleMode


@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
Expand Down Expand Up @@ -43,6 +41,8 @@ def add_args(parser):
)

def __init__(self, args, task):
from wav2letter.criterion import ASGLoss, CriterionScaleMode

super().__init__(args, task)
self.tgt_dict = task.target_dictionary
self.eos = self.tgt_dict.eos()
Expand Down
2 changes: 1 addition & 1 deletion examples/speech_recognition/datasets/asr_prep_json.py
Expand Up @@ -14,14 +14,14 @@
import json
import sentencepiece as spm
import multiprocessing
import torchaudio

from fairseq.data import Dictionary

MILLISECONDS_TO_SECONDS = 0.001


def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
import torchaudio
input = {}
output = {}
si, ei = torchaudio.info(aud_path)
Expand Down
26 changes: 16 additions & 10 deletions examples/speech_recognition/w2l_decoder.py
Expand Up @@ -13,16 +13,22 @@
import torch
from fairseq import utils
from examples.speech_recognition.data.replabels import unpack_replabels
from wav2letter.common import create_word_dict, load_words
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import (
CriterionType,
DecoderOptions,
KenLM,
SmearingMode,
Trie,
WordLMDecoder,
)

try:
from wav2letter.common import create_word_dict, load_words
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import (
CriterionType,
DecoderOptions,
KenLM,
SmearingMode,
Trie,
WordLMDecoder,
)
except ImportError:
# wav2letter is a required dependency for the speech_recognition
# example, but don't break on import
pass


class W2lDecoder(object):
Expand Down
6 changes: 3 additions & 3 deletions examples/translation_moe/README.md
Expand Up @@ -18,7 +18,7 @@ The following command will train a `hMoElp` model with `3` experts:
fairseq-train --ddp-backend='no_c10d' \
data-bin/wmt17_en_de \
--max-update 100000 \
--task translation_moe \
--task translation_moe --user-dir examples/translation_moe/src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--arch transformer_wmt_en_de --share-all-embeddings \
Expand All @@ -37,7 +37,7 @@ For example, to generate from expert 0:
fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--task translation_moe \
--task translation_moe --user-dir examples/translation_moe/src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
Expand All @@ -61,7 +61,7 @@ for EXPERT in $(seq 0 2); do \
--beam 1 \
--bpe subword_nmt --bpe-codes $BPE_CODE \
--buffer-size 500 --max-tokens 6000 \
--task translation_moe \
--task translation_moe --user-dir examples/translation_moe/src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert $EXPERT ; \
Expand Down
6 changes: 6 additions & 0 deletions examples/translation_moe/src/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import translation_moe # noqa
File renamed without changes.
File renamed without changes.
Expand Up @@ -5,10 +5,13 @@

import torch

from fairseq import metrics, modules, utils
from fairseq import metrics, utils
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask

from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork


@register_task('translation_moe')
class TranslationMoETask(TranslationTask):
Expand Down Expand Up @@ -100,7 +103,7 @@ def build_model(self, args):
else:
raise ValueError('Must specify --mean-pool-gating-network-dropout')

model.gating_network = modules.MeanPoolGatingNetwork(
model.gating_network = MeanPoolGatingNetwork(
encoder_dim, args.num_experts, dropout,
)
else:
Expand Down Expand Up @@ -171,7 +174,7 @@ def get_lprob_yz(winners=None):
loss = -get_lprob_yz(winners)
else:
lprob_yz = get_lprob_yz() # B x K
loss = -modules.LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

loss = loss.sum()
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
Expand Down
4 changes: 0 additions & 4 deletions fairseq/modules/__init__.py
Expand Up @@ -17,8 +17,6 @@
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .scalar_bias import ScalarBias
Expand Down Expand Up @@ -47,8 +45,6 @@
'LightweightConv1dTBC',
'LightweightConv',
'LinearizedConvolution',
'LogSumExpMoE',
'MeanPoolGatingNetwork',
'MultiheadAttention',
'PositionalEmbedding',
'ScalarBias',
Expand Down
7 changes: 7 additions & 0 deletions fairseq/options.py
Expand Up @@ -113,6 +113,13 @@ def parse_args_and_arch(

from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY

# Before creating the true parser, we need to import optional user module
# in order to eagerly import custom tasks, optimizers, architectures, etc.
usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
usr_parser.add_argument("--user-dir", default=None)
usr_args, _ = usr_parser.parse_known_args(input_args)
utils.import_user_module(usr_args)

if modify_parser is not None:
modify_parser(parser)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_binaries.py
Expand Up @@ -375,6 +375,7 @@ def test_mixture_of_experts(self):
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'transformer_iwslt_de_en', [
'--task', 'translation_moe',
'--user-dir', 'examples/translation_moe/src',
'--method', 'hMoElp',
'--mean-pool-gating-network',
'--num-experts', '3',
Expand All @@ -385,6 +386,7 @@ def test_mixture_of_experts(self):
])
generate_main(data_dir, [
'--task', 'translation_moe',
'--user-dir', 'examples/translation_moe/src',
'--method', 'hMoElp',
'--mean-pool-gating-network',
'--num-experts', '3',
Expand Down

0 comments on commit 8845dcf

Please sign in to comment.