
[wip] [pipeline parallel] t5 - experiment #9765

Closed
wants to merge 14 commits

Conversation

stas00
Contributor

@stas00 stas00 commented Jan 24, 2021

This PR is not ready for reviews.

I'm putting it up primarily for those who want an early preview of a possible Pipeline solution. @PeterAJansen, you wanted to see if you could get it working with a 4x 40GB rig and t5-11b. Please give it a try.


Intention

We want to replace the naive model parallel (MP) implementation with a more efficient pipeline parallel (PP) implementation, which takes advantage of all participating gpus, instead of having one gpu run while the rest idle, which is the case with the naive MP.

To give you a visual from the GPipe paper,

[figure "mp-pp" from the GPipe paper: the 1st diagram shows naive model parallelism, the 2nd shows pipeline parallelism with the batch split into micro-batches]

You will find a new argument, chunks, which controls how many micro-batches each batch is split into for pipelining; in the 2nd diagram of the image above you can see that chunks=4.

So with chunks=1 you get the equivalent of the naive MP, except it'd be even slower than the naive MP because of the RPC overhead.
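
For context, here is a minimal, self-contained sketch (not this PR's code) of how pytorch's experimental Pipe uses chunks to split a batch into micro-batches. It assumes a recent pytorch build with torch.distributed.pipeline.sync.Pipe and 2 visible gpus:

import os
import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe is built on top of the RPC framework, so it must be initialized
# even for a single-process run.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker", rank=0, world_size=1)

# Two toy stages, each already placed on its own device; Pipe moves the
# activations between them.
stage0 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()).to("cuda:0")
stage1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()).to("cuda:1")
model = Pipe(nn.Sequential(stage0, stage1), chunks=4, checkpoint="never")

# The 64-sample batch is split into chunks=4 micro-batches of 16, which
# flow through the two stages concurrently.
x = torch.randn(64, 512, device="cuda:0")
out = model(x).local_value()  # forward returns an RRef; fetch the tensor
print(out.shape, out.device)  # torch.Size([64, 512]) cuda:1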

Overview

Porting t5 to Pipeline Parallelism proved to be a study in hacking, due to the very restrictive original pipeline interface, which only allows tensors or tuples of tensors as input/output arguments to forward, while in transformers we have a ton of very complex variables to pass to forward and return from it.
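
To make the restriction concrete, here is a toy illustration (not the actual T5StackPipeSegment) of the pattern used throughout this port: everything that isn't a plain tensor gets bound to the segment at construction time, so that forward() only ever exchanges a tuple of tensors between stages:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # stand-in for a transformer block that wants a mask alongside hidden states
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask):
        return self.ff(hidden_states) * attention_mask

class BlockPipeSegment(nn.Module):
    # non-tensor arguments (config flags, per-layer constants) are baked in
    # at construction time, so forward() only sees a tuple of tensors
    def __init__(self, block, output_attentions=False):
        super().__init__()
        self.block = block
        self.output_attentions = output_attentions

    def forward(self, inputs):
        hidden_states, attention_mask = inputs          # tuple of tensors in
        hidden_states = self.block(hidden_states, attention_mask)
        return hidden_states, attention_mask            # tuple of tensors out

segments = nn.Sequential(*(BlockPipeSegment(ToyBlock(16)) for _ in range(6)))
h, mask = torch.randn(2, 4, 16), torch.ones(2, 4, 1)
out, _ = segments((h, mask))
print(out.shape)  # torch.Size([2, 4, 16])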

We are trying to change the Pipeline design to be much more user-friendly: pytorch/pytorch#50693

This implementation tries to take advantage of 2 natural stacks, so I implemented it as 2 pipes:

T5ForConditionalGeneration->
   T5Stack(encoder)->Pipe(Sequential([T5StackPipeSegment * 6]))
   T5Stack(decoder)->Pipe(Sequential([T5StackPipeSegment * 6]))

(That's 6 blocks per stack for t5-small.)

Please don't even bother looking at the code; it is one big hack that took many hours of fiddling to make the pipeline work at all, so clearly it is not something very portable or readable.

Setup

important: you need pytorch-nightly to be able to use this.

pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu110/torch_nightly.html -U

Just create another conda env so as not to mess up your normal env; pt-nightly is a solid piece of software, I use it all the time. Here is a quick copy-n-paste of what you will need - just edit the location of the transformers checkout dir.

conda create -y -n py38-pt18 python=3.8
conda activate py38-pt18
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu110/torch_nightly.html -U

git clone https://github.com/huggingface/transformers
cd transformers
gh pr checkout 9765 # or use whatever other method to checkout this PR
pip install -e .[dev]
pip install -r examples/_tests_requirements.txt

Down the road I will also look at using fairscale/deepspeed, but for now pytorch is just more accessible and hopefully will become more flexible soon.

Deployment: script

You can deploy PP directly via your own trainer/script, e.g. this is what I have been using while developing it:

from transformers import T5Tokenizer, T5ForConditionalGeneration
import transformers.models.t5.modeling_t5
import transformers.utils.logging

transformers.models.t5.modeling_t5.logger.setLevel(transformers.logging.INFO)

mname = "t5-large"
tokenizer = T5Tokenizer.from_pretrained(mname)
model = T5ForConditionalGeneration.from_pretrained(mname, return_dict=True)
model.to("cuda:0")
model.pipeline_enable(chunks=2, device_map=None)

texts = ["This is good", "This is bad", "This is really bad", "This is fantastic",]
texts = ["translate English to French: "+x for x in texts]
batch = tokenizer.prepare_seq2seq_batch(texts, return_tensors="pt")
batch.to("cuda:0")
outputs = model.generate(**batch)
for x in outputs:
    decoded = tokenizer.decode(x, skip_special_tokens=True)
    print(decoded)

model.pipeline_finalize()

Deployment: HF Trainer

But you can also use HF trainer. I tweaked the trainer to activate PP with:

--pipeline "chunks=4"

This will let the program do the partitioning for you. But you can control the partitioning manually by passing:

--pipeline "chunks=4 device_map=0:0-3,1:3-12"

Here we basically pass the equivalent of the dict {0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9, 10, 11]}, which, btw, you can pass in your script as:

device_map = {0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9, 10, 11]}
model.pipeline_enable(chunks=30, device_map=device_map)

The syntax is what you'd pass to range, so device_map=0:0-3,1:3-12 is the same as:

device_map = {0: list(range(0, 3)), 1: list(range(3, 12))}

The keys are the gpu ids.

The number of layers is at the moment just the depth of the encoder stack, so 12 for t5-base, 6 for t5-small, etc.
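
For illustration, here is a tiny helper (hypothetical - not necessarily what the trainer does internally) that turns the string form into the dict form:

def parse_device_map(spec):
    # "gpu:start-end,gpu:start-end" -> {gpu: list(range(start, end)), ...}
    device_map = {}
    for entry in spec.split(","):
        gpu, layer_range = entry.split(":")
        start, end = (int(x) for x in layer_range.split("-"))
        device_map[int(gpu)] = list(range(start, end))
    return device_map

assert parse_device_map("0:0-3,1:3-12") == {0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9, 10, 11]}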

Later we should have a different way as well, where we define the desired balance, rather than the specific layers.

Since each t5 model has a different number of blocks, the easiest way is to first run without a device map and then check the logger output, which will show you which device map it's using. Then I recommend re-balancing it so that gpu 0 has fewer layers than the remaining gpus.

Benchmarks

example for 2x 24GB cards

export BS=160 MODEL=t5-base; rm -r output_dir; PYTHONPATH=src USE_TF=0 examples/seq2seq/run_seq2seq.py \
--model_name_or_path $MODEL --output_dir output_dir --adam_eps 1e-06 --do_eval --do_train --evaluation_strategy=steps \
--label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 \
--max_target_length 128 --val_max_target_length 128  --num_train_epochs 1 --overwrite_output_dir \
--per_device_eval_batch_size $BS --per_device_train_batch_size $BS  --eval_steps 25000  --sortish_sampler  \
--warmup_steps 50 \
--task translation_en_to_ro --dataset_name wmt16 --dataset_config ro-en --source_prefix "translate English to Romanian: " \
--max_train_samples 10 --max_val_samples 10   \
--pipeline "chunks=4 device_map=0:0-3,1:3-12" --dataloader_num_workers 4

Performance-wise:

  • prediction speed is terrible - just as bad as with the naive MP we have in t5 and others
  • training/eval without prediction is 20-40% slower than the single-process baseline - this is primarily due to data copying and the currently quite inefficient implementation forced by the Pipeline API restrictions.
  • the key is to find a value for chunks so that there is enough in the pipe that the gpus don't idle, but not so big that performance goes down. Still, I wasn't able to get past ~50% gpu utilization, so it's not much different from the naive implementation - I don't know yet why - probably data copying takes most of the overhead.
  • I think on 4 gpus it'd be good to try an experiment and put the encoder stack on gpus 0+1 and the decoder on gpus 2+3, instead of copying data between 4 devices as happens now. This will require a more complex device map, with separate encoder and decoder sub-maps, like the one I designed for the Bart MP. But then it'd affect the pipeline, as half the gpus will surely idle while the encoder is running - so not great either. We will have to experiment with real data once I have access to a rig with 4 gpus. That's why I don't think this is urgent to work on, but such a change would be easy to do. We will have to do it anyway for other models whose stacks aren't necessarily symmetrical.

Here are some stats on 2x 24GB Titan RTX:

Baseline: (1gpu)

export BS=64 MODEL=t5-base; rm -r output_dir;  CUDA_VISIBLE_DEVICES=0  PYTHONPATH=src USE_TF=0  \
examples/seq2seq/run_seq2seq.py --model_name_or_path $MODEL --output_dir output_dir --adam_eps 1e-06  --do_eval \
--do_train --evaluation_strategy=steps  --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 \
 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir \
--per_device_eval_batch_size  $BS --per_device_train_batch_size $BS --eval_steps 25000  --sortish_sampler   \
--val_max_target_length 128 --warmup_steps 50 --max_train_samples 1000 --max_val_samples 1000 \
--task translation_en_to_ro --dataset_name wmt16 --dataset_config ro-en --source_prefix "translate English to Romanian: "

  train_runtime              = 6.9149
  eval_loss                 =  3.5492
  eval_runtime              =  3.2802

XXX: need to re-test with rebased code-base

Now with pipeline:

  • can run a much higher batch size
  • note that I'm using a user-provided device map that puts more layers on gpu 1, since gpu 0 needs much more RAM
# device_map=0:0-3,1:3-12 - so splitting 1:4
# {0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9, 10, 11]}
export BS=160 MODEL=t5-base;  rm -r output_dir;  CUDA_VISIBLE_DEVICES=0,1  PYTHONPATH=src USE_TF=0  \
examples/seq2seq/run_seq2seq.py --model_name_or_path $MODEL --output_dir output_dir --adam_eps 1e-06  --do_eval \
--do_train --evaluation_strategy=steps  --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 \
 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir \
--per_device_eval_batch_size  $BS --per_device_train_batch_size $BS --eval_steps 25000  --sortish_sampler   \
--val_max_target_length 128 --warmup_steps 50 --max_train_samples 1000 --max_val_samples 1000 \
--task translation_en_to_ro --dataset_name wmt16 --dataset_config ro-en --source_prefix "translate English to Romanian: " \
--pipeline "chunks=4 device_map=0:0-3,1:3-12" --dataloader_num_workers 4

XXX: need to re-test with rebased code-base

Future

I'm also trying to instrument this feature with reporting that will help users fine-tune chunks/device_map. This is the model.pipeline_finalize() call. Things I'm thinking would be useful:

  • gpu utilization stats (average/peak) - probably need to fire off a thread that samples gpu utilization via pynvml, then calculates the average + peak (see the sketch after this list)
  • the peak memory usage per device report that I added seems to report too low - I think it has to do with the pipeline threads - need to sort it out
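
Here is a minimal sketch of that sampling-thread idea (assuming the pynvml package; hypothetical class name, not code from this PR):

import threading
import time
import pynvml

class GpuUtilSampler:
    def __init__(self, interval=0.5):
        pynvml.nvmlInit()
        self.interval = interval
        self.samples = {i: [] for i in range(pynvml.nvmlDeviceGetCount())}
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        handles = {i: pynvml.nvmlDeviceGetHandleByIndex(i) for i in self.samples}
        while not self._stop.is_set():
            for i, h in handles.items():
                self.samples[i].append(pynvml.nvmlDeviceGetUtilizationRates(h).gpu)
            time.sleep(self.interval)

    def start(self):
        self._thread.start()

    def stop(self):
        # returns {gpu_id: {"avg": ..., "peak": ...}} for every sampled device
        self._stop.set()
        self._thread.join()
        return {
            i: {"avg": sum(s) / len(s), "peak": max(s)}
            for i, s in self.samples.items() if s
        }

# usage: sampler = GpuUtilSampler(); sampler.start(); <run training>; print(sampler.stop())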

Any other ideas/requests/needs?

@PeterAJansen, please let me know if you managed to run this on your 4x gpu setup.

Next, I think I'm going to scrap the current implementation and try a new one afresh.

Also, this PR should be good enough to try figuring out how to use it with DeepSpeed, once I get access to 4 gpus (we need at least 4 gpus to do 2D parallelism).

I did warn you not to look at the code.

I also removed big chunks of the MP code for now, as it was getting in the way and adding noise; I will restore it once this is all sorted out.

@PeterAJansen

Thanks @stas00 , I am getting what looks like a torch error when I run this (I'm not sure if the "Failed to look up the IP address for the hostname" error is related -- I'm not able to find much on this except for an issue from a few days ago that mentions this: pytorch/pytorch#50700 ):

git clone https://www.github.com/huggingface/transformers.git
cd transformers/
gh pr checkout 9765

conda create -y -n py38-pt18 python=3.8
conda activate py38-pt18
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu110/torch_nightly.html -U
pip install -e .[dev]
pip install -r examples/_tests_requirements.txt

cd examples/seq2seq/
ln -s ~/github/transformers/examples/seq2seq/wmt_en_ro wmt_en_ro

export BS=160 MODEL=t5-base; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  ./finetune_trainer.py --model_name_or_path $MODEL --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS  --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 50 --n_train 1000 --n_val 1000 --pipeline "chunks=4 device_map=0:0-3,1:3-12" --dataloader_num_workers 4 

Output:


01/25/2021 11:39:51 - WARNING - __main__ -   Process rank: -1, device: cuda:0, n_gpu: 4, distributed training: False, 16-bits training: False
01/25/2021 11:39:51 - INFO - __main__ -   Training/evaluation parameters Seq2SeqTrainingArguments(output_dir='output_dir', overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=False, evaluation_strategy=<EvaluationStrategy.STEPS: 'steps'>, prediction_loss_only=False, per_device_train_batch_size=160, per_device_eval_batch_size=160, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=1, eval_accumulation_steps=None, learning_rate=3e-05, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-06, max_grad_norm=1.0, num_train_epochs=1.0, max_steps=-1, lr_scheduler_type=<SchedulerType.LINEAR: 'linear'>, warmup_steps=50, logging_dir='runs/Jan25_11-39-51_seahorse', logging_first_step=True, logging_steps=1000, save_steps=500, save_total_limit=None, no_cuda=False, seed=42, fp16=False, fp16_opt_level='O1', fp16_backend='auto', local_rank=-1, tpu_num_cores=None, tpu_metrics_debug=False, debug=False, pipeline='chunks=4 device_map=0:0-3,1:3-12', dataloader_drop_last=False, eval_steps=25000, dataloader_num_workers=4, past_index=-1, run_name='output_dir', disable_tqdm=False, remove_unused_columns=True, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, sharded_ddp=False, deepspeed=None, label_smoothing_factor=0.1, adafactor=False, group_by_length=False, report_to=['tensorboard'], sortish_sampler=True, predict_with_generate=False)
[INFO|configuration_utils.py:445] 2021-01-25 11:39:51,546 >> loading configuration file https://huggingface.co/t5-base/resolve/main/config.json from cache at /home/pajansen/.cache/huggingface/transformers/91e9fe874e06c44883b535d6c950b8b89d6eaa3298d8e7fb3b2c78039e9f8b7b.66b9637a52aa11e9285cdd6e668cc0df14b3bcf0b6674cf3ba5353c542649637
[INFO|configuration_utils.py:481] 2021-01-25 11:39:51,547 >> Model config T5Config {
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to German: "
    },
    "translation_en_to_fr": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to French: "
    },
    "translation_en_to_ro": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to Romanian: "
    }
  },
  "transformers_version": "4.3.0.dev0",
  "use_cache": true,
  "vocab_size": 32128
}


[INFO|tokenization_utils_base.py:1766] 2021-01-25 11:39:52,522 >> loading file https://huggingface.co/t5-base/resolve/main/spiece.model from cache at /home/pajansen/.cache/huggingface/transformers/684a47ca6257e4ca71f0037771464c5b323e945fbc58697d2fad8a7dd1a2f8ba.3b69006860e7b5d0a63ffdddc01ddcd6b7c318a6f4fd793596552c741734c62d
[INFO|tokenization_utils_base.py:1766] 2021-01-25 11:39:52,523 >> loading file https://huggingface.co/t5-base/resolve/main/tokenizer.json from cache at /home/pajansen/.cache/huggingface/transformers/90de37880b5ff5ac7ab70ff0bd369f207e9b74133fa153c163d14c5bb0116207.8627f1bd5d270a9fd2e5a51c8bec3223896587cc3cfe13edeabb0992ab43c529
[INFO|modeling_utils.py:1027] 2021-01-25 11:39:52,809 >> loading weights file https://huggingface.co/t5-base/resolve/main/pytorch_model.bin from cache at /home/pajansen/.cache/huggingface/transformers/ab4e948915b067f5cb6e5105f6f85044fd717b133f43240db67899a8fc7b29a2.26934c75adf19ceac3c268b721ba353356b7609c45f5627550326f275a2163b4
[INFO|modeling_utils.py:1143] 2021-01-25 11:39:58,232 >> All model checkpoint weights were used when initializing T5ForConditionalGeneration.

[INFO|modeling_utils.py:1151] 2021-01-25 11:39:58,233 >> All the weights of T5ForConditionalGeneration were initialized from the model checkpoint at t5-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use T5ForConditionalGeneration for predictions without further training.
01/25/2021 11:39:58 - INFO - utils -   setting model.config to task specific params for translation_en_to_ro:
 {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to Romanian: '}
01/25/2021 11:39:58 - INFO - utils -   note: command line args may override some of these
[INFO|modeling_t5.py:1536] 2021-01-25 11:39:58,479 >> enabling pipeline with chunks=4
[INFO|modeling_t5.py:1545] 2021-01-25 11:39:58,479 >> using user-provided device_map
[INFO|modeling_t5.py:1563] 2021-01-25 11:39:58,479 >> using pipeline partitioning: {0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9, 10, 11]}
[W ProcessGroupGloo.cpp:532] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[W tensorpipe_agent.cpp:63] Failed to look up the IP address for the hostname (EADDRNOTAVAIL: address not available), defaulting to 127.0.0.1
01/25/2021 11:40:01 - INFO - __main__ -   *** Train ***
[INFO|trainer.py:807] 2021-01-25 11:40:01,659 >> ***** Running training *****
[INFO|trainer.py:808] 2021-01-25 11:40:01,659 >>   Num examples = 1000
[INFO|trainer.py:809] 2021-01-25 11:40:01,659 >>   Num Epochs = 1
[INFO|trainer.py:810] 2021-01-25 11:40:01,659 >>   Instantaneous batch size per device = 160
[INFO|trainer.py:811] 2021-01-25 11:40:01,659 >>   Total train batch size (w. parallel, distributed & accumulation) = 160
[INFO|trainer.py:812] 2021-01-25 11:40:01,659 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:813] 2021-01-25 11:40:01,659 >>   Total optimization steps = 7
2021-01-25 11:40:01.766436: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
  0%|                                                                                                                                                                                                               | 0/7 [00:00<?, ?it/s]Traceback (most recent call last):
  File "./finetune_trainer.py", line 373, in <module>
    main()
  File "./finetune_trainer.py", line 303, in main
    train_result = trainer.train(
  File "/home/pajansen/stass-test1/transformers/src/transformers/trainer.py", line 904, in train
    tr_loss += self.training_step(model, inputs)
  File "/home/pajansen/stass-test1/transformers/src/transformers/trainer.py", line 1271, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/pajansen/stass-test1/transformers/src/transformers/trainer.py", line 1301, in compute_loss
    outputs = model(**inputs)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/pajansen/stass-test1/transformers/src/transformers/models/t5/modeling_t5.py", line 1704, in forward
    encoder_outputs = self.encoder(
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/pajansen/stass-test1/transformers/src/transformers/models/t5/modeling_t5.py", line 1088, in forward
    outputs = block_pipe(inputs)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/pipe.py", line 362, in forward
    self.pipeline.run(batches)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/pipeline.py", line 117, in run
    self.compute(batches, schedule, skip_trackers)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/pipeline.py", line 257, in compute
    raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/worker.py", line 79, in worker
    batch = task.compute()
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/worker.py", line 60, in compute
    return self._compute()
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/pipeline.py", line 222, in compute
    return batch.call(partition)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/microbatch.py", line 70, in call
    return Batch(function(self.value))
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/pajansen/stass-test1/transformers/src/transformers/models/t5/modeling_t5.py", line 838, in forward
    layer_outputs = self.layer_module(hidden_states,
  File "/home/pajansen/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'head_mask'

@stas00
Contributor Author

stas00 commented Jan 25, 2021

It looks more like a warning, as it recovers with a fallback. Make sure you have:

$ cat /etc/hosts
127.0.0.1       localhost

It looks like I forgot to commit the last change. My apologies. Could you please update and try again?

@PeterAJansen

Thanks -- it appears to be working -- on t5-3B it spreads the model evenly across the 4 A100s (13.0-13.3GB each with a batch size of 1). For t5-11B there's an out-of-memory error -- I suppose (naively) that if 11B is ~3.7x larger than 3B then it would require ~49GB per card without some form of offloading?
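
For reference, a quick back-of-the-envelope check of that estimate, using the numbers quoted above (a rough sketch, not a measurement):

per_card_3b_gb = 13.3          # observed per-A100 usage for t5-3b
scale = 11 / 3                 # ~3.7x more parameters in t5-11b
print(per_card_3b_gb * scale)  # ~48.8 GB per card, beyond a 40GB A100 without offloading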

@stas00
Contributor Author

stas00 commented Jan 25, 2021

Thank you for confirming that you were able to use it with t5-3b on your 4 gpus.

Were you able to get a decent gpu utilization across the board? Or were they all under 25%?


Please make sure you read my notes on balancing in the OP and experiment with the device map so that all gpus get balanced GPU memory usage. gpu 0 is already busy with many things, so I'd try a spread of 2/4/4/4 parts, or perhaps 1/3/3/3, in your definition of:

--pipeline "chunks=4 device_map=0:0-3,1:3-12"

In this example we have a 1/3 balance between gpu 0 and gpu 1, i.e. 3 times more layers for gpu 1.

Of course, it needs to be adjusted to 4 gpus, and I don't remember how many encoder blocks t5-11b has, but as I mentioned, if you look at the logs you will find a ready map there - just re-adjust it to balance things better. Please let me know if I communicated clearly what I'm trying to say: we want all 4 gpus to have about the same memory usage - then we maximize the chance of fitting t5-11b on those 4 gpus.
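
For example (purely illustrative - this assumes 24 encoder blocks, so verify against the auto map logged for t5-11b), a 3/7/7/7 split over 4 gpus would look like:

--pipeline "chunks=4 device_map=0:0-3,1:3-10,2:10-17,3:17-24"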


Next we need to try to bolt DeepSpeed onto it. We will try to use 2 gpus for the pipeline and 2 gpus for ZeRO-DP, and perhaps some ZeRO-Offload too. I should get access to 4 gpus soon and will start working on figuring that step out. I will post back once I have something practical to share.

@PeterAJansen

Thanks -- the autobalancing (just "chunks=4") actually seemed to give nearly entirely even results on -3B (~13.0-13.3GB each), so I tried that with 11B instead of manually supplying the device map (since it seemed a bit uneven when I tested on -base) -- but I'll tinker on 11B and report back.

@stas00
Contributor Author

stas00 commented Jan 25, 2021

the autobalancing

FYI, currently the automatic device map just tries to split n_layers/n_gpus evenly per gpu, without taking into account gpu 0's extra load. Once everything else is working we will come up with much better heuristics based on actual gpu capacity and each layer's real memory demands.
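
Roughly, that naive split looks like this (a sketch, not the exact code in this PR):

def naive_device_map(n_layers, n_gpus):
    # split n_layers block indices as evenly as possible, ignoring gpu 0's extra load
    per_gpu = (n_layers + n_gpus - 1) // n_gpus  # ceil division
    return {
        gpu: list(range(gpu * per_gpu, min((gpu + 1) * per_gpu, n_layers)))
        for gpu in range(n_gpus)
    }

print(naive_device_map(12, 2))  # {0: [0, 1, 2, 3, 4, 5], 1: [6, 7, 8, 9, 10, 11]}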

@PeterAJansen

What's interesting is that I'm not generally observing GPU0 to have a higher load. Here's an example with unifiedqa-t5-3b (essentially just a further pre-trained t5-3b, not relevant here), chunks=4, autobalancing (with a different visualization tool). They all tend to show about the same RAM usage over time. The graph also shows the utilization (also generally under 30% most of the time):

[screenshot: nvtop graph showing four GPUs with roughly equal memory usage over time and utilization generally under 30%]

BTW -- I tinkered with different manual device_map settings for t5-11b, but it always quickly gave out of memory errors.

@stas00
Contributor Author

stas00 commented Jan 26, 2021

Oh, what tool is that? I want it too!

It looks like different GPUs behave differently, it will take some experimentation to make sense of it all.

But clearly you're also not seeing much benefit from the pipeline over the naive MP. Same as I. Either my workarounds to make it work slow everything down, or there is another problem elsewhere. As I mentioned, I'd like to redesign my implementation in the hope of reducing the unnecessary logic and data copying.

BTW -- I tinkered with different manual device_map settings for t5-11b, but it always quickly gave out of memory errors.

Thank you for the experimentation. I'm still waiting to get access to a 4-gpu setup and when it happens will immediately start experimenting with bolting DeepSpeed on it and then will get back to you.

@PeterAJansen

Thanks -- this handy visualization tool is nvtop -- I just found it, to plot the relative changes rather than staring at nvidia-smi and hoping to keep it all in my brain. It's available with apt (sudo apt-get install nvtop).

Happy to offer my rig for some testing if you need a 4 GPU setup sooner. :)

@stas00
Contributor Author

stas00 commented Jan 26, 2021

Oh, yes, I had it and forgot about its usefulness. Thank you!

I typically use

alias wn='watch -n 1 nvidia-smi'

but this is way better.

Happy to offer my rig for some testing if you need a 4 GPU setup sooner. :)

If I don't find access by tomorrow I will gladly accept your generous offer, @PeterAJansen. Thank you!

@stas00
Contributor Author

stas00 commented Jan 26, 2021

hmm, how do you get a split screen per card in nvtop? for some reason my version reports both cards as one card. I don't see any command line options to configure that.

@PeterAJansen

hmmm, it actually worked out-of-the-box for me (but looks very different depending on the dimensions of the terminal). Does it show only 1 GPU (with memory for both?), or two separate GPUs?

@stas00
Contributor Author

stas00 commented Jan 26, 2021

It reports 2 gpus but shows the graph only for gpu 0. Could be a bug. I just saw that for you it showed all 4 gpus.
[screenshot: nvtop reporting 2 gpus but graphing only gpu 0]

@PeterAJansen

What happens if you make the window really tall/wide? It changes the display for me if I resize the terminal -- if I make it really tiny, it looks something like yours:

[screenshot: nvtop in a small terminal window, collapsed to a single panel]

@stas00
Contributor Author

stas00 commented Jan 26, 2021

Sorry, I forgot to mention that I already tried this, to no avail. I gave it a huge console.

I even tried various terminals - same.

I think it may have to do with my 2nd card being an rtx-3090, which doesn't work with cuda < 11.1 - most likely nvtop was built against cuda-10, so while it replicates the nvidia-smi stats, it can't access nvml for that card and thus doesn't show the graph.

Yup, I installed nvtop on a machine with 2 normal gpus and it shows them both in a same-size terminal. So it just can't handle rtx-30* cards unless it's rebuilt from source against cuda-11.1+.

But even when it works it gives no way to separate the 2 gpus other than by color, and 4 lines of often similar magnitude for different things are impossible to make sense of. This is an odd design.

@PeterAJansen

PeterAJansen commented Jan 26, 2021

:-/ That's unfortunate (though I suppose that's the cost of using bleeding-edge hardware). The A100s are supported with CUDA 11.0, so they must just squeak in on the current version available with apt.

(And, the usability is a little unusual, but with ASCII graphics there are strong limits... :) )

@stas00
Contributor Author

stas00 commented Jan 26, 2021

pytorch w/ cuda-11.2 nightly should be available any day now. cuda-11.2 has been out for a month now.

(And, the usability is a little unusual, but with ASCII graphics there are strong limits... :) )

This is a good point. But at least one could use a double line or asterisks or something to differentiate the 4 different things. Perhaps some people can track 4 similar colors and remember which is which. Not me. I guess the source code is there; if I really need to, I could probably hack it to be more user-friendly.

@stas00
Contributor Author

stas00 commented Jan 31, 2021

Update: this overload of the term MP to mean totally different things is a big problem.

I was sure I could easily combine the non-DeepSpeed pipeline with DeepSpeed after reading
https://www.deepspeed.ai/features/#support-for-custom-model-parallelism
Except I have just now realized that it's not PP but the super-confusing, means-different-things-in-different-contexts abbreviation MP, which in this particular context means horizontal MP and not vertical MP/PP. And there are no instructions on how to integrate non-DeepSpeed PP. So I have been trying to fix the wrong thing: microsoft/DeepSpeed#710

So this particular branch takes us nowhere closer to integration of PP with DeepSpeed.

Back to the drawing board.

Comment on lines +1117 to +1148
# rewrite the model after pre-trained weights were loaded
layers = [
T5StackPipeSegment(
idx,
n_layers,
layer_module,
self.is_decoder,
head_mask[idx],
encoder_head_mask[idx],
output_hidden_states,
use_cache,
output_attentions,
all_hidden_states_add,
present_key_value_states_add,
all_attentions_add,
all_cross_attentions_add,
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]

# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights),
# (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)

if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
for idx, layer_module in enumerate(self.block)
]
# block_sequential = nn.Sequential(*layers)

# for now don't enable the pipe
if self.pipeline_is_enabled:

# print("using partitioning: ", dict(zip(devices, layer_splits)))
for device_id, layer_partition in self.device_map.items():
for layer_id in layer_partition:
# print(f"{layer_id} => {device_id}")
layers[layer_id].to(device_id)

block_sequential = nn.Sequential(*layers)
block_pipe = Pipe(block_sequential, chunks=self.pipeline_chunks, checkpoint="never")

IIUC, we are creating the entire pipeline model each time in the forward pass. Isn't this going to be pretty expensive? Why don't we just create the Pipe only during init and just use it in the forward pass?

Contributor Author

@stas00 stas00 Feb 19, 2021

Because if you look a bit higher, where I create T5StackPipeSegment to populate layers, I don't have all those arguments available at init time - i.e. I only get them during forward. But I could pass them to forward, of course. I was just trying to come up with ways to overcome the Tensor/tuple(Tensor) restriction, and at the time this seemed like a good workaround. If I don't do it, the passing of arguments to forward becomes even more complicated.

But as you started looking into it in order to identify why there is no speed-up over the naive MP, it could be the cause - I have no idea how long it takes to build these 2 pipes - but surely it's expensive doing it on every model run.

My first experiment was just to make the pipe work at all costs, so clearly this is not a way forward.
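
For reference, the build-once pattern being suggested would look roughly like this (hypothetical names, not this PR's code; it assumes the RPC framework is initialized as in the earlier sketch, and that per-call non-tensor arguments are either baked into the segments or encoded as tensors):

import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe

class EncoderStackWithPipe(nn.Module):
    def __init__(self, blocks, device_map, chunks):
        super().__init__()
        segments = nn.Sequential(*blocks)
        # place each segment on its gpu according to the device map
        for device_id, layer_ids in device_map.items():
            for layer_id in layer_ids:
                segments[layer_id].to(f"cuda:{device_id}")
        # the Pipe is built once here, not rebuilt on every forward pass
        self.block_pipe = Pipe(segments, chunks=chunks, checkpoint="never")

    def forward(self, inputs):
        # only tensors / tuples of tensors flow through; the output comes back as an RRef
        return self.block_pipe(inputs).local_value()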

Contributor Author

@stas00 stas00 Feb 19, 2021

And btw, continuing our discussion from the pytorch issue: we also can't easily change the shape of the model, because we need to adhere to the pre-trained model layout to be able to load pre-trained models. I suppose it could be re-mapped after the normal state dict is loaded. So that's another complication I was trying to overcome here - injecting proxy layers that have no weights of their own.

@stas00 stas00 added the WIP label Apr 14, 2021
@huggingface huggingface deleted a comment from github-actions bot Apr 14, 2021
@stas00
Contributor Author

stas00 commented Jun 4, 2021

too long. closing.

@hyunwoongko
Contributor

hyunwoongko commented Jul 22, 2021

We will test this branch soon.

@stas00
Contributor Author

stas00 commented Jul 22, 2021

There are probably some things that can be salvaged from this PR, but its main utility is to show the difficulties I ran into. And of course, this is not a good solution, not only because the code is absolutely nuts, but because it's very inefficient.

As I mentioned in the other thread, pytorch now has a better API, so some of the encoding/decoding of non-tensor inputs/outputs I did won't be needed anymore, as it now supports non-tensor inputs/outputs.

Labels
Pipeline Parallel, WIP