In [1]:
# Switch the kernel's working directory to the machop checkout before the
# `chop` imports below (presumably so they resolve from the local source tree).
# NOTE(review): hardcoded absolute path from the author's machine — adjust to
# your own mase/machop checkout (or drive it from an environment variable).
%cd /home/qizhu/Desktop/Work/mase/machop

import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity, get_logger

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

from chop.passes.graph.utils import (
    deepcopy_mase_graph,
    get_mase_op,
    get_mase_type,
    get_node_actual_target,
    get_parent_name,
    get_similar_node_actual_target,
    match_a_pattern,
    get_node_target_by_name,
)

# --- Logging: chop messages at INFO level ---------------------------------
set_logging_verbosity("info")
logger = get_logger("chop")
logger.setLevel(logging.INFO)

# --- Experiment settings reused by the cells below -------------------------
batch_size = 8
model_name = "jsc-tiny"
dataset_name = "jsc"

# --- Data: build the module, then download/prepare and set up its splits ---
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=0
)
data_module.prepare_data()
data_module.setup()

# --- Model info and an input generator over the training dataloader --------
model_info = get_model_info(model_name)
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# Example input for the metadata passes: element 0 of the first training
# batch (presumably the feature tensor of an (inputs, labels) pair).
dummy_in = dict(x=next(iter(data_module.train_dataloader()))[0])

/home/qizhu/Desktop/Work/mase/machop


  from .autonotebook import tqdm as notebook_tqdm
[32mINFO    [0m [34mSet logging level to info[0m


In [2]:
from torch import nn
from chop.passes.graph.utils import get_parent_name

# define a new model
class JSC_Three_Linear_Layers(nn.Module):
    """Small MLP: BatchNorm1d(16) -> ReLU, then three Linear + ReLU stages.

    Indices inside ``seq_blocks`` (they become the FX node names
    ``seq_blocks_<i>`` used by the transform config):
        0 BatchNorm1d(16)   1 ReLU
        2 Linear(16, 16)    3 ReLU
        4 Linear(16, 16)    5 ReLU
        6 Linear(16, 5)     7 ReLU
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as listed above, so parameter
        # initialisation consumes the RNG identically.
        stages = [nn.BatchNorm1d(16), nn.ReLU()]
        for fan_in, fan_out in ((16, 16), (16, 16), (16, 5)):
            stages.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        self.seq_blocks = nn.Sequential(*stages)

    def forward(self, x):
        return self.seq_blocks(x)

model = JSC_Three_Linear_Layers()
mg = MaseGraph(model=model)

# Attach node metadata by running the analysis passes in order; each pass
# returns a pair whose first element is the updated graph.
for metadata_pass, metadata_args in (
    (init_metadata_analysis_pass, None),
    (add_common_metadata_analysis_pass, {"dummy_in": dummy_in}),
    (add_software_metadata_analysis_pass, None),
):
    mg, _ = metadata_pass(mg, metadata_args)

In [6]:
# List the "results" entries recorded under each node's common MASE metadata.
for node in mg.fx_graph.nodes:
    common_params = node.meta["mase"].parameters["common"]
    print(common_params["results"].keys())

dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])
dict_keys(['data_out_0'])


In [3]:
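# Local prototype of redefine_linear_transform_pass, kept for reference only;
# it is superseded by the library version imported below.
# NOTE(review): in the "relu" branch, nn.ReLU's only argument is `inplace`, so
# passing a feature count there does not resize anything.
# def instantiate_linear(in_features, out_features, bias):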
#     if bias is not None:
#         bias = True
#     return nn.Linear(
#         in_features=in_features,
#         out_features=out_features,
#         bias=bias)

# def redefine_linear_transform_pass(graph, pass_args=None):
#     main_config = pass_args.pop('config')
#     default = main_config.pop('default', None)
#     if default is None:
#         raise ValueError(f"default value must be provided.")
#     i = 0
#     for node in graph.fx_graph.nodes:
#         i += 1
#         # if node name is not matched, it won't be tracked
#         config = main_config.get(node.name, default)['config']
#         name = config.get("name", None)
#         if name is not None:
#             ori_module = graph.modules[node.target]
#             if name == "relu":
#                 new_module = nn.ReLU(node.meta["mase"].parameters["common"]["args"]["data_in_0"]["shape"][1] * config["channel_multiplier"])
#             else:
#                 in_features = ori_module.in_features
#                 out_features = ori_module.out_features
#                 bias = ori_module.bias
#                 if name == "output_only":
#                     out_features = out_features * config["channel_multiplier"]
#                 elif name == "both":
#                     in_features = in_features * config["channel_multiplier"][0]
#                     out_features = out_features * config["channel_multiplier"][1]
#                 elif name == "input_only":
#                     in_features = in_features * config["channel_multiplier"]
#                 new_module = instantiate_linear(in_features, out_features, bias)
#             parent_name, name = get_parent_name(node.target)
#             print(parent_name, name)
#             setattr(graph.modules[parent_name], name, new_module)
#     return graph, {}
from chop.passes.graph import redefine_linear_transform_pass


import copy

# Per-node transform config keyed by FX node name (see the forward() printout
# below: seq_blocks_2/4/6 are the three Linear layers). Nodes not listed fall
# back to "default"; in the prototype above, name=None means "leave unchanged".
pass_config = {
    "by": "name",
    "default": {"config": {"name": None}},
    "seq_blocks_2": {"config": {"name": "linear", "channel_multiplier": 2}},
    "seq_blocks_4": {"config": {"name": "linear", "channel_multiplier": 4}},
    # NOTE(review): the printed model below shows seq_blocks_6 as
    # Linear(64, 20) although this multiplier is 1 and the original head has
    # 5 outputs; confirm the library pass's channel_multiplier semantics (or
    # whether that stored output is stale).
    "seq_blocks_6": {"config": {"name": "linear", "channel_multiplier": 1}},
}

# Start from a freshly built baseline graph so re-running this cell transforms
# the original widths instead of compounding on an already-widened `mg`.
mg = MaseGraph(model=JSC_Three_Linear_Layers())
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

# This performs the architecture transformation based on the config. A deep
# copy is passed because the prototype above pops keys from pass_args, so the
# library pass may mutate its argument as well.
mg, _ = redefine_linear_transform_pass(
    graph=mg, pass_args=copy.deepcopy(pass_config)
)
# Refresh metadata so it reflects the new layer widths.
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

In [4]:
# Print (rather than display) so the GraphModule's string form is shown: the
# module tree followed by the generated forward() code, as in the output below.
print(str(mg.model))

GraphModule(
  (seq_blocks): Module(
    (0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=64, bias=True)
    (5): ReLU()
    (6): Linear(in_features=64, out_features=20, bias=True)
    (7): ReLU()
  )
)



def forward(self, x):
    seq_blocks_0 = getattr(self.seq_blocks, "0")(x);  x = None
    seq_blocks_1 = getattr(self.seq_blocks, "1")(seq_blocks_0);  seq_blocks_0 = None
    seq_blocks_2 = getattr(self.seq_blocks, "2")(seq_blocks_1);  seq_blocks_1 = None
    seq_blocks_3 = getattr(self.seq_blocks, "3")(seq_blocks_2);  seq_blocks_2 = None
    seq_blocks_4 = getattr(self.seq_blocks, "4")(seq_blocks_3);  seq_blocks_3 = None
    seq_blocks_5 = getattr(self.seq_blocks, "5")(seq_blocks_4);  seq_blocks_4 = None
    seq_blocks_6 = getattr(self.seq_blocks, "6")(seq_blocks_5);  seq_blocks_5 = None
    seq_blocks_7 = getattr

In [9]:
# Print the MASE operator recorded for each FX node, in graph order.
# (The previous version also computed an unused enumerate index, mase_type and
# get_parent_name(node.target); only mase_op was ever used.)
for node in mg.fx_graph.nodes:
    print(node.meta["mase"].parameters["common"]["mase_op"])

placeholder
batch_norm1d
relu
linear
relu
linear
relu
linear
relu
output
