Merge pull request #5609 from Stanwang1210/master
Integrate adapter for s3prl frontend
Showing 14 changed files with 1,238 additions and 266 deletions.
egs2/ml_superb/asr1/conf/tuning/train_asr_s3prl_houlsby.yaml (80 additions, 0 deletions)
encoder: transformer
encoder_conf:
    output_size: 256
    attention_heads: 8
    linear_units: 1024
    num_blocks: 2
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d2
    normalize_before: true

decoder: none

model_conf:
    ctc_weight: 1.0
    extract_feats_in_collect_stats: False

unused_parameters: true
freeze_param: [
    "frontend.upstream"
]

frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: hubert_base  # Note: if the upstream is changed, also change input_size in the preencoder.
    download_dir: ./hub
    multilayer_feature: True

preencoder: linear
preencoder_conf:
    input_size: 768  # Note: if the upstream is changed, change this value accordingly.
    output_size: 80

use_adapter: true
adapter: houlsby
save_strategy: required_grad_only
adapter_conf:
    bottleneck: 32
    # Target layers to insert adapters into; adapters are inserted into all layers if not specified.
    # target_layers:
    #     - 0
    #     - 1

num_workers: 4
batch_type: sorted
batch_size: 8
accum_grad: 4
patience: none
init: none
best_model_criterion:
    - - valid
      - loss
      - min
keep_nbest_models: 5

optim: adam
optim_conf:
    lr: 0.0001
    weight_decay: 0.000001

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
        - 0
        - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
        - 0.
        - 0.05
    num_time_mask: 10

num_iters_per_epoch: 0  # number of iterations per epoch
max_epoch: 1
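For orientation, adapter: houlsby inserts a small bottleneck module into each layer of the frozen hubert_base upstream: hidden states are projected down to bottleneck: 32 dimensions, passed through a nonlinearity, projected back up, and added residually, so only these modules (plus the downstream encoder) receive gradients. The sketch below is a minimal illustration of that idea with assumed class and attribute names; it is not the create_houlsby_adapter implementation this PR adds.

# Minimal Houlsby-style adapter sketch (illustrative; names are assumed,
# not the create_houlsby_adapter implementation from this PR).
import torch


class HoulsbyAdapter(torch.nn.Module):
    """Bottleneck adapter: down-project, nonlinearity, up-project, residual."""

    def __init__(self, hidden_size: int = 768, bottleneck: int = 32):
        super().__init__()
        self.down = torch.nn.Linear(hidden_size, bottleneck)
        self.up = torch.nn.Linear(bottleneck, hidden_size)
        # Zero-init of the up-projection keeps the adapted layer identical
        # to the pretrained one at the start of training.
        torch.nn.init.zeros_(self.up.weight)
        torch.nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: at init this reduces to the identity.
        return x + self.up(torch.nn.functional.gelu(self.down(x)))

With hubert_base's hidden size of 768 and bottleneck: 32, each such block adds roughly 2 × 768 × 32 ≈ 49k weights per layer, a small fraction of the ~95M-parameter upstream; the name save_strategy: required_grad_only likewise suggests that only parameters with requires_grad set are written to the checkpoint.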
Next file: 82 additions, 0 deletions
encoder: transformer
encoder_conf:
    output_size: 256
    attention_heads: 8
    linear_units: 1024
    num_blocks: 2
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d2
    normalize_before: true

decoder: none

model_conf:
    ctc_weight: 1.0
    extract_feats_in_collect_stats: False

unused_parameters: true
freeze_param: [
    "frontend.upstream"
]

frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: hubert_base  # Note: if the upstream is changed, also change input_size in the preencoder.
    download_dir: ./hub
    multilayer_feature: True

preencoder: linear
preencoder_conf:
    input_size: 768  # Note: if the upstream is changed, change this value accordingly.
    output_size: 80

use_adapter: true
adapter: lora
save_strategy: adapter
adapter_conf:
    rank: 4
    alpha: 4
    dropout_rate: 0.1
    target_modules:
        - q_proj
        - k_proj

num_workers: 4
batch_type: sorted
batch_size: 8
accum_grad: 4
patience: none
init: none
best_model_criterion:
    - - valid
      - loss
      - min
keep_nbest_models: 5

optim: adam
optim_conf:
    lr: 0.0001
    weight_decay: 0.000001

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
        - 0
        - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
        - 0.
        - 0.05
    num_time_mask: 10

num_iters_per_epoch: 500  # number of iterations per epoch
max_epoch: 30
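adapter: lora takes a different route: instead of inserting new blocks, it attaches a trainable low-rank update, scaled by alpha / rank, to selected frozen weight matrices, here the modules named q_proj and k_proj (the query and key projections inside the upstream's attention layers); save_strategy: adapter suggests that only these adapter weights are saved. The sketch below shows the LoRA update on a single linear layer with assumed names; it is not the create_lora_adapter implementation this PR adds.

# Minimal LoRA sketch for one linear layer (illustrative; names are assumed,
# not the create_lora_adapter implementation from this PR).
import torch


class LoRALinear(torch.nn.Module):
    """Adds a trainable low-rank update B @ A to a frozen linear layer."""

    def __init__(self, base: torch.nn.Linear, rank: int = 4,
                 alpha: float = 4.0, dropout_rate: float = 0.1):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay frozen
        self.lora_a = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_b.weight)  # update is zero at init
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + (alpha / rank) * B(A(x))
        return self.base(x) + self.scaling * self.lora_b(
            self.lora_a(self.dropout(x))
        )

With rank: 4 and alpha: 4 the scaling factor is 1.0, and wrapping a 768 × 768 projection adds only 2 × 768 × 4 = 6,144 trainable weights.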
Next file: 43 additions, 0 deletions
"""Definition of the low-rank adaptation (LoRA) for large models. | ||
References: | ||
1. LoRA: Low-Rank Adaptation of Large Language Models | ||
(https://arxiv.org/pdf/2106.09685.pdf) | ||
2. https://github.com/microsoft/LoRA.git | ||
3. https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py | ||
""" | ||
|
||
from typing import List | ||
|
||
import torch | ||
from typeguard import check_argument_types | ||
|
||
from espnet2.layers.create_adapter_fn import create_houlsby_adapter, create_lora_adapter | ||
from espnet2.train.class_choices import ClassChoices | ||
|
||
create_adapter_fn_table = { | ||
"lora": create_lora_adapter, | ||
"houlsby": create_houlsby_adapter, | ||
} | ||
|
||
|
||
def create_adapter( | ||
model: torch.nn.Module, | ||
adapter: str, | ||
adapter_conf: dict, | ||
): | ||
"""Create adapter for the base model. | ||
Args: | ||
model (torch.nn.Module): Base model to be adapted. | ||
adapter_type (str): Name of adapter | ||
adapter_conf (dict): Configuration for the adapter | ||
e.g. {"rank": 8, "alpha": 8, ...} for lora | ||
""" | ||
assert check_argument_types() | ||
assert adapter in create_adapter_fn_table, f"Adapter {adapter} is not supported." | ||
create_adapter_fn = create_adapter_fn_table[adapter] | ||
create_adapter_fn(model=model, **adapter_conf) |
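Since create_adapter only looks up the requested adapter in the table and forwards adapter_conf as keyword arguments, applying LoRA from Python mirrors the YAML above. A usage sketch, assuming the module lives at espnet2.layers.create_adapter (its path is not preserved here) and that model is an ESPnet model wrapping the s3prl upstream:

# Usage sketch; the import path and the `model` object are assumptions.
from espnet2.layers.create_adapter import create_adapter

model = ...  # an instantiated ESPnet ASR model holding the s3prl frontend

create_adapter(
    model=model,
    adapter="lora",
    adapter_conf={
        "rank": 4,
        "alpha": 4,
        "dropout_rate": 0.1,
        "target_modules": ["q_proj", "k_proj"],
    },
)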