Enable AdaLoRA tests for LLM adapter #3896

Merged: 2 commits, Jan 19, 2024
16 changes: 13 additions & 3 deletions tests/integration_tests/test_llm.py
@@ -1209,12 +1209,22 @@ def llm_encoder_config() -> dict[str, Any]:
 
 @pytest.mark.parametrize(
     "adapter,quantization",
-    [(None, None), ("lora", None), ("lora", {"bits": 4}), ("lora", {"bits": 8})],
-    ids=["FFT", "LoRA", "LoRA 4-bit", "LoRA 8-bit"],
+    [
+        (None, None),
+        ("lora", None),
+        ("lora", {"bits": 4}),
+        ("lora", {"bits": 8}),
+        ("adalora", None),
+        ("adalora", {"bits": 4}),
+        ("adalora", {"bits": 8}),
+    ],
+    ids=["FFT", "LoRA", "LoRA 4-bit", "LoRA 8-bit", "AdaLoRA", "AdaLoRA 4-bit", "AdaLoRA 8-bit"],
 )
 def test_llm_encoding(llm_encoder_config, adapter, quantization, tmpdir):
     if (
-        _finetune_strategy_requires_cuda(finetune_strategy_name=adapter, quantization_args=quantization)
+        _finetune_strategy_requires_cuda(
+            finetune_strategy_name="lora" if adapter else None, quantization_args=quantization
+        )
         and not (torch.cuda.is_available() and torch.cuda.device_count()) > 0
     ):
         pytest.skip("Skip: quantization requires GPU and none are available.")
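For context on the test_llm.py hunk above: the parametrization now exercises AdaLoRA at the same quantization levels as LoRA, and the GPU guard normalizes any adapter type to the "lora" finetune strategy name before calling _finetune_strategy_requires_cuda. Below is a minimal sketch of that guard under the assumption that only quantized (4-/8-bit) PEFT runs require CUDA; should_skip_for_missing_gpu is a hypothetical helper for illustration, not part of the Ludwig test suite.

import torch

def should_skip_for_missing_gpu(adapter, quantization) -> bool:
    # Any adapter type ("lora" or "adalora") maps to the "lora" finetune
    # strategy name for the CUDA-requirement check.
    strategy_name = "lora" if adapter else None
    # Assumed rule for illustration; the real logic lives in
    # _finetune_strategy_requires_cuda.
    requires_cuda = strategy_name is not None and quantization is not None
    has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
    return requires_cuda and not has_gpu

Under that assumption, the "AdaLoRA 4-bit" and "AdaLoRA 8-bit" cases skip on CPU-only runners exactly like their LoRA counterparts.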
53 changes: 42 additions & 11 deletions tests/ludwig/encoders/test_llm_encoders.py
@@ -7,11 +7,14 @@
 
 from ludwig.encoders.text_encoders import LLMEncoder
 from ludwig.schema.encoders.text_encoders import LLMEncoderConfig
-from ludwig.schema.llms.peft import BaseAdapterConfig, LoraConfig
+from ludwig.schema.llms.peft import AdaloraConfig, BaseAdapterConfig, LoraConfig
 from ludwig.utils.llm_utils import get_context_len
 
 # Mapping of adapter types to test against and their respective config objects.
-ADAPTER_CONFIG_MAP = {"lora": LoraConfig}
+ADAPTER_CONFIG_MAP = {
+    "lora": LoraConfig,
+    "adalora": AdaloraConfig,
+}
 
 
 @pytest.fixture()
@@ -58,13 +61,30 @@ def create_encoder_config_with_adapter(
         new_config.adapter = ADAPTER_CONFIG_MAP[adapter](**kwargs)
         return new_config
 
+    def adapter_param_name_prefix(self, adapter: str) -> str:
+        """Get the PEFT parameter name prefix for a given adapter type.
+
+        Args:
+            adapter: A valid config value for `adapter.type`
+
+        Returns:
+            The PEFT-applied prefix for the adapter's parameter names.
+
+        Raises:
+            KeyError: raised when the provided adapter name is not valid for LLMEncoder.
+        """
+        return LLMEncoder.ADAPTER_PARAM_NAME_PREFIX[adapter]
+
     def test_init(self, encoder_config: LLMEncoderConfig, model_config):
         # Test initializing without an adapter
        encoder = LLMEncoder(encoder_config=encoder_config)
 
         assert encoder.model_name == encoder_config.base_model
         assert isinstance(encoder.model, PreTrainedModel)
-        assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys()))  # Check adapter was not initialized
+        # Check adapter was not initialized
+        for k in ADAPTER_CONFIG_MAP.keys():
+            prefix = self.adapter_param_name_prefix(k)
+            assert all(map(lambda k: prefix not in k, encoder.state_dict().keys()))
         assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
         assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])
 
@@ -77,7 +97,10 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):
 
         assert encoder.model_name == encoder_config.base_model
         assert isinstance(encoder.model, PreTrainedModel)
-        assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys()))  # Check adapter was not initialized
+        # Check adapter was not initialized
+        for k in ADAPTER_CONFIG_MAP.keys():
+            prefix = self.adapter_param_name_prefix(k)
+            assert all(map(lambda k: prefix not in k, encoder.state_dict().keys()))
         assert encoder.input_shape == torch.Size([context_len])
         assert encoder.output_shape == torch.Size([context_len, model_config.hidden_size])
 
@@ -87,10 +110,11 @@ def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str,
 
         encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
         encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
+        prefix = self.adapter_param_name_prefix(adapter)
 
         # The adapter should not be initialized until `prepare_for_training` is called
         assert not isinstance(encoder.model, PeftModel)
-        assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
+        assert not any(map(lambda k: prefix in k, encoder.state_dict().keys()))
 
         assert encoder.model_name == encoder_config.base_model
         assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
@@ -102,31 +126,36 @@ def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: s
 
         encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
         encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
+        prefix = self.adapter_param_name_prefix(adapter)
 
         # The adapter should not be initialized until `prepare_for_training` is called
         assert not isinstance(encoder.model, PeftModel)
-        assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
+        assert not any(map(lambda k: prefix in k, encoder.state_dict().keys()))
 
         # Initialize the adapter
         encoder.prepare_for_training()
 
         # At this point, the adapter should be initialized and the state dict should contain adapter parameters
         assert isinstance(encoder.model, PeftModel)
-        assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
+        assert any(map(lambda k: prefix in k, encoder.state_dict().keys()))
 
     def test_save_to_state_dict(self, encoder_config: LLMEncoderConfig, tmpdir):
         # With no adapter, the state dict should only contain the model parameters
         encoder = LLMEncoder(encoder_config=encoder_config)
-        assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys()))
+        # Check adapter was not initialized
+        for k in ADAPTER_CONFIG_MAP.keys():
+            prefix = self.adapter_param_name_prefix(k)
+            assert all(map(lambda k: prefix not in k, encoder.state_dict().keys()))
 
     @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
     def test_save_to_state_dict_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, tmpdir):
         # With an adapter, the state dict should only contain adapter parameters
         encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
         encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
+        prefix = self.adapter_param_name_prefix(adapter)
         # Initialize the adapters
         encoder.prepare_for_training()
-        assert all(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
+        assert all(map(lambda k: prefix in k, encoder.state_dict().keys()))
 
     @pytest.mark.parametrize("wrap", [False, True], ids=["no_wrapper", "with_wrapper"])
     def test_load_from_state_dict(self, encoder_config: LLMEncoderConfig, wrap: bool):
@@ -164,6 +193,8 @@ def weights_init(m):
             if hasattr(m, "weight") and m.weight.ndim > 1:
                 torch.nn.init.xavier_uniform_(m.weight.data)
 
+        prefix = self.adapter_param_name_prefix(adapter)
+
         # Update the config with an adapter
         encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
 
@@ -183,8 +214,8 @@
 
         encoder1_sd = encoder1.state_dict()
         encoder2_sd = encoder2.state_dict()
-        adapter_keys = [k for k in encoder1_sd.keys() if "lora_" in k and "weight" in k]
-        model_keys = [k for k in encoder1_sd.keys() if "lora_" not in k]
+        adapter_keys = [k for k in encoder1_sd.keys() if prefix in k and "weight" in k]
+        model_keys = [k for k in encoder1_sd.keys() if prefix not in k]
 
         # The LoRA weights should no longer be equal
         assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), adapter_keys))
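The recurring pattern in the updated encoder tests is a prefix-based scan of the encoder's state dict, with the prefix looked up via adapter_param_name_prefix (backed by LLMEncoder.ADAPTER_PARAM_NAME_PREFIX) instead of the hard-coded "lora_" string. Below is a minimal sketch of that check; the prefix values shown are illustrative assumptions, not the real mapping on LLMEncoder.

from typing import Dict, Iterable

# Illustrative stand-in for LLMEncoder.ADAPTER_PARAM_NAME_PREFIX; the actual
# values are defined on LLMEncoder and may differ.
ADAPTER_PARAM_NAME_PREFIX: Dict[str, str] = {
    "lora": "lora_",     # assumed value
    "adalora": "lora_",  # assumed value
}

def has_adapter_params(state_dict_keys: Iterable[str], adapter: str) -> bool:
    # The check the tests run after encoder.prepare_for_training(): at least
    # one parameter name carries the adapter's PEFT prefix.
    prefix = ADAPTER_PARAM_NAME_PREFIX[adapter]
    return any(prefix in key for key in state_dict_keys)

def has_no_adapter_params(state_dict_keys: Iterable[str]) -> bool:
    # The inverse check used before prepare_for_training(): no key may carry
    # any known adapter prefix.
    keys = list(state_dict_keys)
    return all(prefix not in key for prefix in ADAPTER_PARAM_NAME_PREFIX.values() for key in keys)

Because the assertions loop over ADAPTER_CONFIG_MAP and the prefix lookup, supporting a further adapter type in these tests largely amounts to adding its schema config class and prefix entry.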