Merge pull request #5609 from Stanwang1210/master
Integrate adapter for s3prl frontend
sw005320 committed Feb 22, 2024
2 parents 9718190 + 6f1ae7d commit 98b0387
Showing 14 changed files with 1,238 additions and 266 deletions.
60 changes: 60 additions & 0 deletions egs2/ml_superb/asr1/README.md
@@ -96,6 +96,66 @@ General steps to run tasks in LID track are as follows:
```
./run_multi.sh --asr_config <your_training_config> --duration {10min, 1h} --lid true --only_lid false
```
## Adapter usage guidelines
### General steps to run ASR tasks with adapters
- Follow the MLSUPERB preparation steps through stage 10.
- Enable the adapter by setting `asr_config` to `conf/tuning/train_asr_s3prl_houlsby.yaml` or `conf/tuning/train_asr_s3prl_lora.yaml`.
- Pretrained model: https://huggingface.co/espnet/s3prl_adapter_model
- For example:
```
./run_mono.sh --asr_config conf/tuning/train_asr_s3prl_houlsby.yaml
```
- To configure the adapter, set the following arguments in the YAML-style config files under `conf/tuning` (a short sketch of how these options are applied follows the block below):
```
# LoRA
use_adapter: true
adapter: lora
save_strategy: adapter
adapter_conf:
    rank: 4
    alpha: 4
    dropout_rate: 0.1
    target_modules:
    - fc1
    - fc2

# Houlsby
use_adapter: true
adapter: houlsby
save_strategy: required_grad_only
adapter_conf:
    bottleneck: 32
    # Target layers to insert adapters into; adapters are inserted into all layers if not specified.
    target_layers:
    - 0
    - 1
```
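These options are consumed by the new `espnet2.layers.create_adapter.create_adapter` helper added in this PR (see the file further down). Below is a hedged sketch of how the LoRA block above maps onto that call; `ToyFeedForward` is a made-up stand-in for the frozen upstream, and the sketch assumes `create_lora_adapter` wraps any `torch.nn.Linear` whose name matches an entry in `target_modules` (installing `loralib` may be required).

```python
import torch

from espnet2.layers.create_adapter import create_adapter


# Toy stand-in for a frozen upstream block, with fc1/fc2 linear layers matching
# the `target_modules` names in the LoRA example above. In the recipe itself the
# model is the full ESPnet ASR model with the s3prl hubert_base frontend.
class ToyFeedForward(torch.nn.Module):
    def __init__(self, hidden: int = 768, inner: int = 3072):
        super().__init__()
        self.fc1 = torch.nn.Linear(hidden, inner)
        self.fc2 = torch.nn.Linear(inner, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))


model = ToyFeedForward()

# Same keys as the YAML `adapter_conf` block; create_adapter dispatches to
# create_lora_adapter and forwards these entries as keyword arguments.
create_adapter(
    model=model,
    adapter="lora",
    adapter_conf={
        "rank": 4,
        "alpha": 4,
        "dropout_rate": 0.1,
        "target_modules": ["fc1", "fc2"],
    },
)
```

In the recipe itself the call is made on the full ASR model, so `fc1`/`fc2` (or `q_proj`/`k_proj` in the LoRA config) presumably refer to submodules inside the `hubert_base` upstream.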
### Results: CER/PER (lower is better)
Experiment setup:

- SSL: HuBERT Base
- Optimizer: Adam
- Other settings follow the MLSUPERB defaults

#### eng1

|Setting|10min|1h|
|---|---|---|
|No Adapter|33.8|26.7|
|Houlsby Adapter|31.0|23.6|

#### deu1
|Setting|10min|1h|
|---|---|---|
|No Adapter|35.1|30.2|
|Houlsby Adapter|33.7|27.7|

#### jpn
|Setting|10min|1h|
|---|---|---|
|No Adapter|20.6|15.6|
|Houlsby Adapter|15.3|11.9|


## Credits

80 changes: 80 additions & 0 deletions egs2/ml_superb/asr1/conf/tuning/train_asr_s3prl_houlsby.yaml
@@ -0,0 +1,80 @@
encoder: transformer
encoder_conf:
    output_size: 256
    attention_heads: 8
    linear_units: 1024
    num_blocks: 2
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d2
    normalize_before: true

decoder: none

model_conf:
    ctc_weight: 1.0
    extract_feats_in_collect_stats: False
unused_parameters: true
freeze_param: [
    "frontend.upstream"
]

frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: hubert_base  # Note: if the upstream is changed, please change input_size in the preencoder.
    download_dir: ./hub
    multilayer_feature: True

preencoder: linear
preencoder_conf:
    input_size: 768  # Note: if the upstream is changed, please change this value accordingly.
    output_size: 80

use_adapter: true
adapter: houlsby
save_strategy: required_grad_only
adapter_conf:
    bottleneck: 32
    # Target layers to insert adapters into; adapters are inserted into all layers if not specified.
    # target_layers:
    # - 0
    # - 1

num_workers: 4
batch_type: sorted
batch_size: 8
accum_grad: 4
patience: none
init: none
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 5

optim: adam
optim_conf:
    lr: 0.0001
    weight_decay: 0.000001

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 10


num_iters_per_epoch: 0  # number of iterations per epoch
max_epoch: 1
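For reference, `bottleneck: 32` above is the hidden size of a Houlsby-style bottleneck adapter (Houlsby et al., 2019) inserted into the otherwise frozen upstream layers. The following is a minimal, self-contained sketch of that computation only; it is not the module actually built by `create_houlsby_adapter`, and the 768-dimensional hidden size is taken from the `hubert_base`/preencoder settings above.

```python
import torch


class HoulsbyAdapterSketch(torch.nn.Module):
    """Bottleneck adapter in the style of Houlsby et al. (2019).

    Down-projects the hidden state to `bottleneck` dimensions, applies a
    nonlinearity, projects back up, and adds a residual connection.
    Illustrative only; the actual ESPnet module may differ in detail.
    """

    def __init__(self, hidden_size: int = 768, bottleneck: int = 32):
        super().__init__()
        self.down = torch.nn.Linear(hidden_size, bottleneck)
        self.up = torch.nn.Linear(bottleneck, hidden_size)
        self.act = torch.nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.up(self.act(self.down(x)))


adapter = HoulsbyAdapterSketch(hidden_size=768, bottleneck=32)
out = adapter(torch.randn(2, 50, 768))  # (batch, time, feature)
print(out.shape)  # torch.Size([2, 50, 768])
```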
82 changes: 82 additions & 0 deletions egs2/ml_superb/asr1/conf/tuning/train_asr_s3prl_lora.yaml
@@ -0,0 +1,82 @@
encoder: transformer
encoder_conf:
    output_size: 256
    attention_heads: 8
    linear_units: 1024
    num_blocks: 2
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d2
    normalize_before: true

decoder: none

model_conf:
    ctc_weight: 1.0
    extract_feats_in_collect_stats: False
unused_parameters: true
freeze_param: [
    "frontend.upstream"
]

frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: hubert_base  # Note: if the upstream is changed, please change input_size in the preencoder.
    download_dir: ./hub
    multilayer_feature: True

preencoder: linear
preencoder_conf:
    input_size: 768  # Note: if the upstream is changed, please change this value accordingly.
    output_size: 80


use_adapter: true
adapter: lora
save_strategy: adapter
adapter_conf:
    rank: 4
    alpha: 4
    dropout_rate: 0.1
    target_modules:
    - q_proj
    - k_proj

num_workers: 4
batch_type: sorted
batch_size: 8
accum_grad: 4
patience: none
init: none
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 5

optim: adam
optim_conf:
    lr: 0.0001
    weight_decay: 0.000001

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 10


num_iters_per_epoch: 500  # number of iterations per epoch
max_epoch: 30
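For reference, `rank`, `alpha`, and `dropout_rate` above parameterize a low-rank (LoRA) update of the frozen `q_proj`/`k_proj` linear layers: the adapted output is `W x + (alpha / rank) * B A dropout(x)`, with only `A` and `B` trained. The following is a minimal sketch of that idea, not the module actually built by `create_lora_adapter`.

```python
import torch


class LoRALinearSketch(torch.nn.Module):
    """Low-rank adaptation (LoRA) of a frozen linear layer.

    Computes base(x) + (alpha / rank) * lora_b(lora_a(dropout(x))); the
    pretrained projection stays frozen and only the low-rank factors are
    trained. Illustrative only.
    """

    def __init__(self, base: torch.nn.Linear, rank: int = 4, alpha: float = 4.0,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # frozen pretrained projection
            p.requires_grad_(False)
        self.lora_a = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_b.weight)  # start as a no-op update
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(self.dropout(x)))


# e.g. a 768-dim attention projection such as q_proj in hubert_base
q_proj = LoRALinearSketch(torch.nn.Linear(768, 768), rank=4, alpha=4, dropout_rate=0.1)
print(q_proj(torch.randn(2, 50, 768)).shape)  # torch.Size([2, 50, 768])
```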
43 changes: 43 additions & 0 deletions espnet2/layers/create_adapter.py
@@ -0,0 +1,43 @@
"""Definition of the low-rank adaptation (LoRA) for large models.
References:
1. LoRA: Low-Rank Adaptation of Large Language Models
(https://arxiv.org/pdf/2106.09685.pdf)
2. https://github.com/microsoft/LoRA.git
3. https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py
"""

from typing import List

import torch
from typeguard import check_argument_types

from espnet2.layers.create_adapter_fn import create_houlsby_adapter, create_lora_adapter
from espnet2.train.class_choices import ClassChoices

create_adapter_fn_table = {
"lora": create_lora_adapter,
"houlsby": create_houlsby_adapter,
}


def create_adapter(
model: torch.nn.Module,
adapter: str,
adapter_conf: dict,
):
"""Create adapter for the base model.
Args:
model (torch.nn.Module): Base model to be adapted.
adapter_type (str): Name of adapter
adapter_conf (dict): Configuration for the adapter
e.g. {"rank": 8, "alpha": 8, ...} for lora
"""
assert check_argument_types()
assert adapter in create_adapter_fn_table, f"Adapter {adapter} is not supported."
create_adapter_fn = create_adapter_fn_table[adapter]
create_adapter_fn(model=model, **adapter_conf)
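The two configs above also differ in `save_strategy` (`adapter` for LoRA, `required_grad_only` for Houlsby). The sketch below illustrates only the general idea behind such strategies, using a made-up helper rather than ESPnet's actual checkpointing code: once `freeze_param` freezes the upstream and only the adapter weights remain trainable, a checkpoint only needs that small trainable subset.

```python
import torch


def trainable_state_dict(model: torch.nn.Module) -> dict:
    """Keep only the parameters that still require gradients.

    Made-up helper illustrating the intent of `save_strategy`; not ESPnet's
    actual checkpointing code.
    """
    requires_grad = {n for n, p in model.named_parameters() if p.requires_grad}
    return {k: v for k, v in model.state_dict().items() if k in requires_grad}


# Example: freeze a "backbone" and keep only a small trainable head.
model = torch.nn.Sequential(torch.nn.Linear(768, 768), torch.nn.Linear(768, 32))
for p in model[0].parameters():
    p.requires_grad_(False)
print(list(trainable_state_dict(model)))  # ['1.weight', '1.bias']
```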
