
Allowing other models for extracting features #2

Closed
jinyongyoo opened this issue Dec 10, 2021 · 1 comment

Comments

@jinyongyoo

jinyongyoo commented Dec 10, 2021

Hello!

First off, thanks for sharing the code. The paper says that MAUVE works with other embedding models, so I wanted to try models such as Microsoft's DialoGPT. However, the code restricts the model and tokenizer names to the "gpt2" family. I think it would be better to remove this restriction, since others might also want to try other models.

If you want, I can make a PR to change this.

mauve/src/mauve/utils.py

Lines 25 to 39 in b3c01d5

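```python
from transformers import AutoModel, AutoTokenizer

# get_device_from_arg is a device helper defined elsewhere in mauve (not shown in this excerpt).

def get_model(model_name, tokenizer, device_id):
    device = get_device_from_arg(device_id)
    if 'gpt2' in model_name:
        # GPT-2 has no pad token, so the EOS token id is reused for padding.
        model = AutoModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).to(device)
        model = model.eval()
    else:
        # Any model name that does not contain 'gpt2' is rejected.
        raise ValueError(f'Unknown model: {model_name}')
    return model


def get_tokenizer(model_name='gpt2'):
    if 'gpt2' in model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    else:
        raise ValueError(f'Unknown model: {model_name}')
    return tokenizer
```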
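A minimal sketch of how the `'gpt2'` check might be lifted, assuming any checkpoint that Hugging Face `AutoModel`/`AutoTokenizer` can load should be accepted. The helper `resolve_pad_token_id` is hypothetical, and this version takes a plain `device` string (e.g. `"cpu"` or `"cuda:0"`) instead of mauve's `get_device_from_arg`, purely to keep the example self-contained. It does not check whether a given architecture's hidden states suit MAUVE's feature extraction.

```python
from typing import Any, Optional


def resolve_pad_token_id(tokenizer: Any) -> Optional[int]:
    """Hypothetical helper: choose a pad token id for the model config.

    Uses the tokenizer's own pad token if it has one; otherwise falls back to
    the EOS token, as the original gpt2-only branch does.
    """
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is not None:
        return pad_id
    return getattr(tokenizer, "eos_token_id", None)


def get_tokenizer(model_name: str = "gpt2"):
    # Imported lazily so resolve_pad_token_id works without transformers installed.
    from transformers import AutoTokenizer
    # No 'gpt2' check: any tokenizer AutoTokenizer can resolve is accepted.
    return AutoTokenizer.from_pretrained(model_name)


def get_model(model_name: str, tokenizer: Any, device: str = "cpu"):
    from transformers import AutoModel
    kwargs = {}
    pad_id = resolve_pad_token_id(tokenizer)
    if pad_id is not None:
        # Only override the config's pad_token_id when we actually have one.
        kwargs["pad_token_id"] = pad_id
    # No 'gpt2' check: any checkpoint AutoModel can load is accepted.
    model = AutoModel.from_pretrained(model_name, **kwargs)
    return model.to(device).eval()
```

For example, `get_model("microsoft/DialoGPT-medium", get_tokenizer("microsoft/DialoGPT-medium"))` would load DialoGPT (downloading the weights on first use). Falling back to the EOS token only when the tokenizer defines no pad token keeps the existing GPT-2 behavior while leaving models that already have a pad token unchanged.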

@krishnap25
Owner

Hi @jinyongyoo, a PR for this would be fantastic. Thanks!
