Integrate text-generation pipeline from inference.py to TSModelForCausalLM #300
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks a lot for this integration @jiqing-feng
optimum/intel/generation/modeling.py (Outdated)

```diff
@@ -81,9 +81,34 @@ def load_model(file_name: Union[str, Path]):
     torch.jit.freeze(model.eval())
     return model

+    @staticmethod
+    def jit_trace(model: PreTrainedModel, task: str, config: PretrainedConfig, use_cache: bool = True):
```
I think it would make sense to create a separate function `trace` (working for all architectures, not only causal LM) and to use it in `TSModelForCausalLM._from_transformers`, instead of having `jit_trace` and `export_model` methods.
Hi @echarlaix, I have created a separate function `trace` in a new file. Could you please help review it? Thanks! BTW, the failed check seems unrelated to my changes.
optimum/intel/generation/tracing.py (Outdated)

```python
from optimum.exporters import TasksManager


def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = True):
```
Could we keep `prepare_jit_inputs` and `jit_trace` in `modeling.py`?
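
For readers following along, here is a minimal sketch of what these two helpers might look like, assuming dummy inputs are derived from the task's ONNX export config via `TasksManager`; the exact constructor arguments, the `example_kwarg_inputs` tracing path, and the handling of past key values for `use_cache` are simplifications rather than the PR's actual code:

```python
import inspect

import torch
from transformers import PreTrainedModel

from optimum.exporters import TasksManager


def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = True):
    # Derive dummy inputs from the ONNX export config registered for this task,
    # so the same helper works for any architecture, not only causal LM.
    # NOTE: past key values for use_cache=True are omitted in this sketch.
    onnx_config_class = TasksManager.get_exporter_config_constructor(
        exporter="onnx", model=model, task=task
    )
    onnx_config = onnx_config_class(model.config)
    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
    # Keep only the inputs that the model's forward signature actually accepts.
    forward_params = inspect.signature(model.forward).parameters
    return {k: v for k, v in dummy_inputs.items() if k in forward_params}


def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = True):
    model_inputs = prepare_jit_inputs(model, task, use_cache)
    model.config.return_dict = False  # TorchScript traces tuples, not ModelOutput dicts
    # example_kwarg_inputs requires torch >= 2.0; older versions need positional inputs.
    traced = torch.jit.trace(model.eval(), example_kwarg_inputs=dict(model_inputs), strict=False)
    traced = torch.jit.freeze(traced.eval())
    # Two warm-up calls so the fused graph is optimized before first real use.
    traced(**model_inputs)
    traced(**model_inputs)
    return traced
```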
optimum/intel/ipex/inference.py (Outdated)

```python
if self._model.task == "text-generation":
    jit_model = jit_trace(
        model=model,
        task=self._model.task,
        use_cache=self._original.config.use_cache,
    )
    model = TSModelForCausalLM(
        model=jit_model,
        config=self._original.config,
        use_cache=self._original.config.use_cache,
    )
else:
    jit_inputs = []
    dummy_input = self._model.tokenizer("")
    for key in dummy_input:
        jit_inputs.append(
            torch.ones((1, len(dummy_input[key])), dtype=torch.long)
        )
    model = torch.jit.trace(model, jit_inputs, strict=False)
    model = torch.jit.freeze(model)
    model(*jit_inputs)
    model(*jit_inputs)
```
Why not:

```diff
-if self._model.task == "text-generation":
-    jit_model = jit_trace(
-        model=model,
-        task=self._model.task,
-        use_cache=self._original.config.use_cache,
-    )
-    model = TSModelForCausalLM(
-        model=jit_model,
-        config=self._original.config,
-        use_cache=self._original.config.use_cache,
-    )
-else:
-    jit_inputs = []
-    dummy_input = self._model.tokenizer("")
-    for key in dummy_input:
-        jit_inputs.append(
-            torch.ones((1, len(dummy_input[key])), dtype=torch.long)
-        )
-    model = torch.jit.trace(model, jit_inputs, strict=False)
-    model = torch.jit.freeze(model)
-    model(*jit_inputs)
-    model(*jit_inputs)
+model = jit_trace(
+    model=model,
+    task=self._model.task,
+    use_cache=self._original.config.use_cache,
+)
+if self._model.task == "text-generation":
+    model = TSModelForCausalLM(
+        model=model,
+        config=self._original.config,
+        use_cache=self._original.config.use_cache,
+    )
```
Thanks for integrating `TSModelForCausalLM` into `inference_mode` @jiqing-feng!!
Hi @echarlaix, thanks for your advice. I have updated the code. Could you please review it? Thanks!
optimum/intel/ipex/inference.py (Outdated)

```python
model = torch.jit.freeze(model)
model(*jit_inputs)
model(*jit_inputs)
jit_model = jit_trace(
```
Shouldn't it be:

```diff
-jit_model = jit_trace(
+model = jit_trace(
```
optimum/intel/ipex/inference.py (Outdated)

```diff
-if self._model.task == "text-generation":
-    self._model.model = _ModelGenerationWrapper(model, self._original)
+if self._model.task == "text-generation" and self._jit:
+    self._model.model = model
```
Shouldn't it be, for all cases: `self._model.model = _ModelFallbackWrapper(model, self._original)`?
Hi @echarlaix. If I use `_ModelFallbackWrapper`, then when I execute `model.generate` it goes to `__getattr__(self, item)`, which returns `getattr(self._default, item)`, so I actually execute `self._default.generate`. I didn't use `_ModelFallbackWrapper` because then I cannot use the generation of my optimized model.
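
For illustration, a rough sketch of the delegation being described, with the `_default` attribute taken from the comment above and the rest assumed; the real `_ModelFallbackWrapper` may differ in details:

```python
class _ModelFallbackWrapper:
    """Dispatch __call__ to the optimized model, everything else to the original."""

    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        # Plain forward calls go to the traced/optimized model.
        return self._optimized(*args, **kwargs)

    def __getattr__(self, item):
        # Any other attribute (including .generate) falls back to the original model.
        return getattr(self._default, item)
```

With this shape, `wrapper.generate(...)` resolves through `__getattr__` and runs on the original eager model, never on the optimized graph, which is the concern raised above.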
Why not subclass from `_ModelFallbackWrapper` to enable it, then?
@echarlaix, I think we should enhance TSModelForCausalLM for all text-generation tasks; then, for the text-generation task, we don't need to use `_ModelFallbackWrapper`. Is that OK?
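
As a hedged sketch of the subclassing alternative suggested above (the routed attribute names and the reuse of the `_ModelGenerationWrapper` name are illustrative assumptions, not the PR's final code):

```python
class _ModelGenerationWrapper(_ModelFallbackWrapper):
    """Like _ModelFallbackWrapper, but keeps text generation on the optimized model."""

    def __getattr__(self, item):
        # Route generation entry points to the optimized TSModelForCausalLM;
        # everything else still falls back to the original eager model.
        if item in ("generate", "config", "generation_config", "can_generate"):
            return getattr(self._optimized, item)
        return getattr(self._default, item)
```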
optimum/intel/ipex/inference.py (Outdated)

```diff
@@ -188,16 +101,23 @@ def __enter__(self):
     with torch.cpu.amp.autocast(enabled=(self._dtype == torch.bfloat16)), torch.no_grad():
         if self._model.tokenizer is not None and self._jit:
```
Shouldn't it be:

```diff
-if self._model.tokenizer is not None and self._jit:
+if self._jit:
```
Hi @echarlaix. I subclassed a new class from `_ModelFallbackWrapper`. BTW, could we merge it if there are no big issues? I have a follow-up PR for enabling text2text-generation and other generation tasks. Some issues like variable naming could be fixed in the next PR.
Hi @echarlaix, I have updated the test cases of 310. Could you help me have a look at the failed check? Thanks!
Force-pushed from feb5ac7 to 2395ea2
@mfuntowicz please help review. This PR utilizes TSModelForCausalLM and fixes the JIT issue in Llama text generation.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Hi @echarlaix. I have integrated my changes from inference.py into TSModelForCausalLM, refer to #227. I also use the function `export_model` to enable the IPEX model in TSModelForCausalLM. I would like your opinion. Thanks!
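
For context, a usage sketch of the resulting flow; the checkpoint name, generation arguments, and the import path of `jit_trace` are assumptions, while the `TSModelForCausalLM(...)` constructor call mirrors the snippet earlier in this thread:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Import path assumed per the discussion above (helpers kept in modeling.py).
from optimum.intel.generation.modeling import TSModelForCausalLM, jit_trace

model_id = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Trace the eager model, then wrap it so generate() runs on the TorchScript graph.
traced = jit_trace(model=model, task="text-generation", use_cache=model.config.use_cache)
ts_model = TSModelForCausalLM(model=traced, config=model.config, use_cache=model.config.use_cache)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    output_ids = ts_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```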