Skip to content

Commit

Permalink
wav2vec model (#654)
Browse files Browse the repository at this point in the history
Summary:
Merging wav2vec to master. Includes renames (Cpc -> wav2vec) and some light example files.
Pull Request resolved: fairinternal/fairseq-py#654

Differential Revision: D15913409

Pulled By: alexeib

fbshipit-source-id: f723e6f211706cd9431c7d76dc12c4e80c9cfc80
  • Loading branch information
alexeib authored and facebook-github-bot committed Jun 20, 2019
1 parent bd710e7 commit 392fce8
Show file tree
Hide file tree
Showing 12 changed files with 1,091 additions and 22 deletions.
2 changes: 2 additions & 0 deletions README.md
Expand Up @@ -9,6 +9,7 @@ of various sequence-to-sequence models, including:
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- **_New_** [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
- **LightConv and DynamicConv models**
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- **Long Short-Term Memory (LSTM) networks**
Expand Down Expand Up @@ -82,6 +83,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional models are available

We also have more detailed READMEs to reproduce results from specific papers:
- [Schneider et al. (2019): wav2vec: Unsupervised Pre-training for Speech Recognition](examples/wav2vec/README.md)
- [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
Expand Down
31 changes: 31 additions & 0 deletions examples/wav2vec/README.md
@@ -0,0 +1,31 @@
# wav2vec

Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862).

## Training a new model with the CLI tools

Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length)

### Prepare training data manifest:

```
$ python scripts/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav
```

### Train a wav2vec model:

```
$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
--arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \
--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \
--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \
--max-sample-size 150000 --max-tokens 1500000 ---skip-invalid-size-inputs-valid-test
```

### Extract embeddings from the downstream task data:

```
$ PYTHONPATH /path/to/fairseq python scripts/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \
--model /model/path/checkpoint_best.pt --split train valid test
```
73 changes: 73 additions & 0 deletions fairseq/criterions/binary_cross_entropy.py
@@ -0,0 +1,73 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
import torch
import torch.nn.functional as F

from fairseq import utils

from . import FairseqCriterion, register_criterion


@register_criterion('binary_cross_entropy')
class BinaryCrossEntropyCriterion(FairseqCriterion):

def __init__(self, args, task):
super().__init__(args, task)

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output, expand_steps=False).float()

if hasattr(model, 'get_target_weights'):
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
else:
weights = 1.

loss = F.binary_cross_entropy_with_logits(logits, target, reduce=False)

loss = loss * weights

if reduce:
loss = loss.sum()

sample_size = target.numel()
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample_size,
'nsentences': logits.size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output

@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
2 changes: 1 addition & 1 deletion fairseq/criterions/fairseq_criterion.py
Expand Up @@ -13,7 +13,7 @@ class FairseqCriterion(_Loss):
def __init__(self, args, task):
super().__init__()
self.args = args
self.padding_idx = task.target_dictionary.pad()
self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100

@staticmethod
def add_args(parser):
Expand Down
2 changes: 2 additions & 0 deletions fairseq/data/__init__.py
Expand Up @@ -10,6 +10,7 @@

from .fairseq_dataset import FairseqDataset

from .audio.raw_audio_dataset import RawAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .block_pair_dataset import BlockPairDataset
from .concat_dataset import ConcatDataset
Expand Down Expand Up @@ -51,6 +52,7 @@
'MMapIndexedDataset',
'MonolingualDataset',
'NoisingDataset',
'RawAudioDataset',
'RoundRobinZipDatasets',
'ShardedIterator',
'TokenBlockDataset',
Expand Down
128 changes: 128 additions & 0 deletions fairseq/data/audio/raw_audio_dataset.py
@@ -0,0 +1,128 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.


import os
import numpy as np
import sys
import torch
import torch.nn.functional as F

from .. import FairseqDataset


class RawAudioDataset(FairseqDataset):

def __init__(self, manifest_path, sample_rate, max_sample_size=None, min_sample_size=None,
shuffle=True):
super().__init__()

self.sample_rate = sample_rate
self.fnames = []
self.sizes = []
self.max_sample_size = max_sample_size if max_sample_size is not None else sys.maxsize
self.min_sample_size = min_sample_size if min_sample_size is not None else self.max_sample_size

with open(manifest_path, 'r') as f:
self.root_dir = f.readline().strip()
for line in f:
items = line.strip().split('\t')
assert len(items) == 2, line
self.fnames.append(items[0])
self.sizes.append(int(items[1]))
self.shuffle = shuffle

def __getitem__(self, index):
fname = os.path.join(self.root_dir, self.fnames[index])
import soundfile as sf

wav, curr_sample_rate = sf.read(fname)
feats = torch.from_numpy(wav).float()

if feats.dim() == 2:
feats = feats.mean(-1)

if curr_sample_rate != self.sample_rate:
factor = self.sample_rate / curr_sample_rate
feats = self.resample(feats, factor)

assert feats.dim() == 1, feats.dim()

return {
'id': index,
'source': feats,
}

def resample(self, x, factor):
return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze()

def __len__(self):
return len(self.fnames)

def collater(self, samples):
if len(samples) == 0:
return {}

sources = [s['source'] for s in samples]
sizes = [len(s) for s in sources]
target_size = min(min(sizes), self.max_sample_size)

if self.min_sample_size < target_size:
target_size = np.random.randint(self.min_sample_size, target_size + 1)

collated_sources = sources[0].new(len(sources), target_size)
for i, (source, size) in enumerate(zip(sources, sizes)):
diff = size - target_size
assert diff >= 0
if diff == 0:
collated_sources[i] = source
else:
start = np.random.randint(0, diff + 1)
end = size - diff + start
collated_sources[i] = source[start:end]

return {
'id': torch.LongTensor([s['id'] for s in samples]),
'net_input': {
'source': collated_sources,
},
}

def get_dummy_batch(
self, num_tokens, max_positions, src_len=2048, tgt_len=128,
):
"""Return a dummy batch with a given number of tokens."""
if isinstance(max_positions, float) or isinstance(max_positions, int):
src_len = min(src_len, max_positions)
bsz = num_tokens // src_len
return self.collater([
{
'id': i,
'source': torch.rand(src_len),
}
for i in range(bsz)
])

def num_tokens(self, index):
return self.size(index)

def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return min(self.sizes[index], self.max_sample_size)

def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""

if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]

order.append(self.sizes)
return np.lexsort(order)

0 comments on commit 392fce8

Please sign in to comment.