
XLNET SQuAD2.0 Fine-Tuning - What May Have Changed? #2651

Closed
ahotrod opened this issue Jan 26, 2020 · 29 comments

@ahotrod
Contributor

ahotrod commented Jan 26, 2020

❓ Questions & Help

I fine-tuned xlnet-large-cased on SQuAD 2.0 in November 2019 with Transformers v2.1.1, yielding satisfactory results:

xlnet_large_squad2_512_bs48
{
  "exact": 82.07698138633876,
  "f1": 85.898874470488,
  "total": 11873,
  "HasAns_exact": 79.60526315789474,
  "HasAns_f1": 87.26000954590184,
  "HasAns_total": 5928,
  "NoAns_exact": 84.54163162321278,
  "NoAns_f1": 84.54163162321278,
  "NoAns_total": 5945,
  "best_exact": 83.22243746315169,
  "best_exact_thresh": -11.112004280090332,
  "best_f1": 86.88541353813282,
  "best_f1_thresh": -11.112004280090332
}

[loss graph]

with script:

#!/bin/bash

export OMP_NUM_THREADS=6
RUN_SQUAD_DIR=/media/dn/dssd/nlp/transformers/examples
SQUAD_DIR=${RUN_SQUAD_DIR}/scripts/squad2.0
MODEL_PATH=${RUN_SQUAD_DIR}/runs/xlnet_large_squad2_512_bs48

python -m torch.distributed.launch --nproc_per_node=2 ${RUN_SQUAD_DIR}/run_squad.py \
  --model_type xlnet \
  --model_name_or_path xlnet-large-cased \
  --do_train \
  --train_file ${SQUAD_DIR}/train-v2.0.json \
  --predict_file ${SQUAD_DIR}/dev-v2.0.json \
  --version_2_with_negative \
  --num_train_epochs 3 \
  --learning_rate 3e-5 \
  --adam_epsilon 1e-6 \
  --max_seq_length 512 \
  --doc_stride 128 \
  --save_steps 2000 \
  --per_gpu_train_batch_size 1 \
  --gradient_accumulation_steps 24 \
  --output_dir ${MODEL_PATH}

CUDA_VISIBLE_DEVICES=0 python ${RUN_SQUAD_DIR}/run_squad.py \
  --model_type xlnet \
  --model_name_or_path ${MODEL_PATH} \
  --do_eval \
  --train_file ${SQUAD_DIR}/train-v2.0.json \
  --predict_file ${SQUAD_DIR}/dev-v2.0.json \
  --version_2_with_negative \
  --max_seq_length 512 \
  --per_gpu_eval_batch_size 48 \
  --output_dir ${MODEL_PATH}
$@

After upgrading Transformers to Version 2.3.0 I decided to see if there would be any improvements in the fine-tuning results using the same script above. I got the following results:

xlnet_large_squad2_512_bs48
Results: {
'exact': 45.32131727448834,
'f1': 45.52929325627209,
'total': 11873,
'HasAns_exact': 0.0,
'HasAns_f1': 0.4165483859174251,
'HasAns_total': 5928,
'NoAns_exact': 90.51303616484441,
'NoAns_f1': 90.51303616484441,
'NoAns_total': 5945,
'best_exact': 50.07159100480081,
'best_exact_thresh': 0.0,
'best_f1': 50.07229287739689,
'best_f1_thresh': 0.0}

No learning takes place:
[loss graph]

Looking for potential explanation(s)/source(s) for the loss of performance. I have searched Transformers releases and issues for anything pertaining to XLNet, with no clues. Are there new fine-tuning hyperparameters I've missed that now need to be assigned, or that maybe didn't exist in earlier Transformers versions? Are there any issues with later PyTorch/TensorFlow versions? I may have to recreate the Nov 2019 environment for a re-run to verify the earlier results, and then incrementally update Transformers, PyTorch, TensorFlow, etc.

Current system configuration:
OS: Linux Mint 19.3 based on Ubuntu 18.04.3 LTS and Linux Kernel 5.0
GPU/CPU: 2 x NVIDIA 1080Ti / Intel i7-8700
Seasonic 1300W Prime Gold Power Supply
CyberPower 1500VA/1000W battery backup

Transformers: 2.3.0
PyTorch: 1.3.0
TensorFlow: 2.0.0
Python: 3.7.5

@adymaharana

I have been facing the same problem with RoBERTa finetuning for multiple choice QA datasets. I have even tried going back to the older version of transformers (version 2.1.0 from Oct 2019) and re-running my experiments but I am not able to replicate results from before anymore. The loss just varies within a range of +/- 0.1.

@cronoik
Contributor

cronoik commented Jan 27, 2020

Are you using one of the recent versions of run_squad.py? It was quite heavily refactored in December. Maybe there is a mistake now. Can you try it with the run_squad.py of the 2.1.1 release again?

@LysandreJik
Member

LysandreJik commented Jan 27, 2020

Could it be related to 96e8350? Before November 29 there was a mistake where the script would only evaluate on 1/N_GPU of the entire evaluation set.

@ahotrod
Contributor Author

ahotrod commented Jan 28, 2020

@cronoik good suggestion

Are you using one of the recent versions of run_squad.py? It was quite heavily refactored in December. Maybe there is a mistake now. Can you try it with the run_squad.py of the 2.1.1 release again?

I'm attempting to recreate the environment that existed for the successful fine-tuning above, dated 26 Nov 2019. I have the .yml file for that environment, but after re-creating it and re-running the script I get errors about missing "Albert files" and others. That doesn't make much sense since this run uses XLNet. I'm keeping after it.

@LysandreJik helpful information

Could it be related to 96e8350? Before November 29 there was a mistake where the script would only evaluate on 1/N_GPU of the entire evaluation set.

Perhaps, but given that the successful run was before 29 Nov 2019, and my eval script uses a single GPU (CUDA_VISIBLE_DEVICES=0), could 96e8350 be the culprit?

How best to debug my latest, up-to-date environment?
Transformers: 2.3.0
PyTorch: 1.4.0
TensorFlow: 2.1.0
Python: 3.7.6

@ahotrod
Contributor Author

ahotrod commented Jan 30, 2020

How about the cached files at .cache/torch/transformers?
I have over 6 GB of models cached dating back to November 2019.
Any chance the wrong config.json, spiece.model, model.bin, etc. are getting loaded from the cache and don't match up with the new Transformers code/libraries?
I think it's time to clear out the cache.
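For reference, a minimal sketch of what I mean by clearing it (assuming the default cache location; adjust the path if TRANSFORMERS_CACHE is set):

# move the downloaded-model cache aside so everything gets re-downloaded on the next run
mv ~/.cache/torch/transformers ~/.cache/torch/transformers.bak

# note: the cached SQuAD features are separate; they live next to the train/dev JSON files
# and run_squad.py rebuilds them when --overwrite_cache is passed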

Ran the script above on a single GPU (GPU 0) with gradient accumulation set to 48, everything else the same. Results and loss were the same. Apparently it is not a distributed-processing issue.

Update 30 Jan 2020: Cleared the caches and ran the distributed-processing script in the first post above, adding --overwrite_cache; same results and losses.

@WilliamNurmi

WilliamNurmi commented Feb 3, 2020

Hi guys! I just ran into the same issue. I fine-tuned XLNet on the SQuAD 2.0 training set over the weekend, exactly as instructed on the examples page, and got the same inferior results:

python examples/run_squad.py --model_type xlnet --model_name_or_path xlnet-large-cased --do_train --do_eval --version_2_with_negative --train_file ./squad/train-v2.0.json --predict_file ./squad/dev-v2.0.json --learning_rate 3e-5 --num_train_epochs 4 --max_seq_length 384 --doc_stride 128 --output_dir ./xlnet_large_squad2_out/ --per_gpu_eval_batch_size=2 --per_gpu_train_batch_size=2 --save_steps 50000

02/01/2020 00:50:47 - WARNING - __main__ - Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False
...
02/03/2020 01:50:51 - INFO - __main__ - Results: {'exact': 45.35500715910048, 'f1': 45.42776379790963, 'total': 11873, 'HasAns_exact': 0.08434547908232119, 'HasAns_f1': 0.23006740428154376, 'HasAns_total': 5928, 'NoAns_exact': 90.49621530698066, 'NoAns_f1': 90.49621530698066, 'NoAns_total': 5945, 'best_exact': 50.07159100480081, 'best_exact_thresh': 0.0, 'best_f1': 50.07159100480081, 'best_f1_thresh': 0.0}

My versions:
transformers: 0aa40e9 (same as v2.4.0)
python 3.6.8
pytorch 1.2.0+cu92

I will proceed to run it again on transformers v2.1.1 and report back whether the old code still works for XLNet.
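For the comparison run I'm pinning the old release from source so that the matching examples/run_squad.py comes along (a rough sketch of the steps, nothing more):

git clone https://github.com/huggingface/transformers.git
cd transformers
git checkout v2.1.1   # release tag whose run_squad.py produced the good Nov 2019 results
pip install .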

@LysandreJik
Member

Hi @WilliamNurmi, thank you for taking the time to do this. Do you mind making sure that you're using SequentialSampler in your evaluation, even when running against transformers v2.1.1? This affects the evaluation, which should be the same as the one you did in v2.4.0.

This should only affect setups with more than 1 GPU, and that does not seem to be your case, but if it is, it would be great to update the sampler.
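For reference, the evaluation dataloader should be built roughly like this (a minimal sketch of the relevant lines; dataset and args follow run_squad.py's naming):

from torch.utils.data import DataLoader, SequentialSampler

# evaluation should walk the full dev set in order on a single process;
# a distributed sampler here would silently score only a 1/N_GPU shard
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)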

@WilliamNurmi

Hi @LysandreJik, I'm indeed using only 1 gpu, so we should be good there!

@WilliamNurmi

No dice with XLNet on v2.1.1. I used the same parameters as @ahotrod except for slight changes for gradient_accumulation_steps (not used), max_seq_length (368) and per_gpu_train_batch_size (1).

python examples/run_squad.py --model_type xlnet --model_name_or_path xlnet-large-cased --do_train --do_eval --version_2_with_negative --train_file ./squad/train-v2.0.json --predict_file ./squad/dev-v2.0.json --learning_rate 3e-5 --num_train_epochs 3 --max_seq_length 368 --doc_stride 128 --output_dir ./xlnet_cased_finetuned_squad/ --per_gpu_eval_batch_size=2 --per_gpu_train_batch_size=2 --save_steps 63333 --logging_steps 63333 --evaluate_during_training --adam_epsilon 1e-6

Inferior results:

{ "exact": 37.45472921755243, "f1": 41.95943914787417, "total": 11873, "HasAns_exact": 70.05735492577598, "HasAns_f1": 79.07969315160429, "HasAns_total": 5928, "NoAns_exact": 4.945332211942809, "NoAns_f1": 4.945332211942809, "NoAns_total": 5945, "best_exact": 50.07159100480081, "best_exact_thresh": 0.0, "best_f1": 50.07159100480081, "best_f1_thresh": 0.0 }

I tried to mimic the setup at the time with the following versions:
Transformers v2.1.1
Python 3.6.9
PyTorch 1.3.1

Interestingly, the first run with v2.4.0 gave an answer to only 5% of the test questions, while this v2.1.1 version dared to answer 90% of the questions.

Does anyone have any idea what could have changed since last November that completely broke SQuAD 2.0 training? Could it be the files (pretrained network, tokenization, hyperparameters, etc.) that the transformers lib downloads at the beginning of training?

@cronoik
Contributor

cronoik commented Feb 5, 2020

Is the run_squad.py the 2.1.1 version?

@WilliamNurmi

WilliamNurmi commented Feb 5, 2020

@cronoik, yeah. I'm installing from source and I re-cloned the whole repo.

I didn't think to clean ~/.cache/torch/transformers/ though, but @ahotrod seems to have tried that with no luck.

EDIT: and looking at the cache file timestamps, it seems it has downloaded new files anyway.

@WilliamNurmi

WilliamNurmi commented Feb 5, 2020

As noted in other issues, plain old BERT is working better, so the issue seems to be specific to XLNet, RoBERTa and ALBERT(?).

On transformers 2.4.0
python examples/run_squad.py --model_type=bert --model_name_or_path=bert-base-uncased --do_train --do_eval --do_lower_case --version_2_with_negative --train_file=./squad/train-v2.0.json --predict_file=./squad/dev-v2.0.json --per_gpu_train_batch_size=12 --learning_rate=3e-5 --num_train_epochs=2.0 --max_seq_length=384 --doc_stride=128 --save_steps=20000 --output_dir=bert_out --overwrite_output_dir

Results: {'exact': 73.04809231028383, 'f1': 76.29336127902307, 'total': 11873, 'HasAns_exact': 71.99730094466936, 'HasAns_f1': 78.49714549018896, 'HasAns_total': 5928, 'NoAns_exact': 74.09587888982338, 'NoAns_f1': 74.09587888982338, 'NoAns_total': 5945, 'best_exact': 73.04809231028383, 'best_exact_thresh': 0.0, 'best_f1': 76.29336127902297, 'best_f1_thresh': 0.0}

@ahotrod
Contributor Author

ahotrod commented Feb 5, 2020

After nearly two weeks of unsuccessful varied XLNet fine-tunes, I gave up and switched to fine-tuning ALBERT as an alternative model:

albert_xxlargev1_sqd2_512_bs48 results:
{'exact': 85.65653162637918,
 'f1': 89.260458954177,
 'total': 11873,
 'HasAns_exact': 82.6417004048583,
 'HasAns_f1': 89.85989020967376,
 'HasAns_total': 5928,
 'NoAns_exact': 88.66274179983179,
 'NoAns_f1': 88.66274179983179,
 'NoAns_total': 5945,
 'best_exact': 85.65653162637918,
 'best_exact_thresh': 0.0,
 'best_f1': 89.2604589541768,
 'best_f1_thresh': 0.0}

Ahhh, the beauty and flexibility of Transformers, out with one model and in with another.
My QA app is performing well with ALBERT.

Current system configuration:
OS: Linux Mint 19.3 based on Ubuntu 18.04.3 LTS and Linux Kernel 5.0
GPU/CPU: 2 x NVIDIA 1080Ti / Intel i7-8700
Transformers: 2.3.0
PyTorch: 1.4.0
TensorFlow: 2.1.0
Python: 3.7.6

@WilliamNurmi

WilliamNurmi commented Feb 6, 2020

I was originally going for ALBERT, but tried XLNet instead because many people seemed to be reporting that ALBERT doesn't work (#202, #2609). But looking into it more, it looks like it is only the v2 model that doesn't work!

@WilliamNurmi

WilliamNurmi commented Feb 6, 2020

After nearly two weeks of unsuccessful varied XLNet fine-tunes, I gave up and switched to fine-tuning ALBERT as an alternative model:

albert_xxlargev1_sqd2_512_bs48 results:
{'exact': 85.65653162637918,
 'f1': 89.260458954177,

Nice results @ahotrod! Better than what you got in Dec:
albert_xxlargev1_squad2_512_bs48: "exact": 83.65198349195654, "f1": 87.4736247587816,

Could you share the hyper-parameters you used?
And could you elaborate a bit on whether you trained it with run_squad.py or some custom code? run_squad.py doesn't seem to allow us to apply 0.1 dropout to the classification layer as suggested in the paper.

@ahotrod
Contributor Author

ahotrod commented Feb 6, 2020

@WilliamNurmi thanks for your feedback

When Google Research released v2 of their ALBERT LMs, they stated that xxlarge-v1 outperforms xxlarge-v2 and discuss why: https://github.com/google-research/ALBERT. So I've stuck with v1 for that reason, plus the "teething" issues that have been associated with the v2 LMs.

Yes, it seems there have been transformers revisions positively impacting ALBERT SQuAD 2.0 fine-tuning since my Dec 2019 results, as you noted. I think including --max_steps 8144 & --warmup_steps 814 in my script produced the improvement listed above.

Additional ALBERT & transformers refinements, hopefully significant, are in transformers v2.4.1: classifier dropout and gelu_new, thanks to @peteriz & @LysandreJik #2679. I am 18 hours into a 67-hour fine-tune & eval of albert_xxlargev1_sqd2_512_bs48 with the script below, using transformers v2.4.1. I will post results when processing is complete.
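For anyone who wants to try the new classifier dropout outside of run_squad.py, something along these lines should work (a sketch only, assuming the classifier_dropout_prob config field added with that change; it is not exposed as a command-line flag):

from transformers import AlbertConfig, AlbertForQuestionAnswering

# 0.1 classification-layer dropout as suggested in the ALBERT paper
config = AlbertConfig.from_pretrained("albert-xxlarge-v1", classifier_dropout_prob=0.1)
model = AlbertForQuestionAnswering.from_pretrained("albert-xxlarge-v1", config=config)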

BTW the heat produced from my hardware-challenged computer, hotrod, is a welcome tuning by-product for my winter office, summer not so much. Hoping for a NVIDIA Ampere upgrade before this summer's heat.

My fine-tuning has been with Transformers' run_squad.py, not custom code. Here's my latest script:

albert_xxlargev1_sqd2_512_bs48.sh:

#!/bin/bash

export OMP_NUM_THREADS=8
RUN_SQUAD_DIR=/media/dn/dssd/nlp/transformers/examples
SQUAD_DIR=${RUN_SQUAD_DIR}/scripts/squad2.0
MODEL_PATH=${RUN_SQUAD_DIR}/runs/albert_xxlargev1_squad2_512_bs48

python -m torch.distributed.launch --nproc_per_node=2 ${RUN_SQUAD_DIR}/run_squad.py \
  --model_type albert \
  --model_name_or_path albert-xxlarge-v1 \
  --do_train \
  --train_file ${SQUAD_DIR}/train-v2.0.json \
  --predict_file ${SQUAD_DIR}/dev-v2.0.json \
  --version_2_with_negative \
  --num_train_epochs 3 \
  --max_steps 8144 \
  --warmup_steps 814 \
  --do_lower_case \
  --learning_rate 3e-5 \
  --max_seq_length 512 \
  --doc_stride 128 \
  --save_steps 1000 \
  --per_gpu_train_batch_size 1 \
  --gradient_accumulation_steps 24 \
  --overwrite_cache \
  --logging_steps 100 \
  --threads 8 \
  --output_dir ${MODEL_PATH}

CUDA_VISIBLE_DEVICES=0 python ${RUN_SQUAD_DIR}/run_squad.py \
  --model_type albert \
  --model_name_or_path ${MODEL_PATH} \
  --do_eval \
  --train_file ${SQUAD_DIR}/train-v2.0.json \
  --predict_file ${SQUAD_DIR}/dev-v2.0.json \
  --version_2_with_negative \
  --do_lower_case \
  --max_seq_length 512 \
  --per_gpu_eval_batch_size 24 \
  --eval_all_checkpoints \
  --overwrite_output_dir \
  --output_dir ${MODEL_PATH}
$@

@WilliamNurmi

Thanks for all the details @ahotrod, I had missed the fact that classifier dropout had just been added! I restarted my run with v2.4.1. Loss seems to be going down nicely, so far so good.

It's gonna be 6 days for me since I'm on a single 1080 Ti. I'm gonna have to look for some new hardware / instances soon as well. Any bigger model or sequence length and I couldn't fit a single batch on this GPU anymore :D

Looking forward to the sneak peek of the results when your run finishes!

@knuser

knuser commented Feb 7, 2020

@ahotrod could you consider sharing your trained ALBERT SQuAD model on https://huggingface.co/models?

@ahotrod
Contributor Author

ahotrod commented Feb 7, 2020

@ahotrod could you consider sharing your trained ALBERT SQuAD model on https://huggingface.co/models?

@knuser Absolutely, I signed up some time ago with that intent but have yet to contribute.
I'm 26 hours from completion of this v2.4.1 albert_xxlargev1_sqd2_512_bs48 run and afterwards will share the best run to date.

FYI, inference/prediction on 11 questions with this 512 max_seq_length xxlarge ALBERT model takes 37 seconds on CPU and 5 seconds on a single GPU with large batches on my computer, hotrod, described above.

BTW, sharing can definitely save some energy & lower the carbon footprint. As an example, my office electric bill doubled last month from just under $100 to over $200 with nearly constant hotrod fine-tuning. Perhaps the gas heater didn't need to fire up as often, though. ;-]

@ahotrod
Contributor Author

ahotrod commented Feb 8, 2020

@WilliamNurmi @knuser :

Fine-tuning the albert_xxlargev1_sqd2_512_bs48 script with Transformers 2.4.1 yielded the following results:

{'exact': 85.47123726101238,
 'f1': 89.0856118938743,
 'total': 11873,
 'HasAns_exact': 82.11875843454791,
 'HasAns_f1': 89.35787280971171,
 'HasAns_total': 5928,
 'NoAns_exact': 88.81412952060555,
 'NoAns_f1': 88.81412952060555,
 'NoAns_total': 5945,
 'best_exact': 85.46281478985935,
 'best_exact_thresh': 0.0,
 'best_f1': 89.07718942272103,
 'best_f1_thresh': 0.0}

which is no improvement over fine-tuning the same script with Transformers 2.3.0.

My best model to date is now posted at: https://huggingface.co/ahotrod/albert_xxlargev1_squad2_512
You can access this albert_xxlargev1_sqd2_512 fine-tuned model with:

from transformers import AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer

config_class, model_class, tokenizer_class = \
    AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer

model_name_or_path = "ahotrod/albert_xxlargev1_squad2_512"
config = config_class.from_pretrained(model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
model = model_class.from_pretrained(model_name_or_path, config=config)

The Auto classes (AutoConfig, AutoTokenizer & AutoModel) should also work; however, I have yet to use them.
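If you prefer the Auto classes, the equivalent should be roughly the following, though again I haven't tried it myself (AutoModelForQuestionAnswering being the QA counterpart):

from transformers import AutoConfig, AutoTokenizer, AutoModelForQuestionAnswering

model_name_or_path = "ahotrod/albert_xxlargev1_squad2_512"
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)
model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path, config=config)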

Hope this furthers your efforts!

@LysandreJik
Member

Hi guys, thanks for the great discussion. I've been trying to reproduce the XLNet fine-tuning myself, but have failed to do so so far. I stumbled upon a few issues along the way, mostly related to the padding side.

There was an issue that I fixed this morning related to the tokens that were used for evaluation, which were not correctly computed. I updated that in 125a75a; however, it does not improve the accuracy.

I'm still actively working on it and will let you know as I progress (it is quite a lengthy process, as a fine-tuning run requires a full day of computing on my machine).

@WilliamNurmi

Hi @LysandreJik, thanks for hunting the bugs! It's going to be a great help for many people.

I don't know the details of the remaining bugs, but at least the ones I encountered were so severe that you should be able to tell very quickly after starting fine-tuning whether it works, by checking on TensorBoard whether the loss is decreasing.
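For reference, run_squad.py writes the training loss to TensorBoard event files (a runs/ directory by default, logged every --logging_steps), so a quick sanity check is simply:

tensorboard --logdir runs/   # then watch whether the loss curve is actually decreasing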

@yilisg

yilisg commented Feb 28, 2020

I can also confirm the issue after fine-tuning xlnet-large-cased on SQuAD 2.0 for 1 epoch. The F1 score is 46.53 although NoAns_f1 was 89.05, probably because the model is predicting so many blanks (most with "start_log_prob": -1000000.0, "end_log_prob": -1000000.0) while HasAns_exact is close to 0.

Not sure if it is related to the CLS token position mentioned in #947 and #1088, but it might be specific to the unanswerable questions in SQuAD 2.0. Hopefully the bug will be found and fixed soon :-)

Transformers: 2.5.1
PyTorch: 1.4.0
Python: 3.8.1

@elgeish
Contributor

elgeish commented Mar 2, 2020

@ahotrod I saw you're using a different eval script (run_squad_II.py) for your model at https://huggingface.co/ahotrod/xlnet_large_squad2_512 — have you figured out what was wrong with run_squad.py? Thanks!

@ahotrod
Contributor Author

ahotrod commented Mar 2, 2020

@elgeish - good eye on my eval script using run_squad_II.py, as posted in my model card. Unfortunately I have not figured out what is wrong with training using the latest run_squad.py versions, as outlined in this issue.

My https://huggingface.co/ahotrod/xlnet_large_squad2_512 model is from Nov 2019, the same successful fine-tuned model described in my first post above. run_squad_II.py contained experimental code I was working on at the time, trying to overcome the multi-GPU distributed-processing eval limitation. Fortunately, I did not modify the eval portion of the code, so when run_squad_II.py evals were run on a single GPU (CUDA_VISIBLE_DEVICES=0) they were identical to those of the original Transformers v2.1.1 run_squad.py. I failed to switch that eval script back to run_squad.py in the model card, but again, since the run_squad_II.py eval was run on a single GPU, it performed the same eval as the original. Sorry for the confusion.

@elgeish
Contributor

elgeish commented Mar 2, 2020

@ahotrod thanks for the explanation!

@stale

stale bot commented May 1, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label May 1, 2020
@stale stale bot closed this as completed May 8, 2020
@brgsk

brgsk commented Oct 11, 2021

Any update on this issue? I am facing the same issue when fine-tuning a custom RoBERTa.
Cheers

@brgsk

brgsk commented Oct 11, 2021

I'm on 4.4.0dev
