BART on CNN/DM: how to train on a small GPU? #1413

Closed
astariul opened this issue Nov 22, 2019 · 21 comments

@astariul
Contributor

astariul commented Nov 22, 2019

I'm trying to reproduce the CNN/DM results of BART.
Unfortunately, I don't have access to good GPUs. I only have access to 2 GPUs with 8GB of memory each.


I updated the finetuning command accordingly, changing UPDATE_FREQ for the number of GPUs.

But I have an issue with GPU memory: I tried reducing MAX_TOKENS to 512 in order to make the data fit in my 8GB, but I receive the following error:

AssertionError: sentence at index 227550 of size 728 exceeds max_tokens limit of 512!

If I set MAX_TOKENS to 1024, I have a CUDA out of memory error (expected).


What modifications do I need to make to be able to finetune the model on small GPUs (8GB)?

@ngoyal2707 @yinhanliu

@yinhanliu

Set your update-freq to 32 in your case and try max_tokens 800 (32 GPUs * 2048 / 800 / 2 GPUs = 32ish).

If this still doesn't work, then you need to modify the code, changing to --max-target-positions 512 --max-source-positions 512 (this will filter out samples that are longer than 512).

Also, you can train with a smaller batch (lower update freq but with longer training).
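
As a rough worked example of the arithmetic above, here is a minimal sketch of picking update-freq so the effective number of tokens per optimizer step stays roughly constant. The reference numbers (32 GPUs, max_tokens 2048, update_freq 1) are assumptions taken from the comment, not verified against the released script:

import math

# Assumed reference setup (from the comment above): 32 GPUs, max_tokens=2048, update_freq=1.
ref_gpus, ref_max_tokens, ref_update_freq = 32, 2048, 1
# The setup discussed in this issue: 2 GPUs, max_tokens=800.
my_gpus, my_max_tokens = 2, 800

tokens_per_step = ref_gpus * ref_max_tokens * ref_update_freq
my_update_freq = math.ceil(tokens_per_step / (my_gpus * my_max_tokens))
print(my_update_freq)  # 41 with these assumed numbers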

@astariul
Contributor Author

Thanks for the fast answer!

With MAX_TOKENS set to 800 I still have a similar issue:

AssertionError: sentence at index 228353 of size 903 exceeds max_tokens limit of 800!

If I understood the paper correctly, this is expected because BART takes 1024 tokens maximum, not 512 like BERT. And in CNN/DM there are a lot of samples with more than 800 tokens.


If I try using --max-target-positions 512 + --max-source-positions 512, I get the following error:

RuntimeError: Error(s) in loading state_dict for BARTModel:
size mismatch for encoder.embed_positions.weight: copying a param with shape torch.Size([1026, 1024]) from checkpoint, the shape in current model is torch.Size([514, 1024]).
size mismatch for decoder.embed_positions.weight: copying a param with shape torch.Size([1026, 1024]) from checkpoint, the shape in current model is torch.Size([514, 1024]).


Also, I didn't understand what you mean by:

Also, you can train with a smaller batch (lower update freq but with longer training).

Do you mean reducing UPDATE_FREQ and increasing TOTAL_NUM_UPDATES?
Like going from UPDATE_FREQ = 64 and TOTAL_NUM_UPDATES = 20000
to UPDATE_FREQ = 32 and TOTAL_NUM_UPDATES = 40000?

As far as I understand, it will not change the real batch size, it will just change the accumulated batch size. But for my memory problem, it's the real batch size that matters.

Did I misunderstand something?

Thanks again for your help!

@yinhanliu

Actually, BART took 512 tokens during pretraining. However, we initialized the model with 1024 positional embeddings -- the 512-1024 position embeddings didn't get updated during pretraining.

During fine-tuning, we use 1024 position embeddings -- the 512-1024 positions start to get updated in this phase.

It looks like in your case, an 8GB GPU won't even fit a single instance.

You have to cut the pretrained model's position embeddings from 1024 down to 512 (i.e., rewrite the pretrained model's state dict),

then use --max-target-positions 512. This will definitely hurt performance on the CNN/DM dataset --- tons of instances are longer than 512.
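
For reference, a minimal sketch of that checkpoint surgery, assuming the standard bart.large checkpoint layout where the 1026 embedding rows are 1024 positions plus 2 offset rows (the output filename model_512.pt is just an example; this mirrors the truncation snippet shared later in this thread):

import torch

# Hedged sketch: truncate the pretrained positional embeddings to 512 positions.
# 514 = 512 positions + the 2 extra rows fairseq's learned positional embeddings
# reserve (the original checkpoint has 1026 = 1024 + 2 rows).
model = torch.load("bart.large/model.pt")
for key in ("encoder.embed_positions.weight", "decoder.embed_positions.weight"):
    model["model"][key] = model["model"][key][:514]
torch.save(model, "bart.large/model_512.pt")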

I only did a brief tuning run on CNN/DM. Training with a smaller batch size but for longer (more than 30000 steps) probably won't hurt the performance. You can try.

@astariul
Contributor Author

astariul commented Nov 22, 2019

Thanks for the kind explanation 👍

I can make the training run by specifying --memory-efficient-fp16 instead of --fp16 (and keeping MAX_TOKENS=1024).

However, my GPU does not support FP16...
Aside from potentially slowing down the training process, do you know if using memory-efficient FP16 will also reduce the performance on such a GPU?


Just curious, did you try to train BART using memory-efficient FP16?
How do the results compare to the other versions of BART?

@yinhanliu

We did try memory-efficient FP16 on RoBERTa. With this setting, RoBERTa base on the bookwiki data gets a ppl of 4.00 (memory-efficient FP16) vs. 3.90 (FP16), so we only used FP16 on BART.

@astariul
Contributor Author

astariul commented Nov 22, 2019

Thanks for sharing your knowledge. It is helpful!

I'm going to try this path (--memory-efficient-fp16).

I believe results are going to be higher this way than by truncating articles to 512 tokens, because, as you mentioned, a lot of articles are longer than 512...

What's your opinion about this?

@yinhanliu

Sure. --memory-efficient-fp16 sounds better.

@myleott
Contributor

myleott commented Nov 25, 2019

Note that --memory-efficient-fp16 can produce worse results, especially with small batch sizes. You're probably better off decreasing the batch size and/or training in FP32, since FP16 can actually use more memory: it needs to maintain both an FP32 and an FP16 copy of the model.
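
To make the memory trade-off concrete, here is a rough back-of-the-envelope sketch. The ~400M parameter count for bart.large is an approximation, and activation memory (which dominates at max_tokens=1024) is ignored:

# Rough sketch of weight/optimizer memory for ~400M parameters (approximately bart.large).
# Activation memory is not counted here.
params = 400e6

fp32_weights = params * 4               # plain FP32 training: one FP32 copy of the weights
fp16_mixed_weights = params * (2 + 4)   # --fp16: FP16 working copy + FP32 master copy
adam_fp32_states = params * 4 * 2       # Adam keeps two FP32 moment buffers per parameter

print(f"weights, FP32 training:   {fp32_weights / 1e9:.1f} GB")       # ~1.6 GB
print(f"weights, --fp16 training: {fp16_mixed_weights / 1e9:.1f} GB") # ~2.4 GB
print(f"Adam optimizer state:     {adam_fp32_states / 1e9:.1f} GB")   # ~3.2 GB

# --memory-efficient-fp16 avoids the FP32 master copy and keeps FP16 optimizer state,
# which is where the savings come from (at the accuracy cost noted above).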

@wonjininfo

Hi @colanim ,
I am wondering if you could share your training results on small GPU.

I've also tried to train the model on multiple 24GB GPUs (the number of machines varying from 2 to 8). Since my GPUs do not support FP16, I trained the model without the --memory-efficient-fp16 or --fp16 flags. To compensate for this (regarding the memory issue), I set MAX_TOKENS to 1792 and adjusted UPDATE_FREQ according to the number of machines. It worked well, with almost identical performance.

Back to the point, have you tried training the model with MAX_TOKENS set to 512? Would you please share the performance, if you do not mind?

ps) Merry Christmas!

@astariul
Contributor Author

astariul commented Dec 24, 2019

@wonjininfo

On my side, I trained BART on 4 x 11GB GPUs.
As mentioned earlier, 11GB is not enough to fit 1 sample (1024 tokens). So I used --memory-efficient-fp16. Even though my GPU does not support FP16 training, this reduced the required memory by almost half.

But still, it was not enough, so I reduced MAX_TOKENS from 1024 to 928. With these parameters, I could fit 1 sample on my GPU.

With MAX_TOKENS = 928 and --memory-efficient-fp16, I got the following results:

R1 = 43.61
R2 = 20.90
RL = 40.41

It's a bit lower than regular BART, but that was expected given my parameters.


I didn't try training the model with a lower MAX_TOKENS, as I could already fit 1 sample with 928.

Merry Christmas :)

@zide05

zide05 commented Feb 12, 2020

Hi @colanim, I am wondering how long it took you to finetune BART on CNN/DM using 4 GPUs?

@astariul
Contributor Author

@zide05 It took quite a long time; I don't remember exactly, but something like 24 hours.

@zide05

zide05 commented Feb 12, 2020

@colanim Got it, thank you for your quick reply!

@DevHyung

@colanim
Can you show the entire set of parameters you used for training?

@astariul
Contributor Author

My configuration:

restore-file = ./bart.large/model_928.pt
max-tokens = 928
task = translation
source-lang = source
target-lang = target
layernorm-embedding = True
share-all-embeddings = True
share-decoder-input-output-embed = True
reset-optimizer = True
reset-dataloader = True
reset-meters = True
required-batch-size-multiple = 1
arch = bart_large
criterion = label_smoothed_cross_entropy
label-smoothing = 0.1
dropout = 0.1
attention-dropout = 0.1
weight-decay = 0.01
optimizer = adam
adam-betas = (0.9, 0.999)
adam-eps = 1e-08
clip-norm = 0.1
lr-scheduler = polynomial_decay
lr = 3e-05
total-num-update = 20000
warmup-updates = 500
memory-efficient-fp16 = True
update-freq = 16
skip-invalid-size-inputs-valid-test = True
find-unused-parameters = True
truncate-source = True
max-source-positions = 928
max-target-positions = 928
tensorboard-logdir = ./logs

I had to create a new model checkpoint where I kept only the first 928 position embeddings. I did it with:

import torch
model = torch.load("bart.large/model.pt")

print(model["model"]["encoder.embed_positions.weight"].size())
print(model["model"]["decoder.embed_positions.weight"].size())

model["model"]["encoder.embed_positions.weight"] = model["model"]["encoder.embed_positions.weight"][:930]
model["model"]["decoder.embed_positions.weight"] = model["model"]["decoder.embed_positions.weight"][:930]
torch.save(model, "bart.large/model_928.pt")

@DevHyung

@colanim
Thank you for your reply :D

May I ask one more question?
Why did you reduce max_tokens to 928?
Also, the average length of the dataset is around 600, so can I reduce it more?

@astariul
Contributor Author

astariul commented Mar 31, 2020

@DevHyung

My GPUs have small memory, so I couldn't even fit a batch size of 1 in memory if the sample size is 1024.

By reducing the length to 928, the samples take less space in memory and I can fit a batch size of 1.


You can reduce it more, but you should expect a score decrease.

@monologue1107

Hi, I would like to ask some questions about fine-tuning on CNN/DM.
Since my GPU doesn't support FP16 mode, I want to try FP32. Is it right that I will train in FP32 mode if I remove "--fp16" or "--memory-efficient-fp16" from the command and do nothing else?
I have faced a situation where I can train in --memory-efficient-fp16 mode, but I get OOM when I remove this flag.

@astariul
Contributor Author

astariul commented Nov 6, 2020

Is it right that I will train in FP32 mode if I remove "--fp16" or "--memory-efficient-fp16" from the command and do nothing else?

Right


I can train in --memory-efficient-fp16 mode, but I get OOM when I remove this flag.

This is expected: training in FP16 mode requires less memory than FP32 mode.
If you're getting OOM and still want to keep FP32 mode, you need to reduce the batch size.

@monologue1107

Thanks for your quick answer! It helps a lot.

@monologue1107

monologue1107 commented Nov 7, 2020

Hi, now I have faced another problem during training.
I can fine-tune the model at first; it even trains through epoch 1 entirely. However, it goes OOM in epoch 2, around 4517/21194. I tried changing settings like total_num_updates and update_freq several times, but it didn't help.
Do you have any idea why the OOM problem occurs in the middle of training, and could you give me some tips? Looking forward to your kind help.
The log looks like this:

2020-11-06 22:55:35 | WARNING | fairseq.trainer | OOM: Ran out of memory with exception: CUDA out of memory. Tried to allocate 28.00 MiB (GPU 1; 10.92 GiB total capacity; 10.13 GiB already allocated; 13.38 MiB free; 10.33 GiB reserved in total by PyTorch)
...
2020-11-06 22:55:35 | WARNING | fairseq.trainer | attempting to recover from OOM in forward/backward pass
Traceback (most recent call last):
  File "/data/rwd/anaconda3/envs/fairseq/bin/fairseq-train", line 33, in <module>
    sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())
  File "/data/rwd/fairseq/fairseq_cli/train.py", line 352, in cli_main
    distributed_utils.call_main(args, main)
  File "/data/rwd/fairseq/fairseq/distributed_utils.py", line 254, in call_main
    nprocs=args.distributed_num_procs,
  File "/data/rwd/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/data/rwd/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception: 
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/data/rwd/fairseq/fairseq/distributed_utils.py", line 339, in all_gather_list
    result.append(pickle.loads(bytes(out_buffer[header_size:header_size + enc_size].tolist())))
_pickle.UnpicklingError: unpickling stack underflow
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/data/rwd/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/data/rwd/fairseq/fairseq/distributed_utils.py", line 238, in distributed_main
    main(args, **kwargs)
  File "/data/rwd/fairseq/fairseq_cli/train.py", line 125, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/data/rwd/anaconda3/envs/fairseq/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/data/rwd/fairseq/fairseq_cli/train.py", line 208, in train
    log_output = trainer.train_step(samples)
  File "/data/rwd/anaconda3/envs/fairseq/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/data/rwd/fairseq/fairseq/trainer.py", line 531, in train_step
    logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
  File "/data/rwd/fairseq/fairseq/trainer.py", line 885, in _aggregate_logging_outputs
    logging_outputs, *extra_stats_to_sum, ignore=ignore
  File "/data/rwd/fairseq/fairseq/trainer.py", line 906, in _all_gather_list_sync
    group=self.data_parallel_process_group,
  File "/data/rwd/fairseq/fairseq/distributed_utils.py", line 343, in all_gather_list
    'Unable to unpickle data from other workers. all_gather_list requires all '
Exception: Unable to unpickle data from other workers. all_gather_list requires all workers to enter the function together, so this error usually indicates that the workers have fallen out of sync somehow. Workers can fall out of sync if one of them runs out of memory, or if there are other conditions in your training script that can cause one worker to finish an epoch while other workers are still iterating over their portions of the data. Try rerunning with --ddp-backend=no_c10d and see if that helps.
