
[Feature Request] Runtime validation of TaskInstanceParameter() by subclass bound #268

Open
ganow opened this issue Jan 19, 2022 · 2 comments

Comments

@ganow

ganow commented Jan 19, 2022

Hello, thank you for developing such a great tool!

Summary

I'd like to request one feature: runtime validation of TaskInstanceParameter() against a subclass bound, like the following:

class UpstreamBase(gokart.TaskOnKart): ...
class UpstreamA(UpstreamBase): ...
class UpstreamB(UpstreamBase): ...

class Example(gokart.TaskOnKart):
    upstream_task = gokart.TaskInstanceParameter(bound=UpstreamBase)
    ...
    def requires(self):
        return self.upstream_task  # guaranteed at runtime to be an instance of UpstreamBase (or a subclass)
    ...

Detail

A more concrete motivating example follows. Suppose we want to perform feature-embedding pre-processing followed by an actual data analysis task on top of it. Since we would like to compare empirically which embedding method works better, consider making this pre-processing task (the upstream task) abstract and injecting the concrete choice as a TaskInstanceParameter(). Illustrative pipelines look like this:

  • pattern 1: data -> PCA embedding -> downstream analysis (classification, visualization, etc.)
  • pattern 2: data -> Isomap embedding -> downstream analysis (classification, visualization, etc.)

In such a situation, we may want to restrict the task that can be injected as the upstream task to tasks with a specific output format, rather than any task that happens to be defined. The current default behavior of TaskInstanceParameter() can hide bugs because it accepts an instance of any subclass of luigi.Task without raising an exception.

Here is an example code:

from abc import ABCMeta, abstractmethod

import gokart
import luigi

import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import Isomap

class CreateFakeData(gokart.TaskOnKart):
    def run(self):
        self.dump(np.random.normal(size=(100, 20)))

class EmbedTask(gokart.TaskOnKart, metaclass=ABCMeta):
    data_task = CreateFakeData()
    n_components = luigi.IntParameter(default=2)

    def requires(self):
        return self.data_task

    def run(self):
        data = self.load()
        embedding = self.get_embedding(data)
        assert embedding.shape == (len(data), self.n_components)
        self.dump(embedding)

    @abstractmethod
    def get_embedding(self, data): ...

class PCAEmbed(EmbedTask):
    whiten = luigi.BoolParameter(default=False)

    def get_embedding(self, data):
        pca = PCA(n_components=self.n_components, whiten=self.whiten)
        return pca.fit_transform(data)

class IsomapEmbed(EmbedTask):
    n_neighbors = luigi.IntParameter(default=5)

    def get_embedding(self, data):
        isomap = Isomap(n_components=self.n_components, n_neighbors=self.n_neighbors)
        return isomap.fit_transform(data)

class DownstreamTask(gokart.TaskOnKart):
    upstream_task = gokart.TaskInstanceParameter()  # we want to restrict this with gokart.TaskInstanceParameter(bound=EmbedTask)

    def requires(self):
        return self.upstream_task

    def run(self):
        print('start downstream task')
        print(f'data shape: {self.load().shape}')
        self.dump(self.load().shape)

class SayHello(gokart.TaskOnKart):
    def run(self):
        self.dump("Hello, world!")

With the code above, using PCAEmbed or IsomapEmbed as the upstream task works as expected:

>>> gokart.build(DownstreamTask(upstream_task=PCAEmbed(whiten=True)))
start downstream task
data shape: (100, 2)
>>> gokart.build(DownstreamTask(upstream_task=IsomapEmbed(n_neighbors=5)))
start downstream task
data shape: (100, 2)

However, if we intentionally inject the wrong upstream task, DownstreamTask starts running anyway and fails only when it accesses an attribute the loaded data does not have (duck typing). If that access happens only after a long-running step in the downstream task (e.g., NN training), we won't notice the problem until that error is finally raised.

>>> gokart.build(DownstreamTask(upstream_task=SayHello()))
start downstream task  # DownstreamTask().run() started without error!
ERROR: [pid 2305701] Worker Worker(...) failed    DownstreamTask(upstream_task=SayHello(251d2defb17d8f40d3dfb3128ef72945))
Traceback (most recent call last):
...
    print(f'data shape: {self.load().shape}')
AttributeError: 'str' object has no attribute 'shape'
...

It would be better if the problem were detected when the task is instantiated:

>>> DownstreamTask(upstream_task=SayHello())
DownstreamTask(upstream_task=SayHello(251d2defb17d8f40d3dfb3128ef72945))  # we want to raise an exception here!
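Stripped of gokart, the requested check is an isinstance() test against the abstract upstream class: concrete subclasses such as PCAEmbed and IsomapEmbed pass, an unrelated task such as SayHello is rejected, and the abstract base itself cannot be instantiated at all. This stdlib-only sketch uses plain Python classes as simplified stand-ins for the gokart tasks in the example (they are not gokart classes), and `check_bound` is a hypothetical helper:

```python
from abc import ABC, abstractmethod


# Plain-Python stand-ins for the gokart tasks in the example (not gokart classes).
class EmbedTask(ABC):
    @abstractmethod
    def get_embedding(self, data): ...


class PCAEmbed(EmbedTask):
    def get_embedding(self, data):
        return data  # placeholder; the real task runs sklearn's PCA


class IsomapEmbed(EmbedTask):
    def get_embedding(self, data):
        return data  # placeholder; the real task runs sklearn's Isomap


class SayHello:
    pass


def check_bound(value, bound):
    """Hypothetical helper: reject `value` unless it is an instance of `bound` or a subclass."""
    if not isinstance(value, bound):
        raise ValueError(f'{type(value).__name__} is not an instance of {bound.__name__}')
    return value


check_bound(PCAEmbed(), EmbedTask)     # passes: concrete subclass
check_bound(IsomapEmbed(), EmbedTask)  # passes: concrete subclass

try:
    check_bound(SayHello(), EmbedTask)
except ValueError as e:
    print(e)  # SayHello is not an instance of EmbedTask

try:
    EmbedTask()  # an ABC with unimplemented abstract methods cannot be instantiated
except TypeError as e:
    print('TypeError:', e)
```

Because isinstance() honors inheritance, a single abstract base such as EmbedTask is enough to admit every current and future embedding task without listing them individually.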

Implementation idea

luigi.Parameter() provides normalize(v) to normalize and validate injected parameter values at runtime (spotify/luigi#1273). This method is executed when a task object is instantiated. Therefore, it seems that a subclass check can be added to TaskInstanceParameter() with an implementation like the following:

from typing import Optional

import luigi

class TaskInstanceParameter(luigi.Parameter):
    def __init__(self, *args, bound: Optional[type] = None, **kwargs):
        super().__init__(*args, **kwargs)
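        # Every injected value must be a luigi.Task; `bound` optionally narrows this further.
        self._bound = [luigi.Task]
        if bound is not None:
            if isinstance(bound, type):
                self._bound.append(bound)
            else:
                raise ValueError(f'bound must be a type, not {type(bound)}')

    def normalize(self, v):
        # Called by luigi when the task object is instantiated, so a wrong
        # upstream task is rejected before any run() method executes.
        for t in self._bound:
            if not isinstance(v, t):
                raise ValueError(f'{v} is not an instance of {t}')
        return v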

    ...
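To see why hooking into normalize() yields fail-fast behavior, here is a dependency-free toy model. The Parameter and Task classes are hypothetical stand-ins that mimic only the one behavior relied on above (the task constructor passes each injected value through its parameter's normalize()); luigi's real classes do much more, such as parsing, serialization and instance caching. The TaskInstanceParameter here mirrors the proposed one but checks against the toy Task instead of luigi.Task:

```python
from typing import Optional


class Parameter:
    """Toy stand-in for luigi.Parameter: only the normalize() hook is modeled."""

    def normalize(self, v):
        return v


class Task:
    """Toy stand-in for luigi.Task: runs normalize() on every injected value."""

    def __init__(self, **kwargs):
        cls = type(self)
        for name in dir(cls):
            param = getattr(cls, name)
            if isinstance(param, Parameter):
                # Validation happens here, at construction time.
                setattr(self, name, param.normalize(kwargs[name]))


class TaskInstanceParameter(Parameter):
    def __init__(self, bound: Optional[type] = None):
        # Every value must be a Task; an optional bound narrows that further.
        self._bound = [Task]
        if bound is not None:
            if not isinstance(bound, type):
                raise ValueError(f'bound must be a type, not {type(bound)}')
            self._bound.append(bound)

    def normalize(self, v):
        for t in self._bound:
            if not isinstance(v, t):
                raise ValueError(f'{type(v).__name__} is not an instance of {t.__name__}')
        return v


class EmbedTask(Task): ...
class PCAEmbed(EmbedTask): ...
class SayHello(Task): ...


class DownstreamTask(Task):
    upstream_task = TaskInstanceParameter(bound=EmbedTask)


ok = DownstreamTask(upstream_task=PCAEmbed())  # accepted

try:
    DownstreamTask(upstream_task=SayHello())  # rejected before anything runs
except ValueError as e:
    print(e)  # SayHello is not an instance of EmbedTask
```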

I am happy to create a PR for this if the proposal and implementation idea seem reasonable to you.
Thank you.

@hirosassa
Collaborator

@ganow Thank you for your feature request and very concrete implementation idea!
It seems very reasonable to me.

@Hi-king @mski-iksm What do you think about this?

@mski-iksm
Contributor

@ganow Thank you for raising this issue! I also think it's a good implementation.

Can you go ahead and make a PR?

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants