Skip to content

Commit

Permalink
Merge pull request #708 from gabe-l-hart/MultitaskTrainModules-707
Browse files Browse the repository at this point in the history
Multitask train modules 707
  • Loading branch information
gabe-l-hart committed May 3, 2024
2 parents bccda9a + 4fd2d86 commit 51593ad
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 9 deletions.
16 changes: 12 additions & 4 deletions caikit/core/modules/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,14 @@ def decorator(cls_):
cls_.MODULE_CLASS = classname
cls_.PRODUCER_ID = ProducerId(cls_.MODULE_NAME, cls_.MODULE_VERSION)

cls_._TASK_CLASSES = tasks

# Parse the `train` and `run` signatures
cls_.RUN_SIGNATURE = CaikitMethodSignature(cls_, "run")
cls_.TRAIN_SIGNATURE = CaikitMethodSignature(cls_, "train")
cls_._TASK_INFERENCE_SIGNATURES = {}

# If the module has tasks, validate them:
for t in cls_._TASK_CLASSES:
task_classes = tasks
for t in task_classes:
if not t.has_inference_method_decorators(module_class=cls_):
# Hackity hack hack - make sure at least one flavor is supported
validated = False
Expand Down Expand Up @@ -231,7 +230,16 @@ def decorator(cls_):
tasks_in_hierarchy.extend(class_._TASK_CLASSES)

if tasks_in_hierarchy:
cls_._TASK_CLASSES += tasks_in_hierarchy
task_classes += tasks_in_hierarchy

# Make sure the tasks are unique. Note that the order here is important
# so that iterating the list of tasks is deterministic, unique, and the
# tasks given in the class' module list are shown before tasks inherited
# from parent classes.
cls_._TASK_CLASSES = []
for task in task_classes:
if task not in cls_._TASK_CLASSES:
cls_._TASK_CLASSES.append(task)

# If no backend support described in the class, add current backend
# as the only backend that can load models trained by this module
Expand Down
6 changes: 3 additions & 3 deletions caikit/core/modules/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def injected_load(*args):
"""

# Standard
from typing import TYPE_CHECKING, Set
from typing import TYPE_CHECKING, List
import abc
import functools

Expand Down Expand Up @@ -158,8 +158,8 @@ def metadata_injecting_load(clz, *args, **kwargs):
return super().__new__(mcs, name, bases, attrs)

@property
def tasks(cls) -> Set["TaskBase"]:
return set(cls._TASK_CLASSES)
def tasks(cls) -> List["TaskBase"]:
return [task for task in cls._TASK_CLASSES]

def __setattr__(cls, name, val):
"""Overwrite __setattr__ to warn on any dynamic updates to the load function.
Expand Down
36 changes: 34 additions & 2 deletions tests/core/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@

# Local
from caikit.core import TaskBase, task
from caikit.interfaces.common.data_model import File
from sample_lib import SampleModule
from sample_lib.data_model.sample import SampleInputType, SampleOutputType, SampleTask
from sample_lib.data_model.sample import (
OtherOutputType,
SampleInputType,
SampleOutputType,
SampleTask,
)
from sample_lib.modules.multi_task import FirstTask, MultiTaskModule, SecondTask
import caikit.core

Expand Down Expand Up @@ -171,7 +177,7 @@ def test_task_is_not_required_for_modules():
class Stuff(caikit.core.ModuleBase):
pass

assert Stuff.tasks == set()
assert Stuff.tasks == []


def test_raises_if_tasks_not_list():
Expand Down Expand Up @@ -611,6 +617,32 @@ def run(self, sample_input: Union[str, int]) -> SampleOutputType:
pass


def test_tasks_property_order():
"""Ensure that the tasks returned by .tasks have a deterministic order that
respects the order given in the module decorator
"""
assert MultiTaskModule.tasks == [FirstTask, SecondTask]


def test_tasks_property_unique():
"""Ensure that entries in the tasks list is unique even when inherited from
modules with the same tasks
"""

@caikit.core.module(
id=str(uuid.uuid4()),
name="DerivedMultitaskModule",
version="0.0.1",
task=SecondTask,
)
class DerivedMultitaskModule(MultiTaskModule):
@SecondTask.taskmethod()
def run_second_task(self, file_input: File) -> OtherOutputType:
return OtherOutputType("I'm a derivative!")

assert DerivedMultitaskModule.tasks == [SecondTask, FirstTask]


# ----------- BACKWARDS COMPATIBILITY ------------------------------------------- ##


Expand Down

0 comments on commit 51593ad

Please sign in to comment.