forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
wav2vec2_encoder.py
168 lines (135 loc) · 5.5 KB
/
wav2vec2_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright 2021 Xuankai Chang
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder definition."""
import contextlib
import copy
import logging
import os
from typing import Optional, Tuple
import torch
from filelock import FileLock
from typeguard import check_argument_types
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
class FairSeqWav2Vec2Encoder(AbsEncoder):
    """FairSeq Wav2Vec2 encoder module.

    Wraps a pretrained fairseq Wav2Vec2.0 model (downloaded on first use) as an
    ESPnet2 encoder.  The wav2vec parameters are kept frozen (no gradient) until
    ``freeze_finetune_updates`` forward calls have happened, after which they
    are fine-tuned.

    Args:
        input_size: input dim.  NOTE(review): not referenced anywhere in this
            class — presumably accepted only to satisfy the AbsEncoder
            interface; confirm against the caller.
        w2v_url: url to Wav2Vec2.0 pretrained model
        w2v_dir_path: directory to download the Wav2Vec2.0 pretrained model.
        output_size: dimension of attention (a Linear projection is added when
            the wav2vec embedding dim differs from this)
        normalize_before: if True, apply a LayerNorm to the final encoder
            output in ``forward``
        freeze_finetune_updates: number of forward calls during which the
            wav2vec parameters stay frozen.  0 means fine-tune from the start.
    """

    def __init__(
        self,
        input_size: int,
        w2v_url: str,
        w2v_dir_path: str = "./",
        output_size: int = 256,
        normalize_before: bool = False,
        freeze_finetune_updates: int = 0,
    ):
        assert check_argument_types()
        super().__init__()

        if w2v_url != "":
            # fairseq is an optional dependency; fail early with install hint.
            try:
                import fairseq
                from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
            except Exception as e:
                print("Error: FairSeq is not properly installed.")
                print(
                    "Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done"
                )
                raise e

        self.w2v_model_path = download_w2v(w2v_url, w2v_dir_path)

        self._output_size = output_size

        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [self.w2v_model_path],
            arg_overrides={"data": w2v_dir_path},
        )
        model = models[0]

        if not isinstance(model, Wav2Vec2Model):
            # Checkpoints such as Wav2VecCTC wrap the raw Wav2Vec2Model;
            # unwrap to get at the underlying feature encoder.
            try:
                model = model.w2v_encoder.w2v_model
            except Exception as e:
                print(
                    "Error: pretrained models should be within: "
                    "'Wav2Vec2Model, Wav2VecCTC' classes, etc."
                )
                raise e

        self.encoders = model
        # Snapshot of the pretrained weights so they can be restored later
        # via reload_pretrained_parameters().
        self.pretrained_params = copy.deepcopy(model.state_dict())

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        if model.cfg.encoder_embed_dim != output_size:
            # TODO(xkc09): try LSTM
            self.output_layer = torch.nn.Sequential(
                torch.nn.Linear(model.cfg.encoder_embed_dim, output_size),
            )
        else:
            self.output_layer = None

        self.freeze_finetune_updates = freeze_finetune_updates
        # Buffer (not a parameter) so the forward-call count is saved in
        # checkpoints and moved with the module across devices.
        self.register_buffer("num_updates", torch.LongTensor([0]))

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Forward FairSeqWav2Vec2 Encoder.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            position embedded tensor and mask
        """
        masks = make_pad_mask(ilens).to(xs_pad.device)

        # ft: whether fine-tuning of the wav2vec parameters has started.
        # NOTE(review): the counter is bumped on every forward call, not only
        # on optimizer updates or in training mode — confirm that is intended.
        ft = self.freeze_finetune_updates <= self.num_updates
        if self.num_updates <= self.freeze_finetune_updates:
            self.num_updates += 1
        elif ft and self.num_updates == self.freeze_finetune_updates + 1:
            # Log exactly once, on the first fine-tuned forward call.
            self.num_updates += 1
            logging.info("Start fine-tuning wav2vec parameters!")

        # While frozen, run wav2vec under no_grad so no gradients flow into it.
        with torch.no_grad() if not ft else contextlib.nullcontext():
            enc_outputs = self.encoders(
                xs_pad,
                masks,
                mask=self.training,
                features_only=True,
            )

        xs_pad = enc_outputs["x"]  # (B,T,C),
        bs = xs_pad.shape[0]
        if enc_outputs["padding_mask"] is not None:
            masks = enc_outputs["padding_mask"]  # (B, T)
            olens = (~masks).sum(dim=1)  # (B)
        else:
            # No padding reported: every sequence has the full frame length.
            olens = torch.IntTensor([xs_pad.shape[1]]).repeat(bs).to(xs_pad.device)

        if self.output_layer is not None:
            xs_pad = self.output_layer(xs_pad)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        return xs_pad, olens, None

    def reload_pretrained_parameters(self):
        """Restore the wav2vec weights saved at construction time."""
        self.encoders.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Wav2Vec model parameters reloaded!")
def download_w2v(model_url, dir_path):
    """Download a Wav2Vec2.0 checkpoint (plus the letter dict) if not cached.

    Args:
        model_url: URL of the pretrained wav2vec model checkpoint.
        dir_path: local directory in which to cache the downloads.

    Returns:
        Filesystem path of the (possibly pre-existing) model checkpoint.
    """
    os.makedirs(dir_path, exist_ok=True)

    dict_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt"
    # Cache filenames are the last path component of each URL.
    model_path = os.path.join(dir_path, model_url.split("/")[-1])
    dict_path = os.path.join(dir_path, dict_url.split("/")[-1])

    # Serialize concurrent workers so only one process downloads the files.
    with FileLock(model_path + ".lock"):
        if os.path.exists(model_path):
            logging.info(f"Wav2Vec model {model_path} already exists.")
        else:
            torch.hub.download_url_to_file(model_url, model_path)
            torch.hub.download_url_to_file(dict_url, dict_path)
            logging.info(f"Wav2Vec model downloaded {model_path}")

    return model_path