Skip to content

Commit

Permalink
Merge pull request #4600 from simpleoier/s3prl_update
Browse files Browse the repository at this point in the history
Update to fit the recent update in s3prl.
  • Loading branch information
sw005320 committed Aug 26, 2022
2 parents b6f65c3 + 0a62508 commit 80e0420
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 120 deletions.
3 changes: 1 addition & 2 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ fi
# pycodestyle
pycodestyle -r ${modules} --show-source --show-pep8

LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}:$(pwd)/tools/chainer_ctc/ext/warp-ctc/build" \
PYTHONPATH="${PYTHONPATH:-}:$(pwd)/tools/s3prl" pytest -q
LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}:$(pwd)/tools/chainer_ctc/ext/warp-ctc/build" pytest -q

echo "=== report ==="
coverage report
Expand Down
2 changes: 1 addition & 1 deletion egs2/americasnlp22/asr1/conf/train_asr_transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ frontend: s3prl
frontend_conf:
frontend_conf:
upstream: wav2vec2_url # Note: If the upstream is changed, please change the input_size in the preencoder.
upstream_ckpt: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_300m.pt
path_or_url: https://huggingface.co/s3prl/converted_ckpts/resolve/main/xlsr2_300m.pt
download_dir: ./hub
multilayer_feature: True

Expand Down
2 changes: 1 addition & 1 deletion egs2/catslu/asr1/conf/train_asr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ frontend_conf:
frontend: s3prl
frontend_conf:
frontend_conf:
upstream: wav2vec2_xlsr # Note: If the upstream is changed, please change the input_size in the preencoder.
upstream: xlsr_53 # Note: If the upstream is changed, please change the input_size in the preencoder.
download_dir: ./hub
multilayer_feature: True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ frontend_conf:
frontend: s3prl
frontend_conf:
frontend_conf:
upstream: wav2vec2_xlsr # Note: If the upstream is changed, please change the input_size in the preencoder.
upstream: xlsr_53 # Note: If the upstream is changed, please change the input_size in the preencoder.
download_dir: ./hub
multilayer_feature: True

Expand Down
2 changes: 1 addition & 1 deletion egs2/ms_indic_18/asr1/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
- Commit date: `Mon Mar 14 22:32:17 2022 -0400`
- Pretrained model: [espnet/chai_microsoft_indian_langs_te](https://huggingface.co/espnet/chai_microsoft_indian_langs_te)

## Self-supervised learning features [wav2vec2_xlsr, Conformer, utt_mvn](conf/tuning/train_asr_xlsr53_conformer.yaml) with [Transformer-LM](conf/tuning/train_lm_transformer.yaml) and [RNN-LM](conf/tuning/train_lm_rnn.yaml). During inference, all below models use the same [decoding parameters](conf/tuning/decode_asr_transformer.yaml).
## Self-supervised learning features [wav2vec2_xlsr_53, Conformer, utt_mvn](conf/tuning/train_asr_xlsr53_conformer.yaml) with [Transformer-LM](conf/tuning/train_lm_transformer.yaml) and [RNN-LM](conf/tuning/train_lm_rnn.yaml). During inference, all below models use the same [decoding parameters](conf/tuning/decode_asr_transformer.yaml).

### WER

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ frontend_conf:

- frontend_type: s3prl
frontend_conf:
upstream: wav2vec2_xlsr
upstream: xlsr_53
download_dir: ./hub
multilayer_feature: True

Expand Down
125 changes: 39 additions & 86 deletions espnet2/asr/frontend/s3prl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import copy
import logging
import os
from argparse import Namespace
from pathlib import Path
from typing import Optional, Tuple, Union

import humanfriendly
Expand All @@ -12,18 +9,6 @@
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet.nets.pytorch_backend.frontends.frontend import Frontend
from espnet.nets.pytorch_backend.nets_utils import pad_list


def base_s3prl_setup(args):
args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None)
args.upstream_model_config = getattr(args, "upstream_model_config", None)
args.upstream_refresh = getattr(args, "upstream_refresh", False)
args.upstream_ckpt = getattr(args, "upstream_ckpt", None)
args.init_ckpt = getattr(args, "init_ckpt", None)
args.verbose = getattr(args, "verbose", False)
args.tile_factor = getattr(args, "tile_factor", 1)
return args


class S3prlFrontend(AbsFrontend):
Expand All @@ -36,75 +21,48 @@ def __init__(
download_dir: str = None,
multilayer_feature: bool = False,
):
try:
import s3prl
from s3prl.nn import Featurizer, S3PRLUpstream
except Exception as e:
print("Error: S3PRL is not properly installed.")
print("Please install S3PRL: cd ${MAIN_ROOT}/tools && make s3prl.done")
raise e

assert check_argument_types()
super().__init__()

if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
if fs != 16000:
logging.warning(
"All the upstream models in S3PRL now only support 16 kHz audio."
)

if download_dir is not None:
torch.hub.set_dir(download_dir)

self.multilayer_feature = multilayer_feature
self.upstream, self.featurizer = self._get_upstream(frontend_conf)
self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
self.output_dim = self.featurizer.output_dim
self.frontend_type = "s3prl"
self.hop_length = self.upstream.get_downsample_rates("key")
s3prl.util.download.set_dir(download_dir)

def _get_upstream(self, frontend_conf):
"""Get S3PRL upstream model."""
s3prl_args = base_s3prl_setup(
Namespace(**frontend_conf, device="cpu"),
assert frontend_conf.get("upstream", None) in S3PRLUpstream.available_names()
upstream = S3PRLUpstream(
frontend_conf.get("upstream"),
path_or_url=frontend_conf.get("path_or_url", None),
)
self.args = s3prl_args

try:
import s3prl # noqa
except ModuleNotFoundError:
raise ModuleNotFoundError(
"s3prl is not installed, please git clone s3prl"
" (DO NOT USE PIP or CONDA) "
"and install it from Github repo, "
"by cloning it locally."
)
s3prl_path = Path(os.path.abspath(s3prl.__file__)).parent.parent
if not os.path.exists(os.path.join(s3prl_path, "hubconf.py")):
raise RuntimeError(
"You probably have s3prl installed as a pip"
"package, please uninstall it and then install it from "
"the GitHub repo, by cloning it locally."
)

s3prl_upstream = torch.hub.load(
s3prl_path,
s3prl_args.upstream,
ckpt=s3prl_args.upstream_ckpt,
model_config=s3prl_args.upstream_model_config,
refresh=s3prl_args.upstream_refresh,
source="local",
).to("cpu")

upstream.eval()
if getattr(
s3prl_upstream, "model", None
) is not None and s3prl_upstream.model.__class__.__name__ in [
upstream, "model", None
) is not None and upstream.model.__class__.__name__ in [
"Wav2Vec2Model",
"HubertModel",
]:
s3prl_upstream.model.encoder.layerdrop = 0.0

from s3prl.upstream.interfaces import Featurizer

if self.multilayer_feature:
feature_selection = "hidden_states"
else:
feature_selection = "last_hidden_state"
s3prl_featurizer = Featurizer(
upstream=s3prl_upstream,
feature_selection=feature_selection,
upstream_device="cpu",
)
upstream.model.encoder.layerdrop = 0.0
featurizer = Featurizer(upstream)

return s3prl_upstream, s3prl_featurizer
self.multilayer_feature = multilayer_feature
self.upstream, self.featurizer = upstream, featurizer
self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
self.frontend_type = "s3prl"
self.hop_length = self.featurizer.downsample_rate
self.tile_factor = frontend_conf.get("tile_factor", 1)

def _tile_representations(self, feature):
"""Tile up the representations by `tile_factor`.
Expand All @@ -118,33 +76,28 @@ def _tile_representations(self, feature):
assert (
len(feature.shape) == 3
), "Input argument `feature` has invalid shape: {}".format(feature.shape)
tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
tiled_feature = feature.repeat(1, 1, self.tile_factor)
tiled_feature = tiled_feature.reshape(
feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
feature.size(0), feature.size(1) * self.tile_factor, feature.size(2)
)
return tiled_feature

def output_size(self) -> int:
return self.output_dim
return self.featurizer.output_size

def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
self.upstream.eval()
feats = self.upstream(wavs)
feats = self.featurizer(wavs, feats)
feats, feats_lens = self.upstream(input, input_lengths)
if self.multilayer_feature:
feats, feats_lens = self.featurizer(feats, feats_lens)
else:
feats, feats_lens = self.featurizer(feats[-1:], feats_lens[-1:])

if self.args.tile_factor != 1:
if self.tile_factor != 1:
feats = self._tile_representations(feats)

input_feats = pad_list(feats, 0.0)
feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

# Saving CUDA Memory
del feats

return input_feats, feats_lens
return feats, feats_lens

def reload_pretrained_parameters(self):
self.upstream.load_state_dict(self.pretrained_params)
Expand Down
18 changes: 14 additions & 4 deletions espnet2/torch_utils/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,20 @@ def initialize(model: torch.nn.Module, init: str):
model.encoder, "reload_pretrained_parameters", None
):
model.encoder.reload_pretrained_parameters()
if getattr(model, "frontend", None) and getattr(
model.frontend, "reload_pretrained_parameters", None
):
model.frontend.reload_pretrained_parameters()
if getattr(model, "frontend", None):
if getattr(model.frontend, "reload_pretrained_parameters", None):
model.frontend.reload_pretrained_parameters()
elif isinstance(
getattr(model.frontend, "frontends", None),
torch.nn.ModuleList,
):
for i, _ in enumerate(getattr(model.frontend, "frontends")):
if getattr(
model.frontend.frontends[i],
"reload_pretrained_parameters",
None,
):
model.frontend.frontends[i].reload_pretrained_parameters()
if getattr(model, "postencoder", None) and getattr(
model.postencoder, "reload_pretrained_parameters", None
):
Expand Down
32 changes: 21 additions & 11 deletions test/espnet2/asr/frontend/test_s3prl.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import pytest
import torch
from packaging.version import parse as V

from espnet2.asr.frontend.s3prl import S3prlFrontend

is_torch_1_7_plus = V(torch.__version__) >= V("1.7.0")
is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0")


def test_frontend_init():
if not is_torch_1_7_plus:
if not is_torch_1_8_plus:
return

frontend = S3prlFrontend(
fs=16000,
frontend_conf=dict(upstream="mel"),
)
assert frontend.frontend_type == "s3prl"
assert frontend.output_dim > 0
assert frontend.output_size() > 0


def test_frontend_output_size():
# Skip some testing cases
if not is_torch_1_7_plus:
if not is_torch_1_8_plus:
return

frontend = S3prlFrontend(
Expand All @@ -32,17 +33,26 @@ def test_frontend_output_size():
wavs = torch.randn(2, 1600)
lengths = torch.LongTensor([1600, 1600])
feats, _ = frontend(wavs, lengths)
assert feats.shape[-1] == frontend.output_dim


def test_frontend_backward():
if not is_torch_1_7_plus:
assert feats.shape[-1] == frontend.output_size()


@pytest.mark.parametrize(
"fs, frontend_conf, multilayer_feature",
[
(16000, dict(upstream="mel"), True),
(16000, dict(upstream="mel"), False),
(16000, dict(upstream="mel", tile_factor=1), False),
],
)
def test_frontend_backward(fs, frontend_conf, multilayer_feature):
if not is_torch_1_8_plus:
return

frontend = S3prlFrontend(
fs=16000,
frontend_conf=dict(upstream="mel"),
fs=fs,
frontend_conf=frontend_conf,
download_dir="./hub",
multilayer_feature=multilayer_feature,
)
wavs = torch.randn(2, 1600, requires_grad=True)
lengths = torch.LongTensor([1600, 1600])
Expand Down
17 changes: 5 additions & 12 deletions tools/installers/install_s3prl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ fi
if ! python -c "import packaging.version" &> /dev/null; then
python3 -m pip install packaging
fi
torch_17_plus=$(python3 <<EOF
torch_18_plus=$(python3 <<EOF
from packaging.version import parse as V
import torch
if V(torch.__version__) >= V("1.7"):
if V(torch.__version__) >= V("1.8"):
print("true")
else:
print("false")
Expand Down Expand Up @@ -46,17 +46,10 @@ EOF
)
echo "cuda_version=${cuda_version}"

if "${torch_17_plus}" && "${python_36_plus}"; then

rm -rf s3prl

# S3PRL Commit id when making this PR: `commit e2db27b2fa87b09fc720264635dcc4515dc63825`
git clone https://github.com/s3prl/s3prl.git
cd s3prl
git checkout -b legacy_version e2db27b2fa87b09fc720264635dcc4515dc63825
cd ..
if "${torch_18_plus}" && "${python_36_plus}"; then
python -m pip install s3prl

else
echo "[WARNING] s3prl is not prepared for pytorch<1.7.0, python<3.6 now"
echo "[WARNING] s3prl is not prepared for pytorch<1.8.0, python<3.6 now"

fi

0 comments on commit 80e0420

Please sign in to comment.