Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve task api code quality #2376

Merged
merged 3 commits into from May 25, 2021
Merged

Conversation

mariosasko
Copy link
Collaborator

@mariosasko mariosasko commented May 18, 2021

Improves the code quality of the TaskTemplate dataclasses.

Changes:

  • replaces return NotImplemented with raise NotImplementedError
  • replaces sorted with len in the uniqueness check
  • defines label2id and id2label in the TextClassification template as properties
  • replaces the object.__setattr__(self, attr, value) syntax with (IMO nicer) self.__dict__[attr] = value

object.__setattr__(self, "labels", tuple(sorted(self.labels)))
object.__setattr__(self, "text_column", self.text_column)
object.__setattr__(self, "label_column", self.label_column)
self.__dict__["labels"] = tuple(sorted(self.labels))
self.label_schema["labels"] = ClassLabel(names=self.labels)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is an issue. Modifying a class attribute based on the value of an instance attribute doesn't make sense if multiple instances of the same class are allowed (if that's the case, the class attribute will have a valid value only for the instance that was initialized last). One way to fix this is with the help of descriptors:

class FeaturesWithLazyClassLabel:
    def __init__(self, features, label_column="labels"):
        assert label_column in features, f"Key '{label_column}' missing in features {features}"
        self._features = features
        self._label_column = label_column

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self._features
     
        assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'"
        # note: this part is not cached, but we can add this easily
        features = self._features.copy()
        features["labels"] = ClassLabel(names=getattr(obj, self._label_column))
        return features

and then in TextClassification:

label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))

Copy link
Member

@lhoestq lhoestq May 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on this !

The fix using FeaturesWithLazyClassLabel is fine IMO (though it would be nice if we find a simpler way to fix this)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch @mariosasko! i'm also happy with the FeatureWithLazyClassLabel fix for now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a temporary solution. I agree we should try to find a simpler way to fix this. The current API design does not seem to fit the task very well (having label_schema defined as a class attribute but depends on the value of instances?), so we should rethink that part IMO.

@lhoestq
Copy link
Member

lhoestq commented May 19, 2021

Looks good thanks, what do you think @lewtun ?

Copy link
Member

@lewtun lewtun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow, the code is not even one day old and already getting improved 😍 - thanks a lot @mariosasko. once the lazy features class is included, it LGTM!


from ..features import ClassLabel, Features, Value
from .base import TaskTemplate


@dataclass(frozen=True)
class TextClassification(TaskTemplate):
task = "text-classification"
input_schema = Features({"text": Value("string")})
task: ClassVar[str] = "text-classification"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for spotting this!

@abc.abstractmethod
def from_dict(cls, template_dict: dict) -> "TaskTemplate":
return NotImplemented
def from_dict(cls: Type[T], template_dict: dict) -> T:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very elegant approach - thanks!

@lewtun
Copy link
Member

lewtun commented May 20, 2021

thanks for including the lazy ClassLabel class @mariosasko ! from my side this LGTM!

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks :) merging

@lhoestq lhoestq merged commit 633ddcd into huggingface:master May 25, 2021
@mariosasko mariosasko deleted the task-code-quality branch June 2, 2021 20:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants