Add fairseq-hydra-train and update docs (#1449)
Summary: Pull Request resolved: fairinternal/fairseq-py#1449

Test Plan: Imported from OSS

Reviewed By: alexeib

Differential Revision: D25094525

Pulled By: myleott

fbshipit-source-id: 430387d11196d3292933bb168cf09ea16ebc0d3b
myleott authored and facebook-github-bot committed Nov 20, 2020
1 parent bf71f14 commit dbfca6e
Showing 7 changed files with 203 additions and 113 deletions.
226 changes: 135 additions & 91 deletions docs/hydra_integration.md

Large diffs are not rendered by default.

21 changes: 14 additions & 7 deletions examples/wav2vec/README.md
@@ -56,8 +56,10 @@ This configuration was used for the base model trained on the Librispeech dataset
Note that the input is expected to be single channel, sampled at 16 kHz

```shell script
-$ python fairseq_cli/hydra_train.py task.data=/path/to/data \
---config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining --config-name wav2vec2_base_librispeech
+$ fairseq-hydra-train \
+    task.data=/path/to/data \
+    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
+    --config-name wav2vec2_base_librispeech
```

Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before --config-path)
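The exact overrides are collapsed out of this hunk; based on fairseq's Hydra config groups they are presumably `distributed_training.distributed_world_size` and `+optimizer.update_freq`, as in this sketch (not part of the diff):

```shell script
# Hypothetical example: pretrain the base model on k = 8 GPUs while simulating
# 64 GPUs, i.e. update_freq = 64 / 8 = 8; overrides go before --config-path.
$ fairseq-hydra-train \
    task.data=/path/to/data \
    distributed_training.distributed_world_size=8 \
    +optimizer.update_freq='[8]' \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_base_librispeech
```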
@@ -68,8 +70,10 @@ Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before --config-path)
This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper

```shell script
-$ python fairseq_cli/hydra_train.py task.data=/path/to/data \
---config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining --config-name wav2vec2_large_librivox
+$ fairseq-hydra-train \
+    task.data=/path/to/data \
+    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
+    --config-name wav2vec2_large_librivox
```

Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before --config-path)
@@ -88,9 +92,12 @@ $ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split

Fine-tuning on 100h of Librispeech with letter targets:
```shell script
-python fairseq_cli/hydra_train.py distributed_training.distributed_port=$PORT task.data=/path/to/data \
-model.w2v_path=/path/to/model.pt --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \
---config-name base_100h
+$ fairseq-hydra-train \
+    distributed_training.distributed_port=$PORT \
+    task.data=/path/to/data \
+    model.w2v_path=/path/to/model.pt \
+    --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \
+    --config-name base_100h
```

There are other config files in the config/finetuning directory that can be used to fine-tune on other splits.
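For instance (assuming a 10-hour config named `base_10h` ships in that directory, which is not shown in this diff), switching splits should only require a different config name:

```shell script
# Hypothetical example: fine-tune on the 10h split instead of 100h by swapping
# the config name; the data and model overrides stay the same.
$ fairseq-hydra-train \
    distributed_training.distributed_port=$PORT \
    task.data=/path/to/data \
    model.w2v_path=/path/to/model.pt \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \
    --config-name base_10h
```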
6 changes: 3 additions & 3 deletions fairseq/config/config.yaml
@@ -1,10 +1,10 @@
# @package _group_
defaults:
-  - task: language_modeling
+  - task: null
  - model: null
  - criterion: cross_entropy
-  - optimizer: adam
-  - lr_scheduler: cosine
+  - optimizer: null
+  - lr_scheduler: fixed
  - bpe: null
  - tokenizer: null
  - scoring: null
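With `task`, `optimizer`, and `lr_scheduler` no longer pre-selected here, a run has to choose these groups through a `--config-name` preset (as in the wav2vec examples above) or through Hydra group overrides on the command line. A minimal sketch, assuming the previous options (`language_modeling`, `adam`, `cosine`) remain available as group choices:

```shell script
# Hypothetical example: reproduce the old defaults by selecting the config
# groups explicitly instead of relying on config.yaml.
$ fairseq-hydra-train \
    task=language_modeling \
    task.data=/path/to/data \
    optimizer=adam \
    lr_scheduler=cosine
```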
6 changes: 6 additions & 0 deletions fairseq/dataclass/configs.py
@@ -173,6 +173,12 @@ class CommonConfig(FairseqDataclass):
    profile: bool = field(
        default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
    )
+    reset_logging: bool = field(
+        default=True,
+        metadata={
+            "help": "when using Hydra, reset the logging at the beginning of training"
+        },
+    )


@dataclass
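Because `reset_logging` defaults to `True` and is read as `cfg.common.reset_logging` in `hydra_train.py` below, the previous behaviour (leaving Hydra's own logging setup untouched) should be recoverable with an ordinary override; a sketch, assuming standard Hydra boolean syntax:

```shell script
# Hypothetical example: opt out of the new logging reset for one run.
$ fairseq-hydra-train \
    common.reset_logging=false \
    task.data=/path/to/data \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_base_librispeech
```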
6 changes: 4 additions & 2 deletions fairseq/modules/cross_entropy.py
@@ -26,12 +26,14 @@ def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"):
    import xentropy_cuda
    from apex.contrib import xentropy

-    logger.info("using fused cross entropy")
-
    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        if logits.device == torch.device("cpu"):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
+            if not getattr(cross_entropy, "_has_logged_once", False):
+                logger.info("using fused cross entropy")
+                cross_entropy._has_logged_once = True
+
            half_to_float = logits.dtype == torch.half
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits,
34 changes: 28 additions & 6 deletions fairseq_cli/hydra_train.py
@@ -4,29 +4,32 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

-import hydra
-from omegaconf import OmegaConf
+import logging
import os
+import sys

from fairseq.dataclass.initialize import hydra_init
from fairseq_cli.train import main as pre_main
from fairseq import distributed_utils
from fairseq.dataclass.configs import FairseqConfig

-import logging
+import hydra
import torch
+from omegaconf import OmegaConf


-logger = logging.getLogger(__name__)
+logger = logging.getLogger("fairseq_cli.hydra_train")


@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> None:

    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))

    OmegaConf.set_struct(cfg, True)

+    if cfg.common.reset_logging:
+        reset_logging()  # Hydra hijacks logging, fix that

    if cfg.common.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
@@ -35,7 +38,22 @@ def hydra_main(cfg: FairseqConfig) -> None:
        distributed_utils.call_main(cfg, pre_main)


if __name__ == "__main__":
def reset_logging():
root = logging.getLogger()
for handler in root.handlers:
root.removeHandler(handler)
root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
root.addHandler(handler)


def cli_main():
try:
from hydra._internal.utils import get_args

@@ -46,3 +64,7 @@ def hydra_main(cfg: FairseqConfig) -> None:

    hydra_init(cfg_name)
    hydra_main()
+
+
+if __name__ == "__main__":
+    cli_main()
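Moving the body of the old `__main__` block into `cli_main()` is what lets `setup.py` (next file) expose it as the `fairseq-hydra-train` console script, while the module should stay runnable directly. A sketch of the two invocations expected to behave the same, assuming an editable install of the checkout:

```shell script
# Hypothetical example: both commands are expected to end up in cli_main().
$ fairseq-hydra-train \
    task.data=/path/to/data \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_base_librispeech
$ python fairseq_cli/hydra_train.py \
    task.data=/path/to/data \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_base_librispeech
```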
17 changes: 13 additions & 4 deletions setup.py
@@ -22,14 +22,18 @@ def write_version_py():

    # append latest commit hash to version string
    try:
-        sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
+        sha = (
+            subprocess.check_output(["git", "rev-parse", "HEAD"])
+            .decode("ascii")
+            .strip()
+        )
        version += "+" + sha[:7]
    except Exception:
        pass

    # write version info to fairseq/version.py
    with open(os.path.join("fairseq", "version.py"), "w") as f:
-        f.write("__version__ = \"{}\"\n".format(version))
+        f.write('__version__ = "{}"\n'.format(version))
    return version


@@ -194,14 +198,16 @@ def do_setup(package_data):
"tests",
"tests.*",
]
) + extra_packages,
)
+ extra_packages,
package_data=package_data,
ext_modules=extensions,
test_suite="tests",
entry_points={
"console_scripts": [
"fairseq-eval-lm = fairseq_cli.eval_lm:cli_main",
"fairseq-generate = fairseq_cli.generate:cli_main",
"fairseq-hydra-train = fairseq_cli.hydra_train:cli_main",
"fairseq-interactive = fairseq_cli.interactive:cli_main",
"fairseq-preprocess = fairseq_cli.preprocess:cli_main",
"fairseq-score = fairseq_cli.score:cli_main",
@@ -230,8 +236,11 @@ def get_files(path, relative_to="fairseq"):
    fairseq_examples = os.path.join("fairseq", "examples")
    if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples):
        os.symlink(os.path.join("..", "examples"), fairseq_examples)
+
    package_data = {
-        "fairseq": get_files("fairseq/examples"),
+        "fairseq": (
+            get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config"))
+        )
    }
    do_setup(package_data)
finally:
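Since `fairseq/config` is now included in `package_data`, the default Hydra config tree should ship with installed copies of the package. A quick sanity check, not part of this commit:

```shell script
# Hypothetical check: print the packaged config directory and list its contents
# (expect config.yaml plus group subdirectories such as task/ and model/).
$ python -c "import os, fairseq; print(os.path.join(os.path.dirname(fairseq.__file__), 'config'))"
$ ls "$(python -c "import os, fairseq; print(os.path.dirname(fairseq.__file__))")/config"
```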
