[Feature Request/Help] BLEURT model -> PyTorch #224

Closed
adamwlev opened this issue May 30, 2020 · 6 comments
Labels
enhancement New feature or request

Comments

@adamwlev

Hi, I am interested in porting Google Research's new BLEURT learned metric to PyTorch (because I want to do something experimental with language generation and backpropagating through BLEURT). I noticed that you guys don't have it yet, so I'm partly just asking whether you plan to add it (@thomwolf said on Twitter that you want to).

I had a go at manually using the checkpoint they publish, which includes the weights. The architecture seems to be exactly the out-of-the-box BertModel in transformers, with just a single linear layer on top of the CLS embedding. I loaded all the weights into the PyTorch model, but I can't get the same numbers as the BLEURT package's Python API. If you have any pointers on what might be going wrong, that would be much appreciated! My attempt is in this Colab notebook: https://colab.research.google.com/drive/1Bfced531EvQP_CpFvxwxNl25Pj6ptylY?usp=sharing
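For reference, the head described above (a single linear layer reading the [CLS] position of a BERT encoder) can be sketched in NumPy as follows. This is a hypothetical illustration, not BLEURT's code: the function and parameter names are made up, and the optional `w_pool`/`b_pool` step assumes the standard BERT pooler (a dense layer plus tanh applied to the [CLS] hidden state). Whether the head reads the raw [CLS] state or the pooled one changes the score, so it is one thing worth checking when the numbers don't match.

```python
import numpy as np


def bleurt_style_score(hidden_states, w_out, b_out, w_pool=None, b_pool=None):
    """Hypothetical sketch: one linear layer on top of the [CLS] embedding.

    hidden_states: (batch, seq_len, hidden) encoder output; position 0 is [CLS].
    w_out: (hidden,) head weights; b_out: scalar head bias.
    w_pool, b_pool: optional BERT-style pooler parameters, with w_pool stored in
    TF "kernel" layout (hidden_in, hidden_out) so it is applied as x @ w_pool.
    """
    h_cls = hidden_states[:, 0, :]  # (batch, hidden): the [CLS] position
    if w_pool is not None:
        # Standard BERT pooler (assumed here): dense layer followed by tanh
        h_cls = np.tanh(h_cls @ w_pool + b_pool)
    # One scalar score per (reference, candidate) input
    return h_cls @ w_out + b_out


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 5, 4))  # toy batch: 2 inputs, 5 tokens, hidden size 4
w_out, b_out = rng.normal(size=4), 0.1
print(bleurt_style_score(hidden, w_out, b_out))  # head on the raw [CLS] state
print(bleurt_style_score(hidden, w_out, b_out, np.eye(4), np.zeros(4)))  # head on the pooled state
```

With real weights, the two variants give different scores for the same input, so matching the reference implementation means picking the same one it uses.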

Thank you muchly!

@adamwlev changed the title from [Feature Request/Help] to [Feature Request/Help] BLEURT model -> PyTorch May 30, 2020
@yjernite self-assigned this Jun 2, 2020
@thomwolf added the enhancement (New feature or request) label Jun 20, 2020
@manikbhandari

Is there any update on this?

Thanks!

@ohmeow

ohmeow commented Dec 27, 2020

Hitting this error when using BLEURT with PyTorch ...

UnrecognizedFlagError: Unknown command line flag 'f'

... and I'm assuming that's because it was built specifically for TF. Is there a way to use this metric in PyTorch?

@yjernite
Member

yjernite commented Jan 4, 2021

We currently provide a wrapper on the TensorFlow implementation: https://huggingface.co/metrics/bleurt

We have long-term plans to better handle model-based metrics, but they probably won't be implemented right away.

@adamwlev it would still be cool to add the BLEURT checkpoints to the transformers repo if you're interested, but that would best be discussed there :)

Closing for now.

@yjernite closed this as completed Jan 4, 2021
@LoraIpsum

Hi there. We ran into the same problem this year (converting BLEURT to PyTorch) and, thanks to @adamwlev, found his Colab notebook, which didn't work but served as a good starting point. We finally got it working with just two simple conceptual fixes:

  1. Transposing the 'kernel' weights, rather than the 'dense' layers, when copying parameters from the original model;
  2. Using pooler_output as cls_state in the forward function of the BleurtModel class.

Plus a few minor syntax fixes for the outdated parts. The result is still not exactly the same, but it is very close to the expected one (1.0483 vs. 1.0474).

Find the fixed version here (fixes are commented): https://colab.research.google.com/drive/1KsCUkFW45d5_ROSv2aHtXgeBa2Z98r03?usp=sharing
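The first fix comes down to a layout difference: a TF dense layer stores its `kernel` as (in_features, out_features) and computes `x @ kernel + bias`, while `torch.nn.Linear` stores `weight` as (out_features, in_features) and computes `x @ weight.T + bias`. Below is a minimal NumPy sketch of a converter built on that fact. The helper names are hypothetical, and it assumes the encoder variables follow the original Google BERT naming (`bert/encoder/layer_0/attention/self/query/kernel`, and so on) and should map onto the `state_dict` keys of transformers' `BertModel`. The BLEURT regression head's variables are not covered and would need their own mapping.

```python
import numpy as np


def tf_name_to_torch(tf_name):
    """Hypothetical mapping from a BERT-style TF variable name to a BertModel state_dict key."""
    parts = tf_name.split("/")
    if parts[0] == "bert":
        parts = parts[1:]  # BertModel keys carry no "bert." prefix
    key = []
    for part in parts:
        if part.startswith("layer_"):
            key += ["layer", part.split("_")[1]]  # layer_3 -> layer.3
        elif part in ("kernel", "gamma"):
            key.append("weight")
        elif part == "beta":
            key.append("bias")
        elif part.endswith("_embeddings"):
            key += [part, "weight"]  # embedding tables are stored as a bare variable
        else:
            key.append(part)
    return ".".join(key)


def convert_tf_vars(tf_vars):
    """Rename variables and transpose every 'kernel'; embeddings and LayerNorm stay as-is."""
    state = {}
    for name, array in tf_vars.items():
        # TF kernel: (in, out) -> torch Linear weight: (out, in)
        state[tf_name_to_torch(name)] = array.T if name.endswith("/kernel") else array
    return state


# Toy check: the transposed weight reproduces the TF dense output in PyTorch's convention
rng = np.random.default_rng(0)
kernel, bias = rng.normal(size=(8, 3)), rng.normal(size=3)
tf_vars = {
    "bert/encoder/layer_0/attention/self/query/kernel": kernel,
    "bert/encoder/layer_0/attention/self/query/bias": bias,
}
state = convert_tf_vars(tf_vars)
weight = state["encoder.layer.0.attention.self.query.weight"]
x = rng.normal(size=(2, 8))
print(weight.shape)  # (3, 8)
print(np.allclose(x @ kernel + bias, x @ weight.T + bias))  # True
```

The TF arrays can be read with `tf.train.list_variables` and `tf.train.load_variable`; optimizer slots such as `adam_m`/`adam_v` and `global_step`, if present in the checkpoint, should be filtered out first. Loading could then look like `model.load_state_dict({k: torch.tensor(v) for k, v in state.items()}, strict=False)`, checking the returned missing and unexpected keys so that no parameter is silently left at its random initialization.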

@lucadiliello
Contributor

I created a new model based on transformers that can load every BLEURT checkpoint released so far: https://github.com/lucadiliello/bleurt-pytorch

@vaiibhavgupta

@LoraIpsum Thanks for sharing your work here. However, I'm unable to reproduce your results, which is strange because you were able to. FYI, I am trying to convert a fine-tuned BLEURT checkpoint to PyTorch. Any suggestions on how I can reproduce the results?


8 participants