In [1]:
import copy
import logging
import random
import sys
import uuid
from functools import partial

import networkx as nx
from sklearn.model_selection import KFold

#### Add the directory that contains `pytorch_toolbox` to `sys.path` so the package can be imported

In [2]:
# Make `pytorch_toolbox` importable. The path is relative to the kernel's
# working directory (assumed to be this notebook's folder — TODO confirm).
# Guarded so re-running the cell does not append the same entry again.
TOOLBOX_PARENT_DIR = "../../.."
if TOOLBOX_PARENT_DIR not in sys.path:
    sys.path.append(TOOLBOX_PARENT_DIR)

In [3]:
from pytorch_toolbox import pipeline_parser

In [4]:
def _ref(node, output):
    """Return a pipeline reference to the `output` produced by node `node`."""
    return {"Ref": {node: output}}


def _partial_dataset_node(output_name):
    """Return a node spec that partially initialises `MyDataset`, exposed as `output_name`."""
    return {
        "type": "MyDataset",
        "properties": {"partial_initialization": True},
        "output": [output_name],
    }


# Pipeline description consumed by `PipelineGraph.create_pipeline_graph_from_config`.
# Each top-level key is a node name; "type" is looked up in `lookup`, and
# `{"Ref": {node: output}}` values wire one node's output into another's arguments.
# NOTE(review): the inspection cell further below shows `Ref` entries already
# replaced by concrete values in the graph's node config; if the parser keeps
# these dicts by reference, re-run this cell before rebuilding the graph. TODO confirm.
config = {
    "LoadTrainingData": {
        "type": "load_training_data",
        "properties": {"callable_arguments": {"n_samples": 100}},
        "output": ["train_X", "train_y"],
    },
    "LoadTestingData": {
        "type": "load_testing_data",
        "properties": {"callable_arguments": {"n_samples": 100}},
        "output": ["test_X"],
    },
    "CreateTrainingDataset": _partial_dataset_node("TrainingDataset"),
    "CreateValidationDataset": _partial_dataset_node("ValidationDataset"),
    "CreateTestingDataset": _partial_dataset_node("TestingDataset"),
    "CreateDataBunch": {
        "type": "create_data_bunch",
        "properties": {
            "partial_callable": True,
            "callable_arguments": {
                "train_X": _ref("LoadTrainingData", "train_X"),
                "train_y": _ref("LoadTrainingData", "train_y"),
                "test_X": _ref("LoadTestingData", "test_X"),
                "train_ds": _ref("CreateTrainingDataset", "TrainingDataset"),
                "val_ds": _ref("CreateValidationDataset", "ValidationDataset"),
                "test_ds": _ref("CreateTestingDataset", "TestingDataset"),
                "train_bs": 32,
                "val_bs": 64,
                "test_bs": 64,
            },
        },
        "output": ["DataBunch"],
    },
    "CreateLearner": {
        "type": "create_learner",
        "properties": {
            "callable_arguments": {
                "data_splitter_iterable": _ref("CreateDataSplitter", "DataSplitter"),
                "data_bunch_creator": _ref("CreateDataBunch", "DataBunch"),
            }
        },
    },
    "CreateDataSplitter": {
        "type": "KFold",
        "properties": {
            "initialization_arguments": {"n_splits": 5},
            "callable_arguments": {
                "X": _ref("LoadTrainingData", "train_X"),
                "y": _ref("LoadTrainingData", "train_y"),
            },
        },
        "output": ["DataSplitter"],
    },
}

In [5]:
def load_training_data(n_samples):
    X = [uuid.uuid4() for _ in range(n_samples)]
    y = [int(random.random()*100) for _ in range(n_samples)]
    return X, y 

def load_testing_data(n_samples, seed=None):
    """Generate a synthetic, unlabelled test set.

    Parameters
    ----------
    n_samples : int
        Number of inputs to generate.
    seed : int, optional
        When given, UUIDs are built from a private ``random.Random(seed)`` so
        the same seed always yields the same data. When ``None`` (default)
        ``uuid.uuid4()`` is used, as before.

    Returns
    -------
    list[uuid.UUID]
        ``n_samples`` version-4 UUIDs.
    """
    if seed is None:
        return [uuid.uuid4() for _ in range(n_samples)]

    rng = random.Random(seed)
    # uuid4() cannot be seeded; derive v4 UUIDs from the seeded generator.
    return [uuid.UUID(int=rng.getrandbits(128), version=4) for _ in range(n_samples)]

class MyDataset:
    """Minimal indexable dataset pairing inputs ``X`` with optional labels ``y``.

    Parameters
    ----------
    X : sequence
        Inputs; its length defines the dataset length.
    y : sequence, optional
        Labels aligned index-by-index with ``X``. Omit it for unlabelled data,
        e.g. the output of ``load_testing_data``, which returns inputs only.
    """

    def __init__(self, X, y=None):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        """Return the ``(x, y)`` pair at index ``i``.

        ``y`` is ``None`` for an unlabelled dataset, so items are always
        2-tuples regardless of whether labels were supplied.
        """
        label = None if self.y is None else self.y[i]
        return self.X[i], label

def create_data_bunch(train_idx, test_idx, train_ds, val_ds, test_ds, train_bs, val_bs, test_bs,
                      train_X=None, train_y=None, test_X=None):
    """Placeholder data-bunch factory used to demonstrate the pipeline.

    In the pipeline config this node is ``partial_callable``: every argument
    except the fold indices is bound by keyword through ``functools.partial``,
    including ``train_X``, ``train_y`` and ``test_X``. Those three are
    accepted here so the resulting partial can actually be invoked with
    ``(train_idx, test_idx)``; previously such a call raised ``TypeError``.

    Parameters
    ----------
    train_idx, test_idx : sequence of int
        Fold indices, presumably produced by iterating the data splitter (TODO confirm).
    train_ds, val_ds, test_ds
        Dataset factories/objects for each split.
    train_bs, val_bs, test_bs : int
        Batch sizes per split.
    train_X, train_y, test_X : sequence, optional
        Raw data bound by the pipeline. Currently unused by this placeholder.

    Returns
    -------
    tuple
        ``(train_idx, test_idx, train_ds, val_ds, test_ds)``, passed through unchanged.
    """
    print("Create data bunch")
    print("train_idx")
    print(train_idx)
    print("test_idx")
    print(test_idx)
    return (train_idx, test_idx, train_ds, val_ds, test_ds)

def create_learner(data_splitter_iterable, data_bunch_creator):
    """Placeholder learner factory: echo its inputs to stdout and return ``None``.

    Each argument is printed as a label line followed by the argument itself.
    """
    print("Create Learner")
    labelled_inputs = (
        ("data_splitter_iterable", data_splitter_iterable),
        ("data_bunch_creator", data_bunch_creator),
    )
    for label, value in labelled_inputs:
        print(label)
        print(value)
    return None

In [6]:
# Resolve each "type" string used in `config` to the callable or class that implements it.
lookup = dict(
    load_training_data=load_training_data,
    load_testing_data=load_testing_data,
    KFold=KFold,
    create_data_bunch=create_data_bunch,
    create_learner=create_learner,
    MyDataset=MyDataset,
)

In [7]:
# Build the graph from a deep copy of `config`. The inspection cell further
# below shows `Ref` entries replaced by concrete values in the graph's node
# config; if the parser stores those dicts by reference (TODO confirm against
# pipeline_parser), handing it a copy keeps `config` pristine so this cell and
# the next can be re-run without re-running the config cell.
pipeline_graph = pipeline_parser.PipelineGraph.create_pipeline_graph_from_config(copy.deepcopy(config))

In [8]:
# Execute the pipeline, resolving node "type" strings through `lookup`
# (presumably in dependency order — TODO confirm in pipeline_parser).
# The printed output below comes from `create_learner`; "Create data bunch"
# never appears because `create_data_bunch` is only bound with
# functools.partial here, not called.
pipeline_graph.run_graph(lookup)

Create Learner
data_splitter_iterable
KFold(n_splits=5, random_state=None, shuffle=False)
data_bunch_creator
functools.partial(<function create_data_bunch at 0x7ff66ceeed08>, train_X=[UUID('e3a37441-3505-435b-8ee5-1c8ab8b936b4'), UUID('7a836ffc-3ad9-4639-a5b7-13bd3aa05d75'), UUID('c534738a-6e67-40a7-a1d2-7cdfa575fc06'), UUID('06417c89-853b-4d19-a01e-45bb182ea8c4'), UUID('bd73bd7b-3c8c-4cf2-b1be-428c7de5965b'), UUID('1d39e617-f314-49cf-928d-f0fd206bcef0'), UUID('41cf9909-add1-4523-b00c-c63f6e70cb39'), UUID('dec2265f-3bc2-4bb8-a4ad-e5b311b24c87'), UUID('d72b9797-0bd4-4ce4-92b4-2af9abe6ea93'), UUID('6fec703a-d764-4ac8-b6b8-ea7422ba53c4'), UUID('1dab85b0-2d56-482c-9790-53b7ca9bb25b'), UUID('7d8142bb-6183-4f3c-92fd-1211aaf92374'), UUID('d3ffacfe-1b45-4b57-b07d-a8413448b74a'), UUID('793252ba-b7a8-474d-aec6-424e91a36ffb'), UUID('62bb285d-b8d5-4ff0-bbff-bc605d7ecd45'), UUID('6cff911c-1859-433f-b9ad-817df33869cb'), UUID('bdcfc0d3-34e1-4b99-88b2-6137210236a3'), UUID('6ddc22ac-a3e9-4f9f-96d7-dcba

In [9]:
# Inspect the attribute dict stored on the CreateDataBunch node after the run
# (its config now holds concrete values where `config` had `Ref` entries).
data_bunch_node_attrs = pipeline_graph.graph.nodes["CreateDataBunch"]
data_bunch_node_attrs

{'config': {'type': 'create_data_bunch',
  'properties': {'partial_callable': True,
   'callable_arguments': {'train_X': [UUID('e3a37441-3505-435b-8ee5-1c8ab8b936b4'),
     UUID('7a836ffc-3ad9-4639-a5b7-13bd3aa05d75'),
     UUID('c534738a-6e67-40a7-a1d2-7cdfa575fc06'),
     UUID('06417c89-853b-4d19-a01e-45bb182ea8c4'),
     UUID('bd73bd7b-3c8c-4cf2-b1be-428c7de5965b'),
     UUID('1d39e617-f314-49cf-928d-f0fd206bcef0'),
     UUID('41cf9909-add1-4523-b00c-c63f6e70cb39'),
     UUID('dec2265f-3bc2-4bb8-a4ad-e5b311b24c87'),
     UUID('d72b9797-0bd4-4ce4-92b4-2af9abe6ea93'),
     UUID('6fec703a-d764-4ac8-b6b8-ea7422ba53c4'),
     UUID('1dab85b0-2d56-482c-9790-53b7ca9bb25b'),
     UUID('7d8142bb-6183-4f3c-92fd-1211aaf92374'),
     UUID('d3ffacfe-1b45-4b57-b07d-a8413448b74a'),
     UUID('793252ba-b7a8-474d-aec6-424e91a36ffb'),
     UUID('62bb285d-b8d5-4ff0-bbff-bc605d7ecd45'),
     UUID('6cff911c-1859-433f-b9ad-817df33869cb'),
     UUID('bdcfc0d3-34e1-4b99-88b2-6137210236a3'),
     UUID('6ddc2