In [None]:
import tvm
from tvm import relay
import tvm.relay.analysis_tools

We'll start by examining a simple Relay program:

In [None]:
# Build the Relay expression 1 - (x * y) from one constant and two free variables.
program = relay.const(1) - (relay.var('x') * relay.var('y'))

This simple analysis pass visits all `Call`s. It uses the `AnalysisPass` helper method `_add_detail` to attach analysis results to an expression. In this case, it attaches an analysis result named `'readable_name'` to the `Call` being visited. `_add_detail` is one of the main conveniences added by this simple analysis framework.

In [None]:
class GetReadableName(relay.analysis_tools.AnalysisPass):
    """Analysis pass that tags every visited call with its operator's name.

    The name is stored on the call under the detail key ``readable_name``.
    """

    def visit_call(self, call):
        """Delegate to the base visitor, then record ``call.op.name`` on ``call``."""
        super().visit_call(call)
        # NOTE(review): this reads `.name` off the callee, so it presumably
        # expects the callee to be an operator -- confirm for other callees.
        op_name = call.op.name
        self._add_detail(call, readable_name=op_name)

We can build another simple `AnalysisPass` to give all of our nodes unique ids:

In [None]:
class GetIndex(relay.analysis_tools.AnalysisPass):
    """Analysis pass that numbers the visited calls 0, 1, 2, ...

    The number is stored on each call under the detail key ``id``. The counter
    lives on the instance, so a fresh instance starts again from zero.
    """

    def __init__(self):
        super().__init__()
        # Next id to hand out; private to this class (name-mangled).
        self.__next_id = 0

    def visit_call(self, call):
        """Delegate to the base visitor, then attach the next free id to ``call``."""
        super().visit_call(call)
        current = self.__next_id
        self._add_detail(call, id=current)
        # Advance only after the detail has been attached.
        self.__next_id = current + 1

Next, we use the `run_analyses` helper method to run a batch of analyses on our program.

In [None]:
# One fresh instance of each pass; both are run over `program` in one batch.
analyses = [GetReadableName(), GetIndex()]
# run_analyses hands back two values: the per-expression results and the
# program-level summaries (unused until the summaries section below).
analysis_results, summary_results = relay.analysis_tools.run_analyses(program, analyses)
# Bare expression so the notebook displays the raw per-expression results.
analysis_results

As we can see, the analysis results are in a bit of a raw form. The rest of the interesting functions in the analysis framework mostly pertain to helping wrangle the output into a useful form.

For example, we can get the columns from the output, and then once we have the columns, we can turn the data into record format:

In [None]:
# Extract the columns present in the raw analysis results.
analysis_columns = relay.analysis_tools.get_analysis_columns(analysis_results)
# Bare expression so the notebook displays the columns.
analysis_columns

In [None]:
# Convert the raw results into record format using the columns found above.
records = relay.analysis_tools.get_records(analysis_results, analysis_columns)
# Bare expression so the notebook displays the records.
records

Record format helps us put the data into pandas:

In [None]:
import pandas as pd
# Load the records into a DataFrame indexed by the 'id' column.
# NOTE(review): the passes above record details named 'readable_name' and
# 'id', while the columns here are labelled ['id', 'op'] -- confirm that this
# matches the field order get_records produces.
pd.DataFrame.from_records(records, columns=['id', 'op'], index='id')

Now, let's see what it looks like with a bigger program, like MobileNet:

In [None]:
from tvm.relay.testing.mobilenet import get_workload

# Build the MobileNet workload; `module['main']` is the function we analyse.
module, params = get_workload()

# Use fresh pass instances: GetIndex keeps a running counter on the instance,
# so a new instance restarts the numbering at zero for this program.
analyses = [GetReadableName(), GetIndex()]
analysis_results, summary_results = relay.analysis_tools.run_analyses(module['main'], analyses)
# Recompute the columns from *these* results rather than reusing the ones
# derived from the earlier toy program, so this cell does not depend on stale
# notebook state.
analysis_columns = relay.analysis_tools.get_analysis_columns(analysis_results)
records = relay.analysis_tools.get_records(analysis_results, analysis_columns)
# NOTE(review): the column labels ['id', 'op'] are kept as in the earlier
# cell -- confirm they match the field order get_records produces.
pd.DataFrame.from_records(records, columns=['id', 'op'], index='id')

## Building Analysis Summaries

The last feature of the analysis framework is the ability to build _summaries_: analysis results that are tied not to a single expression within the program, but to the program as a whole.

We can make a summary by overriding `_summarize`. This method runs once the `AnalysisPass` finishes visiting all expressions in a program. `_summarize` can read the `_existing_data` field to get previously-generated data. Using this data, it can then generate summary info and attach it to the program using `_add_summary`.

Here, we show a summary pass which generates a histogram of call types:

In [None]:
class SummarizeOpTypes(relay.analysis_tools.AnalysisPass):
    """Summary pass that builds a histogram of ``readable_name`` values.

    It reads the per-node data already collected in ``_existing_data``, so a
    pass that records ``readable_name`` must have produced that data first.
    """

    def _summarize(self):
        """Count how often each ``readable_name`` occurs and attach the counts."""
        counts = {}
        for _node, details in self._existing_data.items():
            name = details['readable_name']
            # First sighting starts at 1; later sightings add 1.
            counts[name] = counts.get(name, 0) + 1
        self._add_summary(counts)

We generate summaries with `run_analyses` as well: 

In [None]:
# SummarizeOpTypes is listed after GetReadableName because it reads the
# 'readable_name' details from previously generated data.
analyses = [GetReadableName(), GetIndex(), SummarizeOpTypes()]
# `module` is whatever an earlier cell bound it to (the MobileNet workload when
# the notebook is run top to bottom).
analysis_results, summary_results = relay.analysis_tools.run_analyses(module['main'], analyses)
# Bare expression so the notebook displays the summaries.
summary_results

There are also utilities for pulling out information about summaries:

In [None]:
# Extract the columns present in the summary results.
summary_columns = relay.analysis_tools.get_summary_columns(summary_results)
# Bare expression so the notebook displays the columns.
summary_columns

We can generate summaries for two networks, and display the results in a summary table:

In [None]:
# NOTE: this rebinds `get_workload`, replacing the MobileNet import from the
# earlier cell; later cells that call get_workload() get the ResNet version.
from tvm.relay.testing.resnet import get_workload

# Rebinds `module`/`params` to the ResNet workload.
module, params = get_workload()

# Keep the summary from the previous run_analyses call under a clearer name.
# This relies on `summary_results` still holding the MobileNet summaries, i.e.
# on the cells having been run in order.
mobilenet_summary = summary_results

analyses = [GetReadableName(), GetIndex(), SummarizeOpTypes()]
# Only the summaries are needed here; per-expression results are discarded.
_, resnet_summary = relay.analysis_tools.run_analyses(module['main'], analyses)
resnet_columns = relay.analysis_tools.get_summary_columns(resnet_summary)

# Merge the ResNet columns into the existing columns (in place) so one table
# can hold both networks.
summary_columns.update(resnet_columns)
# Bare expression so the notebook displays the merged columns.
summary_columns

In [None]:
# One record per network, both laid out against the merged column set.
mobilenet_record = relay.analysis_tools.summary_to_record(summary_columns, mobilenet_summary)
resnet_record = relay.analysis_tools.summary_to_record(summary_columns, resnet_summary)

# Build the comparison table; missing entries become 0 and everything is
# displayed as integers.
(
    pd.DataFrame
    .from_records([mobilenet_record, resnet_record], columns=summary_columns)
    .fillna(0)
    .astype('int')
)

# Dynamic Analysis

I also experimented with adding simple dynamic analyses in Relay this summer. These analyses were effectively just counters inserted into the program to count program events. 

In [None]:
class InstrumentProgramWithReference(relay.ExprMutator):
    """Mutator that inserts a counter increment in front of every ``nn.relu`` call.

    The counter is a Relay reference created from the constant 0 and exposed as
    ``self.reference`` so callers can read it back after instrumentation.
    """

    def __init__(self):
        super().__init__()
        # The counter: a reference created from the constant 0.
        # NOTE(review): the same RefCreate expression object is embedded at
        # every use site -- confirm it is evaluated as a single shared cell.
        self.reference = relay.RefCreate(relay.const(0))

    def visit_call(self, call):
        """Rewrite ``call``; wrap it in a counter increment if it is ``nn.relu``.

        Returns the rewritten call unchanged for any other callee, including
        callees that have no ``name`` attribute.
        """
        call = super().visit_call(call)
        # Not every callee exposes `.name`; getattr keeps such calls from
        # raising AttributeError and simply leaves them uninstrumented.
        if getattr(call.op, "name", None) != "nn.relu":
            return call
        increment = relay.RefWrite(self.reference,
                                   relay.RefRead(self.reference) + relay.const(1))
        # Bind the write to a throwaway variable so that it is part of the
        # expression, with the original call as the let body.
        return relay.Let(relay.var('unused'), increment, call)
    
def instrument_program(function : relay.Function):
    """Instrument ``function`` and return a function that also yields the counter.

    The returned function has the instrumented parameters, type parameters and
    attributes, and its body is the tuple ``(instrumented body, counter read)``.
    """
    mutator = InstrumentProgramWithReference()
    instrumented = mutator.visit(function)
    # Read the counter after the body so the result carries the final count.
    counter_value = relay.RefRead(mutator.reference)
    body_with_counter = relay.Tuple([instrumented.body, counter_value])
    return relay.Function(
        params=instrumented.params,
        body=body_with_counter,
        ret_type=None,  # TODO
        type_params=instrumented.type_params,
        attrs=instrumented.attrs,
    )

In [None]:
# NOTE: relies on `get_workload` currently being the ResNet one (the most
# recent import in the cells above).
resnet_mod, resnet_params = get_workload()
resnet = resnet_mod['main']

resnet_instrumented = instrument_program(resnet)
# Wrap the instrumented function in a module, run type inference on it, and
# pull the main function back out.
resnet_instrumented = relay.transform.InferType()(relay.Module.from_expr(resnet_instrumented))['main']
print(resnet_instrumented)

In [None]:
import numpy as np
# Random float32 input of shape (1, 3, 224, 224). Named `input_data` rather
# than `input` so the builtin input() is not shadowed for the rest of the
# notebook session.
input_data = np.random.rand(1,3,224,224).astype('float32')

ex = relay.create_executor()
# Evaluate the instrumented function on the random input plus the workload's
# parameters; the notebook displays the result.
ex.evaluate(resnet_instrumented)(input_data, **resnet_params)

In [None]:
class InstrumentProgramFunctionally(relay.ExprMutator):
    """Mutator that threads an int32 counter through a program without references.

    Variables, constants and calls are each rewritten into a 2-tuple
    ``(value, counter)``. This base class never increments the counter; a
    subclass can override ``visit_call`` to bump it for the calls it wants to
    count. This is experimental: several expression kinds are not handled yet
    (see the disabled sketches kept as comments below).
    """

    def visit_function(self, fn):
        """Rebuild ``fn`` from its rewritten parameters and body."""
        new_params = [self.visit(x) for x in fn.params]
        new_body = self.visit(fn.body)
        # The return type will change, so fn.ret_type is not forwarded;
        # fn.type_params and fn.attrs are dropped for now as well.
        return relay.Function(new_params, new_body)

    # Disabled sketch:
    # def visit_let(self, let):
    #     new_var = self.visit(let.var)
    #     new_val = self.visit(let.value)
    #     new_body = self.visit(let.body)
    #     return Let(new_var, new_val, new_body)

    def visit_call(self, call):
        """Rewrite ``call`` into ``(call on unwrapped args, summed arg counters)``.

        The rewritten arguments are tuples, so the real values are pulled out
        of item 0 and the counters out of item 1. The call itself is returned
        as the first item of the result tuple.
        """
        # TODO(gus) how might op change?
        new_fn = self.visit(call.op)
        # This makes them tuples
        wrapped_args = [self.visit(arg) for arg in call.args]
        # Reduce the counters
        new_counter = sum((relay.TupleGetItem(arg, 1) for arg in wrapped_args),
                          relay.const(0, dtype='int32'))
        # Get the first value
        new_args = [relay.TupleGetItem(arg, 0) for arg in wrapped_args]
        new_call = relay.Call(new_fn, new_args, call.attrs)
        return relay.Tuple([new_call, new_counter])

    # TODO(gus) does var need to change? i don't think it does. I don't
    # think it makes sense.
    def visit_var(self, var):
        """Return a var with the appropriate tuple type.

        Reads ``var.checked_type``, so the input must already be type-checked.
        """
        counter_type = relay.TensorType(tuple(), dtype='int32')
        tuple_type = relay.TupleType([var.checked_type, counter_type])
        return relay.var(var.name_hint, type_annotation=tuple_type)

    # Disabled sketch:
    # def visit_global_id(self, global_var):
    #     # TODO(gus) how is this different from a global var?
    #     return global_var

    # Disabled sketch:
    # def visit_if(self, ite):
    #     # TODO(gus) doesn't change?
    #     return If(
    #         self.visit(ite.cond),
    #         self.visit(ite.true_branch),
    #         self.visit(ite.false_branch))

    # Disabled sketch:
    # TODO(gus) these are wrong, pretty sure.
    # def visit_tuple(self, tup):
    #     return relay.Tuple([relay.Tuple([self.visit(field) for field in tup.fields]), relay.const(0)])
    #
    # def visit_tuple_getitem(self, op):
    #     tuple_value = self.visit(op.tuple_value)
    #     return relay.TupleGetItem(relay.TupleGetItem(tuple_value, 0), op.index)

    # TODO(gus) also this var
    # Disabled sketch:
    # def visit_global_var(self, gvar):
    #     return relay.Tuple([gvar, relay.const(1)])

    # Disabled sketch:
    # def visit_op(self, op):
    #     return op

    def visit_constant(self, const):
        """Pair the constant with a zero counter."""
        return relay.Tuple([const, relay.const(0, dtype='int32')])

    # Disabled sketch:
    # def visit_constructor(self, con):
    #     return con

    def visit_match(self, m):
        """Rebuild the match on the unwrapped scrutinee with rewritten clause bodies."""
        # Match and Clause are referenced through `relay`: the bare names are
        # not imported here and would raise NameError.
        return relay.Match(
            # TODO(gus) do we need to get the item here? I think we do.
            relay.TupleGetItem(self.visit(m.data), 0),
            [relay.Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
            complete=m.complete)

    def visit_ref_create(self, r):
        """References are not supported by this mutator yet."""
        # Intended shape: relay.RefCreate(self.visit(r.value))
        raise NotImplementedError()

    def visit_ref_write(self, r):
        """References are not supported by this mutator yet."""
        # Intended shape: relay.RefWrite(self.visit(r.ref), self.visit(r.value))
        raise NotImplementedError()

    def visit_ref_read(self, r):
        """References are not supported by this mutator yet."""
        # Intended shape: relay.RefRead(self.visit(r.ref))
        raise NotImplementedError()


In [None]:
# Rewrite ResNet with the functional (tuple-threading) instrumentation.
resnet_instrumented = InstrumentProgramFunctionally().visit(resnet)
# Print the original and the rewritten program for comparison.
print(resnet)
print(resnet_instrumented)

In [None]:
class CountReLUFunctionally(InstrumentProgramFunctionally):
    """Functional instrumentation that adds 1 to the counter for each ``nn.relu`` call."""

    def visit_call(self, call):
        """Rewrite ``call`` via the base class; bump the counter if it is ``nn.relu``."""
        # Read the callee name from the *original* call, before rewriting.
        # Not every callee exposes `.name`; getattr keeps such calls from
        # raising AttributeError and simply leaves them uncounted.
        op_name = getattr(call.op, "name", None)
        rewritten = super().visit_call(call)
        if op_name != 'nn.relu':
            return rewritten
        value = relay.TupleGetItem(rewritten, 0)
        # increment counter
        counter = relay.TupleGetItem(rewritten, 1) + relay.const(1, dtype='int32')
        return relay.Tuple([value, counter])

In [None]:
# Run SimplifyInference over ResNet and print the result. Despite the variable
# name, no instrumentation has been applied at this point.
resnet_instrumented = relay.transform.SimplifyInference()(relay.Module.from_expr(resnet))['main']
print(resnet_instrumented)

In [None]:
# Simplify ResNet first, then apply the ReLU-counting instrumentation.
resnet_instrumented = relay.transform.SimplifyInference()(relay.Module.from_expr(resnet))['main']
resnet_instrumented = CountReLUFunctionally().visit(resnet_instrumented)
# Follow-up passes that are currently disabled:
#resnet_instrumented = relay.transform.InferType()(relay.Module.from_expr(resnet_instrumented))['main']
#resnet_instrumented = relay.transform.PartialEvaluate()(relay.Module.from_expr(resnet_instrumented))['main']
print(resnet_instrumented)