
Add zero test (#5476)
ashbhandare committed Oct 22, 2020
1 parent 6d35be2 commit 0a9b83a
Showing 2 changed files with 106 additions and 2 deletions.
2 changes: 1 addition & 1 deletion orttraining/orttraining/python/training/checkpoint.py
@@ -25,7 +25,7 @@ def experimental_state_dict(ort_trainer, include_optimizer_state=True):

    # extract untrained weights and buffer
    for n in ort_trainer._onnx_model.graph.initializer:
        if n.name not in torch_state:
        if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights:
            torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n)))

    # Need to remove redundant (optimizer) initializers to map back to original torch state names
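With this change, experimental_state_dict only pulls an ONNX initializer into the returned state when the caller explicitly froze that weight. A minimal sketch of the path being exercised, assuming the utils.frozen_weights option referenced in the new condition and an already-built model/model_desc/optim_config (the weight name below is illustrative):

from onnxruntime.training import orttrainer
from onnxruntime.training.checkpoint import experimental_state_dict

def state_dict_with_frozen_weight(model, model_desc, optim_config):
    # Freeze one weight by name (illustrative); ORTTrainer then keeps it as an
    # untrained ONNX initializer instead of a trained parameter.
    opts = orttrainer.ORTTrainerOptions(
        {'utils': {'frozen_weights': ['bert.embeddings.word_embeddings.weight']}})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    # Call after at least one train_step so the ONNX graph exists. The frozen
    # initializer is copied into the extracted state only because its name
    # appears in options.utils.frozen_weights.
    return experimental_state_dict(trainer)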
106 changes: 105 additions & 1 deletion orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py
@@ -4,6 +4,7 @@

# ==================
import os
import shutil
import logging
import random
import h5py
@@ -29,6 +30,7 @@
import onnxruntime as ort
from onnxruntime.training import amp, optim, orttrainer
from onnxruntime.training.optim import PolyWarmupLRScheduler, LinearWarmupLRScheduler
from onnxruntime.training.checkpoint import experimental_save_checkpoint, _list_checkpoint_files, _CombineZeroCheckpoint

# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases.
# the fix for ignore_index == -100 cases is already in pytorch master.
@@ -210,6 +212,11 @@ class PretrainArguments:
metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."}
)

deepspeed_zero_stage: Optional[int] = field(
default=0,
metadata={"help": "Deepspeed Zero Stage. 0 => disabled"}
)

log_freq: Optional[float] = field(
default=1.0,
metadata={"help": "frequency of logging loss."}
@@ -235,6 +242,16 @@ class PretrainArguments:
metadata={"help": "Number of update steps until a model checkpoint is saved to disk."}
)

save_checkpoint: Optional[bool] = field(
default=False,
metadata={"help": "Enable for saving a model checkpoint to disk."}
)

init_state_dict: Optional[dict] = field(
default=None,
metadata={"help": "State to load before training."}
)

phase2: bool = field(
default=False,
metadata={"help": "Whether to train with seq len 512."}
@@ -296,6 +313,7 @@ def setup_training(args):

    if args.local_rank == -1:
        args.local_rank = 0
        args.world_rank = 0

    print("args.local_rank: ", args.local_rank)
    torch.cuda.set_device(args.local_rank)
@@ -317,6 +335,14 @@
logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size)
return device, args

def setup_torch_distributed(world_rank, world_size):
os.environ['RANK'] = str(world_rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str('12345')
torch.distributed.init_process_group(backend='nccl', world_size=world_size,
rank=world_rank)
return
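For context, a minimal sketch of how this helper might be driven for a single-node run with one process per GPU (the spawn wrapper and worker function here are illustrative):

import torch
import torch.multiprocessing as mp

def _pretrain_worker(rank, world_size):
    # Join the NCCL process group created by setup_torch_distributed, one process per GPU.
    setup_torch_distributed(rank, world_size)
    torch.cuda.set_device(rank)
    # ... build PretrainArguments with world_rank=rank and call do_pretrain(args) ...

if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()  # assumption: single node, one worker per GPU
    mp.spawn(_pretrain_worker, args=(n_gpus,), nprocs=n_gpus)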

def prepare_model(args, device):
    config = BertConfig.from_pretrained('bert-base-uncased', cache_dir=args.cache_dir)
@@ -327,6 +353,8 @@ def prepare_model(args, device):
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    if args.init_state_dict is not None:
        model.load_state_dict(args.init_state_dict, strict=False)
    model_desc = bert_model_description(config)

    lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)
@@ -346,7 +374,8 @@ def prepare_model(args, device):
            'world_rank': max(0, args.local_rank),
            'world_size': args.world_size,
            'local_rank': max(0, args.local_rank),
            'allreduce_post_accumulation': args.allreduce_post_accumulation},
            'allreduce_post_accumulation': args.allreduce_post_accumulation,
            'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage}},
        'lr_scheduler': lr_scheduler
    })

@@ -455,6 +484,9 @@ def do_pretrain(args):
                if tb_writer:
                    tb_writer.close()

            if global_step >= args.max_steps:
                if args.save_checkpoint:
                    experimental_save_checkpoint(model, args.output_dir)
                final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps)
                return final_loss

@@ -549,6 +581,69 @@ def test_pretrain_convergence(self):
            tensorboard_dir=generate_tensorboard_logdir('/bert_data/hf_data/test_out/'))
        final_loss = do_pretrain(args)
        return final_loss

    def test_pretrain_zero(self):
        assert self.world_size > 0, "ZeRO test requires a distributed run."
        setup_torch_distributed(self.world_rank, self.world_size)
        per_gpu_batch_size = 32
        optimization_batch_size = per_gpu_batch_size * self.world_size  # set to disable grad accumulation

        self.train_batch_size = optimization_batch_size
        self.gradient_accumulation_steps = 1
        self.deepspeed_zero_stage = 1
        self.force_num_hidden_layers = 2
        self.max_seq_length = 32
        self.output_dir = './bert_pretrain_ckpt'
        if self.world_rank == 0:
            if os.path.isdir(self.output_dir):
                shutil.rmtree(self.output_dir)
            os.makedirs(self.output_dir, exist_ok=True)

        torch.distributed.barrier()

        assert os.path.exists(self.output_dir)

        # run a few optimization steps
        self.max_steps = 200
        args = PretrainArguments(
            output_dir=self.output_dir,
            bert_model=self.bert_model,
            local_rank=self.local_rank,
            world_rank=self.world_rank,
            world_size=self.world_size,
            max_steps=self.max_steps,
            learning_rate=self.learning_rate,
            max_seq_length=self.max_seq_length,
            max_predictions_per_seq=self.max_predictions_per_seq,
            train_batch_size=self.train_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            input_dir=self.input_dir,
            fp16=self.fp16,
            allreduce_post_accumulation=self.allreduce_post_accumulation,
            force_num_hidden_layers=self.force_num_hidden_layers,
            deepspeed_zero_stage=self.deepspeed_zero_stage,
            save_checkpoint=True)
        train_loss = do_pretrain(args)

        # ensure all workers reach this point before loading the checkpointed state
        torch.distributed.barrier()

        # on rank 0, load the trained state
        if args.world_rank == 0:
            checkpoint_files = _list_checkpoint_files(self.output_dir, "ORT_checkpoint")
            ckpt_agg = _CombineZeroCheckpoint(checkpoint_files)
            final_state_dict = ckpt_agg.aggregate_checkpoints()

            args.init_state_dict = final_state_dict

        torch.distributed.barrier()

        # run a single step to get the loss; on rank 0 it should be less than the starting loss
        args.save_checkpoint = False
        args.max_steps = 1
        args.deepspeed_zero_stage = 0
        final_loss = do_pretrain(args)
        return final_loss


# to do parallel training:
@@ -590,6 +685,15 @@ def test_pretrain_convergence(self):
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
test.assertLess(final_loss, 8.5)
logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed")
elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_zero':
logger.info("running ORTBertPretrainTest.test_pretrain_zero()...")
final_loss = test.test_pretrain_zero()
logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss)
if local_rank == 0:
test.assertLess(final_loss, 10.2)
else:
test.assertGreater(final_loss, 11.0)
logger.info("ORTBertPretrainTest.test_pretrain_zero() passed")
else:
    # https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29
    # to make equivalent args for cpp convergence test
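In outline, the ZeRO round trip exercised by test_pretrain_zero: every rank saves its own checkpoint file, and rank 0 then aggregates the per-rank files into one full state dict that a fresh, non-ZeRO run can start from. A short sketch using the helpers imported above (trainer, ckpt_dir and world_rank are placeholders):

# Each rank writes its own ORT_checkpoint file holding its ZeRO-partitioned state.
experimental_save_checkpoint(trainer, ckpt_dir)

# Rank 0 merges the per-rank files back into a single state dict.
if world_rank == 0:
    checkpoint_files = _list_checkpoint_files(ckpt_dir, "ORT_checkpoint")
    full_state_dict = _CombineZeroCheckpoint(checkpoint_files).aggregate_checkpoints()
    # full_state_dict can seed a new run, e.g. PretrainArguments(init_state_dict=full_state_dict,
    # deepspeed_zero_stage=0, ...), which is how the test checks the restored model's loss.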
