
Refactor transformed estimators #51

Closed
aazuspan opened this issue Sep 18, 2023 · 6 comments · Fixed by #52
Labels: estimator (Related to one or more estimators), refactor (Code cleanup without changing functionality)

aazuspan commented Sep 18, 2023

@grovduck I have a proposal to run by you, inspired by your refactor of the inheritance in #50. This design builds off of the changes there, so that would be merged before tackling this. I thought about proposing this as part of that PR, but didn't want to derail things with even more refactoring.

I suggest we turn TransformedKNeighborsRegressor into an abstract class by inheriting from ABC, then add an abstract method _get_transform that all subclasses would be required to implement, e.g.:

class EuclideanKNNRegressor(TransformedKNeighborsRegressor):
    def _get_transform(self) -> TransformerMixin:
        return StandardScalerWithDOF(ddof=1)

Next, we would get rid of the fit methods on those estimators and move that functionality into TransformedKNeighborsRegressor.fit, using a _set_fit_transform method to handle the instantiation and fitting of the transformer.

from abc import ABC, abstractmethod

class TransformedKNeighborsRegressor(RawKNNRegressor, ABC):
    ...

    @abstractmethod
    def _get_transform(self) -> TransformerMixin:
        """Return the transformer to use for fitting. Must be implemented by subclasses."""
        ...

    def _set_fit_transform(self, X, y) -> None:
        self.transform_ = self._get_transform().fit(X, y)

    def fit(self, X, y):
        """Fit using transformed feature data."""
        self._validate_data(X, y, force_all_finite=True, multi_output=True)        
        self._set_dataframe_index_in(X)
        self._set_fit_transform(X, y)

        X_transformed = self.transform_.transform(X)
        return super().fit(X_transformed, y)

To accommodate fitting transformers with separate y data in GNN and MSN, we could add a kNN-independent YFitMixin that overrides the fit method to accept the additional argument, store it as an attribute, and fit the transformer with it using an overridden _set_fit_transform:

class YFitMixin:
    """Mixin for transformed estimators that can fit the transformer with a separate y."""
    def _set_fit_transform(self, X, y):
        y_fit = self.y_fit_ if self.y_fit_ is not None else y
        self.transform_ = self._get_transform().fit(X, y_fit)

    def fit(self, X, y, y_fit=None):
        self.y_fit_ = y_fit
        return super().fit(X, y)

Overall, this should reduce some code duplication in the fit methods, prevent instantiation of TransformedKNeighborsRegressor without having to rely on making it private, and add a runtime check to ensure that all transformed estimators define a transformer function. The main downside is the need to store a reference to y_fit_. There may be another way to handle the YFitMixin, but that was the best solution I could come up with after trying a few different strategies.
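To make the instantiation guard concrete, here is a minimal sketch of the ABC behavior using toy stand-ins (the class bodies here are placeholders, not the real package code):

```python
# Minimal sketch with toy stand-in classes (not the real package code) showing
# how inheriting from ABC blocks direct instantiation and forces subclasses
# to implement _get_transform.
from abc import ABC, abstractmethod

class RawKNNRegressor:  # stand-in for the concrete base estimator
    pass

class TransformedKNeighborsRegressor(RawKNNRegressor, ABC):
    @abstractmethod
    def _get_transform(self):
        """Return the transformer to use for fitting."""
        ...

class EuclideanKNNRegressor(TransformedKNeighborsRegressor):
    def _get_transform(self):
        return "a transformer"  # placeholder for StandardScalerWithDOF(ddof=1)

try:
    TransformedKNeighborsRegressor()
except TypeError as err:
    print(err)  # can't instantiate abstract class with abstract method

EuclideanKNNRegressor()  # works: the abstract method is implemented
```

The check happens at instantiation time, so a missing `_get_transform` fails loudly rather than surfacing later during fit.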

Curious to hear your thoughts on this design, and if you foresee any limitations.

@aazuspan added the estimator (Related to one or more estimators) and refactor (Code cleanup without changing functionality) labels Sep 18, 2023
@aazuspan added this to the Core Estimators milestone Sep 18, 2023
@grovduck (Member) commented:

@aazuspan, I like it! I especially like the prevention of instantiation by code rather than by convention and, as you say, it gets rid of some duplicated code as well. A couple of quick questions/comments just to make sure I get it:

  • GNN and MSN would be, for example, GNNRegressor(YFitMixin, TransformedKNeighborsRegressor), correct? Just wanted to verify that YFitMixin.fit gets called before TransformedKNeighborsRegressor.fit ...
  • I don't know if this was intentional, but _set_fit_transform could be confused with the more generic fit_transform method defined in each transformer, which fits/transforms in a single step. But I think the intention of _set_fit_transform is really just to set the transform? I actually wonder if it's even necessary to have that method - perhaps this instead? Is there an advantage to this that I'm overlooking?
from abc import ABC, abstractmethod

class TransformedKNeighborsRegressor(RawKNNRegressor, ABC):
    ...

    @abstractmethod
    def _get_transform(self) -> TransformerMixin:
        """Return the transformer to use for fitting. Must be implemented by subclasses."""
        ...

    def fit(self, X, y):
        """Fit using transformed feature data."""
        self._validate_data(X, y, force_all_finite=True, multi_output=True)        
        self._set_dataframe_index_in(X)
        self.transform_ = self._get_transform().fit(X, y)

        X_transformed = self.transform_.transform(X)
        return super().fit(X_transformed, y)

@aazuspan (Contributor, Author) commented:

> GNN and MSN would be, for example, GNNRegressor(YFitMixin, TransformedKNeighborsRegressor), correct?

Bingo!
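For anyone double-checking the resolution order, a toy sketch (stand-in classes, not the real estimators) confirms that `YFitMixin.fit` runs first and delegates via `super()`:

```python
# Toy stand-in classes (not the real estimators) confirming the method
# resolution order: with GNNRegressor(YFitMixin, TransformedKNeighborsRegressor),
# Python finds fit on the mixin first, which then delegates via super().
class TransformedKNeighborsRegressor:
    def fit(self, X, y):
        self.calls = getattr(self, "calls", []) + ["TransformedKNeighborsRegressor.fit"]
        return self

class YFitMixin:
    def fit(self, X, y, y_fit=None):
        self.calls = getattr(self, "calls", []) + ["YFitMixin.fit"]
        self.y_fit_ = y_fit
        return super().fit(X, y)

class GNNRegressor(YFitMixin, TransformedKNeighborsRegressor):
    pass

est = GNNRegressor().fit(X=None, y=None)
print(est.calls)  # ['YFitMixin.fit', 'TransformedKNeighborsRegressor.fit']
print([c.__name__ for c in GNNRegressor.__mro__])
```

Because the mixin precedes the base class in the bases list, it also precedes it in the MRO, so its `fit` is the one callers reach first.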

> I don't know if this was intentional, but _set_fit_transform could be confused with the more generic fit_transform method defined in each transformer, which fits/transforms in a single step. But I think the intention of _set_fit_transform is really just to set the transform?

Yes, great point! That was unintentional - we should definitely choose a better name. Is _set_fitted_transform still too similar? Or _set_fitted_transformer (probably combined with _get_transformer)?

> I actually wonder if it's even necessary to have that method - perhaps this instead? Is there an advantage to this that I'm overlooking?

Separating out the transformer fitting is a workaround to an inheritance problem that arises when YFitMixin.fit calls super().fit, which calls TransformedKNeighborsRegressor.fit and refits the transformer with the wrong y data. I also played around with 1) adding a reset=True argument to TransformedKNeighborsRegressor.fit to avoid resetting the transformer if called intentionally, and 2) having TransformedKNeighborsRegressor.fit accept **kwargs and fit the transformer with y_fit if present. Those generally worked, but I wanted to keep as much of the y_fit logic out of TransformedKNeighborsRegressor as possible, and having a separate fitting method allowed that.
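To see why the separate hook avoids the double-fit, here is a toy sketch (stand-in classes and a hypothetical FakeTransformer, not the real package code): because the base `fit` calls `self._set_fit_transform`, the call resolves to the mixin's override, so the transformer is fitted exactly once with `y_fit`:

```python
# Sketch (toy stand-ins) of the hook pattern: the base fit invokes
# self._set_fit_transform, which resolves to the most-derived override,
# so the transformer is fitted once with y_fit rather than refit with y.
fits = []

class FakeTransformer:  # hypothetical transformer that records what it was fit with
    def fit(self, X, y):
        fits.append(y)
        return self

class TransformedKNeighborsRegressor:
    def _get_transform(self):
        return FakeTransformer()
    def _set_fit_transform(self, X, y):
        self.transform_ = self._get_transform().fit(X, y)
    def fit(self, X, y):
        self._set_fit_transform(X, y)  # resolves to the mixin's override
        return self

class YFitMixin:
    def _set_fit_transform(self, X, y):
        y_fit = self.y_fit_ if self.y_fit_ is not None else y
        self.transform_ = self._get_transform().fit(X, y_fit)
    def fit(self, X, y, y_fit=None):
        self.y_fit_ = y_fit
        return super().fit(X, y)

class MSNRegressor(YFitMixin, TransformedKNeighborsRegressor):
    pass

MSNRegressor().fit(X="X", y="y", y_fit="y_fit")
print(fits)  # ['y_fit'] -- fitted once, with the separate y
```

If `YFitMixin.fit` instead fitted the transformer itself before calling `super().fit`, the base `fit` would fit it a second time with the wrong y, which is exactly the problem described above.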

It's 100% possible there's a better workaround that I didn't think of though!

@grovduck (Member) commented:

> Yes, great point! That was unintentional - we should definitely choose a better name. Is _set_fitted_transform still too similar? Or _set_fitted_transformer (probably combined with _get_transformer)?

I guess I have a slight preference for the latter, but either is fine by me.

> Separating out the transformer fitting is a workaround to an inheritance problem that arises when YFitMixin.fit calls super().fit, which calls TransformedKNeighborsRegressor.fit and refits the transformer with the wrong y data. I also played around with 1) adding a reset=True argument to TransformedKNeighborsRegressor.fit to avoid resetting the transformer if called intentionally, and 2) having TransformedKNeighborsRegressor.fit accept **kwargs and fit the transformer with y_fit if present. Those generally worked, but I wanted to keep as much of the y_fit logic out of TransformedKNeighborsRegressor as possible, and having a separate fitting method allowed that.

Oof, let's avoid that complexity! Having a separate _set_fitted_transform or _set_fitted_transformer method sounds like a preferable solution.

@aazuspan (Contributor, Author) commented:

> I guess I have a slight preference for the latter, but either is fine by me.

Agreed! Should we stick with transform_ for the attribute, or do you think we should go with transformer_?

> Oof, let's avoid that complexity!

Sounds like a plan! I already worked most of this up to make sure it was possible, so I'll make a PR shortly.

@grovduck (Member) commented:

> Should we stick with transform_ for the attribute, or do you think we should go with transformer_?

Ooo, maybe transformer_ is better to pair with the class names (CCATransformer, MahalanobisTransformer, etc.). Maybe we've gone back and forth on this one 😉

@aazuspan self-assigned this Sep 18, 2023
@aazuspan linked a pull request Sep 18, 2023 that will close this issue
aazuspan added a commit that referenced this issue Sep 19, 2023
This resolves #51 by:

* Refactoring the TransformedKNeighborsRegressor into an abstract class
* Moving transformer instantiation out of the estimator fit methods and into an abstract _get_transformer method
* Moving transformer fitting out of the estimator fit methods and into the TransformedKNeighborsRegressor._set_fitted_transformer method to reduce duplication
* Creating a YFitMixin to handle transformer fitting for GNN and MSN

This also:

* Renames the transform_ attribute to transformer_ for consistency with the new methods
* Adds a _validate_data check with force_all_finite=True when fitting all transformed estimators. This was needed by MSN, but also fixed an xfailing estimator check for GNN, which allowed us to drop that from the tags.
@aazuspan (Contributor, Author) commented:

Resolved by #52
