Cast LLMEncoder output to torch.float32, freeze final layer at init. #3900

Merged · 2 commits · Jan 19, 2024
17 changes: 9 additions & 8 deletions ludwig/encoders/text_encoders.py
@@ -2425,6 +2425,14 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):

         clear_data_cache()

+        # Because we use the last hidden state as encoder output rather than the logits, the final module of the model
+        # has input pass through but no gradient update in the backward pass. This can lead to a DDP error. Freezing
+        # the module prevents this from happening. This is done at initialization to prevent "unused parameters" errors
+        # from happening when the encoder is used before `prepare_for_training` is called, for example during batch
+        # size tuning.
+        out_module = list(self.model.modules())[-1]
+        out_module.requires_grad_(requires_grad=False)
+
     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
         return LLMEncoderConfig
@@ -2459,13 +2467,6 @@ def prepare_for_training(self):
         self.prepare_for_quantized_training()
         self.initialize_adapter()

-        # Because we use the last hidden state as encoder output rather than the logits, the final module of the model
-        # has input pass through but no gradient update in the backward pass. This can lead to a DDP error. Freezing
-        # the module prevents this from happening.
-        if not self.config.adapter:
-            out_module = list(self.model.modules())[-1]
-            out_module.requires_grad_(requires_grad=False)
-
     def prepare_for_quantized_training(self):
         from peft import prepare_model_for_kbit_training

@@ -2479,7 +2480,7 @@ def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None):
         # Get the hidden state of the last layer and return it as the text encoding
         model_outputs = self.model(input_ids=inputs, output_hidden_states=True).hidden_states[-1]

-        return {ENCODER_OUTPUT: model_outputs}
+        return {ENCODER_OUTPUT: model_outputs.type(torch.float32)}

     def _save_to_state_dict(self, destination: Dict, prefix: str, keep_vars: bool):
         # This is called by `torch.nn.Module.state_dict()` under the hood. `state_dict()` does additional work to
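For context, here is a minimal self-contained sketch of the two changes above, using a toy stand-in model rather than Ludwig's `LLMEncoder` (the `ToyLLMEncoder`, `body`, and `lm_head` names are illustrative assumptions, not part of this PR):

```python
# Illustrative sketch only: a toy stand-in for the LLM encoder, not Ludwig code.
import torch
import torch.nn as nn


class ToyLLMEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for a causal LM: a "body" that yields hidden states and an
        # "lm_head" that yields logits. The encoder only consumes the body's output.
        self.body = nn.Linear(16, 16)
        self.lm_head = nn.Linear(16, 32)

        # The last hidden state is the encoder output, so lm_head receives no
        # gradient. Freezing it at init keeps DDP from reporting its parameters
        # as unused, even if a forward/backward runs before training setup
        # (e.g. during batch size tuning).
        out_module = list(self.modules())[-1]  # lm_head
        out_module.requires_grad_(requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.body(x)
        # Half-precision or quantized checkpoints emit float16/bfloat16 hidden
        # states; casting keeps downstream modules in float32.
        return hidden.type(torch.float32)


encoder = ToyLLMEncoder()
assert all(not p.requires_grad for p in encoder.lm_head.parameters())
print(encoder(torch.randn(2, 16)).dtype)  # torch.float32
```

Freezing at init rather than in `prepare_for_training` means the invariant holds even when the encoder is exercised before training setup, which is exactly the batch-size-tuning case called out in the comment.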
12 changes: 12 additions & 0 deletions tests/ludwig/encoders/test_llm_encoders.py
@@ -68,6 +68,10 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):
         assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
         assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])

+        # The final layer must not be trainable because it is not used
+        last_module = list(encoder.model.modules())[-1]
+        assert all(not p.requires_grad for p in last_module.parameters())
+
         # Test that max sequence length falls back to the context length when too large
         context_len = get_context_len(model_config)
         cl_config = copy.deepcopy(encoder_config)
@@ -81,6 +85,10 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):
         assert encoder.input_shape == torch.Size([context_len])
         assert encoder.output_shape == torch.Size([context_len, model_config.hidden_size])

+        # The final layer must not be trainable because it is not used
+        last_module = list(encoder.model.modules())[-1]
+        assert all(not p.requires_grad for p in last_module.parameters())
+
     @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
     def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, model_config):
         from peft import PeftModel
@@ -96,6 +104,10 @@ def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, model_config):
         assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
         assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])

+        # The final layer must not be trainable because it is not used
+        last_module = list(encoder.model.modules())[-1]
+        assert all(not p.requires_grad for p in last_module.parameters())
+
     @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
     def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: str):
         from peft import PeftModel
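The three new assertions share the same final-layer check; a small helper along these lines (hypothetical, not part of this PR) could express that intent once and be reused across the tests:

```python
import torch


def assert_final_module_frozen(model: torch.nn.Module) -> None:
    # The final module must not be trainable because its output is unused.
    last_module = list(model.modules())[-1]
    assert all(not p.requires_grad for p in last_module.parameters())
```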