In [66]:
import uuid
import random
from functools import partial
import logging

from sklearn.model_selection import KFold
import torch
import torch.utils.data
import networkx as nx

In [67]:
config = {
    "LoadTrainingData": {
        "type": "load_training_data",
        "properties": {
            "callable_arguments": {
                "n_samples": 100
            }
        },
        "output": ["train_X", "train_y"]
    },
    "LoadTestingData": {
        "type": "load_testing_data",
        "properties": {
            "callable_arguments": {
                "n_samples": 100
            }
        },
        "output": ["test_X"]
    },
    "CreateTrainingDataset": {
        "type": "MyDataset",
        "properties": {
            "partial_initialization": True
        },
        "output": ["TrainingDataset"]
    },
    "CreateValidationDataset": {
        "type": "MyDataset",
        "properties": {
            "partial_initialization": True
        },
        "output": ["ValidationDataset"]
    },
    "CreateTestingDataset": {
        "type": "MyDataset",
        "properties": {
            "partial_initialization": True
        },
        "output": ["TestingDataset"]
    },
    "CreateDataBunch": {
        "type": "create_data_bunch",
        "properties": {
            "partial_callable": True,
            "callable_arguments": {
                "train_X": {"Ref": {"LoadTrainingData": "train_X"}},
                "train_y": {"Ref": {"LoadTrainingData": "train_y"}},
                "test_X": {"Ref": {"LoadTestingData": "test_X"}},
                "train_ds": {"Ref": {"CreateTrainingDataset": "TrainingDataset"}},
                "val_ds": {"Ref": {"CreateValidationDataset": "ValidationDataset"}},
                "test_ds": {"Ref": {"CreateTestingDataset": "TestingDataset"}},
                "train_bs": 32,
                "val_bs": 64,
                "test_bs": 64
            }
        },
        "output": ["DataBunch"],
    },
    "CreateLearner": {
        "type": "create_learner",
        "properties": {
            "callable_arguments": {
                "data_splitter_iterable": {"Ref": {"CreateDataSplitter": "DataSplitter"}},
                "data_bunch_creator": {"Ref": {"CreateDataBunch": "DataBunch"}}
            }
        }   
    },
    "CreateDataSplitter": {
        "type": "KFold",
        "properties": {
            "initialization_arguments": {
                "n_splits": 5
            },
            "callable_arguments": {
                "X": {"Ref": {"LoadTrainingData": "train_X"}},
                "y": {"Ref": {"LoadTrainingData": "train_y"}}
            }
        },
        "output": ["DataSplitter"],
    },
}

In [68]:
def load_training_data(n_samples):
    X = [uuid.uuid4() for _ in range(n_samples)]
    y = [int(random.random()*100) for _ in range(n_samples)]
    return X, y 

def load_testing_data(n_samples):
    X = [uuid.uuid4() for _ in range(n_samples)]
    return X


In [69]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, i):
        return self.X[i], self.y[i]

In [70]:
def create_data_bunch(train_idx, test_idx, train_ds, val_ds, test_ds):
    return (train_idx, test_idx, train_ds, val_ds, test_ds)

In [71]:
def create_learner(data_splitter_iterable, data_bunch_creator):
    print("Create Learner")
    print(data_splitter_iterable)
    print(data_bunch_creator)
    return

In [72]:
lookup ={
    "load_training_data": load_training_data,
    "load_testing_data": load_testing_data,
    "KFold": KFold,
    "create_data_bunch": create_data_bunch,
    "create_learner": create_learner,
    "MyDataset": MyDataset
}

In [36]:
def find_node_dependencies(node_config):
    dependencies = []
    node_properties = node_config.get("properties")
    if node_properties is None:
        pass
    else:
        initialization_dependencies = []
        initialization_arguments = node_properties.get("initialization_arguments", dict())
        assert isinstance(initialization_arguments, dict), f"Please make sure the initialization arguments is of type dict for node {node_config}"
        for arg_name, arg_value in initialization_arguments.items():
            if isinstance(arg_value, dict) and arg_value.get("Ref") is not None:
                initialization_dependencies.append(arg_value.get("Ref"))
        dependencies.extend(initialization_dependencies)
                
        callable_dependencies =[]
        callable_arguments = node_properties.get("callable_arguments", dict())        
        assert isinstance(callable_arguments, dict), f"Please make sure the callable arguments is of type dict for node {node_config}"
        for arg_name, arg_value in callable_arguments.items():
            if isinstance(arg_value, dict) and arg_value.get("Ref") is not None:
                callable_dependencies.append(arg_value.get("Ref"))
        dependencies.extend(callable_dependencies)
        
    return dependencies            

In [37]:
def create_pipeline_graph(config):
    pipeline_graph = nx.DiGraph()
    
    for node_name, node_config in config.items():
        node_dependencies = find_node_dependencies(node_config)
        pipeline_graph.add_node(node_name, config=node_config, dependencies=node_dependencies)
        
    for node, attributes in pipeline_graph.nodes(data=True):
        dependent_on_nodes = list(set(list(dep.keys())[0] for dep in attributes['dependencies']))
        for dependent_on_node in dependent_on_nodes:
            pipeline_graph.add_edge(dependent_on_node, node)
    return pipeline_graph

In [73]:
def listify(x):
    if not isinstance(x, collections.Iterable):
        return [x]
    else:
        return x
    
import types
# Reference from: https://stackoverflow.com/questions/38541015/how-to-monkey-patch-a-call-method
def patch_call(instance, func):
    class _(type(instance)):
        def __call__(self, *arg, **kwarg):
            return func(*arg, **kwarg)
    instance.__class__ = _

def replace_references(graph, arguments):
    for arg_name, arg_value in arguments.items():
        if isinstance(arg_value, dict) and arg_value.get("Ref") is not None:
            reference_node_name = list(arg_value.get("Ref").keys())[0]
            reference_node_output_name = list(arg_value.get("Ref").values())[0]

            try:
                arguments[arg_name] = graph[reference_node_name]['output_lookup'][reference_node_output_name]
            except KeyError as e:
                print(f"KeyError: {e}")
                print(f"Reference node name: {reference_node_name}")
                print(f"Reference node output lookup: {graph[reference_node_name]['output_lookup']}")
    return arguments

In [74]:
def run_graph(graph, reference_lookup, force_run=False):
    sorted_node_names = list(nx.algorithms.dag.topological_sort(graph))

    for name in sorted_node_names:
        node = graph.nodes(data=True)[name]
        node_output_lookup = node.get('output_lookup')
        if not force_run and node_output_lookup is not None:
            continue
        node_config = node['config']
        node_dependencies = node.get('dependencies')
        node_output = node_config.get('output', [])

        node_properties = node_config.get("properties")
        node_callable = reference_lookup[node_config['type']]
        #         print(name)
        #         print(node)
        if node_properties is not None:

            # 1. Replace the properties that reference values from another node with the actual values
            initialization_arguments = node_properties.get("initialization_arguments", dict())
            callable_arguments = node_properties.get("callable_arguments", dict())

            has_dependencies = len(node_dependencies) > 0
            if has_dependencies:
                # replace the reference values in the arguments
                initialization_arguments = replace_references(graph.nodes(data=True), initialization_arguments)
                callable_arguments = replace_references(graph.nodes(data=True), callable_arguments)
            else:
                pass

            # 2. Check if the callable of the node is a function or a class
            is_function = isinstance(node_callable, types.FunctionType)
            partial_initialization = node_properties.get("partial_initialization", False)
            partial_callable = node_properties.get("partial_callable", False)

            # 3. Initialize the callables accordingly
            node['callable'] = node_callable
            node['output'] = node_output
            if is_function:
                # Make sure that if the callable is a function then there are no initialization arguments
                assert len(
                    initialization_arguments) == 0, f"Function: {node_callable.__name__} cannot have initialization arguments: {initialization_arguments}, only callable arguments"

                if partial_callable:
                    assert len(
                        node_output) <= 1, 'If this is a partial callable, then there should be one or less output for this step'
                    if len(node_output) == 0:
                        pass
                    output_name = node_output[0]
                    node['output_lookup'] = {
                        output_name: partial(node_callable, **callable_arguments)
                    }
                else:
                    callable_output = listify(node_callable(**callable_arguments))
                    node['output_lookup'] = {
                        output_name: callable_output for output_name, callable_output in
                        zip(node['output'], callable_output)
                    }
            else:
                assert (
                               partial_initialization and partial_callable) is False, "Can't make both the initialization of a class and it's __call__ method both partial"
                assert len(
                    node_output) == 1, 'If this is a step to initialize an object, then there should only be one output, the object itself'
                output_name = node_output[0]
                if partial_initialization:
                    node['output_lookup'] = {
                        output_name: partial(node_callable, **initialization_arguments)
                    }
                else:
                    initialized_node_object = node_callable(**initialization_arguments)
                    if partial_callable:
                        patch_call(initialized_node_object,
                                   partial(initialized_node_object.__call__, **callable_arguments))
                    node['output_lookup'] = {
                        output_name: initialized_node_object
                    }
        else:
            node['callable'] = node_callable
            callable_output = listify(node_callable())
            node['output_lookup'] = {
                output_name: callable_output for output_name, callable_output in zip(node['output'], callable_output)
            }

In [75]:
pipeline_graph = create_pipeline_graph(config)
sorted_node_names = list(nx.algorithms.dag.topological_sort(pipeline_graph))
print(sorted_node_names)

['CreateTestingDataset', 'CreateValidationDataset', 'CreateTrainingDataset', 'LoadTestingData', 'LoadTrainingData', 'CreateDataSplitter', 'CreateDataBunch', 'CreateLearner']


In [78]:
run_graph(pipeline_graph, reference_lookup=lookup)

Create Learner
KFold(n_splits=5, random_state=None, shuffle=False)
functools.partial(<function create_data_bunch at 0x7fe7d4bd4158>, train_X=[UUID('c985cce3-6499-4a7e-ba71-45b453b4f52e'), UUID('de40ad0a-808d-422a-bdd2-7f3d2c590b7f'), UUID('181f39fe-739e-4e54-ae94-cc93e1e97df5'), UUID('d117c366-70a0-4545-af1d-e3a4d7689850'), UUID('fc96dd71-ffd8-4d9b-b329-9e3b9d563cbe'), UUID('3c8f02ae-a86c-4637-b739-13006af2d551'), UUID('39c81509-6ce7-495d-8a4e-19372b5f6ef1'), UUID('ee4b9a01-9894-420e-bf7a-0e70c8b1e538'), UUID('6edf09d9-9207-4d90-a993-061c0afc549d'), UUID('8721d7f4-a1cc-49df-88d6-34c2c2d1e44b'), UUID('8a75bd43-dd83-4b24-a364-113c17394093'), UUID('cb9eb422-107d-430a-9612-cb1b3472d827'), UUID('3e51fe38-7e97-4ff2-b6a5-e37c52880664'), UUID('893a0889-542d-4224-b27b-931698a60799'), UUID('2c7ce7af-8bee-4247-80f8-53d82b5e7a5f'), UUID('530a97ea-5d56-4834-959d-58f707774bfa'), UUID('b9ad6634-0b00-4b3e-bcae-99d4aaa65a71'), UUID('df54040a-8d11-48c9-b0cb-ddccaf9b4f1d'), UUID('b9ad7093-8274-4d8d-8602-

In [79]:
pipeline_graph.nodes(data=True)['LoadTrainingData']['output_lookup']

{'train_X': [UUID('c985cce3-6499-4a7e-ba71-45b453b4f52e'),
  UUID('de40ad0a-808d-422a-bdd2-7f3d2c590b7f'),
  UUID('181f39fe-739e-4e54-ae94-cc93e1e97df5'),
  UUID('d117c366-70a0-4545-af1d-e3a4d7689850'),
  UUID('fc96dd71-ffd8-4d9b-b329-9e3b9d563cbe'),
  UUID('3c8f02ae-a86c-4637-b739-13006af2d551'),
  UUID('39c81509-6ce7-495d-8a4e-19372b5f6ef1'),
  UUID('ee4b9a01-9894-420e-bf7a-0e70c8b1e538'),
  UUID('6edf09d9-9207-4d90-a993-061c0afc549d'),
  UUID('8721d7f4-a1cc-49df-88d6-34c2c2d1e44b'),
  UUID('8a75bd43-dd83-4b24-a364-113c17394093'),
  UUID('cb9eb422-107d-430a-9612-cb1b3472d827'),
  UUID('3e51fe38-7e97-4ff2-b6a5-e37c52880664'),
  UUID('893a0889-542d-4224-b27b-931698a60799'),
  UUID('2c7ce7af-8bee-4247-80f8-53d82b5e7a5f'),
  UUID('530a97ea-5d56-4834-959d-58f707774bfa'),
  UUID('b9ad6634-0b00-4b3e-bcae-99d4aaa65a71'),
  UUID('df54040a-8d11-48c9-b0cb-ddccaf9b4f1d'),
  UUID('b9ad7093-8274-4d8d-8602-cd3e78f09cb7'),
  UUID('29f36294-2175-450d-9abd-09763a52c62d'),
  UUID('1be2d3dc-590c-45b0-b3