# Graph Manipulation

Exploring ways to manipulate a TensorFlow graph after it has been built. The [Graph Editor][ge] seems promising, but its usage is not trivial.

[ge]: https://www.tensorflow.org/api_guides/python/contrib.graph_editor

In [None]:
import tensorflow as tf
import numpy as np
import keras
import keras.backend as K

# Hide every GPU from TensorFlow (device_count GPU = 0) and register the
# resulting session with Keras so raw TF code and Keras share it.
config = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.Session(config=config)
K.set_session(sess)

## Define a model

A small model (a shared convolutional trunk with three softmax heads) that we are going to perform surgery on.

In [None]:
import keras.layers as layers
import keras.regularizers as reg

# Shared trunk (conv -> relu -> conv -> relu -> pool -> dense) feeding three
# softmax heads. Layer names are assigned in creation order, and the surgery
# cell below refers to 'model/dense_1/' by name, so keep this order intact.
with tf.variable_scope('model'):
    inputs = layers.Input(shape=(32, 32, 3))

    x = inputs
    for n_filters in (32, 64):
        x = layers.Conv2D(n_filters, (3, 3),
                          kernel_initializer='glorot_uniform',
                          kernel_regularizer=reg.l2(0.01))(x)
        x = layers.Activation('relu')(x)

    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation='relu',
                     kernel_initializer='glorot_uniform',
                     kernel_regularizer=reg.l2(0.01))(x)

    # Three independent 10-way softmax heads on top of the same features.
    ys = [layers.Dense(10, activation='softmax',
                       kernel_initializer='glorot_uniform')(x)
          for _head in range(3)]

model = keras.models.Model(inputs=[inputs], outputs=ys)

In [None]:
# Expose the Keras regularization penalties through TensorFlow's standard
# REGULARIZATION_LOSSES collection.
for penalty in model.losses:
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, penalty)

# stripping

Here is our plan:
- Assign recognizable weights to all trainable variables
- Save the graph state and close the session
- Modify the graph with the Graph Editor
- Create a new session and restore the graph state
- Check the weights

In [None]:
# First, overwrite every trainable variable with a recognizable constant
# (1337.0) so the values can be traced through the surgery.
for variable in tf.trainable_variables():
    variable.load(np.full(variable.get_shape().as_list(), 1337.0), sess)

In [None]:
import tensorflow.contrib.graph_editor as ge

def get_tentacles(scope_name, within_ops=None):
    """Find the graph edges that cross the boundary of a name scope.

    Args:
        scope_name: name scope whose ops form the "inside" of the boundary
            (e.g. ``'model/dense_1/'``).
        within_ops: optional iterable of ops to restrict the search to.
            Defaults to every op in the default graph. Ops outside this set
            are ignored both as scope members and as neighbours.

    Returns:
        ``(incoming, outcoming)`` dicts keyed by ops inside the scope.
        ``incoming[op]`` lists the ops outside the scope (but within
        ``within_ops``) that produce inputs of ``op``; ``outcoming[op]``
        lists the ops outside the scope that consume outputs of ``op``.
        Ops without boundary-crossing edges are absent from the dicts.
    """
    if within_ops is None:
        within_ops = tf.get_default_graph().get_operations()
    # Materialize once: the ops are scanned twice below, which would silently
    # see an empty sequence the second time if a generator were passed.
    within_ops = list(within_ops)

    # Sets give O(1) membership tests instead of scanning every op name.
    within_op_names = {op.name for op in within_ops}
    ops = ge.get_name_scope_ops(within_ops, scope_name)
    # "Inside" means exactly the ops the Graph Editor assigns to the scope;
    # a plain startswith() would also swallow siblings such as dense_10 when
    # scope_name lacks a trailing slash.
    inside_names = {op.name for op in ops}

    def _crosses_boundary(other):
        return other.name in within_op_names and other.name not in inside_names

    incoming, outcoming = {}, {}
    for op in ops:
        for src in ge.get_generating_ops(op.inputs):
            if _crosses_boundary(src):
                incoming.setdefault(op, []).append(src)
        for dst in ge.get_consuming_ops(op.outputs):
            if _crosses_boundary(dst):
                outcoming.setdefault(op, []).append(dst)

    return incoming, outcoming

In [None]:
from tensorflow.core.framework import variable_pb2


def duplicate_layer(layer_name,
                    layer_sgv,
                    branch_name,
                    add_to_collections=True):
    """Copy a layer's subgraph and variables into a sibling name scope.

    The copy lives under ``layer_name`` with ``branch_name`` appended to the
    last scope component (``'model/dense_1/'`` + ``'a'`` ->
    ``'model/dense_1a/'``) and is wired to the same input tensors as the
    original layer.

    Args:
        layer_name: name scope of the layer to duplicate.
        layer_sgv: Graph Editor subgraph view of the ops in that scope.
        branch_name: suffix appended to the scope name for the copy.
        add_to_collections: if True, every new variable is also added to each
            graph collection that contains the variable it was copied from.

    Returns:
        ``(info, var_duplication)``: ``info`` is the transform info returned
        by ``ge.copy_with_input_replacements`` and ``var_duplication`` is a
        list of ``(original_var, new_var)`` pairs.
    """
    if layer_name.endswith('/'):
        new_layer_name = layer_name[:-1] + branch_name + '/'
        var_scope = layer_name
    else:
        new_layer_name = layer_name + branch_name
        var_scope = layer_name + '/'

    # Map each input tensor of the layer to itself so the copy reads from the
    # same producer tensors as the original layer.
    identity_inputs = {ts: ts for ts in layer_sgv.inputs}
    _, info = ge.copy_with_input_replacements(
        layer_sgv,
        replacement_ts=identity_inputs,
        src_scope=layer_name,
        dst_scope=new_layer_name)

    graph = tf.get_default_graph()
    var_duplication = []
    # get_collection returns a copy, so variables appended below are not
    # revisited by this loop.
    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        # Prefix match on the scope: a substring test would also pick up
        # e.g. 'model/dense_10/...' or copies made by earlier branches
        # ('model/dense_1a/...') when layer_name has no trailing slash.
        if not v.name.startswith(var_scope):
            continue

        new_vardef = variable_pb2.VariableDef()
        for field, val in v.to_proto().ListFields():
            # String fields hold op/tensor names; point them at the copy.
            if isinstance(val, str):
                val = val.replace(layer_name, new_layer_name)
            setattr(new_vardef, field.name, val)
        new_var = tf.Variable(variable_def=new_vardef)
        tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, new_var)
        var_duplication.append((v, new_var))

        if add_to_collections:
            for key in graph.get_all_collection_keys():
                collection = tf.get_collection(key)
                if v in collection and new_var not in collection:
                    tf.add_to_collection(key, new_var)

    return info, var_duplication

    
def reroute_network(outcoming_dict, endpoints, dup_info):
    """Rewire downstream consumers of a layer to read from its duplicate.

    Only consumers that lie on a path from the layer towards ``endpoints``
    are rewired; consumers that feed other outputs keep reading from the
    original layer.

    Args:
        outcoming_dict: mapping from boundary ops of the original layer to
            the outside ops that consume their outputs (the second value
            returned by ``get_tentacles``).
        endpoints: network outputs (tensors or ops) that should be served by
            the duplicated layer.
        dup_info: transform info from copying the layer; used to look up the
            copied counterpart of each original tensor.
    """
    # Ops between the layer boundary (exclusive) and the endpoints (inclusive).
    branch_ops = set(ge.get_walks_intersection_ops(
        forward_seed_ops=list(outcoming_dict),
        backward_seed_ops=endpoints,
        forward_inclusive=False,
        backward_inclusive=True))

    # A consumer can read several outputs of the layer and would otherwise be
    # listed (and rerouted) more than once; keep first-seen order.
    outputs_to_swap = []
    seen = set()
    for consumers in outcoming_dict.values():
        for node in consumers:
            if node in branch_ops and node not in seen:
                seen.add(node)
                outputs_to_swap.append(node)

    for node in outputs_to_swap:
        new_inputs = []
        for ts in list(node.inputs):
            # Map tensor -> tensor. Mapping ts.op and splicing in all of its
            # outputs would misalign the input list whenever the producer
            # has more than one output.
            new_ts = dup_info.transformed(ts)
            new_inputs.append(ts if new_ts is None else new_ts)
        ge.reroute_inputs(new_inputs, node)
    
    
def do_branching(layer_name, branching_scheme, network_ops=None):
    """Duplicate a layer once per named branch and rewire the outputs.

    Args:
        layer_name: name scope of the layer to split.
        branching_scheme: dict mapping a branch suffix to the network outputs
            that the copy for that suffix should serve. The empty-string key
            stands for the original layer and is skipped.
        network_ops: optional ops to restrict the boundary search to
            (forwarded to ``get_tentacles``).

    Returns:
        List of ``(original_var, new_var)`` pairs across all created copies.
    """
    _, consumers_by_op = get_tentacles(layer_name, network_ops)
    layer_view = ge.make_view_from_scope(layer_name, tf.get_default_graph())

    var_pairs = []
    for suffix, served_outputs in branching_scheme.items():
        if suffix != '':
            transform_info, new_pairs = duplicate_layer(layer_name, layer_view, suffix)
            reroute_network(consumers_by_op, served_outputs, transform_info)
            var_pairs += new_pairs

    return var_pairs

In [None]:
def _new_session(session_prep):
    """Create a session (via ``session_prep`` if given) and register it with Keras."""
    new_sess = tf.Session() if session_prep is None else session_prep()
    K.set_session(new_sess)
    return new_sess


def unzip(sess,
          network_ops,
          layer_name,
          branching_scheme,
          session_prep=None,
          saver=None,
          saver_scope='save',
          checkpoint_dir='/tmp'):
    """Split ``layer_name`` into branches and reload the edited graph.

    The current session is checkpointed and closed. The graph is edited with
    ``do_branching``, and each duplicated variable is seeded from its
    original. The edited graph (minus saver ops) is then exported, the
    default graph is reset, and the export is re-imported and restored into
    a fresh session. Tensors and variables created before the call belong to
    the old, discarded graph afterwards.

    Args:
        sess: session holding the current variable values; it is closed.
        network_ops: ops to restrict the layer-boundary search to.
        layer_name: name scope of the layer to split.
        branching_scheme: dict mapping a branch suffix to the network outputs
            served by that copy ('' denotes the original layer).
        session_prep: optional zero-argument callable returning a new
            session; defaults to ``tf.Session()``.
        saver: optional saver used for the pre-surgery checkpoint; by default
            a ``tf.train.Saver`` named ``saver_scope`` is created.
        saver_scope: name prefix of saver ops; graph nodes whose names start
            with it are dropped from the exported graph.
        checkpoint_dir: directory for the intermediate checkpoints and the
            exported meta graph (defaults to ``/tmp``, the previous
            hard-coded location).

    Returns:
        ``(sess, full_saver)``: the new session bound to the re-imported
        graph and the saver returned by ``tf.train.import_meta_graph``.
    """
    import os

    pre_path = os.path.join(checkpoint_dir, 'pre_surgery')
    post_path = os.path.join(checkpoint_dir, 'post_surgery')
    meta_path = os.path.join(checkpoint_dir, 'full_saver.meta')

    # 1. Checkpoint the current values and release the session before the
    #    graph is edited in place.
    pre_surgery_saver = tf.train.Saver(name=saver_scope) if saver is None else saver
    pre_surgery_saver.save(sess, pre_path)
    sess.close()

    # 2. Graph surgery: copy the layer and rewire the requested outputs.
    duplicate_var_pairs = do_branching(layer_name, branching_scheme, network_ops)

    # 3. Restore the original variables, seed every copy from its source
    #    variable, and checkpoint the full post-surgery state.
    sess = _new_session(session_prep)
    pre_surgery_saver.restore(sess, pre_path)
    for var, new_var in duplicate_var_pairs:
        new_var.load(var.eval(sess), sess)
    post_surgery_saver = tf.train.Saver()
    post_surgery_saver.save(sess, post_path, write_meta_graph=False)

    # 4. Export the edited graph without the nodes under saver_scope, so the
    #    re-imported graph does not carry the savers created above.
    graph_def = tf.get_default_graph().as_graph_def()
    keep_nodes = [node.name for node in graph_def.node
                  if not node.name.startswith(saver_scope)]
    no_saver_graphdef = tf.graph_util.extract_sub_graph(graph_def, keep_nodes)
    tf.train.export_meta_graph(meta_path, graph_def=no_saver_graphdef)

    # 5. Start over from an empty graph, re-import the export and load the
    #    post-surgery checkpoint into it.
    sess.close()
    K.clear_session()
    tf.reset_default_graph()

    sess = _new_session(session_prep)
    full_saver = tf.train.import_meta_graph(meta_path)
    full_saver.restore(sess, post_path)

    return sess, full_saver

In [None]:
# Dump the untouched graph for TensorBoard before any surgery.
writer = tf.summary.FileWriter('vis_before', tf.get_default_graph())
writer.close()

# Split the layer under 'model/dense_1/': output 0 keeps the original layer
# (the '' branch), outputs 1 and 2 are served by a copy suffixed with 'a'.
layer_name = 'model/dense_1/'
network_ops = ge.get_backward_walk_ops(model.outputs)
branching_scheme = {
    '': model.outputs[0],
    'a': model.outputs[1:],
}

sess, saver = unzip(sess, network_ops, layer_name, branching_scheme)

# Dump the rewritten graph for side-by-side comparison.
writer = tf.summary.FileWriter('vis_after', tf.get_default_graph())
writer.close()