This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

add arrange_caffe2_model_inputs in BaseModel #1292

Closed
9 changes: 1 addition & 8 deletions pytext/__init__.py
@@ -31,14 +31,7 @@ def _predict(workspace_id, predict_net, model, tensorizers, input):
         name: tensorizer.prepare_input(input)
         for name, tensorizer in tensorizers.items()
     }
-    model_inputs = model.arrange_model_inputs(tensor_dict)
-    flat_model_inputs = []
-    for model_input in model_inputs:
-        if isinstance(model_input, tuple):
-            flat_model_inputs.extend(model_input)
-        else:
-            flat_model_inputs.append(model_input)
-    model_inputs = flat_model_inputs
+    model_inputs = model.arrange_caffe2_model_inputs(tensor_dict)
     model_input_names = model.get_export_input_names(tensorizers)
     vocab_to_export = model.vocab_to_export(tensorizers)
     for blob_name, model_input in zip(model_input_names, model_inputs):
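For context, the call site above relies on `arrange_caffe2_model_inputs` returning one tensor per exported input name, so the existing `zip` with `get_export_input_names` pairs them one-to-one. Below is a minimal sketch of that pairing with hypothetical blob names and numpy inputs; the `FeedBlob` call stands in for the loop body elided from the diff and is not taken from this PR:

```python
from caffe2.python import workspace  # Caffe2's Python workspace API
import numpy as np

# Hypothetical stand-ins for model.get_export_input_names(tensorizers) and
# model.arrange_caffe2_model_inputs(tensor_dict): both must be flat and aligned.
model_input_names = ["tokens_vals", "tokens_lens"]
model_inputs = [np.array([[1, 2, 3]], dtype=np.int64),
                np.array([3], dtype=np.int64)]

for blob_name, model_input in zip(model_input_names, model_inputs):
    # One blob per name: nested tuples would break this one-to-one pairing,
    # which is why the inputs are flattened first.
    workspace.FeedBlob(blob_name, model_input)
```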
14 changes: 14 additions & 0 deletions pytext/models/model.py
@@ -214,6 +214,20 @@ def arrange_model_context(self, tensor_dict):
     def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
         pass
 
+    def arrange_caffe2_model_inputs(self, tensor_dict):
+        """
+        Generate inputs for the exported caffe2 model; the default behavior is to
+        flatten the input tuples.
+        """
+        model_inputs = self.arrange_model_inputs(tensor_dict)
+        flat_model_inputs = []
+        for model_input in model_inputs:
+            if isinstance(model_input, tuple):
+                flat_model_inputs.extend(model_input)
+            else:
+                flat_model_inputs.append(model_input)
+        return flat_model_inputs
+
 
 class Model(BaseModel):
     """
Expand Down