In [1]:
import json
import kfp
from dataclasses import dataclass
from kfp import components
from networkx import DiGraph
from jinja2 import Environment

In [2]:
# Load the recommender output. The encoding is passed explicitly so decoding
# does not depend on the platform's default locale encoding.
with open('recommender-response.json', encoding='utf-8') as f:
    recommender_response = json.load(f)

In [3]:
# Temporary fix for the recommender response: point every component at the
# local YAML file named after it (overwrites any existing "url" in place).
for node in recommender_response["nodes"]:
    component_info = node["component"]
    component_info["url"] = f"_components/{component_info['name']}.yaml"

In [4]:
class Step:
    """Thin wrapper around a loaded component.

    Exposes the component's name and description, the names of its inputs and
    outputs, and two derived views: outputs whose type starts with
    ``system.`` (artifacts) and inputs whose type does not (parameters).

    Annotations that reference ``kfp`` are written as strings so they are not
    evaluated when the class body is executed.
    """

    # Type prefix used to tell artifacts apart from plain parameters.
    _ARTIFACT_TYPE_PREFIX = "system."

    def __init__(self, component: "components.YamlComponent"):
        self.component = component

    @property
    def _input_specs(self) -> dict:
        """Input specs keyed by input name; empty when the component has none."""
        # NOTE(review): kfp appears to declare ComponentSpec.inputs as Optional —
        # confirm. A None value is treated as "no inputs" instead of crashing.
        return self.component.component_spec.inputs or {}

    @property
    def _output_specs(self) -> dict:
        """Output specs keyed by output name; empty when the component has none."""
        # Same None-guard as for inputs.
        return self.component.component_spec.outputs or {}

    @classmethod
    def _is_artifact(cls, spec) -> bool:
        """Return True when the spec's type string carries the artifact prefix."""
        return spec.type.startswith(cls._ARTIFACT_TYPE_PREFIX)

    @property
    def name(self) -> str:
        """Name of the wrapped component."""
        return self.component.name

    @property
    def description(self) -> str:
        """Description of the wrapped component."""
        return self.component.description

    @property
    def inputs(self) -> list[str]:
        """Names of all inputs, in the order the component spec lists them."""
        return list(self._input_specs)

    @property
    def outputs(self) -> list[str]:
        """Names of all outputs, in the order the component spec lists them."""
        return list(self._output_specs)

    @property
    def produced_artifacts(self) -> "dict[str, kfp.dsl.structures.OutputSpec]":
        """Outputs whose type starts with ``system.``, keyed by output name."""
        return {
            output_name: spec
            for output_name, spec in self._output_specs.items()
            if self._is_artifact(spec)
        }

    @property
    def pipeline_parameters(self) -> "dict[str, kfp.dsl.structures.InputSpec]":
        """Inputs whose type does not start with ``system.``.

        Keys are ``"<step name>-<input name>"`` so that parameters of
        different steps do not collide.
        """
        step_name = self.name
        return {
            f"{step_name}-{input_name}": spec
            for input_name, spec in self._input_specs.items()
            if not self._is_artifact(spec)
        }

In [5]:
@dataclass
class ArtifactConnection:
    """Record of one artifact hand-off between two steps.

    The fields are declared as real dataclass fields so that the generated
    ``__init__``, ``__repr__`` and ``__eq__`` actually use them. With a
    hand-written ``__init__`` and no annotated fields, the dataclass has zero
    fields: every instance compares equal and the repr shows no values.

    Attributes:
        artifact_name: Name of the artifact.
        artifact_key: Key of the artifact.
        producer_step: Name of the step that produces the artifact.
    """

    artifact_name: str
    artifact_key: str
    producer_step: str

In [6]:
class Pipeline(DiGraph):
    """Directed graph whose nodes are ``Step`` objects.

    Built from two lists of dicts:

    * ``component_definitions`` — each entry has a ``"component"`` dict with
      ``"name"`` and ``"url"`` keys; the url is loaded as a component file.
    * ``component_connections`` — each entry has ``"source"``, ``"target"``,
      ``"input"`` and ``"output"`` keys; source/target are component names.

    After construction the instance exposes ``steps`` (name -> ``Step``) and
    ``connection_data`` (target step name -> ``ArtifactConnection``).
    """

    def __init__(self, component_definitions: list[dict], component_connections: list[dict]):
        super().__init__()
        self.component_definitions = component_definitions
        self.component_connections = component_connections

        # Steps must exist before edges are added, because edges look them up by name.
        for build_stage in (self._generate_steps, self._connect_steps, self._store_connections):
            build_stage()

    def _generate_steps(self):
        """Load every defined component, wrap it in a ``Step`` and add it as a node."""
        self.steps = {}
        for definition in self.component_definitions:
            component_info = definition["component"]
            name, uri = component_info["name"], component_info["url"]
            loaded_component = components.load_component_from_file(uri)
            step = Step(loaded_component)
            self.steps[name] = step
            self.add_node(step)

    def _connect_steps(self):
        """Add one edge per connection, from the source step to the target step."""
        for connection in self.component_connections:
            producer = self.steps[connection["source"]]
            consumer = self.steps[connection["target"]]
            self.add_edge(producer, consumer)

    def _store_connections(self):
        """Record, per target step name, the artifact connection feeding it.

        The mapping is keyed by target step name only, so when several
        connections share a target the last one in the list is the one kept.
        """
        self.connection_data = {}
        for connection in self.component_connections:
            target_name = connection["target"]
            input_name = connection["input"]
            output_key = connection["output"]
            source_name = connection["source"]
            self.connection_data[target_name] = ArtifactConnection(
                input_name, output_key, source_name
            )

In [7]:
# Assemble the pipeline graph from the recommender's nodes and edges.
pipe = Pipeline(
    component_definitions=recommender_response["nodes"],
    component_connections=recommender_response["edges"],
)

In [8]:
# Read the pipeline template. The encoding is passed explicitly so decoding
# does not depend on the platform's default locale encoding.
with open("pipeline.jinja", "r", encoding="utf-8") as f:
    template_str = f.read()

# trim_blocks drops the newline after a block tag and lstrip_blocks strips the
# whitespace before one, so template control tags leave no stray blank lines
# or indentation in the rendered output.
env = Environment(trim_blocks=True, lstrip_blocks=True)

# Create the template from our template string
template = env.from_string(template_str)

In [9]:
# Render the template with the assembled pipeline
rendered_yaml = template.render(pipe=pipe)

# Write the rendered YAML to disk. The encoding is explicit so non-ASCII
# characters are not written in the platform's locale encoding.
with open("my_pipe.yaml", 'w', encoding="utf-8") as f:
    f.write(rendered_yaml)