In [1]:
from faulthandler import disable
from unittest import result
from SlicingMachine import TVMSlicer
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor 
import numpy as np
import json
import pygraphviz as pgv
from argparse import ArgumentParser
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *

class UnetPreProcessCallback(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        self.var2 = wildcard()
        tuple_node = is_tuple([wildcard(), self.var2])
        concat_node = is_op('concatenate')(tuple_node)
        self.pattern = concat_node
        self.match_node = []
        self.match_node2 = []

    def callback(self, pre, post, node_map):
        var2 = node_map[self.var2][0]
        self.match_node.append(var2)
        self.match_node2.append(pre)
        return pre 
        
class UnetCallback(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, match_node, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        self.tuple_get_item_node = is_tuple_get_item(wildcard(), 0)
        self.pattern_1 = self.tuple_get_item_node

        self.pattern = self.pattern_1 
        self.match_node = match_node
        self.counter = 0
        self.tmp = []

    def quant(self, node):
        cast_to_int8 = relay.cast(
            relay.clip(
                relay.round(
                    relay.multiply(node, relay.const(8.0))
                ), 
                a_min=-127.0, a_max=127.0
            ),
            dtype="int8"
        )
        result_node = relay.annotation.stop_fusion(cast_to_int8)
        self.tmp.append(result_node)
        return result_node

    def dequant(self, node):
        cast_to_float32 = relay.divide(
            relay.cast(node, dtype='float32'), relay.const(8.0)
        )
        return cast_to_float32

    def callback(self, pre, post, node_map):
        if self.pattern_1.match(pre):
            if pre in self.match_node:
                # print("pat 1")
                return self.dequant(self.quant(post))
        return post

class UnetCallback2(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, match_node, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        # self.tuple_get_item_node = is_tuple_get_item(wildcard(), 0)
        # self.pattern_1 = self.tuple_get_item_node
        self.var2 = wildcard()
        tuple_node = is_tuple([wildcard(), self.var2])
        concat_node = is_op('concatenate')(tuple_node)
        self.pattern = concat_node
        # self.pattern = self.pattern_1 
        self.match_node = match_node
        self.counter = 0
        self.tmp = []

    def quant(self, node):
        cast_to_int8 = relay.cast(
            relay.clip(
                relay.round(
                    relay.multiply(node, relay.const(8.0))
                ), 
                a_min=-127.0, a_max=127.0
            ),
            dtype="int8"
        )
        result_node = relay.annotation.stop_fusion(cast_to_int8)
        self.tmp.append(result_node)
        return result_node

    def dequant(self, node):
        cast_to_float32 = relay.divide(
            relay.cast(node, dtype='float32'), relay.const(8.0)
        )
        return cast_to_float32

    def callback(self, pre, post, node_map):
        if self.pattern.match(pre):
            if pre in self.match_node:
                # print("pat 1")
                return self.dequant(self.quant(post))
        return post


class UnetMaxPool2dCallback(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        max_pool2d_node = is_op('nn.max_pool2d')(wildcard())
        self.pattern = max_pool2d_node
        self.match_node = []

    def callback(self, pre, post, node_map):
        self.match_node.append(pre)
        return post


class UnetCallback3(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, match_node, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        # self.tuple_get_item_node = is_tuple_get_item(wildcard(), 0)
        # self.pattern_1 = self.tuple_get_item_node
        max_pool2d_node = is_op('nn.max_pool2d')(wildcard())
        self.pattern = max_pool2d_node
        self.match_node = match_node
        self.counter = 0
        self.tmp = []

    def quant(self, node):
        cast_to_int8 = relay.cast(
            relay.clip(
                relay.round(
                    relay.multiply(node, relay.const(8.0))
                ), 
                a_min=-127.0, a_max=127.0
            ),
            dtype="int8"
        )
        result_node = relay.annotation.stop_fusion(cast_to_int8)
        self.tmp.append(result_node)
        return result_node

    def dequant(self, node):
        cast_to_float32 = relay.divide(
            relay.cast(node, dtype='float32'), relay.const(8.0)
        )
        return cast_to_float32

    def callback(self, pre, post, node_map):
        # print("match pool2d")

        if self.pattern.match(pre):
            if pre in self.match_node:
                # print("pat 1")
                return self.dequant(self.quant(post))
        return post

class UnetLeakyReLUCallback(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        leaky_relu_node = is_op('nn.leaky_relu')(wildcard())
        self.pattern = leaky_relu_node
        self.match_node = []

    def callback(self, pre, post, node_map):
        self.match_node.append(pre)
        return post


class UnetCallback4(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, match_node, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        # self.tuple_get_item_node = is_tuple_get_item(wildcard(), 0)
        # self.pattern_1 = self.tuple_get_item_node
        leaky_relu_node = is_op('nn.leaky_relu')(wildcard())
        self.pattern = leaky_relu_node
        self.match_node = match_node
        self.counter = 0
        self.tmp = []

    def quant(self, node):
        cast_to_int8 = relay.cast(
            relay.clip(
                relay.round(
                    relay.multiply(node, relay.const(8.0))
                ), 
                a_min=-127.0, a_max=127.0
            ),
            dtype="int8"
        )
        result_node = relay.annotation.stop_fusion(cast_to_int8)
        self.tmp.append(result_node)
        return result_node

    def dequant(self, node):
        cast_to_float32 = relay.divide(
            relay.cast(node, dtype='float32'), relay.const(8.0)
        )
        return cast_to_float32

    def callback(self, pre, post, node_map):
        # print("match leaky_relu_node")

        if self.pattern.match(pre):
            if pre in self.match_node:
                # print("pat 1")
                return self.dequant(self.quant(post))
        return post

class Int8Collector(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        super().__init__(rewrite_once=True)

        int8_cast_node = is_op('cast')(wildcard()).has_attr({'dtype': 'int8'})

        self.pattern = int8_cast_node
        self.match_node = []

    def callback(self, pre, post, node_map):
        # print(pre)
        self.match_node.append(pre)
        return post

In [3]:
model_config = [2, 0, 0, 0]
np.random.seed(0)
img_size = 256
input_data = np.random.normal(0,1,(1,img_size,img_size,3)).astype(np.float32)
model_keras = tf.keras.models.load_model("UNet_M[{}-{}-{}-{}].h5".format(*model_config))

# tvm result
input_data = input_data.transpose([0, 3, 1, 2])
shape_dict = {"input_1": input_data.shape}
mod, params = relay.frontend.from_keras(model_keras, shape_dict)
target = 'cuda'
dev = tvm.cuda()

In [4]:
quantization_level = 2
upc = UnetPreProcessCallback()
out = rewrite(upc, mod['main'])

if quantization_level == 0:
    maxpool = UnetMaxPool2dCallback()
    rewrite(maxpool, out)
    out = relay.Function(out.params, relay.Tuple(upc.match_node + upc.match_node2 + maxpool.match_node + [out.body]), out.ret_type, out.type_params, out.attrs)
else:
    uc = UnetCallback(upc.match_node)
    out = rewrite(uc, mod['main'])
    upc = UnetPreProcessCallback()
    rewrite(upc, out)
    uc2 = UnetCallback2(upc.match_node2)
    out = rewrite(uc2, out)
    
    if quantization_level == 1:
        out = relay.Function(out.params, relay.Tuple(uc.tmp + [out.body]), out.ret_type, out.type_params, out.attrs)

    elif quantization_level == 2:

        upc = UnetMaxPool2dCallback()
        rewrite(upc, out)
        # print(len(upc.match_node))
        uc2 = UnetCallback3(upc.match_node)
        out = rewrite(uc2, out)

        upc = UnetLeakyReLUCallback()
        rewrite(upc, out)
        # print(len(upc.match_node))
        uc2 = UnetCallback4(upc.match_node)
        out = rewrite(uc2, out)

        int8_collector = Int8Collector()
        rewrite(int8_collector, out)
        kk = relay.Function(out.params, relay.Tuple(int8_collector.match_node + [out.body]), out.ret_type, out.type_params, out.attrs)


In [6]:
len(out.body)

TypeError: object of type 'Call' has no len()

In [8]:
len(int8_collector.match_node)

30

In [9]:
len(list(set(int8_collector.match_node)))

30

In [28]:
int8_collector.match_node[1]

CallNode(Op(cast), [CallNode(Op(clip), [CallNode(Op(round), [CallNode(Op(multiply), [CallNode(Op(nn.max_pool2d), [TupleGetItemNode(CallNode(Op(nn.batch_norm), [CallNode(Op(nn.bias_add), [CallNode(Op(nn.conv2d), [CallNode(Op(divide), [CallNode(Op(cast), [CallNode(Op(annotation.stop_fusion), [CallNode(Op(cast), [CallNode(Op(clip), [CallNode(Op(round), [CallNode(Op(multiply), [CallNode(Op(nn.leaky_relu), [TupleGetItemNode(CallNode(Op(nn.batch_norm), [CallNode(Op(nn.bias_add), [CallNode(Op(nn.conv2d), [Var(input_1, ty=TensorType([1, 3, 256, 256], float32)), Var(_param_1, ty=TensorType([64, 3, 3, 3], float32))], relay.attrs.Conv2DAttrs(0x1ab98cf8), []), Var(_param_2, ty=TensorType([64], float32))], relay.attrs.BiasAddAttrs(0x20ca7328), []), Var(_param_3, ty=TensorType([64], float32)), Var(_param_4, ty=TensorType([64], float32)), Var(_param_5, ty=TensorType([64], float32)), Var(_param_6, ty=TensorType([64], float32))], relay.attrs.BatchNormAttrs(0x20825098), []), 0)], relay.attrs.LeakyReluAt

In [30]:
str(int8_collector.match_node[-8])

'free_var %input_1: Tensor[(1, 3, 256, 256), float32];\nfree_var %v_param_1: Tensor[(64, 3, 3, 3), float32];\n%0 = nn.conv2d(%input_1, %v_param_1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);\nfree_var %v_param_2: Tensor[(64), float32];\n%1 = nn.bias_add(%0, %v_param_2);\nfree_var %v_param_3: Tensor[(64), float32];\nfree_var %v_param_4: Tensor[(64), float32];\nfree_var %v_param_5: Tensor[(64), float32];\nfree_var %v_param_6: Tensor[(64), float32];\n%2 = nn.batch_norm(%1, %v_param_3, %v_param_4, %v_param_5, %v_param_6, epsilon=0.001f);\n%3 = %2.0;\n%4 = nn.leaky_relu(%3, alpha=0.2f);\n%5 = multiply(%4, 8f);\n%6 = round(%5);\n%7 = clip(%6, a_min=-127f, a_max=127f);\n%8 = cast(%7, dtype="int8");\n%9 = annotation.stop_fusion(%8);\n%10 = cast(%9, dtype="float32");\n%11 = divide(%10, 8f);\nfree_var %v_param_7: Tensor[(64, 64, 3, 3), float32];\n%12 = nn.conv2d(%11, %v_param_7, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);\nfree_var %v_param_8: Tensor[(64), float32];\n%13 

In [32]:
type(int8_collector.match_node[1)

tvm.relay.expr.Call

In [31]:
str(int8_collector.match_node[1]) == str(int8_collector.match_node[-8])

True

In [11]:
a = [    [
      5, 
      0, 
      0
    ], 
    [
      12, 
      0, 
      0
    ], 
    [
      18, 
      0, 
      0
    ], 
    [
      24, 
      0, 
      0
    ], 
    [
      26, 
      0, 
      0
    ], 
    [
      32, 
      0, 
      0
    ], 
    [
      38, 
      0, 
      0
    ], 
    [
      40, 
      0, 
      0
    ], 
    [
      46, 
      0, 
      0
    ], 
    [
      52, 
      0, 
      0
    ], 
    [
      54, 
      0, 
      0
    ], 
    [
      60, 
      0, 
      0
    ], 
    [
      71, 
      0, 
      0
    ], 
    [
      72, 
      0, 
      0
    ], 
    [
      78, 
      0, 
      0
    ], 
    [
      89, 
      0, 
      0
    ], 
    [
      90, 
      0, 
      0
    ], 
    [
      96, 
      0, 
      0
    ], 
    [
      107, 
      0, 
      0
    ], 
    [
      108, 
      0, 
      0
    ], 
    [
      114, 
      0, 
      0
    ], 
    [
      125, 
      0, 
      0
    ], 
    [
      12, 
      0, 
      0
    ], 
    [
      130, 
      0, 
      0
    ], 
    [
      132, 
      0, 
      0
    ], 
    [
      138, 
      0, 
      0
    ], 
    [
      144, 
      0, 
      0
    ], 
    [
      150, 
      0, 
      0
    ], 
    [
      151, 
      0, 
      0
    ], 
    [
      157, 
      0, 
      0
    ], 
    [
      165, 
      0, 
      0
    ]]

In [16]:
len(list(set([i[0] for i in a])))

30

In [17]:
len(a)

31