[BUG] Same train time with DeepSpeed (despite increased batch size) #1825

@pminhyung

Description

Describe the bug and Expected behavior
Hi! I really appreciate your awesome work! I'm fine-tuning an Electra model with Hugging Face (without the Trainer API) using DeepSpeed. After applying DeepSpeed I was able to increase the training batch size (64 -> 128; 256 still OOMs), so I expected training time to decrease. However, even with DeepSpeed enabled in my code, training time is unchanged. I'm not sure whether DeepSpeed is actually running with my configuration, and if it isn't, I'd like to know how to make it work properly (i.e., train faster).

I am training the model on Colab Pro+; the environment info is below.
(I have been allocated the same GPU every time.)

System info (please complete the following information):

  • OS: Ubuntu 18.04.5 LTS
  • GPU count and types : 1 GPU, Tesla P100-PCIE-16GB (Colab Pro Plus)
  • Python version : 3.7.12
  • torch install path ............... ['/usr/local/lib/python3.7/dist-packages/torch']
  • torch version .................... 1.10.0+cu111
  • torch cuda version ............... 11.1
  • torch hip version ................ None
  • nvcc version ..................... 11.1
  • deepspeed install path ........... ['/usr/local/lib/python3.7/dist-packages/deepspeed']
  • deepspeed info ................... 0.6.0, unknown, unknown
  • deepspeed wheel compiled w. ...... torch 1.10, cuda 11.1, hip 0.0

To Reproduce
I run these commands to use DeepSpeed in JIT mode when initializing the Colab runtime.

!apt-get install libaio-dev
!pip install deepspeed triton==1.0.0

Then, before running the training code, my ds_report output is as below.

ds_report output

DeepSpeed C++/CUDA extension op report

NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.

JIT compiled ops requires ninja
ninja .................. [OKAY]

op name ................ installed .. compatible

cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]

DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.7/dist-packages/torch']
torch version .................... 1.10.0+cu111
torch cuda version ............... 11.1
torch hip version ................ None
nvcc version ..................... 11.1
deepspeed install path ........... ['/usr/local/lib/python3.7/dist-packages/deepspeed']
deepspeed info ................... 0.6.0, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.10, cuda 11.1, hip 0.0

DeepSpeed config
My DeepSpeed configuration is as follows.

"ds_config": {
    "stage2": {
        "train_micro_batch_size_per_gpu": 128,
        "fp16": {
            "enabled": true,
            "loss_scale": 0,
            "loss_scale_window": 1000,
            "initial_scale_power": 16,
            "hysteresis": 2,
            "min_loss_scale": 1
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": 5e-05,
                "betas": [0.9, 0.999],
                "eps": 1e-8
            }
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 5e-5,
                "warmup_num_steps": 0
            }
        },
        "zero_optimization": {
            "stage": 2,
            "allgather_partitions": true,
            "allgather_bucket_size": 5e8,
            "overlap_comm": true,
            "reduce_scatter": true,
            "reduce_bucket_size": 5e8,
            "contiguous_gradients": true,
            "cpu_offload": true,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": true,
                "fast_init": true
            }
        },
        "gradient_accumulation_steps": 1,
        "gradient_clipping": 1.0
    }
}
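As a sanity check on the numbers above, here is a minimal sketch (using only the batch-sizing fields of the config, not the full dict) of how DeepSpeed's effective train_batch_size is derived: micro batch per GPU times gradient accumulation steps times world size.

```python
# Sanity-check the effective batch size DeepSpeed will report.
# Effective batch = micro batch per GPU * grad accumulation steps * world size.
ds_config = {
    "train_micro_batch_size_per_gpu": 128,
    "gradient_accumulation_steps": 1,
}
world_size = 1  # single Tesla P100 in this Colab runtime

effective_batch = (
    ds_config["train_micro_batch_size_per_gpu"]
    * ds_config["gradient_accumulation_steps"]
    * world_size
)
print(effective_batch)  # 128
```

This matches the `train_batch_size ............. 128` line in the engine logs, so the config is being picked up.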

Train code
My training code is almost identical to the script linked in the comment below; I added only a few lines to use DeepSpeed, as shown here.

# https://github.com/monologg/KoELECTRA/blob/master/finetune/run_seq_cls.py

...
import deepspeed
...

def train():
    ...
    ...

    model, optimizer, _, _ = deepspeed.initialize(model=model,
                              config_params=args['ds_config']['stage2'],
                              model_parameters=model.parameters())
    ...
    ...
    for epoch in mb:
        epoch_iterator = progress_bar(train_dataloader, parent=mb)
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            
            outputs = model(**inputs)
            loss = outputs.loss
            model.backward(loss)  # DeepSpeed scales the loss and runs backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    len(train_dataloader) <= args.gradient_accumulation_steps
                    and (step + 1) == len(train_dataloader)
            ):
                # gradient clipping and optimizer.step() are handled by the
                # DeepSpeed engine inside model.step():
                #torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                #optimizer.step()
                model.step()
                ...

def main():
    ....
    deepspeed.init_distributed("nccl")
    ....
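The stepping condition in the loop above can be sanity-checked in plain Python. This is a minimal sketch mirroring the original condition; `should_step`, the batch counts, and the accumulation values below are made-up illustration values, not from the actual run.

```python
# Reproduce the condition that decides when model.step() fires:
# every gradient_accumulation_steps batches, or on the final batch of a
# dataloader shorter than the accumulation window.
def should_step(step, num_batches, grad_accum_steps):
    return (step + 1) % grad_accum_steps == 0 or (
        num_batches <= grad_accum_steps and (step + 1) == num_batches
    )

# With gradient_accumulation_steps = 1 (as in the config), every batch steps.
steps = [s for s in range(4) if should_step(s, num_batches=4, grad_accum_steps=1)]
print(steps)  # [0, 1, 2, 3]

# With accumulation of 2 over 5 batches, steps fire only after batches 1 and 3.
steps = [s for s in range(5) if should_step(s, num_batches=5, grad_accum_steps=2)]
print(steps)  # [1, 3]
```

With `gradient_accumulation_steps: 1`, as configured here, `model.step()` runs after every batch, so the condition is effectively a no-op in this run.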

Output Logs
This is the output log from a training run.

[2022-03-11 02:05:14,971] [INFO] [logging.py:69:log_dist] [Rank 0] DeepSpeed info: version=0.6.0, git-hash=unknown, git-branch=unknown
03/11/2022 02:05:14 - INFO - torch.distributed.distributed_c10d -   Added key: store_based_barrier_key:2 to store for rank: 0
03/11/2022 02:05:14 - INFO - torch.distributed.distributed_c10d -   Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
[2022-03-11 02:05:15,039] [INFO] [engine.py:278:__init__] DeepSpeed Flops Profiler Enabled: False
Using /root/.cache/torch_extensions/py37_cu111 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py37_cu111/cpu_adam...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py37_cu111/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/usr/local/lib/python3.7/dist-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.7/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_60,code=compute_60 -gencode=arch=compute_60,code=sm_60 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_60,code=compute_60 -c /usr/local/lib/python3.7/dist-packages/deepspeed/ops/csrc/common/custom_cuda_kernel.cu -o custom_cuda_kernel.cuda.o 
[2/3] c++ -MMD -MF cpu_adam.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/usr/local/lib/python3.7/dist-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.7/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -O3 -std=c++14 -g -Wno-reorder -L/usr/local/cuda/lib64 -lcudart -lcublas -g -march=native -fopenmp -D__AVX256__ -c /usr/local/lib/python3.7/dist-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp -o cpu_adam.o 
[3/3] c++ cpu_adam.o custom_cuda_kernel.cuda.o -shared -lcurand -L/usr/local/lib/python3.7/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda_cu -ltorch_cuda_cpp -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o cpu_adam.so
Loading extension module cpu_adam...
Time to load cpu_adam op: 26.065569162368774 seconds
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000050, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
[2022-03-11 02:05:42,721] [INFO] [engine.py:1066:_configure_optimizer] Using DeepSpeed Optimizer param name adamw as basic optimizer
[2022-03-11 02:05:42,730] [INFO] [engine.py:1073:_configure_optimizer] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2022-03-11 02:05:42,731] [INFO] [utils.py:49:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2022-03-11 02:05:42,731] [INFO] [logging.py:69:log_dist] [Rank 0] Creating fp16 ZeRO stage 2 optimizer
[2022-03-11 02:05:42,731] [INFO] [stage_1_and_2.py:125:__init__] Reduce bucket size 500000000.0
[2022-03-11 02:05:42,731] [INFO] [stage_1_and_2.py:126:__init__] Allgather bucket size 500000000.0
[2022-03-11 02:05:42,731] [INFO] [stage_1_and_2.py:127:__init__] CPU Offload: True
[2022-03-11 02:05:42,731] [INFO] [stage_1_and_2.py:128:__init__] Round robin gradient partitioning: False
Using /root/.cache/torch_extensions/py37_cu111 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py37_cu111/utils...
Emitting ninja build file /root/.cache/torch_extensions/py37_cu111/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/2] c++ -MMD -MF flatten_unflatten.o.d -DTORCH_EXTENSION_NAME=utils -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.7/dist-packages/torch/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.7/dist-packages/torch/include/THC -isystem /usr/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /usr/local/lib/python3.7/dist-packages/deepspeed/ops/csrc/utils/flatten_unflatten.cpp -o flatten_unflatten.o 
[2/2] c++ flatten_unflatten.o -shared -L/usr/local/lib/python3.7/dist-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o utils.so
Loading extension module utils...
Time to load utils op: 12.981101274490356 seconds
Rank: 0 partition count [1] and sizes[(113066686, False)] 
[2022-03-11 02:05:56,300] [INFO] [utils.py:824:see_memory_usage] Before initializing optimizer states
[2022-03-11 02:05:56,301] [INFO] [utils.py:829:see_memory_usage] MA 0.26 GB         Max_MA 0.47 GB         CA 0.74 GB         Max_CA 1 GB 
[2022-03-11 02:05:56,302] [INFO] [utils.py:834:see_memory_usage] CPU Virtual Memory:  used = 7.39 GB, percent = 14.5%
[2022-03-11 02:05:56,823] [INFO] [utils.py:824:see_memory_usage] After initializing optimizer states
[2022-03-11 02:05:56,824] [INFO] [utils.py:829:see_memory_usage] MA 0.26 GB         Max_MA 0.26 GB         CA 0.74 GB         Max_CA 1 GB 
[2022-03-11 02:05:56,824] [INFO] [utils.py:834:see_memory_usage] CPU Virtual Memory:  used = 8.66 GB, percent = 17.0%
[2022-03-11 02:05:56,825] [INFO] [stage_1_and_2.py:497:__init__] optimizer state initialized
[2022-03-11 02:05:56,963] [INFO] [utils.py:824:see_memory_usage] After initializing ZeRO optimizer
[2022-03-11 02:05:56,963] [INFO] [utils.py:829:see_memory_usage] MA 0.26 GB         Max_MA 0.26 GB         CA 0.74 GB         Max_CA 1 GB 
[2022-03-11 02:05:56,964] [INFO] [utils.py:834:see_memory_usage] CPU Virtual Memory:  used = 8.66 GB, percent = 17.0%
[2022-03-11 02:05:56,964] [INFO] [logging.py:69:log_dist] [Rank 0] DeepSpeed Final Optimizer = adamw
[2022-03-11 02:05:56,964] [INFO] [engine.py:777:_configure_lr_scheduler] DeepSpeed using configured LR scheduler = WarmupLR
[2022-03-11 02:05:56,964] [INFO] [logging.py:69:log_dist] [Rank 0] DeepSpeed LR Scheduler = <deepspeed.runtime.lr_schedules.WarmupLR object at 0x7f4e91ebe7d0>
[2022-03-11 02:05:56,964] [INFO] [logging.py:69:log_dist] [Rank 0] step=0, skipped=0, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:05:56,965] [INFO] [config.py:1058:print] DeepSpeedEngine configuration:
[2022-03-11 02:05:56,965] [INFO] [config.py:1062:print]   activation_checkpointing_config  {
    "partition_activations": false, 
    "contiguous_memory_optimization": false, 
    "cpu_checkpointing": false, 
    "number_checkpoints": null, 
    "synchronize_checkpoint_boundary": false, 
    "profile": false
}
[2022-03-11 02:05:56,965] [INFO] [config.py:1062:print]   aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
[2022-03-11 02:05:56,965] [INFO] [config.py:1062:print]   amp_enabled .................. False
[2022-03-11 02:05:56,965] [INFO] [config.py:1062:print]   amp_params ................... False
[2022-03-11 02:05:56,965] [INFO] [config.py:1062:print]   autotuning_config ............ {
    "enabled": false, 
    "start_step": null, 
    "end_step": null, 
    "metric_path": null, 
    "arg_mappings": null, 
    "metric": "throughput", 
    "model_info": null, 
    "results_dir": null, 
    "exps_dir": null, 
    "overwrite": true, 
    "fast": true, 
    "start_profile_step": 3, 
    "end_profile_step": 5, 
    "tuner_type": "gridsearch", 
    "tuner_early_stopping": 5, 
    "tuner_num_trials": 50, 
    "model_info_path": null, 
    "mp_size": 1, 
    "max_train_batch_size": null, 
    "min_train_batch_size": 1, 
    "max_train_micro_batch_size_per_gpu": 1.024000e+03, 
    "min_train_micro_batch_size_per_gpu": 1, 
    "num_tuning_micro_batch_sizes": 3
}
[2022-03-11 02:05:56,965] [INFO] [config.py:1062:print]   bfloat16_enabled ............. False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   checkpoint_tag_validation_enabled  True
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   checkpoint_tag_validation_fail  False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   communication_data_type ...... None
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   curriculum_enabled ........... False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   curriculum_params ............ False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   dataloader_drop_last ......... False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   disable_allgather ............ False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   dump_state ................... False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   dynamic_loss_scale_args ...... {'init_scale': 65536, 'scale_window': 1000, 'delayed_shift': 2, 'min_scale': 1}
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_enabled ........... False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_gas_boundary_resolution  1
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_layer_name ........ bert.encoder.layer
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_layer_num ......... 0
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_max_iter .......... 100
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_stability ......... 1e-06
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_tol ............... 0.01
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   eigenvalue_verbose ........... False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   elasticity_enabled ........... False
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   flops_profiler_config ........ {
    "enabled": false, 
    "profile_step": 1, 
    "module_depth": -1, 
    "top_modules": 1, 
    "detailed": true, 
    "output_file": null
}
[2022-03-11 02:05:56,966] [INFO] [config.py:1062:print]   fp16_enabled ................. True
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   fp16_master_weights_and_gradients  False
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   fp16_mixed_quantize .......... False
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   global_rank .................. 0
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   gradient_accumulation_steps .. 1
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   gradient_clipping ............ 1.0
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   gradient_predivide_factor .... 1.0
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   initial_dynamic_scale ........ 65536
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   loss_scale ................... 0
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   memory_breakdown ............. False
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   optimizer_legacy_fusion ...... False
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   optimizer_name ............... adamw
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   optimizer_params ............. {'lr': 5e-05, 'betas': [0.9, 0.999], 'eps': 1e-08}
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   pld_enabled .................. False
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   pld_params ................... False
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   prescale_gradients ........... False
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_change_rate ......... 0.001
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_groups .............. 1
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_offset .............. 1000
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_period .............. 1000
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_rounding ............ 0
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_start_bits .......... 16
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_target_bits ......... 8
[2022-03-11 02:05:56,967] [INFO] [config.py:1062:print]   quantize_training_enabled .... False
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   quantize_type ................ 0
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   quantize_verbose ............. False
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   scheduler_name ............... WarmupLR
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   scheduler_params ............. {'warmup_min_lr': 0, 'warmup_max_lr': 5e-05, 'warmup_num_steps': 0}
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   sparse_attention ............. None
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   sparse_gradients_enabled ..... False
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   steps_per_print .............. 10
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   tensorboard_enabled .......... False
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   tensorboard_job_name ......... DeepSpeedJobName
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   tensorboard_output_path ...... 
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   train_batch_size ............. 128
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   train_micro_batch_size_per_gpu  128
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   use_quantizer_kernel ......... False
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   wall_clock_breakdown ......... False
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   world_size ................... 1
[2022-03-11 02:05:56,968] [INFO] [config.py:1062:print]   zero_allow_untested_optimizer  False
[2022-03-11 02:05:56,969] [INFO] [config.py:1062:print]   zero_config .................. {
    "stage": 2, 
    "contiguous_gradients": true, 
    "reduce_scatter": true, 
    "reduce_bucket_size": 5.000000e+08, 
    "allgather_partitions": true, 
    "allgather_bucket_size": 5.000000e+08, 
    "overlap_comm": true, 
    "load_from_fp32_weights": true, 
    "elastic_checkpoint": false, 
    "offload_param": null, 
    "offload_optimizer": {
        "device": null, 
        "nvme_path": null, 
        "buffer_count": 4, 
        "pin_memory": false, 
        "pipeline_read": false, 
        "pipeline_write": false, 
        "fast_init": false
    }, 
    "sub_group_size": 1.000000e+09, 
    "prefetch_bucket_size": 5.000000e+07, 
    "param_persistence_threshold": 1.000000e+05, 
    "max_live_parameters": 1.000000e+09, 
    "max_reuse_distance": 1.000000e+09, 
    "gather_16bit_weights_on_model_save": false, 
    "ignore_unused_parameters": true, 
    "round_robin_gradients": false, 
    "legacy_stage1": false
}
[2022-03-11 02:05:56,969] [INFO] [config.py:1062:print]   zero_enabled ................. True
[2022-03-11 02:05:56,969] [INFO] [config.py:1062:print]   zero_optimization_stage ...... 2
[2022-03-11 02:05:56,969] [INFO] [config.py:1070:print]   json = {
    "train_micro_batch_size_per_gpu": 128, 
    "fp16": {
        "enabled": true, 
        "loss_scale": 0, 
        "loss_scale_window": 1000, 
        "initial_scale_power": 16, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "optimizer": {
        "type": "AdamW", 
        "params": {
            "lr": 5e-05, 
            "betas": [0.9, 0.999], 
            "eps": 1e-08
        }
    }, 
    "scheduler": {
        "type": "WarmupLR", 
        "params": {
            "warmup_min_lr": 0, 
            "warmup_max_lr": 5e-05, 
            "warmup_num_steps": 0
        }
    }, 
    "zero_optimization": {
        "stage": 2, 
        "allgather_partitions": true, 
        "allgather_bucket_size": 5.000000e+08, 
        "overlap_comm": true, 
        "reduce_scatter": true, 
        "reduce_bucket_size": 5.000000e+08, 
        "contiguous_gradients": true, 
        "cpu_offload": true, 
        "offload_optimizer": {
            "device": "cpu", 
            "pin_memory": true, 
            "fast_init": true
        }
    }, 
    "gradient_accumulation_steps": 1, 
    "gradient_clipping": 1.0
}
Using /root/.cache/torch_extensions/py37_cu111 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.00044083595275878906 seconds
03/11/2022 02:05:56 - INFO - __main__ -   ***** Running training *****
03/11/2022 02:05:56 - INFO - __main__ -     Num examples = 234528
03/11/2022 02:05:56 - INFO - __main__ -     Num Epochs = 100
03/11/2022 02:05:56 - INFO - __main__ -     Total train batch size = 128
03/11/2022 02:05:56 - INFO - __main__ -     Gradient Accumulation steps = 1
03/11/2022 02:05:56 - INFO - __main__ -     Total optimization steps = 146600
03/11/2022 02:05:56 - INFO - __main__ -     Logging steps = 3000
03/11/2022 02:05:56 - INFO - __main__ -     Save steps = 3000
[2022-03-11 02:05:59,192] [INFO] [stage_1_and_2.py:1656:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, reducing to 65536
/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py:134: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
[2022-03-11 02:06:00,592] [INFO] [stage_1_and_2.py:1656:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, reducing to 32768.0
[2022-03-11 02:06:13,536] [INFO] [logging.py:69:log_dist] [Rank 0] step=10, skipped=2, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:06:13,537] [INFO] [timer.py:189:stop] 0/10, SamplesPerSec=79.2758999705874, MemAllocated=0.26GB, MaxMemAllocated=9.02GB
[2022-03-11 02:06:28,687] [INFO] [logging.py:69:log_dist] [Rank 0] step=20, skipped=2, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:06:28,688] [INFO] [timer.py:189:stop] 0/20, SamplesPerSec=82.1916862587935, MemAllocated=0.26GB, MaxMemAllocated=9.02GB
[2022-03-11 02:06:43,833] [INFO] [logging.py:69:log_dist] [Rank 0] step=30, skipped=2, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:06:43,834] [INFO] [timer.py:189:stop] 0/30, SamplesPerSec=83.07239837441124, MemAllocated=0.26GB, MaxMemAllocated=9.02GB
[2022-03-11 02:06:58,985] [INFO] [logging.py:69:log_dist] [Rank 0] step=40, skipped=2, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:06:58,985] [INFO] [timer.py:189:stop] 0/40, SamplesPerSec=83.48609168106144, MemAllocated=0.26GB, MaxMemAllocated=9.02GB
[2022-03-11 02:07:14,130] [INFO] [logging.py:69:log_dist] [Rank 0] step=50, skipped=2, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:07:14,130] [INFO] [timer.py:189:stop] 0/50, SamplesPerSec=83.73838805358588, MemAllocated=0.26GB, MaxMemAllocated=9.02GB
[2022-03-11 02:07:29,274] [INFO] [logging.py:69:log_dist] [Rank 0] step=60, skipped=2, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:07:29,274] [INFO] [timer.py:189:stop] 0/60, SamplesPerSec=83.9044415557675, MemAllocated=0.26GB, MaxMemAllocated=9.02GB
[2022-03-11 02:07:44,417] [INFO] [logging.py:69:log_dist] [Rank 0] step=70, skipped=2, lr=[5e-05], mom=[[0.9, 0.999]]
[2022-03-11 02:07:44,418] [INFO] [timer.py:189:stop] 0/70, SamplesPerSec=84.02348484813439, MemAllocated=0.26GB, MaxMemAllocated=9.02GB
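For reference, here is a back-of-the-envelope sketch of what the logged throughput implies for wall-clock time per epoch; the constants are read off the log lines above, and the stable ~84 samples/sec figure is the number to compare against the pre-DeepSpeed run.

```python
# Estimate time per epoch from the logged throughput.
num_examples = 234528   # "Num examples" from the training log
samples_per_sec = 84.0  # SamplesPerSec stabilises around 84 in the log

seconds_per_epoch = num_examples / samples_per_sec
print(round(seconds_per_epoch / 60, 1))  # 46.5 (minutes per epoch)
```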
