
[WIP] ONNX conversion #6

Open · wants to merge 14 commits into master
1 change: 1 addition & 0 deletions .gitignore
@@ -127,3 +127,4 @@ dmypy.json

# Pyre type checker
.pyre/
tmp/
4 changes: 2 additions & 2 deletions DeBERTa/apps/sequence_classification.py
@@ -46,7 +46,7 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)

-    loss = 0
+    loss = torch.tensor(0).to(logits)
    if labels is not None:
      if self.num_labels == 1:
        # regression task
@@ -68,4 +68,4 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
        label_confidence = 1
        loss = -((log_softmax(logits)*labels).sum(-1)*label_confidence).mean()

-    return (logits,loss)
+    return (loss, logits)
74 changes: 69 additions & 5 deletions DeBERTa/apps/train.py
@@ -24,9 +24,10 @@
from ..utils import *
from ..utils import xtqdm as tqdm
from .task_registry import tasks
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler

from ..training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed, kill_children
-from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
+from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, RandomSampler, AsyncDataLoader

def create_model(args, num_labels, model_class_fn):
  # Prepare model
@@ -217,9 +218,63 @@ def run_predict(args, model, device, eval_data, prefix=None):
  if predict_fn:
    predict_fn(predicts, args.output_dir, name, prefix)

def deberta_model_description(args):
  vocab_size = 30528
  # set concrete input sizes to permit optimization
  input_ids_desc = IODescription('input_ids', [args.train_batch_size, args.max_seq_length], torch.int32, num_classes=vocab_size)
  type_ids_desc = IODescription('type_ids', [args.train_batch_size, args.max_seq_length], torch.int32)  # num_classes=?
  position_ids_desc = IODescription('position_ids', [args.train_batch_size, args.max_seq_length], torch.int32)  # num_classes=?
  input_mask_desc = IODescription('input_mask', [args.train_batch_size, args.max_seq_length], torch.int32)  # num_classes=?
  labels_desc = IODescription('labels', [args.train_batch_size, args.max_seq_length], torch.float32)  # num_classes=?

  loss_desc = IODescription('loss', [], torch.float32)
  return ModelDescription([input_ids_desc, type_ids_desc, position_ids_desc, input_mask_desc, labels_desc], [loss_desc])
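A side note on these concrete shapes: every batch fed to train_step must then match [train_batch_size, max_seq_length] exactly. If that proves too rigid (e.g. for a final partial batch), IODescription in this API also accepts symbolic dimension names in place of concrete sizes; the following is a hedged sketch of that variant, with the function name and the 'batch'/'max_seq_len' labels chosen for illustration:

def deberta_model_description_dynamic(args):
  # hypothetical variant: symbolic dims let ORT accept any batch size and
  # sequence length, at the cost of some shape-dependent graph optimizations
  vocab_size = 30528
  inputs = [
    IODescription('input_ids', ['batch', 'max_seq_len'], torch.int32, num_classes=vocab_size),
    IODescription('type_ids', ['batch', 'max_seq_len'], torch.int32),
    IODescription('position_ids', ['batch', 'max_seq_len'], torch.int32),
    IODescription('input_mask', ['batch', 'max_seq_len'], torch.int32),
    IODescription('labels', ['batch', 'max_seq_len'], torch.float32),
  ]
  return ModelDescription(inputs, [IODescription('loss', [], torch.float32)])

Note also that only 'loss' is described as an output while the patched model now returns (loss, logits); ORTTrainer appears to match described outputs positionally, which would explain why the return order in sequence_classification.py was swapped in this PR.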

def create_ort_trainer(args, device, model):
  # default initial optimizer settings: b1=0.9, b2=0.999, e=1e-6
  def map_optimizer_attributes(name):
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    no_decay = False
    for no_decay_key in no_decay_keys:
      if no_decay_key in name:
        no_decay = True
        break
    if no_decay:
      return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
    else:
      return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

  # We request ORTTrainer to create a LambOptimizer with the given optimizer
  # attributes; its train_step then runs forward, backward, and the optimizer
  # step in one call.
  model = ORTTrainer(model, None, deberta_model_description(args), "LambOptimizer",
                     map_optimizer_attributes,
                     IODescription('Learning_Rate', [1,], torch.float32),
                     device,
                     _opset_version=10)

  return model

def run_onnx_training(args, model, device, train_data, prefix=None):
  # run the training loop through ONNX Runtime via ORTTrainer
  trainer = create_ort_trainer(args, device, model)
  train_sampler = RandomSampler(len(train_data))
  batch_sampler = BatchSampler(train_sampler, args.train_batch_size)
  batch_sampler = DistributedBatchSampler(batch_sampler, rank=args.rank, world_size=args.world_size)
  train_dataloader = DataLoader(train_data, batch_sampler=batch_sampler, num_workers=args.workers, pin_memory=True)
  torch.cuda.empty_cache()
  for step, batch in enumerate(AsyncDataLoader(train_dataloader, 100)):
    batch = batch_to(batch, device)
    with torch.no_grad():
      trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels'])
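      # note (assumption): since create_ort_trainer declares a 'Learning_Rate'
      # IODescription, this version of ORTTrainer.train_step typically also
      # expects the learning-rate tensor as a trailing argument, e.g.
      #   trainer.train_step(input_ids, ..., labels, torch.tensor([lr]).to(device))
      # so the call above may need that extra input once export succeeds.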
      # conversion fails now with:
      # site-packages/torch/onnx/utils.py:617: UserWarning: ONNX export failed on ATen operator broadcast_tensors
@ganik (Member Author) commented on Aug 2, 2020:
broadcast_tensors and mse_loss are ATen ops that are not currently implemented in the ONNX exporter. To get unblocked, functional.py needs to be modified as per the comment below.

@ganik (Member Author) commented on Aug 3, 2020:

The mse_loss implementation in https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L2682 uses two ops that the exporter does not implement: torch.broadcast_tensors() and torch._C._nn.mse_loss(). To get unblocked, I made a patch as a workaround:

# expanded_input, expanded_target = torch.broadcast_tensors(input, target)
expanded_input = input + torch.zeros(target.size())
expanded_target = target + torch.zeros(input.size())
# ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
t = expanded_input - expanded_target
t = t * t
ret = torch.mean(t)  # note: this hard-codes reduction='mean'
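Assembled as a standalone function, the workaround could look like the sketch below; the name is illustrative, it assumes reduction='mean', and it uses torch.zeros_like so the zero tensors inherit the inputs' dtype and device (the bare torch.zeros(...) above would allocate float32 on CPU and fail for CUDA inputs):

import torch

def mse_loss_onnx_exportable(input, target):
  # emulate torch.broadcast_tensors(input, target): adding a zero tensor shaped
  # like the other operand broadcasts both sides to the common shape
  expanded_input = input + torch.zeros_like(target)
  expanded_target = target + torch.zeros_like(input)
  # emulate torch._C._nn.mse_loss(...) with elementwise ops the exporter
  # supports; this hard-codes reduction='mean'
  diff = expanded_input - expanded_target
  return torch.mean(diff * diff)

For the regression path in sequence_classification.py, where input and target end up with matching shapes, this should be numerically identical to F.mse_loss with the default reduction.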

      # because torch.onnx.symbolic_opset10.broadcast_tensors does not exist

def main(args):
-  if not args.do_train and not args.do_eval and not args.do_predict:
-    raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` must be True.")
+  if not args.do_train and not args.do_eval and not args.do_predict and not args.do_onnx:
+    raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` or `do_onnx` must be True.")
  os.makedirs(args.output_dir, exist_ok=True)
  task_name = args.task_name.lower()
  random.seed(args.seed)
@@ -236,11 +291,11 @@ def main(args):
  test_data = processor.test_data(max_seq_len=args.max_seq_length)
  logger.info(" Prediction batch size = %d", args.predict_batch_size)

-  if args.do_train:
+  if args.do_train or args.do_onnx:
    train_data = processor.train_data(max_seq_len=args.max_seq_length, mask_gen=None, debug=args.debug)
  model_class_fn = processor.get_model_class_fn()
  model = create_model(args, len(label_list), model_class_fn)
-  if args.do_train:
+  if args.do_train or args.do_onnx:
    with open(os.path.join(args.output_dir, 'model_config.json'), 'w', encoding='utf-8') as fs:
      fs.write(model.config.to_json_string() + '\n')
    logger.info("Model config {}".format(model.config))
@@ -257,6 +312,10 @@ def main(args):
  if args.do_predict:
    run_predict(args, model, device, test_data, prefix=args.tag)

  # trains in ONNX
  if args.do_onnx:
    run_onnx_training(args, model, device, train_data, prefix=args.tag)

def build_argument_parser():
  parser = argparse.ArgumentParser()

@@ -437,6 +496,11 @@ def build_argument_parser():
                      default=None,
                      type=str,
                      help="The path of pre-trained RoBERTa model")

  parser.add_argument("--do_onnx",
                      default=False,
                      action='store_true',
                      help="Whether to run training in ONNX")
  return parser

if __name__ == "__main__":
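Once the export blockers above are resolved, ONNX training would be invoked like the other modes, e.g. something along the lines of `python -m DeBERTa.apps.train --task_name <task> --do_onnx` plus the usual data, model, and batch-size flags; the exact entry point and flag set here are assumptions based on this file's argument parser, not a documented command.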