-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sklearn autologging: Fix behavior when a child and its parent are both patched #3582
Conversation
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
This reverts commit 5a84bba. Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
This reverts commit b4fcfa5. Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
prev_patch = getattr(patch.destination, _ACTIVE_PATCH, None) | ||
if not hasattr(patch.destination, original_name) or ( | ||
prev_patch | ||
and prev_patch.destination != patch.destination | ||
and issubclass(patch.destination, prev_patch.destination) | ||
): | ||
setattr(patch.destination, original_name, target) | ||
|
||
setattr(patch.destination, patch.name, patch.obj) | ||
setattr(patch.destination, _ACTIVE_PATCH, patch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@harupy These are the critical portions of the change to address problem (1) from #3574 (comment).
We now retain a reference to the previous patch that was applied to a class or one of its super classes, and, if the patch was applied to a super class, we overwrite it with the patch that's being applied to this class directly. See https://github.com/christophercrouzet/gorilla/blob/0045a7f4b5c46bda208dbce9e628f97bc9e551e0/gorilla.py#L295-L300 for the original implementation.
should_start_run = mlflow.active_run() is None | ||
if should_start_run: | ||
try_mlflow_log(mlflow.start_run) | ||
|
||
_log_pretraining_metadata(self, *args, **kwargs) | ||
|
||
original_fit = gorilla.get_original_attribute(self, func_name) | ||
original_fit = gorilla.get_original_attribute(clazz, func_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This (together line 914) is the critical change to address problem (2) from #3574 (comment).
Assume that the old line was being invoked from a patched fit
call on a class called Parent
and that self
refers to an instance of class Child
which inherits from Parent
. Calling get_original_attribute()
on self
would yield the Child.fit()
function, rather than Parent.fit()
, leading to recursive behavior.
Now, instead of fetching the attribute from self
, we fetch it from clazz
(which would refer to Parent
in the above example).
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
Signed-off-by: Corey Zumar <corey.zumar@databricks.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 👍
What changes are proposed in this pull request?
Fixes #3574 by addressing the following two issues:
CountVectorizer
) forcesget_original_attribute
of its children to refer to the parent class attribute, even if we patch the child later on. I.e.This ordering is okay
but this ordering is not
This appears to be caused by the following line that prevents the original attribute from being overwritten once it's set: https://github.com/christophercrouzet/gorilla/blob/0c895b34311b6546bde35854e24862c001279add/gorilla.py#L280. As a result, setting the attribute on the parent, which sets the original attribute on the child if the child does not already define it, means that the child is always stuck with the parent's original attribute. This seems to be fixed in the
master
branch of gorilla but has not been released. Either we can pull in the copy frommaster
or pull in a copy from 0.3.0 that removes this problematichasattr
check.fit_mlflow
andfit_predict
, we performget_original_attribute(self, func_name)
. In the case where we're in a super class function (e.g.,CountVectorizer.fit_transform()
),self
still refers to the subclass (e.g.,TfidfVectorizer
). As a result, callingget_original_attribute
in the patched super class method fetches invokes the patched subclass method, which creates the infinite recursion problem observed in the issue description. I've filed Sklearn autologging: Fix behavior when a child and its parent are both patched #3582 to address this problem by fetching the original function attribute by reference to a class rather than by reference toself
. This PR still needs docs & tests, and it needs to incorporate a fix for (1) before it's mergeable.How is this patch tested?
Unit tests
Release Notes
Fixes an infinite recursion bug in scikit-learn autologging caused by invoking a
fit()
method on a model class that includes a call tosuper.fit()
as part of its implementation (see #3574).Is this a user-facing change?
What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/artifacts
: Artifact stores and artifact loggingarea/build
: Build and test infrastructure for MLflowarea/docs
: MLflow documentation pagesarea/examples
: Example codearea/model-registry
: Model Registry service, APIs, and the fluent client calls for Model Registryarea/models
: MLmodel format, model serialization/deserialization, flavorsarea/projects
: MLproject format, project running backendsarea/scoring
: Local serving, model deployment tools, spark UDFsarea/server-infra
: MLflow server, JavaScript dev serverarea/tracking
: Tracking Service, tracking client APIs, autologgingInterface
area/uiux
: Front-end, user experience, JavaScript, plottingarea/docker
: Docker use across MLflow's components, such as MLflow Projects and MLflow Modelsarea/sqlalchemy
: Use of SQLAlchemy in the Tracking Service or Model Registryarea/windows
: Windows supportLanguage
language/r
: R APIs and clientslanguage/java
: Java APIs and clientslanguage/new
: Proposals for new client languagesIntegrations
integrations/azure
: Azure and Azure ML integrationsintegrations/sagemaker
: SageMaker integrationsintegrations/databricks
: Databricks integrationsHow should the PR be classified in the release notes? Choose one:
rn/breaking-change
- The PR will be mentioned in the "Breaking Changes" sectionrn/none
- No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" sectionrn/feature
- A new user-facing feature worth mentioning in the release notesrn/bug-fix
- A user-facing bug fix worth mentioning in the release notesrn/documentation
- A user-facing documentation change worth mentioning in the release notes