
Add use_pretrained attribute for AutoTransformers #3498

Merged: 6 commits merged into master on Aug 5, 2023

Conversation

arnavgarg1 (Contributor)

Fixes the following error:

'AutoTransformerConfig' object has no attribute 'use_pretrained'

when trying to train a custom transformer model from HF using a config that looks like this:

encoder:
  type: auto_transformer
  trainable: false
  pretrained_model_name_or_path: huggyllama/llama-7b
preprocessing:
  tokenizer: space_punct
  max_sequence_length: null

arnavgarg1 (Contributor, Author)

Will add a test

arnavgarg1 (Contributor, Author)

Okay this is probably wrong, going to close and discuss this first.

arnavgarg1 marked this pull request as draft August 4, 2023 18:48
arnavgarg1 closed this Aug 4, 2023
arnavgarg1 reopened this Aug 4, 2023
arnavgarg1 marked this pull request as ready for review August 4, 2023 20:14
@@ -3092,6 +3097,10 @@ def module_name():
description=ENCODER_METADATA["AutoTransformer"]["type"].long_description,
)

# Always set this to True since we always want to use the pretrained weights
# We don't currently support training from scratch for AutoTransformers
use_pretrained: bool = True
Collaborator

Let's make this a property so the user could never modify it.

Collaborator

@property
def use_pretrained(self) -> bool:
    return True
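
A minimal standalone sketch (plain Python, not Ludwig's actual schema classes) of why the property approach blocks user modification: a plain class attribute can be silently overwritten on an instance, while a property with no setter raises an AttributeError.

class AttributeConfig:
    # Plain attribute: callers can silently override it.
    use_pretrained: bool = True


class PropertyConfig:
    @property
    def use_pretrained(self) -> bool:
        # Always True: training AutoTransformers from scratch is not supported.
        return True


attr_cfg = AttributeConfig()
attr_cfg.use_pretrained = False  # allowed; silently changes behavior

prop_cfg = PropertyConfig()
try:
    prop_cfg.use_pretrained = False
except AttributeError as err:
    print(f"blocked: {err}")  # property has no setter, so the value stays True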

@@ -292,3 +289,42 @@ def test_tfidf_encoder(vocab_size: int):
inputs = torch.randint(2, (batch_size, sequence_length)).to(DEVICE)
outputs = text_encoder(inputs)
assert outputs[ENCODER_OUTPUT].shape[1:] == text_encoder.output_shape


def test_hf_auto_transformer_use_pretrained(tmpdir, csv_filename):
Contributor

nit: The True case would be tested implicitly elsewhere, correct? If not, maybe we could parametrize the test with both cases.
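
A hedged sketch of the parametrization suggested here, assuming pytest; the body is schematic, and the data generation, training, and assertion steps would stay as in the existing test:

import pytest


@pytest.mark.parametrize("use_pretrained", [True, False])
def test_hf_auto_transformer_use_pretrained(tmpdir, csv_filename, use_pretrained):
    # Encoder config from the snippet under review, with the parametrized value;
    # the remaining encoder options are unchanged from the original test.
    encoder_config = {
        "type": "auto_transformer",
        "use_pretrained": use_pretrained,
    }
    # ... generate data, train, and assert as in the existing test ...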

github-actions bot commented Aug 4, 2023

Unit Test Results

6 files (±0), 6 suites (±0), 1h 3m 51s duration (-8m 28s)
34 tests (-2 747): 29 passed (-2 739), 5 skipped (-7), 0 failed (-1)
88 runs (-2 736): 72 passed (-2 730), 16 skipped (-5), 0 failed (-1)

Results for commit 46f6c5a. ± Comparison against base commit 6b9a5e4.

♻️ This comment has been updated with latest results.

text_feature(
encoder={
"type": "auto_transformer",
"use_pretrained": False,
Collaborator

Ideally this should be an error if we were more strict with our config validation rules. We should instead just leave this out of the config.

arnavgarg1 merged commit 0c5f251 into master Aug 5, 2023
16 checks passed
arnavgarg1 deleted the use_pretrained branch August 5, 2023 21:20
dennisrall pushed a commit to dennisrall/ludwig that referenced this pull request Aug 9, 2023