# Gathering Data on Relay Models for Guiding Hardware Design

In [None]:
from datetime import datetime
from IPython.display import display, Markdown

# Timestamp the notebook run. `datetime.now()` alone is naive, and `%z`
# formats a naive datetime as an empty string; `.astimezone()` attaches the
# local UTC offset so the printed time is unambiguous.
display(
    Markdown(f"Last updated {datetime.now().astimezone():%Y-%m-%d %H:%M:%S%z}.")
)

## Setting up

In [None]:
import tvm
from tvm import relay
from tvm.relay import analysis_tools
import pandas as pd
from tvm.relay.testing import mlp
from tvm.relay.testing import resnet
from tvm.relay.testing import dqn
from tvm.relay.testing import dcgan
from tvm.relay.testing import mobilenet
from tvm.relay.testing import lstm
from tvm.relay.testing import inception_v3
from tvm.relay.testing import squeezenet
from tvm.relay.testing import vgg
from tvm.relay.testing import densenet

Show every row of pandas DataFrames instead of truncating them.

In [None]:
# Show every row of a DataFrame (no row limit) instead of eliding the middle.
pd.options.display.max_rows = None

## Defining Analyses

In [None]:
class GetReadableName(analysis_tools.AnalysisPass):
    """Attach the callee operator's name to each call node as ``readable_name``."""

    def visit_call(self, call):
        # Let the base pass continue its traversal before recording this node.
        super().visit_call(call)
        # NOTE(review): assumes `call.op` is a primitive operator exposing
        # `.name`; a call to a function/GlobalVar would presumably fail here.
        op_name = call.op.name
        self._add_detail(call, readable_name=op_name)

In [None]:
class GetIndex(analysis_tools.AnalysisPass):
    """Attach a sequential integer ``id`` (0, 1, 2, ...) to each call node.

    Ids are handed out in the order ``visit_call`` finishes the base-class
    visit, i.e. after ``super().visit_call`` returns for that node.
    """

    def __init__(self):
        super().__init__()
        # Next id to assign; double-underscore keeps it private to this class.
        self.__id = 0

    def visit_call(self, call):
        super().visit_call(call)
        current = self.__id
        self._add_detail(call, id=current)
        self.__id = current + 1

In [None]:
class SummarizeOpTypes(relay.analysis_tools.AnalysisPass):
    """Summarize how many times each operator appears in the program.

    Requires the GetReadableName analysis to run first: every entry in
    ``self._existing_data`` must carry a ``readable_name`` key (a missing
    key raises ``KeyError``).

    The summary added is a plain dict mapping operator name to occurrence
    count, ordered by first appearance in ``self._existing_data``.
    """

    def _summarize(self):
        from collections import Counter

        # Only the per-node data matters here; the node keys are unused.
        counts = Counter(
            data["readable_name"] for _node, data in self._existing_data.items()
        )
        # Pass a plain dict (not the Counter subclass) to keep the summary
        # type exactly what the framework received before.
        self._add_summary(dict(counts))

In [None]:
def _extract_shape(t):
    """Convert a Relay type into plain Python shape data.

    Args:
        t: A Relay type, typically a node's ``checked_type``.

    Returns:
        For a ``relay.TensorType``, a list of its dimensions converted with
        ``int``. For a ``relay.TupleType``, a tuple containing the converted
        shape of each field (recursively). For any other type, a message is
        printed to stderr and ``None`` is returned.
    """
    if isinstance(t, relay.TensorType):
        # NOTE(review): int() presumably fails on symbolic/dynamic dims;
        # confirm before adding dynamic-shape workloads.
        dims = t.shape
        return list(map(int, dims))
    if isinstance(t, relay.TupleType):
        return tuple(map(_extract_shape, t.fields))

    # Unsupported type: report it but keep going, yielding None.
    import sys

    print("Unhandled type " + str(type(t)), file=sys.stderr)
    return None

In [None]:
class OutputShape(relay.analysis_tools.AnalysisPass):
    """Attach each call node's output shape as ``shape``.

    The shape is derived from the node's ``checked_type`` via
    ``_extract_shape``, so the program must already be type-checked.
    """

    def __init__(self):
        # No extra state; defer entirely to the base pass.
        super().__init__()

    def visit_call(self, call):
        super().visit_call(call)
        self._add_detail(call, shape=_extract_shape(call.checked_type))

In [None]:
class InputShapes(relay.analysis_tools.AnalysisPass):
    """Attach the shape of each of a call's arguments as ``input_shapes``.

    The detail is a dict mapping argument position (0, 1, ...) to that
    argument's shape as produced by ``_extract_shape``. Shapes are read
    directly from each argument's ``checked_type``, so this pass does not
    depend on OutputShape having run, but it does rely on the program
    having been type-checked so ``checked_type`` is populated.
    """

    def __init__(self):
        super().__init__()

    def visit_call(self, call):
        super().visit_call(call)
        input_shapes = {
            position: _extract_shape(arg.checked_type)
            for position, arg in enumerate(call.args)
        }
        self._add_detail(call, input_shapes=input_shapes)

In [None]:
# Run every analysis over each benchmark workload. `results` holds the
# per-node details and `summaries` the per-model summaries, both keyed by
# model name.
summaries = {}
results = {}
# Union of the summary columns reported by any model, so the summary table
# can have one column per key.
summary_columns = set()
for (module, _), name in [
    (resnet.get_workload(num_layers=18), "resnet18"),
    (resnet.get_workload(num_layers=50), "resnet50"),
    (mobilenet.get_workload(), "mobilenet"),
    (mlp.get_workload(batch_size=1), "mlp"),
    (dqn.get_workload(batch_size=1), "dqn"),
    (dcgan.get_workload(batch_size=1), "dcgan"),
    # LSTM throws an error w/ analysis framework
    #    (lstm.get_workload(iterations=32, num_hidden=32), 'lstm'),
    (inception_v3.get_workload(), "inception_v3"),
    (squeezenet.get_workload(), "squeezenet"),
    (vgg.get_workload(batch_size=1), "vgg"),
    (densenet.get_workload(), "densenet"),
]:

    # Simplify model for inference, which replaces batch norms
    # with their component operations (add, sqrt, etc).
    module = relay.transform.SimplifyInference()(module)
    # Re-run type inference so the nodes SimplifyInference created have
    # `checked_type` populated; OutputShape and InputShapes read it. Some
    # TVM versions may already do this after every function pass, in which
    # case this is a cheap no-op in effect.
    module = relay.transform.InferType()(module)

    program = module["main"]
    # Order matters: SummarizeOpTypes needs GetReadableName's data.
    analyses = [
        GetReadableName(),
        GetIndex(),
        SummarizeOpTypes(),
        OutputShape(),
        InputShapes(),
    ]
    these_results, summary_results = relay.analysis_tools.run_analyses(
        program, analyses
    )
    summary_columns.update(relay.analysis_tools.get_summary_columns(summary_results))
    summaries[name] = summary_results
    results[name] = these_results

# Order the summary columns deterministically; the first element of each
# column key is used as its display name.
summary_columns_ordered = sorted(summary_columns)
summary_column_names = [column[0] for column in summary_columns_ordered]
# One record per model: its name followed by its summary values, aligned to
# `summary_columns_ordered`.
summary_records = [
    (model_name,)
    + analysis_tools.summary_to_record(summary_columns_ordered, model_summary)
    for model_name, model_summary in summaries.items()
]

models_and_operators = pd.DataFrame.from_records(
    summary_records, columns=["model"] + summary_column_names, index="model"
)

In [None]:
# Blank out NaN cells (presumably where a model has no value for a column)
# so the rendered table is easier to read.
models_and_operators = models_and_operators.fillna(value="")

Summary table comparing the analysis summaries (such as operator-type counts) across all networks:

In [None]:
# Render the summary table: a notebook displays a cell's final expression.
models_and_operators

In [None]:
import sys

from IPython.display import display, Markdown

# Columns shown for every model, as (column_id, column_name) pairs. Each
# column_id indexes into the per-node analysis data; each column_name
# becomes a DataFrame header. The spec is the same for every model, so it
# is built once rather than on every loop iteration.
columns = [
    (("id",), "layer #"),
    (("readable_name",), "op in this layer"),
    (("shape",), "output shape"),
    (("input_shapes", 0), "input 0 shape"),
    (("input_shapes", 1), "input 1 shape"),
]
# Unzipping for later use
column_ids = [column_id for column_id, _ in columns]
column_names = [column_name for _, column_name in columns]

for name, these_results in results.items():
    # Flag analysis data that exists but will not appear in the table.
    for column in relay.analysis_tools.get_analysis_columns(these_results):
        if column not in column_ids:
            print(
                f"Warning: missing column {column}, is this intentional?",
                file=sys.stderr,
            )

    as_records = relay.analysis_tools.get_records(these_results, column_ids)

    df = pd.DataFrame.from_records(as_records, columns=column_names, index="layer #")

    # Make output prettier
    df = df.fillna(value="")

    display(Markdown(f"### {name}"))
    display(df)