In [None]:
from caffe2.python import core, workspace, test_util, dyndep, nomnigraph as ng
from caffe2.proto import caffe2_pb2
import pprint as pp
import graphviz as gv
import google.protobuf.text_format

## Load model from a protobuf file
Load a Caffe2 model from a protobuf file and convert it to its nomnigraph representation (see https://github.com/pytorch/pytorch/tree/master/caffe2/core/nomnigraph).

In [None]:
# Path to the serialized NetDef protobuf of the model.
MODEL_FILE = "model.pb"

# Only the read needs the file handle; close it before parsing the proto
# and building the graph.
with open(MODEL_FILE, 'rb') as f:
    serialized_netdef = f.read()

netdef_proto = caffe2_pb2.NetDef()
# Use this for text format protobuf file
#google.protobuf.text_format.Merge(serialized_netdef, netdef_proto)
# Use this for binary format protobuf file
netdef_proto.ParseFromString(serialized_netdef)

# Wrap the NetDef in a nomnigraph NNModule; its dataFlow graph is what the
# following cells explore.
nnmodule = ng.NNModule(netdef_proto)
dfGraph = nnmodule.dataFlow

## Simple graph exploration
Basic graph traversal is supported by nomnigraph.

In [None]:
# How many operator nodes the data-flow graph contains.
num_operators = len(dfGraph.operators)
print(num_operators)

In [None]:
# List the name of every operator in the data-flow graph, one per line.
operator_names = [op.name for op in dfGraph.operators]
for operator_name in operator_names:
    print(operator_name)

In [None]:
# Show the input and output tensor names of the first operator named OP_NAME.
# Nothing is printed when no operator has that name.
OP_NAME = "Mul"
for op in dfGraph.operators:
    if op.name != OP_NAME:
        continue
    print(op.name)
    print("Inputs")
    input_names = [blob.name for blob in op.inputs]
    pp.pprint(input_names)
    print("Outputs")
    output_names = [blob.name for blob in op.outputs]
    pp.pprint(output_names)
    # Only the first matching operator is reported.
    break
    

## Visualize graph

In [None]:
def viz(graph):
    """Wrap the string form of ``graph`` in a ``graphviz.Source`` for display.

    Args:
        graph: any object whose ``str()`` is graphviz-renderable text
            (presumably DOT source for nomnigraph graphs -- confirm).

    Returns:
        The ``gv.Source`` object built from ``str(graph)``.
    """
    graph_text = str(graph)
    return gv.Source(graph_text)


viz(dfGraph)

## Subgraph matching
Nomnigraph can be used to perform subgraph pattern matching.

In [None]:
# Describe the pattern to search for: Mul operator -> data node -> ReplaceNaN operator.
# NOTE(review): "*" presumably acts as a wildcard for the tensor name -- confirm.
mg = ng.NNMatchGraph()
matchMul = mg.createNode(ng.NeuralNetOperator("Mul"), strict=True)
matchT = mg.createNode(ng.NeuralNetData("*"), strict=True)
matchReplaceNan = mg.createNode(ng.NeuralNetOperator("ReplaceNaN"))

# Wire the pattern nodes together in producer -> consumer order.
for src, dst in ((matchMul, matchT), (matchT, matchReplaceNan)):
    mg.createEdge(src, dst)

matches = nnmodule.match(mg)
for match in matches:
    # TODO: visualize subgraph
    matched_names = [node.name for node in match.nodes]
    for matched_name in matched_names:
        print(matched_name)
    # Only the first match is printed.
    break

## Construct and visualize a subgraph

In [None]:
# Collect every tensor whose name contains `feature_name`, together with its
# consumers, its producer (if any) and that producer's inputs, into a subgraph.
sg = ng.NNSubgraph()
feature_name = "TEST_FEATURE"
for blob in dfGraph.tensors:
    if feature_name not in blob.name:
        continue
    sg.addNode(blob)
    # Plain loops: addNode is called for its side effect, so there is no
    # reason to build a throwaway list with a comprehension.
    for consumer in blob.consumers:
        sg.addNode(consumer)
    if blob.hasProducer():
        pro = blob.producer
        for producer_input in pro.inputs:
            sg.addNode(producer_input)
        sg.addNode(pro)
# Nodes were added without edges; induceEdges() is called before rendering.
sg.induceEdges()
viz(sg)