diff --git a/README.md b/README.md index 60ce9e92b6..42253940d3 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ of various sequence-to-sequence models, including: - [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md) - [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel) - [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md) + - **_New_** [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) - **LightConv and DynamicConv models** - **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md) - **Long Short-Term Memory (LSTM) networks** @@ -82,6 +83,7 @@ as well as example training and evaluation commands. - [Language Modeling](examples/language_model/README.md): convolutional models are available We also have more detailed READMEs to reproduce results from specific papers: +- [Schneider et al. (2019): wav2vec: Unsupervised Pre-training for Speech Recognition](examples/wav2vec/README.md) - [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md) - [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md) - [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md new file mode 100644 index 0000000000..325e25420f --- /dev/null +++ b/examples/wav2vec/README.md @@ -0,0 +1,31 @@ +# wav2vec + +Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). 
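+
+## Example usage
+
+Once a model has been trained with the commands below, its representations can also be extracted directly in Python. The snippet below is a minimal sketch (the checkpoint path and the random input are placeholders) of what `scripts/wav2vec_featurize.py` does internally:
+
+```
+import torch
+from fairseq.models.wav2vec import Wav2VecModel
+
+# load a trained checkpoint and rebuild the model from its saved args
+cp = torch.load('/model/path/checkpoint_best.pt')
+model = Wav2VecModel.build_model(cp['args'], task=None)
+model.load_state_dict(cp['model'])
+model.eval()
+
+# dummy 16 kHz mono waveform of shape (batch, time)
+wav_input_16khz = torch.randn(1, 10000)
+z = model.feature_extractor(wav_input_16khz)  # latent features
+c = model.feature_aggregator(z)               # context representations
+```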
+
+## Training a new model with the CLI tools
+
+Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files, each 10 to 30 seconds long):
+
+### Prepare training data manifest:
+
+```
+$ python scripts/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav
+```
+
+### Train a wav2vec model:
+
+```
+$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
+--arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \
+--conv-feature-layers "[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)]" \
+--conv-aggregator-layers "[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]" \
+--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \
+--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test
+```
+
+### Extract embeddings from the downstream task data:
+
+```
+$ PYTHONPATH=/path/to/fairseq python scripts/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \
+--model /model/path/checkpoint_best.pt --split train valid test
+```
diff --git a/fairseq/criterions/binary_cross_entropy.py b/fairseq/criterions/binary_cross_entropy.py
new file mode 100644
index 0000000000..06f269692c
--- /dev/null
+++ b/fairseq/criterions/binary_cross_entropy.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import math
+import torch
+import torch.nn.functional as F
+
+from fairseq import utils
+
+from . import FairseqCriterion, register_criterion
+
+
+@register_criterion('binary_cross_entropy')
+class BinaryCrossEntropyCriterion(FairseqCriterion):
+
+    def __init__(self, args, task):
+        super().__init__(args, task)
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(**sample['net_input'])
+        logits = model.get_logits(net_output).float()
+        target = model.get_targets(sample, net_output, expand_steps=False).float()
+
+        if hasattr(model, 'get_target_weights'):
+            weights = model.get_target_weights(target, net_output)
+            if torch.is_tensor(weights):
+                weights = weights.float()
+        else:
+            weights = 1.
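+        # logits and targets come flattened from the model's prediction head:
+        # one entry per (context step, prediction step) pair, with target 1 for
+        # the true future sample and 0 for every sampled negative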
+ + loss = F.binary_cross_entropy_with_logits(logits, target, reduce=False) + + loss = loss * weights + + if reduce: + loss = loss.sum() + + sample_size = target.numel() + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'ntokens': sample_size, + 'nsentences': logits.size(0), + 'sample_size': sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get('loss', 0) for log in logging_outputs) + ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) + nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) + sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + agg_output = { + 'loss': loss_sum / sample_size / math.log(2), + 'ntokens': ntokens, + 'nsentences': nsentences, + 'sample_size': sample_size, + } + if sample_size != ntokens: + agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) + return agg_output \ No newline at end of file diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index 5c910f4e65..fbe5afc8e8 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -13,7 +13,7 @@ class FairseqCriterion(_Loss): def __init__(self, args, task): super().__init__() self.args = args - self.padding_idx = task.target_dictionary.pad() + self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100 @staticmethod def add_args(parser): diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index e4ef89541a..ffee78429e 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -10,6 +10,7 @@ from .fairseq_dataset import FairseqDataset +from .audio.raw_audio_dataset import RawAudioDataset from .backtranslation_dataset import BacktranslationDataset from .block_pair_dataset import BlockPairDataset from .concat_dataset import ConcatDataset @@ -51,6 +52,7 @@ 'MMapIndexedDataset', 'MonolingualDataset', 'NoisingDataset', + 'RawAudioDataset', 'RoundRobinZipDatasets', 'ShardedIterator', 'TokenBlockDataset', diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py new file mode 100644 index 0000000000..7fb25cc6c3 --- /dev/null +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -0,0 +1,128 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import os +import numpy as np +import sys +import torch +import torch.nn.functional as F + +from .. 
import FairseqDataset + + +class RawAudioDataset(FairseqDataset): + + def __init__(self, manifest_path, sample_rate, max_sample_size=None, min_sample_size=None, + shuffle=True): + super().__init__() + + self.sample_rate = sample_rate + self.fnames = [] + self.sizes = [] + self.max_sample_size = max_sample_size if max_sample_size is not None else sys.maxsize + self.min_sample_size = min_sample_size if min_sample_size is not None else self.max_sample_size + + with open(manifest_path, 'r') as f: + self.root_dir = f.readline().strip() + for line in f: + items = line.strip().split('\t') + assert len(items) == 2, line + self.fnames.append(items[0]) + self.sizes.append(int(items[1])) + self.shuffle = shuffle + + def __getitem__(self, index): + fname = os.path.join(self.root_dir, self.fnames[index]) + import soundfile as sf + + wav, curr_sample_rate = sf.read(fname) + feats = torch.from_numpy(wav).float() + + if feats.dim() == 2: + feats = feats.mean(-1) + + if curr_sample_rate != self.sample_rate: + factor = self.sample_rate / curr_sample_rate + feats = self.resample(feats, factor) + + assert feats.dim() == 1, feats.dim() + + return { + 'id': index, + 'source': feats, + } + + def resample(self, x, factor): + return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze() + + def __len__(self): + return len(self.fnames) + + def collater(self, samples): + if len(samples) == 0: + return {} + + sources = [s['source'] for s in samples] + sizes = [len(s) for s in sources] + target_size = min(min(sizes), self.max_sample_size) + + if self.min_sample_size < target_size: + target_size = np.random.randint(self.min_sample_size, target_size + 1) + + collated_sources = sources[0].new(len(sources), target_size) + for i, (source, size) in enumerate(zip(sources, sizes)): + diff = size - target_size + assert diff >= 0 + if diff == 0: + collated_sources[i] = source + else: + start = np.random.randint(0, diff + 1) + end = size - diff + start + collated_sources[i] = source[start:end] + + return { + 'id': torch.LongTensor([s['id'] for s in samples]), + 'net_input': { + 'source': collated_sources, + }, + } + + def get_dummy_batch( + self, num_tokens, max_positions, src_len=2048, tgt_len=128, + ): + """Return a dummy batch with a given number of tokens.""" + if isinstance(max_positions, float) or isinstance(max_positions, int): + src_len = min(src_len, max_positions) + bsz = num_tokens // src_len + return self.collater([ + { + 'id': i, + 'source': torch.rand(src_len), + } + for i in range(bsz) + ]) + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order) diff --git a/fairseq/models/wav2vec.py b/fairseq/models/wav2vec.py new file mode 100644 index 0000000000..e89ede158b --- /dev/null +++ b/fairseq/models/wav2vec.py @@ -0,0 +1,475 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
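+
+# Model overview: a convolutional feature extractor (ConvFeatureExtractionModel)
+# maps raw audio to latent representations z, a convolutional or GRU aggregator
+# turns z into context vectors c, and Wav2VecPredictionsModel scores
+# (context, future latent) pairs against sampled negatives, producing the
+# flattened logits/targets consumed by the binary_cross_entropy criterion.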
+ +import sys + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import ( + BaseFairseqModel, register_model, register_model_architecture +) + + +@register_model('wav2vec') +class Wav2VecModel(BaseFairseqModel): + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument('--prediction-steps', type=int, metavar='N', help='number of steps ahead to predict') + parser.add_argument('--sample-distance', type=int, metavar='N', + help='sample distance from target. does not work properly with cross-sampling') + parser.add_argument('--cross-sample-negatives', action='store_true', + help='whether to sample negatives across examples in the same batch') + parser.add_argument('--num-negatives', type=int, metavar='N', help='number of negative examples') + parser.add_argument('--conv-feature-layers', type=str, metavar='EXPR', + help='convolutional feature extraction layers [(dim, kernel_size, stride), ...]') + parser.add_argument('--conv-aggregator-layers', type=str, metavar='EXPR', + help='convolutional feature extraction layers [(dim, kernel_size, stride), ...]') + parser.add_argument('--dropout', type=float, metavar='D', help='dropout to apply within the model') + parser.add_argument('--dropout-features', type=float, metavar='D', help='dropout to apply to the features') + parser.add_argument('--dropout-agg', type=float, metavar='D', help='dropout to apply after aggregation step') + parser.add_argument('--encoder', type=str, choices=['cnn'], help='type of encoder to use') + parser.add_argument('--aggregator', type=str, choices=['cnn', 'gru'], + help='type of aggregator to use') + parser.add_argument('--gru-dim', type=int, metavar='N', help='GRU dimensionality') + + parser.add_argument('--no-conv-bias', action='store_true', + help='if set, does not learn bias for conv layers') + parser.add_argument('--agg-zero-pad', action='store_true', + help='if set, zero pads in aggregator instead of repl pad') + + parser.add_argument('--skip-connections-feat', action='store_true', + help='if set, adds skip connections to the feature extractor') + parser.add_argument('--skip-connections-agg', action='store_true', + help='if set, adds skip connections to the aggregator') + parser.add_argument('--residual-scale', type=float, metavar='D', + help='scales residual by sqrt(value)') + + parser.add_argument('--log-compression', action='store_true', + help='if set, adds a log compression to feature extractor') + + parser.add_argument('--balanced-classes', action='store_true', + help='if set, loss is scaled to balance for number of negatives') + parser.add_argument('--project-features', choices=['none', 'same', 'new'], + help='if not none, features are projected using the (same or new) aggregator') + + parser.add_argument('--non-affine-group-norm', action='store_true', + help='if set, group norm is not affine') + + parser.add_argument('--offset', help='if set, introduces an offset from target to predictions. 
' + 'if set to "auto", it is computed automatically from the receptive field') + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_wav2vec_architecture(args) + + model = Wav2VecModel(args) + print(model) + return model + + def __init__(self, args): + super().__init__() + + self.prediction_steps = args.prediction_steps + + offset = args.offset + + if args.encoder == 'cnn': + feature_enc_layers = eval(args.conv_feature_layers) + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0., + log_compression=args.log_compression, + skip_connections=args.skip_connections_feat, + residual_scale=args.residual_scale, + non_affine_group_norm=args.non_affine_group_norm, + ) + embed = feature_enc_layers[-1][0] + else: + raise Exception('unknown encoder type ' + args.encoder) + + if args.offset == 'auto': + assert args.encoder == 'cnn' + jin = 0 + rin = 0 + for _, k, stride in feature_enc_layers: + if rin == 0: + rin = k + rin = rin + (k - 1) * jin + if jin == 0: + jin = stride + else: + jin *= stride + offset = math.ceil(rin / jin) + + offset = int(offset) + + def make_aggregator(): + if args.aggregator == 'cnn': + agg_layers = eval(args.conv_aggregator_layers) + agg_dim = agg_layers[-1][0] + feature_aggregator = ConvAggegator( + conv_layers=agg_layers, + embed=embed, + dropout=args.dropout, + skip_connections=args.skip_connections_agg, + residual_scale=args.residual_scale, + non_affine_group_norm=args.non_affine_group_norm, + conv_bias=not args.no_conv_bias, + zero_pad=args.agg_zero_pad, + ) + elif args.aggregator == 'gru': + agg_dim = args.gru_dim + feature_aggregator = nn.Sequential( + TransposeLast(), + nn.GRU( + input_size=embed, + hidden_size=agg_dim, + num_layers=1, + dropout=args.dropout, + ), + TransposeLast(deconstruct_idx=0), + ) + else: + raise Exception('unknown aggregator type ' + args.aggregator) + + return feature_aggregator, agg_dim + + self.feature_aggregator, agg_dim = make_aggregator() + + self.wav2vec_predictions = Wav2VecPredictionsModel( + in_dim=agg_dim, + out_dim=embed, + prediction_steps=args.prediction_steps, + n_negatives=args.num_negatives, + cross_sample_negatives=args.cross_sample_negatives, + sample_distance=args.sample_distance, + dropout=args.dropout, + offset=offset, + balanced_classes=args.balanced_classes, + ) + + self.dropout_feats = nn.Dropout(p=args.dropout_features) + self.dropout_agg = nn.Dropout(p=args.dropout_agg) + + if args.project_features == 'none': + self.project_features = None + elif args.project_features == 'same': + self.project_features = self.feature_aggregator + elif args.project_features == 'new': + self.project_features, _ = make_aggregator() + + def forward(self, source): + result = {} + + features = self.feature_extractor(source) + + x = self.dropout_feats(features) + x = self.feature_aggregator(x) + x = self.dropout_agg(x) + + if self.project_features is not None: + features = self.project_features(features) + x, targets = self.wav2vec_predictions(x, features) + result['cpc_logits'] = x + result['cpc_targets'] = targets + + return result + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + def max_positions(self): + """Maximum length supported by the model.""" + return sys.maxsize + + def get_logits(self, net_output): + logits = net_output['cpc_logits'] + return logits + + def get_targets(self, sample, net_output, expand_steps=True): + t = net_output['cpc_targets'] + 
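+        # cpc_targets is a (labels, weights) tuple when --balanced-classes is
+        # set; only the labels are needed here (the weights are returned
+        # separately by get_target_weights)
+        if isinstance(t, tuple):
+            t = t[0]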
return t.contiguous() + + def get_target_weights(self, targets, net_output): + targets = net_output['cpc_targets'] + if isinstance(targets, tuple) and targets[-1] is not None: + return targets[-1] + return 1. + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), self.num_groups, self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, self.eps) + return output.type_as(input) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, self.eps) + return output.type_as(input) + + +def norm_block(is_layer_norm, dim, affine=True): + if is_layer_norm: + mod = nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=affine), + TransposeLast(), + ) + else: + mod = Fp32GroupNorm(1, dim, affine=affine) + + return mod + + +class ConvFeatureExtractionModel(nn.Module): + def __init__(self, conv_layers, dropout, log_compression, skip_connections, residual_scale, non_affine_group_norm): + super().__init__() + + def block(n_in, n_out, k, stride): + return nn.Sequential( + nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), + nn.Dropout(p=dropout), + norm_block(is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm), + nn.ReLU(), + ) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, (dim, k, stride) in enumerate(conv_layers): + self.conv_layers.append( + block(in_d, dim, k, stride)) + in_d = dim + + self.log_compression = log_compression + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + residual = x + x = conv(x) + if self.skip_connections and x.size(1) == residual.size(1): + tsz = x.size(2) + r_tsz = residual.size(2) + residual = residual[..., ::r_tsz // tsz][..., :tsz] + x = (x + residual) * self.residual_scale + + if self.log_compression: + x = x.abs() + x = x + 1 + x = x.log() + + return x + + +class ZeroPad1d(nn.Module): + def __init__(self, pad_left, pad_right): + super().__init__() + self.pad_left = pad_left + self.pad_right = pad_right + + def forward(self, x): + return F.pad(x, (self.pad_left, self.pad_right)) + + +class ConvAggegator(nn.Module): + def __init__(self, conv_layers, embed, dropout, skip_connections, residual_scale, non_affine_group_norm, conv_bias, + zero_pad): + super().__init__() + + def block(n_in, n_out, k, stride): + # padding dims only really make sense for stride = 1 + ka = k // 2 + kb = ka - 1 if k % 2 == 0 else ka + + pad = ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0)) + + return nn.Sequential( + pad, + nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias), + nn.Dropout(p=dropout), + norm_block(False, n_out, affine=not non_affine_group_norm), + nn.ReLU(), + ) + + in_d = embed + self.conv_layers = nn.ModuleList() + self.residual_proj = nn.ModuleList() + for i, 
(dim, k, stride) in enumerate(conv_layers): + if in_d != dim and skip_connections: + self.residual_proj.append( + nn.Conv1d(in_d, dim, 1, bias=False), + ) + else: + self.residual_proj.append(None) + + self.conv_layers.append( + block(in_d, dim, k, stride)) + in_d = dim + self.conv_layers = nn.Sequential(*self.conv_layers) + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + for rproj, conv in zip(self.residual_proj, self.conv_layers): + residual = x + x = conv(x) + if self.skip_connections: + if rproj != None: + residual = rproj(residual) + x = (x + residual) * self.residual_scale + return x + + +class Wav2VecPredictionsModel(nn.Module): + def __init__(self, in_dim, out_dim, prediction_steps, n_negatives, cross_sample_negatives, sample_distance, + dropout, offset, balanced_classes): + super().__init__() + + self.n_negatives = n_negatives + self.cross_sample_negatives = cross_sample_negatives + self.sample_distance = sample_distance + + self.project_to_steps = nn.ConvTranspose2d(in_dim, out_dim, (1, prediction_steps)) + self.dropout = nn.Dropout(p=dropout) + self.offset = offset + self.balanced_classes = balanced_classes + + def sample_negatives(self, y): + bsz, fsz, tsz = y.shape + + y = y.transpose(0, 1) # BCT -> CBT + y = y.contiguous().view(fsz, -1) # CBT => C(BxT) + + if self.cross_sample_negatives: + high = tsz * bsz + assert self.sample_distance is None, 'sample distance is not supported with cross sampling' + else: + high = tsz if self.sample_distance is None else min(tsz, self.sample_distance) + + neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz)) + + if self.sample_distance is not None and self.sample_distance < tsz: + neg_idxs += torch.cat( + [torch.arange(start=1, end=tsz - self.sample_distance, device=neg_idxs.device, dtype=neg_idxs.dtype), + torch.arange(start=tsz - self.sample_distance, end=tsz - self.sample_distance * 2 - 1, step=-1, + device=neg_idxs.device, dtype=neg_idxs.dtype)]) + + if not self.cross_sample_negatives: + for i in range(1, bsz): + neg_idxs[i] += i * high + + negs = y[..., neg_idxs.view(-1)] + negs = negs.view(fsz, bsz, self.n_negatives, tsz).permute(2, 1, 0, 3) # to NxBxCxT + + return negs + + def forward(self, x, y): + negatives = self.sample_negatives(y) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + x = x.unsqueeze(-1) + x = self.project_to_steps(x) # BxCxTxS + x = self.dropout(x) + x = x.unsqueeze(0).expand(targets.size(0), -1, -1, -1, -1) + + copies, bsz, dim, tsz, steps = x.shape + steps = min(steps, tsz - self.offset) + predictions = x.new(bsz * copies * (tsz - self.offset + 1) * steps - ((steps + 1) * steps // 2) * copies * bsz) + labels = torch.zeros_like(predictions) + weights = torch.full_like(labels, 1 / self.n_negatives) if self.balanced_classes else None + + start = end = 0 + for i in range(steps): + offset = i + self.offset + end = start + (tsz - offset) * bsz * copies + pos_num = (end - start) // copies + predictions[start:end] = (x[..., :-offset, i] * targets[..., offset:]).sum(dim=2).flatten() + labels[start:start + pos_num] = 1. + if weights is not None: + weights[start:start + pos_num] = 1. 
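+            # copy 0 of `targets` holds the true future samples, so the first
+            # pos_num flattened entries at this step are the positives; the
+            # remaining entries score the sampled negatives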
+ start = end + assert end == predictions.numel(), '{} != {}'.format(end, predictions.numel()) + + if weights is not None: + labels = (labels, weights) + + return predictions, labels + + +@register_model_architecture('wav2vec', 'wav2vec') +def base_wav2vec_architecture(args): + conv_feature_layers = '[(512, 10, 5)]' + conv_feature_layers += ' + [(512, 8, 4)]' + conv_feature_layers += ' + [(512, 4, 2)] * 3' + args.conv_feature_layers = getattr(args, 'conv_feature_layers', conv_feature_layers) + + args.conv_aggregator_layers = getattr(args, 'conv_aggregator_layers', '[(512, 3, 1)] * 9') + + args.prediction_steps = getattr(args, 'prediction_steps', 12) + args.num_negatives = getattr(args, 'num_negatives', 1) + args.sample_distance = getattr(args, 'sample_distance', None) + args.cross_sample_negatives = getattr(args, 'cross_sample_negatives', False) + + args.dropout = getattr(args, 'dropout', 0.) + args.dropout_features = getattr(args, 'dropout_features', 0.) + args.dropout_agg = getattr(args, 'dropout_agg', 0.) + args.encoder = getattr(args, 'encoder', 'cnn') + args.aggregator = getattr(args, 'aggregator', 'cnn') + + args.skip_connections_feat = getattr(args, 'skip_connections_feat', False) + args.skip_connections_agg = getattr(args, 'skip_connections_agg', False) + args.residual_scale = getattr(args, 'residual_scale', 0.5) + + args.gru_dim = getattr(args, 'gru_dim', 512) + + args.no_conv_bias = getattr(args, 'no_conv_bias', False) + args.agg_zero_pad = getattr(args, 'agg_zero_pad', False) + + args.log_compression = getattr(args, 'log_compression', False) + + args.balanced_classes = getattr(args, 'balanced_classes', False) + args.project_features = getattr(args, 'project_features', 'none') + + args.non_affine_group_norm = getattr(args, 'non_affine_group_norm', False) + + args.offset = getattr(args, 'offset', 'auto') diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py new file mode 100644 index 0000000000..ac97c38ada --- /dev/null +++ b/fairseq/tasks/audio_pretraining.py @@ -0,0 +1,60 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import os + +from fairseq.data import RawAudioDataset +from . import FairseqTask, register_task + + +@register_task('audio_pretraining') +class AudioPretrainingTask(FairseqTask): + """ + + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument('data', help='path to data directory') + parser.add_argument('--sample-rate', default=16000, type=int, + help='target sample rate. audio files will be up/down sampled to this rate') + parser.add_argument('--max-sample-size', default=None, type=int, + help='max sample size to crop to for batching. default = min sample length') + parser.add_argument('--min-sample-size', default=None, type=int, + help='min sample size to crop to for batching. default = same as --max-sample-size') + + def __init__(self, args): + super().__init__(args) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + return cls(args) + + def load_dataset(self, split, **kwargs): + """Load a given dataset split. 
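+
+        The manifest for each split is expected at '{data}/{split}.tsv', e.g.
+        as produced by scripts/wav2vec_manifest.py.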
+ + Args: + split (str): name of the split (e.g., train, valid, test) + """ + + manifest = os.path.join(self.args.data, '{}.tsv'.format(split)) + self.datasets[split] = RawAudioDataset(manifest, + sample_rate=self.args.sample_rate, + max_sample_size=self.args.max_sample_size, + min_sample_size=self.args.min_sample_size) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return None \ No newline at end of file diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d3e09efcaf..220eb81429 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -233,11 +233,11 @@ def train_step(self, samples, dummy_batch=False, raise_oom=False): # forward and backward pass logging_outputs, sample_sizes, ooms = [], [], 0 for i, sample in enumerate(samples): - sample = self._prepare_sample(sample) + sample = self._prepare_sample(sample, self.args.fp16) if sample is None: # when sample is None, run forward/backward on a dummy batch # and ignore the resulting gradients - sample = self._prepare_sample(self._dummy_batch) + sample = self._prepare_sample(self._dummy_batch, self.args.fp16) ignore_grad = True else: ignore_grad = False @@ -381,9 +381,9 @@ def valid_step(self, sample, raise_oom=False): self.model.eval() self.criterion.eval() - sample = self._prepare_sample(sample) + sample = self._prepare_sample(sample, self.args.fp16) if sample is None: - sample = self._prepare_sample(self._dummy_batch) + sample = self._prepare_sample(self._dummy_batch, self.args.fp16) ignore_results = True else: ignore_results = False @@ -488,12 +488,19 @@ def set_num_updates(self, num_updates): self._num_updates = num_updates self.lr_step_update() - def _prepare_sample(self, sample): + def _prepare_sample(self, sample, fp16): if sample is None or len(sample) == 0: return None + if self.cuda: sample = utils.move_to_cuda(sample) - return sample + + def apply_half(t): + if t.dtype is torch.float32: + return t.half() + return t + + return utils.apply(apply_half, sample) if fp16 else sample def _set_seed(self): # Set seed based on args.seed and the update number so that we get diff --git a/fairseq/utils.py b/fairseq/utils.py index 10e70a8d08..7af1400275 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -31,24 +31,26 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): ) -def move_to_cuda(sample): +def apply(f, sample): if len(sample) == 0: return {} + if torch.is_tensor(sample): + return f(sample) + elif isinstance(sample, dict): + return { + key: apply(f, value) + for key, value in sample.items() + } + elif isinstance(sample, list): + return [apply(f, x) for x in sample] + else: + return sample - def _move_to_cuda(maybe_tensor): - if torch.is_tensor(maybe_tensor): - return maybe_tensor.cuda() - elif isinstance(maybe_tensor, dict): - return { - key: _move_to_cuda(value) - for key, value in maybe_tensor.items() - } - elif isinstance(maybe_tensor, list): - return [_move_to_cuda(x) for x in maybe_tensor] - else: - return maybe_tensor - - return _move_to_cuda(sample) + +def move_to_cuda(sample): + def _move_to_cuda(tensor): + return tensor.cuda() + return apply(_move_to_cuda, sample) INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) diff --git a/scripts/wav2vec_featurize.py b/scripts/wav2vec_featurize.py new file mode 100644 index 0000000000..70764d7933 --- /dev/null +++ b/scripts/wav2vec_featurize.py @@ -0,0 +1,231 @@ +""" Helper script to pre-compute embeddings for a wav2letter++ dataset +""" + +import glob, os +import tqdm 
+from shutil import copy + +import soundfile as sf + +import h5py +import numpy as np + +import torch +from torch import nn + +from fairseq.models.wav2vec import Wav2VecModel + +import argparse + + +def read_audio(fname): + """ Load an audio file and return PCM along with the sample rate """ + + wav, sr = sf.read(fname) + assert sr == 16e3 + + return wav, 16e3 + + +class PretrainedWav2VecModel(nn.Module): + + def __init__(self, fname): + super().__init__() + + checkpoint = torch.load(fname) + self.args = checkpoint["args"] + model = Wav2VecModel.build_model(self.args, None) + model.load_state_dict(checkpoint["model"]) + model.eval() + + self.model = model + + def forward(self, x): + with torch.no_grad(): + z = self.model.feature_extractor(x) + if isinstance(z, tuple): + z = z[0] + c = self.model.feature_aggregator(z) + return z, c + + +class EmbeddingWriterConfig(argparse.ArgumentParser): + + def __init__(self): + super().__init__("Pre-compute embeddings for wav2letter++ datasets") + + kwargs = {"action": "store", "type": str, "required": True} + + self.add_argument("--input", "-i", + help="Input Directory", **kwargs) + self.add_argument("--output", "-o", + help="Output Directory", **kwargs) + self.add_argument("--model", + help="Path to model checkpoint", **kwargs) + self.add_argument("--split", + help="Dataset Splits", nargs='+', **kwargs) + self.add_argument("--ext", default="wav", required=False, + help="Audio file extension") + + self.add_argument("--no-copy-labels", action="store_true", + help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.") + self.add_argument("--use-feat", action="store_true", + help="Use the feature vector ('z') instead of context vector ('c') for features") + self.add_argument("--gpu", + help="GPU to use", default=0, type=int) + + +class Prediction(): + """ Lightweight wrapper around a fairspeech embedding model """ + + def __init__(self, fname, gpu=0): + self.gpu = gpu + self.model = PretrainedWav2VecModel(fname).cuda(gpu) + + def __call__(self, x): + x = torch.from_numpy(x).float().cuda(self.gpu) + with torch.no_grad(): + z, c = self.model(x.unsqueeze(0)) + + return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy() + + +class H5Writer(): + """ Write features as hdf5 file in wav2letter++ compatible format """ + + def __init__(self, fname): + self.fname = fname + os.makedirs(os.path.dirname(self.fname), exist_ok=True) + + def write(self, data): + channel, T = data.shape + + with h5py.File(self.fname, "w") as out_ds: + data = data.T.flatten() + out_ds["features"] = data + out_ds["info"] = np.array([16e3 // 160, T, channel]) + + +class EmbeddingDatasetWriter(object): + """ Given a model and a wav2letter++ dataset, pre-compute and store embeddings + + Args: + input_root, str : + Path to the wav2letter++ dataset + output_root, str : + Desired output directory. 
Will be created if non-existent + split, str : + Dataset split + """ + + def __init__(self, input_root, output_root, split, + model_fname, + extension="wav", + gpu=0, + verbose=False, + use_feat=False, + ): + + assert os.path.exists(model_fname) + + self.model_fname = model_fname + self.model = Prediction(self.model_fname, gpu) + + self.input_root = input_root + self.output_root = output_root + self.split = split + self.verbose = verbose + self.extension = extension + self.use_feat = use_feat + + assert os.path.exists(self.input_path), \ + "Input path '{}' does not exist".format(self.input_path) + + def _progress(self, iterable, **kwargs): + if self.verbose: + return tqdm.tqdm(iterable, **kwargs) + return iterable + + def require_output_path(self, fname=None): + path = self.get_output_path(fname) + os.makedirs(path, exist_ok=True) + + @property + def input_path(self): + return self.get_input_path() + + @property + def output_path(self): + return self.get_output_path() + + def get_input_path(self, fname=None): + if fname is None: + return os.path.join(self.input_root, self.split) + return os.path.join(self.get_input_path(), fname) + + def get_output_path(self, fname=None): + if fname is None: + return os.path.join(self.output_root, self.split) + return os.path.join(self.get_output_path(), fname) + + def copy_labels(self): + self.require_output_path() + + labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*")))) + for fname in tqdm.tqdm(labels): + copy(fname, self.output_path) + + @property + def input_fnames(self): + return sorted(glob.glob(self.get_input_path("*.{}".format(self.extension)))) + + def __len__(self): + return len(self.input_fnames) + + def write_features(self): + + paths = self.input_fnames + + fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \ + map(os.path.basename, paths)) + + for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)): + wav, sr = read_audio(name) + z, c = self.model(wav) + feat = z if self.use_feat else c + writer = H5Writer(target_fname) + writer.write(feat) + + def __repr__(self): + + return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format( + n_files=len(self), **self.__dict__) + + +if __name__ == "__main__": + + args = EmbeddingWriterConfig().parse_args() + + for split in args.split: + + writer = EmbeddingDatasetWriter( + input_root=args.input, + output_root=args.output, + split=split, + model_fname=args.model, + gpu=args.gpu, + extension=args.ext, + use_feat=args.use_feat, + ) + + print(writer) + writer.require_output_path() + + print("Writing Features...") + writer.write_features() + print("Done.") + + if not args.no_copy_labels: + print("Copying label data...") + writer.copy_labels() + print("Done.") \ No newline at end of file diff --git a/scripts/wav2vec_manifest.py b/scripts/wav2vec_manifest.py new file mode 100644 index 0000000000..949edd58dc --- /dev/null +++ b/scripts/wav2vec_manifest.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. +""" +Data pre-processing: build vocabularies and binarize training data. 
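+
+For wav2vec pretraining this builds train.tsv/valid.tsv manifests from a
+directory of audio files: the first line of each manifest is the root audio
+directory, followed by one "<relative path>\t<number of frames>" entry per file.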
+""" + +import argparse +import glob +import os +import soundfile +import random + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index') + parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D', + help='percentage of data to use as validation set (between 0 and 1)') + parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory') + parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for') + parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed') + parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG', + help='if set, path must contain this substring for a file to be included in the manifest') + return parser + + +def main(args): + assert args.valid_percent >= 0 and args.valid_percent <= 1. + + dir_path = os.path.realpath(args.root) + search_path = os.path.join(dir_path, '**/*.' + args.ext) + rand = random.Random(args.seed) + + with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open( + os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f: + print(dir_path, file=train_f) + print(dir_path, file=valid_f) + + for fname in glob.iglob(search_path, recursive=True): + file_path = os.path.realpath(fname) + + if args.path_must_contain and args.path_must_contain not in file_path: + continue + + frames = soundfile.info(fname).frames + dest = train_f if rand.random() > args.valid_percent else valid_f + print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest) + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args)