Skip to content

Commit

Permalink
1. Add code for DeBERTaV3 pre-training; 2. Fix error in torch 1.11; 3. Add code for ONNX export
Browse files Browse the repository at this point in the history
  • Loading branch information
BigBird01 committed Mar 19, 2023
1 parent 2c5b6b2 commit c794b71
Show file tree
Hide file tree
Showing 16 changed files with 853 additions and 34 deletions.
14 changes: 12 additions & 2 deletions DeBERTa/apps/models/sequence_classification.py
Expand Up @@ -48,7 +48,7 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
pooled_output = self.pooler(encoder_layers[-1])
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = 0
loss = torch.tensor(0).to(logits)
if labels is not None:
if self.num_labels ==1:
# regression task
Expand All @@ -74,7 +74,17 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
'logits' : logits,
'loss' : loss
}
return (logits,loss)

def export_onnx(self, onnx_path, input):
    """Export this model to ONNX via ``torch.onnx.export``.

    Args:
        onnx_path: Destination handed unchanged to ``torch.onnx.export``.
        input: Indexable sequence whose first element is a dict of model
            inputs keyed by argument name. The ``labels`` entry, if any,
            is left out of the exported call. The caller's dict is not
            modified.
    """
    # Work on a shallow copy so the caller's batch keeps its 'labels' entry,
    # and tolerate batches that never carried labels in the first place.
    model_inputs = dict(input[0])
    model_inputs.pop('labels', None)
    export_args = (model_inputs,) + tuple(input[1:])

    # Batch and sequence dimensions stay symbolic for every token-level input.
    sequence_axes = {0: 'batch_size', 1: 'sequence_length'}
    dynamic_inputs = ('input_ids', 'type_ids', 'input_mask', 'position_ids')

    # NOTE(review): 'labels' is still listed in input_names although it is
    # dropped from the inputs above; kept as-is to preserve the exported
    # graph's naming -- confirm whether it should be removed.
    torch.onnx.export(
        self,
        export_args,
        onnx_path,
        opset_version=13,
        do_constant_folding=False,
        input_names=['input_ids', 'type_ids', 'input_mask', 'position_ids', 'labels'],
        output_names=['logits', 'loss'],
        dynamic_axes={name: dict(sequence_axes) for name in dynamic_inputs},
    )

def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
Expand Down
49 changes: 47 additions & 2 deletions DeBERTa/apps/run.py
Expand Up @@ -152,9 +152,21 @@ def _ignore(k):
def run_eval(args, model, device, eval_data, prefix=None, tag=None, steps=None):
# Run prediction for full data
prefix = f'{tag}_{prefix}' if tag is not None else prefix
device = torch.device('cpu') if device is None else device
if args.export_onnx_model:
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType
if args.fp16:
ort_model = os.path.join(args.output_dir, f'{prefix}_onnx_fp16.bin')
ort_model_qt = None
else:
ort_model = os.path.join(args.output_dir, f'{prefix}_onnx_fp32.bin')
ort_model_qt = os.path.join(args.output_dir, f'{prefix}_onnx_qt.bin')

eval_results=OrderedDict()
eval_metric=0
no_tqdm = (True if os.getenv('NO_TQDM', '0')!='0' else False) or args.rank>0
ort_session = None
for eval_item in eval_data:
name = eval_item.name
eval_sampler = SequentialSampler(len(eval_item.data))
Expand All @@ -167,9 +179,38 @@ def run_eval(args, model, device, eval_data, prefix=None, tag=None, steps=None):
predicts=[]
labels=[]
for batch in tqdm(AsyncDataLoader(eval_dataloader), ncols=80, desc='Evaluating: {}'.format(prefix), disable=no_tqdm):
_batch = batch.copy()
batch = batch_to(batch, device)
with torch.no_grad():
output = model(**batch)
if args.export_onnx_model:
if ort_session is None:
if args.rank < 1:
model.export_onnx(ort_model, (batch.copy(),))
if ort_model_qt is not None:
quantize_dynamic(ort_model, ort_model_qt)
ort_model = ort_model_qt
if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
torch.distributed.barrier()
sess_opt = ort.SessionOptions()
os.environ["ORT_TENSORRT_ENGINE_CACHE_ENABLE"] = "1"
os.environ["ORT_TENSORRT_FP16_ENABLE"] = "1" #TRT precision: 1: TRT FP16, 0: TRT FP32
ort_session = ort.InferenceSession(ort_model, sess_options=sess_opt, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider'])
numpy_input = {}
for k in [p.name for p in ort_session.get_inputs()]:
if isinstance(_batch[k], torch.Tensor):
numpy_input[k] = _batch[k].cpu().numpy()

warmup = ort_session.run(None, numpy_input)
#cuda_session = ort.InferenceSession(ort_model, sess_options=sess_opt, providers=['CUDAExecutionProvider'])
#warmup = cuda_session.run(None, numpy_input)
numpy_input = {}
for k in [p.name for p in ort_session.get_inputs()]:
if isinstance(_batch[k], torch.Tensor):
numpy_input[k] = _batch[k].cpu().numpy()
output = ort_session.run(None, numpy_input)
output = dict([(n.name,torch.tensor(o).to(device)) for n,o in zip(ort_session.get_outputs(), output)])
if ort_session is None:
with torch.no_grad():
output = model(**batch)
logits = output['logits'].detach()
tmp_eval_loss = output['loss'].detach()
if 'labels' in output:
Expand Down Expand Up @@ -415,6 +456,10 @@ def build_argument_parser():
type=str,
help="The loss function used to calculate adversarial loss. It can be one of symmetric-kl, kl or mse.")

parser.add_argument('--export_onnx_model',
default=False,
type=boolean_string,
help="Whether to export model to ONNX format.")

return parser

Expand Down

0 comments on commit c794b71

Please sign in to comment.