Sanitize parameters in Model.predict() #96

Closed
arxyzan opened this issue Oct 1, 2023 · 1 comment
@arxyzan
Member

arxyzan commented Oct 1, 2023

A model's predict() method should let the user pass different parameters as kwargs to preprocess(), forward() and post_process().

@arxyzan
Member Author

arxyzan commented Oct 4, 2023

As expected, I had to make a lot of changes to the library to make this happen.

First step: Accepting parameters explicitly

Formerly, all three methods received only two parameters: inputs (a dict of all required inputs) and **kwargs (only there to prevent an invalid-argument error when kwargs were given to predict()). This was safe but vague for users, since they had to read the code of each step to see which parameters were required in inputs.
With this done, all three methods in all models now accept their required parameters explicitly. (I changed the code of every model to make this happen 😭)
The issue is now mostly fixed, but one problem remains: predict() still passes all of its kwargs to all three functions. If a user writes their own model with preprocess(), forward() and post_process() and forgets to include **kwargs in their parameters, any keyword meant for a different step causes an invalid-argument error. So we need to pass only the right kwargs to each method.
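To make the failure concrete, here is a minimal sketch. ToyModel, its parameter names (padding, top_k) and its naive predict() are hypothetical stand-ins, not code from the library; they only illustrate what happens when every kwarg is forwarded to every step and a step omits **kwargs.

```python
class ToyModel:
    """Hypothetical user-defined model whose steps do not accept **kwargs."""

    def preprocess(self, inputs, padding=False):
        return {"text": inputs, "padding": padding}

    def forward(self, inputs):
        return inputs

    def post_process(self, outputs, top_k=1):
        return outputs, top_k

    def predict(self, inputs, **kwargs):
        # Naive behavior: every kwarg is forwarded to every step.
        x = self.preprocess(inputs, **kwargs)
        x = self.forward(x, **kwargs)
        return self.post_process(x, **kwargs)


try:
    # top_k is meant for post_process(), but preprocess() receives it first.
    ToyModel().predict("hello", top_k=3)
    error = None
except TypeError as exc:
    error = exc

print(type(error).__name__, error)
```

The call fails inside preprocess(), before post_process() ever sees the argument it was meant for.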

Second step: Unpacking kwargs in predict()

I defined a method on Model called _unpack_prediction_kwargs, shown below:

import inspect

def _unpack_prediction_kwargs(self, **kwargs):
    # Whether to use forward or generate based on model type (Model or GenerativeModel)
    inference_fn = type(self).generate if hasattr(self, "generate") else type(self).forward

    def accepted_keys(fn):
        # Skip self and the first argument, then drop *args/**kwargs by kind.
        # Filtering by kind (rather than slicing off the last item) keeps the last
        # named parameter when a user-defined method has no **kwargs.
        params = list(inspect.signature(fn).parameters.values())[2:]
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        return [p.name for p in params if p.kind not in variadic]

    # Keep only the kwargs that each step names explicitly
    preprocess_kwargs = {k: kwargs[k] for k in accepted_keys(type(self).preprocess) if k in kwargs}
    forward_kwargs = {k: kwargs[k] for k in accepted_keys(inference_fn) if k in kwargs}
    post_process_kwargs = {k: kwargs[k] for k in accepted_keys(type(self).post_process) if k in kwargs}

    return preprocess_kwargs, forward_kwargs, post_process_kwargs

This method inspects the parameters that each step accepts and splits the keyword arguments into a tuple of

(preprocess_kwargs, forward_kwargs, post_process_kwargs)

So now we can make sure that the right kwargs are passed to each of the functions! 🎉🎉🎉
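As a rough illustration of how predict() could consume per-step kwargs, here is a self-contained sketch. The split_kwargs helper, ToyModel and the parameter names (lowercase, repeat, sep) are assumptions made for this example; the thread does not show the library's actual predict() body, so this may differ from it.

```python
import inspect


def split_kwargs(bound_method, kwargs):
    """Hypothetical helper: keep only kwargs that the bound method names explicitly."""
    # For a bound method, self is already excluded; skip the first (inputs) argument.
    params = list(inspect.signature(bound_method).parameters.values())[1:]
    named = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    names = {p.name for p in params if p.kind in named}
    return {k: v for k, v in kwargs.items() if k in names}


class ToyModel:
    """Hypothetical model; none of its steps accept **kwargs."""

    def preprocess(self, inputs, lowercase=False):
        return inputs.lower() if lowercase else inputs

    def forward(self, inputs, repeat=1):
        return [inputs] * repeat

    def post_process(self, outputs, sep=" "):
        return sep.join(outputs)

    def predict(self, inputs, **kwargs):
        # Each step receives only the keywords it declares.
        x = self.preprocess(inputs, **split_kwargs(self.preprocess, kwargs))
        x = self.forward(x, **split_kwargs(self.forward, kwargs))
        return self.post_process(x, **split_kwargs(self.post_process, kwargs))


result = ToyModel().predict("Hi", lowercase=True, repeat=2, sep="-")
print(result)
```

One trade-off of filtering this way: a misspelled keyword matches no step and is silently dropped, unless predict() also checks for leftover keys.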

@arxyzan arxyzan closed this as completed Oct 7, 2023