How to use fine-tuned BART for prediction? #3853

Closed
riacheruvu opened this issue Apr 18, 2020 · 56 comments · Fixed by #3866
Labels: Discussion, wontfix

Comments

@riacheruvu

❓ Questions & Help

Details

I fine-tuned the BART model on a custom summarization dataset using the transformers/examples/summarization/bart/finetune.py and transformers/examples/summarization/bart/run_train.sh files in the repository for training (which generated three checkpointepoch=*.ckpt files) and prediction (which generated a .txt file with the test loss scores).

I have two questions on using this model for prediction:

  • How can I modify finetune.py to generate predictions for the test set, in addition to the loss scores? I see some test functions in finetune.py, but I'm not sure how to use these for generating a .txt file with the predictions.

  • How can I load the generated .ckpt files into BartForConditionalGeneration()? A config.json file was not generated along with the checkpoint files; there doesn't seem to be a TFBartForConditionalGeneration; and the convert_tf_checkpoint_to_pytorch.py script in the repo doesn't seem to support BART yet.

Thank you for your time!

@prabalbansal

Facing a similar type of issue for T5. @sshleifer

@sshleifer sshleifer linked a pull request Apr 20, 2020 that will close this issue
@sshleifer
Contributor

sshleifer commented Apr 20, 2020

The last ckpt file should be loaded into a pl.LightningModule if the --do_predict flag is specified.

There is a bug on master that messes up the loading, but it's fixed in #3866

To use that code immediately, you can run:

git fetch
git checkout examples-summ-do-predict

then run your same finetune.py command with --do_predict (and not --do_train) and the proper --output_dir.

Would love to know if that works!

cc: @ethanjperez.

@sshleifer
Contributor

Change is on master, let me know if this solves the problem!

@sshleifer sshleifer reopened this Apr 20, 2020
@prabalbansal

Config.json is still not generated while training.

@sshleifer
Contributor

    def log_hyperparams(model: pl.LightningModule):
        model.config.save_pretrained(model.hparams.output_dir)
        with open(os.path.join(model.hparams.output_dir, "hparam.json"), "w") as f:
            json.dump(vars(model.hparams), f)

You can call this somewhere in your code, if that's helpful.
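For example (a minimal sketch, assuming args, model, and trainer are the objects built in finetune.py's main(); a fuller version of this appears later in this thread):

    if args.do_train:
        trainer.fit(model)
        log_hyperparams(model)  # writes config.json (and hparam.json) into output_dir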

@riacheruvu
Author

@sshleifer, thank you - I can run ./run_train.sh with the --do_predict option successfully.

Regarding my original question, could you please specify how to load the checkpoint into the LightningModule?

After inspecting transformer_base.py, I think hparams is equivalent to the arguments provided in run_train.sh, so a separate hparams.json file does not need to be generated. Please correct me if I'm wrong.

I am receiving the following error with my current code:

pytorch_lightning.utilities.exceptions.MisconfigurationException: Checkpoint contains hyperparameters but LightningModule's __init__ is missing the argument 'hparams'. Are you loading the correct checkpoint?

I've been using the following code, based on the discussion in Lightning-AI/pytorch-lightning#525 and https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html:


# load model
import pytorch_lightning as pl

from argparse import Namespace

# usually these come from command line args
args = Namespace(
    data_dir='CE_data/',
    model_type='bart',
    model_name_or_path='bart-large',
    learning_rate='3e-5',
    train_batch_size=4,
    eval_batch_size=4,
    output_dir='transformers/examples/summarization/bart/bart_sum',
    do_predict='do_predict')

pretrained_model = pl.LightningModule.load_from_checkpoint('bart_sum/checkpointepoch=2.ckpt', hparams=args)
pretrained_model.eval()

# or for prediction
out = pretrained_model(inputs['input_ids'])
print(out)

Thank you for your time.

@sshleifer
Contributor

sshleifer commented Apr 21, 2020

Seems close to correct.

model = SummarizationTrainer(args)
trainer = generic_train(model, args)
# Optionally, predict on dev set and write to output_dir
if args.do_predict:
    # See https://github.com/huggingface/transformers/issues/3159
    # pl uses this format to create a checkpoint:
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
    # /pytorch_lightning/callbacks/model_checkpoint.py#L169
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
    model = model.load_from_checkpoint(checkpoints[-1])
    trainer.test(model)

is how we do it @riacheruvu

@prabalbansal

prabalbansal commented Apr 21, 2020

@sshleifer

  1. Originally, config.json is not created, which is a requirement for prediction with a fine-tuned model.
     • As shown in the screenshot below, after I add this code at the end of transformer_base.py, the config and hparam files are created.
     • When I then try to predict with --do_predict, it gives: "We assumed '/content/t5' was a path, a model identifier, or url to a directory containing vocabulary files named ['spiece.model'] but couldn't find such vocabulary files at this path or url."
     What are the requirements for using a fine-tuned model?

     (Screenshot 2020-04-21 at 5:50 PM: the code added at the end of transformer_base.py)

  2. To predict for a single instance using the fine-tuned model, do I need to specify the test.target file as well? I want to predict on an unknown instance without calculating the loss value.

@riacheruvu
Author

riacheruvu commented Apr 22, 2020

@sshleifer, thank you. I've got to the point where I can load the model and generate "outputs" using the forward() function, but I can't decode the outputs - using tokenizer.decode() results in an error. Should I be using model.generate() instead of model.forward()? If so, it seems SummarizationTrainer does not support model.generate?

Revised code:

        tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')['input_ids']
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])
        model.eval()
        model.freeze()
        outputs = model(inputs)
        print(outputs) #Successfully prints two 3D tensors in a tuple
        #print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs]) #Results in ValueError: only one element tensors can be converted to Python scalars
        print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs[0][0]])

The error I'm encountering

Traceback (most recent call last):
  File "finetune.py", line 194, in <module>
    main(args)
  File "finetune.py", line 184, in main
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs[1][0]])
  File "finetune.py", line 184, in <listcomp>
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs[1][0]])
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/tokenization_utils.py", line 2141, in decode
    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/tokenization_gpt2.py", line 235, in convert_tokens_to_string
    text = "".join(tokens)
TypeError: sequence item 0: expected str instance, NoneType found

@riacheruvu
Author

riacheruvu commented Apr 22, 2020

I found a solution. The model.generate() function is necessary to extract the predictions. I defined a separate function in the SummarizationTrainer() class to use self.model.generate(), and was able to use tokenizer.decode() on the outputs.

I was encountering issues when using self.tokenizer, so I assume using 'bart-large-cnn' tokenizer for similar custom summarization datasets is okay.

@prabalbansal, I'm not sure if the same method will apply to T5, but it could work for predicting for a single instance, per one of your questions.

My code is below:

    def text_predictions(self, input_ids):
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        return preds
...
    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')['input_ids']
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])
        model.eval()
        model.freeze()
        outputs = model.text_predictions(inputs)
        print(outputs)

Thank you for the help, @sshleifer !

@prabalbansal

@riacheruvu Thank You. It works for T5 also.

@sangeethabal15

sangeethabal15 commented Apr 22, 2020

I followed the steps given in this thread and am still facing an issue. I get the error below when I try to use my fine-tuned model for prediction.

OSError: Can't load '/home/bart/bart_1/checkpointepoch=3.ckpt'. Make sure that:

  • '/home/bart/bart_1/checkpointepoch=3.ckpt' is a correct model identifier listed on 'https://huggingface.co/models'

  • or '/home/bart/bart_1/checkpointepoch=3.ckpt' is the correct path to a directory containing a 'config.json' file

@riacheruvu
Author

@sangeethabal15, with my model, files were only generated up till the 2nd epoch. Just to confirm, do you have a checkpointepoch=3.ckpt file?

Are you using the load_from_checkpoint() function?

@sangeethabal15

@riacheruvu yes, I do have a checkpointepoch=3.ckpt file. I gave my own number of epochs instead of the default 3.

Yes, I am using the load_from_checkpoint() function.

@riacheruvu
Author

Ok. Could you share your code here, @sangeethabal15? It might be easier to help debug.

@sangeethabal15

sangeethabal15 commented Apr 23, 2020

@riacheruvu This is my modified code -

# Optionally, predict on dev set and write to output_dir
if args.do_predict:
    # See https://github.com/huggingface/transformers/issues/3159
    # pl use this format to create a checkpoint:
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
    # /pytorch_lightning/callbacks/model_checkpoint.py#L169
    examples = [" " + x.rstrip() for x in open("/home/bart/input/test.source").readlines()]
    fout = Path("output.txt").open("w")
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
    model = model.load_from_checkpoint(checkpoints[-1])
    tokenizer = BartTokenizer.from_pretrained("bart-large")

    max_length = 80
    min_length = 5

    for batch in tqdm(list(chunks(examples, 8))):
        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"],
            num_beams=4,
            length_penalty=2.0,
            max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
            min_length=min_length + 1,  # +1 from original because we start at step=1
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=model.config.eos_token_id,
        )
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()

@sshleifer sshleifer removed their assignment Apr 23, 2020
@sshleifer sshleifer added the Discussion label Apr 23, 2020
@riacheruvu
Author

Thank you, @sangeethabal15. From the error message you posted earlier, it seems load_from_checkpoint() is expecting a config.json file in the specified directory.

I have a few more debug questions:

  • Do you have the latest version of the code?

  • Does load_from_checkpoint() work with the checkpoint file for the 2nd epoch?

  • If that fails, does your code run successfully if you use the default number of epochs?

@sangeethabal15

@riacheruvu

  • I do have the latest version of the code, though I have not trained the model on the latest version of it.

  • load_from_checkpoint doesn't work with the 2nd epoch checkpoint either and expects a config.json file.

  • And yes, the code runs successfully with the default number of epochs as well.

@prabalbansal

import json
def log_hyperparams(model: pl.LightningModule):
    model.config.save_pretrained(model.hparams.output_dir)
    with open(os.path.join(model.hparams.output_dir, "hparam.json"),'w') as f:
        json.dump(model.hparams.__dict__, f)
if args.do_train:
    trainer.fit(model)
    log_hyperparams(model)

@sangeethabal15 Could you add this at the end of transformer_base.py. This works for me.

@sangeethabal15

@prabalbansal this is for when I am training my model. Since I have already fine-tuned my model, is there any workaround for test time when I am trying to predict my outputs?

@murugeshmanthiramoorthi

@riacheruvu I am currently working on a text summarization problem. I have collected a small dataset of my own. Implementing BART is very easy, and I can generate a great summary. But I want to know how to train the BART model on my own custom dataset. Can you please kindly help me with this?

I have browsed the internet, but I cannot find any helpful resources, as BART is relatively new compared to other transfer learning models.

@sangeethabal15

@murugeshmanthiramoorthi you can just use run_train.sh in the bart folder, where you pass in your parameters to run the finetune.py file.

@murugeshmanthiramoorthi

@sangeethabal15 Thank you so much for your reply, ma'am. I am completely new to transfer learning and can't quite follow what you mean. Can you kindly explain in more detail or share a resource I can follow up on?
Thanks in advance, ma'am.

@murugeshmanthiramoorthi

@riacheruvu Thank you so much for your help. But when I proceeded with those steps, I get the error

Traceback (most recent call last):
File "finetune.py", line 10, in
from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
ModuleNotFoundError: No module named 'transformer_base'

Do you have any idea solving this issue.

@sangeethabal15

@murugeshmanthiramoorthi Follow the below steps and you should be able to run your code.

Important: To run the latest versions of the examples, you have to install from source and install some example-specific requirements. Execute the following steps in a new virtual environment:

git clone https://github.com/huggingface/transformers
cd transformers
pip install .
pip install -r ./examples/requirements.txt

You can find the above in the readme section of https://github.com/huggingface/transformers/tree/cbbb3c43c55d2d93a156fc80bd12f31ecbac8520/examples

@riacheruvu
Author

@murugeshmanthiramoorthi, I agree with @sangeethabal15, I followed the same steps as well.

After installing the dependencies, the code should run without errors about transformer_base - I believe the following lines in run_train.sh ensure that:

# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"

@sangeethabal15

sangeethabal15 commented Apr 24, 2020

@sshleifer @riacheruvu I keep running into an error every time I change the beam size, define min_length, skip_ngram, length_penalty during decoding time. Here is a snippet of the error

Traceback (most recent call last):
  File "finetune1.py", line 189, in <module>
    main(args)
  File "finetune1.py", line 176, in main
    outputs = model.text_predictions(inputs)
  File "finetune1.py", line 80, in text_predictions
    length_penalty=1.0,
  File "/home/sangeethabal/.local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/sangeethabal/.local/lib/python3.7/site-packages/transformers/modeling_utils.py", line 995, in generate
    attention_mask=attention_mask,
  File "/home/sangeethabal/.local/lib/python3.7/site-packages/transformers/modeling_utils.py", line 1338, in _generate_beam_search
    past = self._reorder_cache(past, beam_idx)
  File "/home/sangeethabal/.local/lib/python3.7/site-packages/transformers/modeling_bart.py", line 933, in _reorder_cache
    ((enc_out, enc_mask), decoder_cached_states) = past
ValueError: too many values to unpack (expected 2)

The function where I have defined all of this

def test(self, input_ids):
    generated_ids = self.model.generate(
        input_ids=input_ids,
        num_beams=6,
        max_length=60,
        min_length=4,
        no_repeat_ngram_size=3,
        length_penalty=1.0,
    )
    preds = [
        self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for g in generated_ids
    ]
    return preds

Any idea how to go about this?

@riacheruvu
Author

riacheruvu commented Apr 24, 2020

@sangeethabal15, I have two ideas: Try explicitly setting use_cache=True in the generate() function to see if it resolves the error. If that does not work, could you try specifying the attention_mask parameter? I'm looking at modeling_utils.py and modeling_bart.py, and I think these are the two parameters that are linked to this issue.

Edit: It also seems evaluate_cnn.py demonstrates a similar configuration for the generate() function, although the parameters are slightly different. If the two ideas above don't work, you could try specifying those parameters to confirm it's not an issue with the values of the parameters that were chosen.
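For illustration, here is a hypothetical variant of the text_predictions() helper from earlier in this thread (untested; the attention_mask would come from the same batch_encode_plus() call that produced input_ids, and whether generate() accepts use_cache depends on the installed transformers version):

    def text_predictions(self, input_ids, attention_mask=None):
        # Pass the attention mask through and request the cache explicitly, as
        # suggested above; drop use_cache if your version rejects the kwarg.
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            num_beams=6,
            max_length=60,
            min_length=4,
            no_repeat_ngram_size=3,
            length_penalty=1.0,
        )
        return [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]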

@murugeshmanthiramoorthi

Thank you so much @sangeethabal15 @riacheruvu I got it. Thanks a ton for your help.

@sangeethabal15

@sshleifer when I use the exact same parameters as in the evaluate_cnn.py code, I still get the exact same error as posted above. There seems to be an issue with the values chosen for these parameters specified in evaluate_cnn.py.

@riacheruvu I have tried the parameters you specified, same issue.


@sshleifer
Contributor

Try passing use_cache=True.
Note that the call here works. The only differences appear to be attention_mask and use_cache.

@sangeethabal15

@sshleifer use_cache is set to true by default in modeling_utils.py. But when I specify the parameter in my function and run the code, it throws the following error:

Traceback (most recent call last):
  File "finetune1.py", line 191, in <module>
    main(args)
  File "finetune1.py", line 178, in main
    outputs = model.text_predictions(inputs)
  File "finetune1.py", line 82, in text_predictions
    use_cache=True,
  File "/home/sangeethabal/.local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
TypeError: generate() got an unexpected keyword argument 'use_cache'

@sshleifer
Contributor

sshleifer commented Apr 24, 2020

This isn't enough information for me to diagnose. My guess with the limited info I have is that you didn't run pip install -e . from transformers/.

What does pip freeze | grep transformers say?

@sangeethabal15

@sshleifer I did run pip install -e .

Here is the output of pip freeze | grep transformers

WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.
Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
transformers==2.8.0

@sshleifer
Contributor

Ok, output should look like -e git+git@...
try

git pull
pip install -e .

You should probably also upgrade pip, though that shouldn't matter much.

@ArijRB

ArijRB commented May 4, 2020

@riacheruvu hello, do you get <extra_id_0> in your generation output?

@riacheruvu
Author

@ArijRB, hi - I don’t remember seeing that in the output of the model.

@isabelcachola

@ArijRB I'm also getting <extra_id_x> generations. Were you able to solve that problem? I'm using a T5 model finetuned on my own dataset.

@claudiatin

@riacheruvu How did you load the model in the line 'model.load_from_checkpoint(checkpoints[-1])' of the following code you posted?

    tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
    ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
    inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')['input_ids']
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
    model = model.load_from_checkpoint(checkpoints[-1])
    model.eval()
    model.freeze()
    outputs = model(inputs)
    print(outputs) #Successfully prints two 3D tensors in a tuple
    #print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs]) #Results in ValueError: only one element tensors can be converted to Python scalars
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs[0][0]])

Is 'model' an instance of pl.LightningModule? I still have the error message that you got in the previous post:

pytorch_lightning.utilities.exceptions.MisconfigurationException: Checkpoint contains hyperparameters but LightningModule's __init__ is missing the argument 'hparams'. Are you loading the correct checkpoint?

@riacheruvu
Author

@claudiatin, model should be defined as an instance of the SummarizationTrainer class. You will need to have the following code (which is already under main() in finetune.py):

model = SummarizationTrainer(args)

I am wondering if there is an easier way to go about generating the predictions, though. I’ve tried calling SummarizationTrainer from another Python file so I can separate my prediction and training code, but ran into some issues, so I needed to stick with using another version of finetune.py running with a clone of the repo. If anyone finds an easier way of accomplishing this, or if the HuggingFace team can build this functionality in, that would be great.

@claudiatin

@riacheruvu Thank you so much for your answer. I did the same as you did, and then saved the .bin file and config.json so I can use BartForConditionalGeneration.from_pretrained. I don't know if it is the best way, actually.

# load the model from the checkpoint and save the weights and config
model = SummarizationTrainer(args)
model = model.load_from_checkpoint('bart_sum/checkpointepoch=2.ckpt')
torch.save(model.state_dict(), args.output_dir + '/pytorch_model.bin')
model.config.to_json_file(args.output_dir + '/config.json')

# load the fine-tuned model and predict
model = BartForConditionalGeneration.from_pretrained('bart_sum')
summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)
summarizer(ARTICLE_TO_SUMMARIZE, max_length=80, min_length=40)

@riacheruvu
Author

riacheruvu commented Jun 4, 2020

@claudiatin, thank you!

Edit: Please ignore my previous response to your newest reply. I just went through the code again, and I was wrong about the inputs to the from_pretrained() function. I apologize for that.

I’ll try using the code block you provided!

@riacheruvu
Author

I tried applying the code provided for T5 (I haven't tried it with BART, but I think it'll work successfully per @claudiatin's response). I am including the results here for documentation and in case anyone knows the solution:

from transformers import T5Model, pipeline

model = T5Model.from_pretrained('tfive_sum')
summarizer = pipeline("summarization", model=model, tokenizer="t5-base", framework="tf")
summarizer(ARTICLE_TO_SUMMARIZE, min_length=5, max_length=20)

I run into the error:

AttributeError: You tried to generate sequences with a model that does not have a LM Head.Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )

I've tried importing T5WithLMHeadModel using from transformers import T5WithLMHeadModel and encountered an ImportError: cannot import name 'T5WithLMHeadModel'. I have the most up-to-date version of the transformers library installed, so I'm not sure if there's something wrong with my setup.
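A minimal sketch of an alternative (untested; depending on the installed transformers version, the LM-head T5 variant is exposed as T5ForConditionalGeneration rather than T5WithLMHeadModel, and this assumes the tfive_sum folder contains the config.json and pytorch_model.bin saved as above):

from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline

ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."

# Load the fine-tuned weights with the LM-head class instead of the bare T5Model,
# so that the summarization pipeline has a head to generate with.
model = T5ForConditionalGeneration.from_pretrained('tfive_sum')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)
print(summarizer(ARTICLE_TO_SUMMARIZE, min_length=5, max_length=20))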

@claudiatin

@riacheruvu, don't worry about the previous answer. For the sake of completeness 'bart_sum' is just the default name of the folder where the checkpoints are saved (the line export OUTPUT_DIR_NAME=bart_sum in the run_train.sh). The complete code in my notebook is the following:

%cd examples/summarization/bart

!bash run_train.sh  # run_train.sh script has been changed in order to use a custom dataset

%cd ../..
from lightning_base import BaseTransformer

%cd summarization/bart
from finetune import SummarizationTrainer

import torch
from argparse import Namespace
args = Namespace(adam_epsilon=1e-08, cache_dir='', config_name='', data_dir='../../../../dataset', do_predict=False, do_train=True, eval_batch_size=2, fp16=False, fp16_opt_level='O1', gradient_accumulation_steps=1, learning_rate=3e-05, max_grad_norm=1.0, max_source_length=1024, max_target_length=56, model_name_or_path='bart-large', n_gpu=1, n_tpu_cores=0, num_train_epochs=3, output_dir='bart_sum', seed=42, tokenizer_name='', train_batch_size=2, warmup_steps=0, weight_decay=0.0)

model = SummarizationTrainer(args)
model = model.load_from_checkpoint('bart_sum/checkpointepoch=2.ckpt')
torch.save(model.state_dict(), args.output_dir + '/pytorch_model.bin')
model.config.to_json_file(args.output_dir + '/config.json') # NOW in the bart_sum folder I have checkpoints, pytorch_model.bin and config.json

In another notebook

import torch
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import pipeline

tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
# load the fine-tuned model
model = BartForConditionalGeneration.from_pretrained('transformers/examples/summarization/bart/bart_sum')

The code works but the performance is not good. I think this is because of my dataset :)

@riacheruvu
Author

Thank you, @claudiatin, and thank you for sharing your code!

@gmlander

gmlander commented Jun 9, 2020

@claudiatin thanks for providing your code. I was able to load a fine-tuned version of facebook/bart-large-cnn into a pipeline, originally using a far hackier approach as well as with your method.

The problem I'm running into, which it sounds like you may have hit as well, is that the predictions from the pipeline after fine-tuning come out as pure gibberish, so something is being lost in translation. Example below:

'redistributionestonestoneston Hag Hag resultant resultant '
'resultantestoneston redistribution redistribution Hag Hag pressuring '
'pressuring redistribution redistribution alternate alternate alternate '
'pressuring pressuring Hag Hagestoneston Champions Champions Champions '
'redistribution redistribution sil sil sil redistribution redistributionbelt '
'redistribution redistributioniopiopiop redistribution redistribution carved '
'carved carved Hag Hag sil sil pressuring pressuring carved carved '
'compartment compartment compartment redistribution redistribution Voyager '
'Voyager Voyager redistribution redistribution pressuring pressuring '

I used the finetune.py script on the cnn tiny dataset, via the tiny version of the bash script in the examples folder. I even attempted to do this fine-tuning with a nearly 0 (1e-10) learning rate, so that I knew I wasn't significantly changing the model. This still led to gibberish predictions.

I tried a version where I loaded the pretrained model into the pipeline, saved it using pipeline.model.save_pretrained("path/to/dir"), and in a new session reloaded it using the second portion of the code provided by @claudiatin plus bart_loaded = pipeline(task='summarization', model=model, device=0, tokenizer=tokenizer).

This worked correctly for predictions; however, I did notice a significant change in inference time on the same article I tested (~3 seconds vs ~20 seconds). The only difference I could see between the config.json and pytorch_model.bin produced by save_pretrained() and those derived from the finetune.py checkpoint is that the save_pretrained() config.json contains the added key:value "architectures": ["BartForConditionalGeneration"]. I made this change to the config generated from my fine-tuned model, but it did not correct the gibberish generation problem.

@sshleifer , any ideas?

@claudiatin

@gmlander, yes, I have the same gibberish issue. It's not clear to me how to solve it; it would be nice to know how.

@stale

stale bot commented Aug 9, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Aug 9, 2020
@stale stale bot closed this as completed Aug 16, 2020
@mriganktiwari

mriganktiwari commented Jan 31, 2021


Hi @riacheruvu, I am facing a similar issue to the one you solved above while tokenizing a piece of text in the QAGS repo. Line 133 in https://github.com/W4ngatang/qags/blob/master/qg_utils.py gives me the same error, which is due to tokenizer.decode() encountering a NoneType object. I would appreciate it if you could help. Please see the error log below:

Traceback (most recent call last):
  File "qg_utils.py", line 169, in <module>
    sys.exit(main(sys.argv[1:]))
  File "qg_utils.py", line 166, in main
    extract_gen_from_fseq_log(args.data_file, args.out_dir)
  File "qg_utils.py", line 142, in extract_gen_from_fseq_log
    gen = tokenizer.decode(tok_ids)
  File "/home/test/miniconda3/envs/qags/lib/python3.6/site-packages/transformers/tokenization_utils_base.py", line 3113, in decode
    **kwargs,
  File "/home/test/miniconda3/envs/qags/lib/python3.6/site-packages/transformers/tokenization_utils.py", line 753, in _decode
    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
  File "/home/test/miniconda3/envs/qags/lib/python3.6/site-packages/transformers/models/gpt2/tokenization_gpt2.py", line 264, in convert_tokens_to_string
    text = "".join(tokens)
TypeError: sequence item 0: expected str instance, NoneType found

@riacheruvu
Author

riacheruvu commented Jan 31, 2021

Hi @mriganktiwari, in my case, I needed to use model.generate() as input to tokenizer.decode() to solve this issue. I had an older version of HuggingFace at the time, so this might not be true today.

You could consider first using model.generate() with tok_ids, followed by tokenizer.decode(). I could be wrong, and I'm not sure what the input data_file consists of, but I would try this to see if it helps.
