Skip to content

Commit

Permalink
Add LoRA installation to Makefile.
Browse files Browse the repository at this point in the history
  • Loading branch information
pengchengguo committed Oct 10, 2023
1 parent 121a77d commit df4536b
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 54 deletions.
2 changes: 1 addition & 1 deletion ci/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ${CXX:-g++} -v

. ./activate_python.sh
# FIXME(kamo): Failed to compile pesq
make TH_VERSION="${TH_VERSION}" WITH_OMP="${WITH_OMP-ON}" all warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pyopenjtalk.done py3mmseg.done s3prl.done transformers.done phonemizer.done fairseq.done k2.done gtn.done longformer.done whisper.done parallel-wavegan.done muskits.done
make TH_VERSION="${TH_VERSION}" WITH_OMP="${WITH_OMP-ON}" all warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pyopenjtalk.done py3mmseg.done s3prl.done transformers.done phonemizer.done fairseq.done k2.done gtn.done longformer.done whisper.done parallel-wavegan.done muskits.done lora.done
rm -rf kaldi
)
. tools/activate_python.sh
Expand Down
14 changes: 0 additions & 14 deletions espnet2/layers/create_lora_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,6 @@ def create_new_module(
r=rank,
lora_alpha=alpha,
)
elif isinstance(target_module, torch.nn.Conv2d):
out_channels, in_channels = target_module.weight.size()[:2]
kernel_size = target_module.weight.size()[2:]
new_module = lora.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=target_module.stride,
padding=target_module.padding,
bias=bias,
r=rank,
lora_alpha=alpha,
lora_dropout=dropout_rate,
)
elif isinstance(target_module, torch.nn.Linear):
new_module = lora.Linear(

Check warning on line 112 in espnet2/layers/create_lora_adapter.py

View check run for this annotation

Codecov / codecov/patch

espnet2/layers/create_lora_adapter.py#L111-L112

Added lines #L111 - L112 were not covered by tests
target_module.in_features,
Expand Down
87 changes: 50 additions & 37 deletions test/espnet2/layers/test_create_lora_adapter.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
import sys

import pytest
import torch

from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.layers.create_lora_adapter import create_lora_adapter

pytest.importorskip("lora")

is_python_3_8_plus = sys.version_info >= (3, 8)
is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0")

def init_encoder():
return TransformerEncoder(
20,
output_size=40,
attention_heads=4,
linear_units=40,
num_blocks=2,
input_layer="conv2d",
)


def init_decoder():
def init_model():
return TransformerDecoder(
vocab_size=10,
encoder_output_size=40,
Expand All @@ -29,44 +21,65 @@ def init_decoder():
)


@pytest.mark.parametrize(
"rank, alpha, target_modules",
[
(2, 4, ["linear_q"]),
(2, 4, ["linear_q", "linear_k", "linear_v", "linear_out"]),
],
@pytest.mark.skipif(
not is_torch_1_8_plus or not is_python_3_8_plus, reason="Not supported"
)
def test_create_lora_adapter_encoder(rank, alpha, target_modules):
model = init_encoder()
@pytest.mark.parametrize("rank, alpha, target_modules", [(2, 4, ["linear_q"])])
def test_create_lora_adapter_linear(rank, alpha, target_modules):
model = init_model()
create_lora_adapter(
model=model, rank=rank, alpha=alpha, target_modules=target_modules
)
print(model)

assert model.decoders[0].self_attn.linear_q.lora_A.shape[0] == rank
assert model.decoders[0].self_attn.linear_q.lora_B.shape[1] == rank


@pytest.mark.parametrize(
"rank, alpha, target_modules",
[
(2, 4, ["linear_q"]),
(2, 4, ["linear_q", "linear_k", "linear_v", "linear_out"]),
(2, 4, ["embed.0"]), # Embedding layer
],
@pytest.mark.skipif(
not is_torch_1_8_plus or not is_python_3_8_plus, reason="Not supported"
)
def test_create_lora_adapter_decoder(rank, alpha, target_modules):
model = init_decoder()
@pytest.mark.parametrize("rank, alpha, target_modules", [(2, 4, ["embed.0"])])
def test_create_lora_adapter_embedding(rank, alpha, target_modules):
model = init_model()
create_lora_adapter(
model=model, rank=rank, alpha=alpha, target_modules=target_modules
)
print(model)

assert model.embed[0].lora_A.shape[0] == rank
assert model.embed[0].lora_B.shape[1] == rank


@pytest.mark.parametrize(
"rank, alpha, target_modules",
[(2, 4, ["linear"])],
@pytest.mark.skipif(
not is_torch_1_8_plus or not is_python_3_8_plus, reason="Not supported"
)
@pytest.mark.parametrize("rank, alpha, target_modules", [(2, 4, ["query_proj"])])
def test_create_lora_adapter_invalid_target(rank, alpha, target_modules):
model = init_encoder()
model = init_model()
with pytest.raises(ValueError):
create_lora_adapter(
model=model, rank=rank, alpha=alpha, target_modules=target_modules
)


@pytest.mark.skipif(
not is_torch_1_8_plus or not is_python_3_8_plus, reason="Not supported"
)
@pytest.mark.parametrize("rank, alpha, target_modules", [(2, 4, ["norm1"])])
def test_create_lora_adapter_unsupport_target(rank, alpha, target_modules):
model = init_model()
with pytest.raises(ValueError):
create_lora_adapter(
model=model, rank=rank, alpha=alpha, target_modules=target_modules
)


@pytest.mark.skipif(
not is_torch_1_8_plus or not is_python_3_8_plus, reason="Not supported"
)
@pytest.mark.parametrize("rank, alpha, target_modules", [(2, 4, 5)])
def test_create_lora_adapter_invalid_type(rank, alpha, target_modules):
model = init_model()
with pytest.raises(TypeError):
create_lora_adapter(
model=model, rank=rank, alpha=alpha, target_modules=target_modules
)
2 changes: 1 addition & 1 deletion test/espnet2/text/test_whisper_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_init_lang_invalid():
)
def test_init_task_invalid():
with pytest.raises(ValueError):
OpenAIWhisperTokenizer("whisper_multilingual", "aaa", "transcribe_aaa")
OpenAIWhisperTokenizer("whisper_multilingual", "zh", "transcribe_aaa")


@pytest.mark.skipif(
Expand Down
6 changes: 5 additions & 1 deletion tools/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ WITH_OMP=ON
all: showenv python conda_packages.done ffmpeg.done sctk sph2pipe check_install

python: activate_python.sh packaging.done espnet.done pytorch.done chainer.done fairscale.done torch_optimizer.done
extra: warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pesq kenlm.done pyopenjtalk.done py3mmseg.done beamformit.done fairseq.done s3prl.done k2.done transformers.done phonemizer.done longformer.done muskits.done whisper.done rvad_fast.done sounfile_test parallel-wavegan.done
extra: warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pesq kenlm.done pyopenjtalk.done py3mmseg.done beamformit.done fairseq.done s3prl.done k2.done transformers.done phonemizer.done longformer.done muskits.done whisper.done rvad_fast.done sounfile_test parallel-wavegan.done lora.done

activate_python.sh:
test -f activate_python.sh || { echo "Error: Run ./setup_python.sh or ./setup_anaconda.sh"; exit 1; }
Expand Down Expand Up @@ -215,6 +215,10 @@ whisper.done: espnet.done
. ./activate_python.sh && ./installers/install_whisper.sh
touch whisper.done

lora.done: espnet.done
. ./activate_python.sh && ./installers/install_lora.sh
touch lora.done

k2.done: espnet.done
. ./activate_python.sh && ./installers/install_k2.sh
touch k2.done
Expand Down
51 changes: 51 additions & 0 deletions tools/installers/install_lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env bash

set -euo pipefail

MAKE=make

if [ $# != 0 ]; then
echo "Usage: $0"
exit 1;
fi

if ! python -c "import packaging.version" &> /dev/null; then
python3 -m pip install packaging
fi
torch_17_plus=$(python3 <<EOF
from packaging.version import parse as V
import torch
if V(torch.__version__) >= V("1.7.0"):
print("true")
else:
print("false")
EOF
)

python_38_plus=$(python3 <<EOF
from packaging.version import parse as V
import sys
if V("{}.{}.{}".format(*sys.version_info[:3])) >= V("3.8.0"):
print("true")
else:
print("false")
EOF
)

cuda_version=$(python3 <<EOF
import torch
if torch.cuda.is_available():
version=torch.version.cuda.split(".")
# 10.1.aa -> 101
print(version[0] + version[1])
else:
print("")
EOF
)
echo "cuda_version=${cuda_version}"

if "${torch_17_plus}" && "${python_38_plus}"; then
python -m pip install loralib
else
echo "[ERROR] lora does not work with pytorch<1.7.0, python<3.8"
fi

0 comments on commit df4536b

Please sign in to comment.