In [19]:
import sys
sys.path.append('..')
from core import computation_graph
from core import graph_merge
from visualize_graph import network_graph
from image_utilities import load_images

import run_image_experiment

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
@computation_graph.optex_process('return')
def transform(x):
    """Identity process: hand ``x`` back unchanged (a placeholder transformation)."""
    result = x
    return result

@computation_graph.optex_process('return')
def combine(x, y, z):
    """Add the three inputs together, evaluated left to right as ``(x + y) + z``."""
    partial = x + y
    return partial + z

@computation_graph.optex_composition('return')
def pipeline1(x):
    """Fan ``x`` out through three ``transform`` calls and join them with ``combine``.

    Every artifact (input, each transform result, the output) is given a
    ``name``; presumably these become node labels in the generated graph --
    confirm in ``computation_graph``.

    Returns:
        The artifact produced by ``combine(t1, t2, t3)``, named ``'output'``.
    """
    x.name = "input"
    # Three parallel transforms of the same input artifact.
    t1 = transform(x)
    t1.name = "transform_1"
    t2 = transform(x)
    t2.name = "transform_2"
    t3 = transform(x)
    t3.name = "transform_3"
    # Join the three branches back into a single node.
    c = combine(t1, t2, t3)
    c.name = 'output'
    return c

### A more complex graph: two inputs and two outputs

In [34]:
@computation_graph.optex_process(('scale_2_out', 'scale_4_out'))
def scale(arg1):
    """Return ``arg1`` multiplied by 2 and by 4, as a two-element tuple in that order."""
    doubled = arg1 * 2
    quadrupled = arg1 * 4
    return doubled, quadrupled

@computation_graph.optex_process('add_three_out')
def add_three(arg):
    """Return ``arg`` plus the constant 3."""
    offset = 3
    return arg + offset

@computation_graph.optex_process('sum_output')
def sum_inputs(arg1, arg2):
    """Return the sum ``arg1 + arg2``."""
    total = arg1 + arg2
    return total


@computation_graph.optex_composition(('output_1', 'output_2'))
def dual_return(arg1, arg2):
    """Two-input, two-output composition used by the graph-merge experiments below.

    ``arg1`` feeds ``scale`` (which yields two outputs) and ``arg2`` feeds
    ``add_three``. The second ``scale`` output is then added to the
    ``add_three`` result via ``sum_inputs``.

    Returns:
        ``(out_1, final)``: the first ``scale`` output and the ``sum_inputs``
        result, declared to the decorator as ``'output_1'`` and ``'output_2'``.
    """
    arg1.name = 'arg1'
    arg2.name = 'arg2'
    # scale produces two artifacts; only the second one feeds the sum below.
    out_1, out_2 = scale(arg1)
    out_1.name = 'scale_out_1'
    out_2.name = 'scale_out_2'
    out = add_three(arg2)
    out.name = 'add_three_out'
    # Join the arg1 branch and the arg2 branch.
    final = sum_inputs(out_2, out)
    final.name = 'sum_out'
    return out_1, final


# Build two independent, expanded copies of the same `dual_return` graph.
graph1 = graph_merge.make_expanded_graph_copy(computation_graph.generate_static_graph(dual_return, "Graph 1"))
graph2 = graph_merge.make_expanded_graph_copy(computation_graph.generate_static_graph(dual_return, "Graph 2"))

# Gather both graphs' outputs into one EdgeGraph so they render side by side, unmerged.
unmerged_graph = computation_graph.EdgeGraph.from_output_artifacts(graph1.outputs + graph2.outputs, name='unmerged_graph')
# Use a cell-specific HTML filename. `pyvis_graph` is presumably a pyvis Network,
# whose notebook `show()` returns an IFrame that points at the written file.
# Sharing one filename (e.g. 'test.html') across cells makes every earlier
# cell's output display whichever graph was written last.
network_graph.Network_Graph(unmerged_graph, notebook=True).pyvis_graph.show('unmerged_graph.html')

### Merging two graphs when all of their arguments match

In [37]:
# Fresh copies of both graphs, presumably so this experiment does not depend on
# (or get affected by) graphs built or merged in other cells.
graph1 = graph_merge.make_expanded_graph_copy(computation_graph.generate_static_graph(dual_return, "Graph 1"))
graph2 = graph_merge.make_expanded_graph_copy(computation_graph.generate_static_graph(dual_return, "Graph 2"))

# Both graphs are bound to the *same* Artifact objects for both parameters.
arg1 = computation_graph.Artifact(10)
arg1.name = 'arg1'
arg2 = computation_graph.Artifact(20)
arg2.name = 'arg2'
# Each tuple appears to bind (graph, process, parameter name, artifact) -- confirm in graph_merge.get_inputs.
inputs = graph_merge.get_inputs([
    (graph1, scale, 'arg1', arg1),
    (graph1, add_three, 'arg', arg2),
    (graph2, scale, 'arg1', arg1),
    (graph2, add_three, 'arg', arg2)])

merged_graph, merged_inputs, merged_outputs = graph_merge.merge_graphs([graph1, graph2], inputs, "merged")
# Cell-specific filename: the notebook output references the HTML file by path,
# so a shared name would make other cells' outputs show this graph (or vice versa).
network_graph.Network_Graph(merged_graph, notebook=True).pyvis_graph.show('merged_all_args.html')

### When only some arguments match (`arg1` shared, `arg2` different), only part of the graph is merged

In [36]:
# Fresh copies of both graphs, presumably so this experiment does not depend on
# (or get affected by) graphs built or merged in other cells.
graph1 = graph_merge.make_expanded_graph_copy(computation_graph.generate_static_graph(dual_return, "Graph 1"))
graph2 = graph_merge.make_expanded_graph_copy(computation_graph.generate_static_graph(dual_return, "Graph 2"))


# `arg1` is shared between the two graphs; `arg2` gets a distinct artifact per graph.
arg1_merged = computation_graph.Artifact(10)
arg1_merged.name = 'arg1_merged'
arg2_graph1 = computation_graph.Artifact(20)
arg2_graph1.name = 'arg2_graph1'
arg2_graph2 = computation_graph.Artifact(30)
arg2_graph2.name = 'arg2_graph2'

# Each tuple appears to bind (graph, process, parameter name, artifact) -- confirm in graph_merge.get_inputs.
inputs = graph_merge.get_inputs([
    (graph1, scale, 'arg1', arg1_merged),
    (graph1, add_three, 'arg', arg2_graph1),
    (graph2, scale, 'arg1', arg1_merged),
    (graph2, add_three, 'arg', arg2_graph2)])

merged_graph, merged_inputs, merged_outputs = graph_merge.merge_graphs([graph1, graph2], inputs, "merged")
# Cell-specific filename: the notebook output references the HTML file by path,
# so a shared name would make other cells' outputs show this graph (or vice versa).
network_graph.Network_Graph(merged_graph, notebook=True).pyvis_graph.show('merged_partial_args.html')