## Lab 4

In [10]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity, get_logger

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

# Ask chop's logging helper for "info"-level verbosity.
set_logging_verbosity("info")

# Fetch the "chop" logger and pin it to INFO; later cells in this notebook log through it.
logger = get_logger("chop")
logger.setLevel(logging.INFO)

# Experiment settings handed to the chop data/model helpers below.
batch_size = 8
model_name = "jsc-tiny"
dataset_name = "jsc"

# Build the data module for the chosen dataset, then run its prepare/setup steps
# before any dataloader is requested.
data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)

# NOTE(review): input_generator is not referenced again in the visible part of this
# notebook; dummy_in below is built directly from the train dataloader instead.
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# Take element 0 of the first training batch (presumably the feature tensor —
# TODO confirm) and key it by "x", the name of the model's forward() argument.
dummy_in = {"x": next(iter(data_module.train_dataloader()))[0]}

INFO     Set logging level to info


In [38]:
from torch import nn
from chop.passes.graph.utils import get_parent_name

# define a new model
class JSC_Three_Linear_Layers(nn.Module):
    """Small fully connected network with three linear layers.

    All layers live in ``seq_blocks`` (indices 0-7): a ``BatchNorm1d`` over 16
    features and a ReLU, followed by three ``Linear`` layers (16 -> 16 -> 16 -> 5),
    each followed by a ReLU.
    """

    def __init__(self):
        super().__init__()
        # nn.ReLU's only constructor argument is the boolean ``inplace`` flag, not
        # a feature count; passing 16 or 5 there would silently enable in-place
        # activation, so the layers are built without arguments.
        self.seq_blocks = nn.Sequential(
            nn.BatchNorm1d(16),  # 0
            nn.ReLU(),  # 1
            nn.Linear(16, 16),  # 2: first linear layer
            nn.ReLU(),  # 3
            nn.Linear(16, 16),  # 4: second linear layer
            nn.ReLU(),  # 5
            nn.Linear(16, 5),  # 6: third linear layer
            nn.ReLU(),  # 7
        )

    def forward(self, x):
        """Run ``x`` through ``seq_blocks`` and return the result."""
        return self.seq_blocks(x)


# Instantiate the hand-written model defined above.
model = JSC_Three_Linear_Layers()

# generate the mase graph and initialize node metadata
# (the pass returns a tuple; only the graph is kept)
mg = MaseGraph(model=model)
mg, _ = init_metadata_analysis_pass(mg, None)


In [21]:
# Echo the fx graph object held by the MaseGraph (the notebook prints its repr below).
mg.fx_graph

<torch.fx.graph.Graph at 0x7fe523e4c4d0>

In [39]:
from chop.passes.graph.analysis.report.report_graph import report_graph_analysis_pass

def instantiate_linear(in_features, out_features, bias):
    """Create a fresh ``nn.Linear`` of the requested size.

    Args:
        in_features: Number of input features of the new layer.
        out_features: Number of output features of the new layer.
        bias: Either a boolean, or the ``bias`` attribute of an existing linear
            layer (a parameter tensor, or ``None`` when that layer has no bias).

    Returns:
        A newly constructed ``nn.Linear`` that has a bias term exactly when
        ``bias`` is ``True`` or a non-``None`` bias tensor.
    """
    if isinstance(bias, bool):
        # Respect an explicit flag; ``False`` must not be turned into ``True``.
        has_bias = bias
    else:
        # A tensor cannot be truth-tested directly, so only test for presence.
        has_bias = bias is not None
    return nn.Linear(
        in_features=in_features,
        out_features=out_features,
        bias=has_bias,
    )

def instantiate_relu(in_features):
    """Create a fresh ``nn.ReLU``.

    Args:
        in_features: Accepted so the signature mirrors the other ``instantiate_*``
            helpers. It is deliberately not forwarded: ``nn.ReLU``'s only
            constructor argument is the boolean ``inplace`` flag, so passing a
            feature count there would silently enable in-place activation.

    Returns:
        A new, out-of-place ``nn.ReLU`` module.
    """
    return nn.ReLU()

def instantiate_batchnorm(in_features):
    """Return a new ``nn.BatchNorm1d`` layer normalising ``in_features`` features."""
    batchnorm = nn.BatchNorm1d(num_features=in_features)
    return batchnorm

def _replace_submodule(graph, target, new_module):
    """Install ``new_module`` at the dotted module path ``target`` of ``graph``."""
    parent_name, attr_name = get_parent_name(target)
    setattr(graph.modules[parent_name], attr_name, new_module)


def _find_adjacent_linear(graph, node, direction):
    """Look for the closest ``nn.Linear`` before or after ``node``.

    The walk follows the fx node list (``node.prev`` / ``node.next``) and only
    steps over ``call_module`` nodes, so it stops at placeholders, at the output
    node and at either end of the node list instead of wrapping around.

    Args:
        graph: Object exposing ``modules`` (mapping of module path to module).
        node: The fx node to start from; the node itself is not inspected.
        direction: ``"prev"`` to search backwards, ``"next"`` to search forwards.

    Returns:
        ``(linear_node, between)`` where ``linear_node`` is the node of the linear
        layer that was found and ``between`` lists the nodes stepped over on the
        way (closest first), or ``(None, [])`` when no linear layer is reachable.
    """
    between = []
    current = getattr(node, direction)
    while current is not None and current.op == "call_module":
        module = graph.modules.get(current.target)
        if module is None:
            break
        if isinstance(module, nn.Linear):
            return current, between
        between.append(current)
        current = getattr(current, direction)
    return None, []


def _resize_between(graph, between, num_features):
    """Re-create the ReLU / BatchNorm1d modules in ``between`` for ``num_features``.

    Modules of any other type are left untouched.
    """
    for between_node in between:
        module = graph.modules[between_node.target]
        if isinstance(module, nn.ReLU):
            _replace_submodule(graph, between_node.target, instantiate_relu(num_features))
        elif isinstance(module, nn.BatchNorm1d):
            _replace_submodule(graph, between_node.target, instantiate_batchnorm(num_features))


def _resolve_multiplier(config, key, node_name):
    """Return ``config[key]``, else ``config["channel_multiplier"]``, else 1 (with a warning)."""
    multiplier = config.get(key, config.get("channel_multiplier"))
    if not multiplier:
        logger.warning(f"Could not find {key} or channel_multiplier for node {node_name}. Using value of 1.")
        return 1
    return multiplier


def redefine_linear_transform_pass(graph, pass_args=None):
    """Resize selected linear layers and the layers attached to them.

    Each node of ``graph.fx_graph`` is looked up by name in
    ``pass_args["config"]``; nodes without an entry use the ``"default"`` entry.
    The ``"name"`` field of the resulting config selects what happens:

    * ``None``: the node is skipped.
    * ``"input_only"``: the input width is multiplied by ``input_multiplier``
      (falling back to ``channel_multiplier``). The closest preceding linear layer
      gets a matching output width, and ReLU / BatchNorm1d modules in between are
      re-created for the new width.
    * ``"output_only"``: the output width is multiplied by ``output_multiplier``
      (falling back to ``channel_multiplier``). The closest following linear layer
      gets a matching input width, with in-between modules re-created likewise.
    * ``"both"``: both of the above.
    * any other value: the widths stay as they are, but the layer is still
      re-created.

    "Preceding" and "following" refer to fx node-list order, which coincides with
    data flow for a purely sequential model. When no neighbouring linear layer is
    found on a side, a warning is logged and that side keeps its original width.
    Every replaced layer is a newly constructed module, so its parameters are
    freshly initialised. Neither ``pass_args`` nor the config dict is modified.

    Args:
        graph: Graph wrapper exposing ``fx_graph.nodes`` and ``modules``.
        pass_args: Dict whose ``"config"`` entry holds the per-node configuration
            and a mandatory ``"default"`` entry.

    Returns:
        A ``(graph, {})`` tuple; ``graph`` is the same object, updated in place.

    Raises:
        ValueError: If ``pass_args`` is ``None``, if no default config is given,
            or if a configured node is not a linear layer.
        KeyError: If ``pass_args`` has no ``"config"`` entry.
    """
    if pass_args is None:
        raise ValueError("pass_args must be provided and contain a 'config' entry.")
    # Read rather than pop, so the caller's dictionaries can be reused afterwards.
    main_config = pass_args["config"]
    default = main_config.get("default")
    if default is None:
        raise ValueError("default value must be provided.")

    for node in graph.fx_graph.nodes:
        # Nodes without their own entry fall back to the default config.
        config = main_config.get(node.name, default)["config"]
        mode = config.get("name")
        if mode is None:
            continue

        ori_module = graph.modules.get(node.target) if node.op == "call_module" else None
        if not isinstance(ori_module, nn.Linear):
            raise ValueError(f"Node {node.name} is not a linear layer.")

        in_features = ori_module.in_features
        out_features = ori_module.out_features

        if mode in ("input_only", "both"):
            prev_linear_node, between = _find_adjacent_linear(graph, node, "prev")
            if prev_linear_node is None:
                logger.warning(f"Node {node.name} is not connected to a linear layer on the input side. " +
                               "Skipping input transformation.")
            else:
                in_features = in_features * _resolve_multiplier(config, "input_multiplier", node.name)
                _resize_between(graph, between, in_features)
                # The producer must emit as many features as this layer now consumes.
                prev_linear = graph.modules[prev_linear_node.target]
                _replace_submodule(
                    graph,
                    prev_linear_node.target,
                    instantiate_linear(prev_linear.in_features, in_features, prev_linear.bias),
                )

        if mode in ("output_only", "both"):
            next_linear_node, between = _find_adjacent_linear(graph, node, "next")
            if next_linear_node is None:
                logger.warning(f"Node {node.name} is not connected to a linear layer on the output side. " +
                               "Skipping output transformation.")
            else:
                out_features = out_features * _resolve_multiplier(config, "output_multiplier", node.name)
                _resize_between(graph, between, out_features)
                # The consumer must accept as many features as this layer now emits.
                next_linear = graph.modules[next_linear_node.target]
                _replace_submodule(
                    graph,
                    next_linear_node.target,
                    instantiate_linear(out_features, next_linear.out_features, next_linear.bias),
                )

        # Finally, set the new linear module for the node itself.
        _replace_submodule(
            graph,
            node.target,
            instantiate_linear(in_features, out_features, ori_module.bias),
        )

    return graph, {}


# Configuration consumed by redefine_linear_transform_pass. Apart from "by" and
# "default", every key is the name of an fx node (seq_blocks_N is the call to
# seq_blocks.N). "channel_multiplier" is the fallback used when the side-specific
# "input_multiplier" / "output_multiplier" is absent.
pass_config = {
    # NOTE(review): "by" is not read by redefine_linear_transform_pass as defined in this file.
    "by": "name",
    # Applied to every node without its own entry; a "name" of None makes the pass skip the node.
    "default": {"config": {"name": None}},
    # First linear layer.
    "seq_blocks_2": {
        "config": {
            "name": "output_only",
            "channel_multiplier": 2,
        }
    },
    # Second linear layer, with separate input and output multipliers.
    "seq_blocks_4": {
        "config": {
            "name": "both",
            "input_multiplier": 2,
            "output_multiplier": 4,
        }
    },
    # Third (last) linear layer.
    "seq_blocks_6": {
        "config": {
            "name": "input_only",
            "channel_multiplier": 4,
        }
    },
}

# this performs the architecture transformation based on the config
mg, _ = redefine_linear_transform_pass(graph=mg, pass_args={"config": pass_config})

# Run the metadata passes again on the transformed graph, then print a report of it.
# dummy_in is the sample input handed to the common-metadata pass.
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)
mg, _ = report_graph_analysis_pass(mg, {})


graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    %seq_blocks_4 : [num_users=1] = call_module[target=seq_blocks.4](args = (%seq_blocks_3,), kwargs = {})
    %seq_blocks_5 : [num_users=1] = call_module[target=seq_blocks.5](args = (%seq_blocks_4,), kwargs = {})
    %seq_blocks_6 : [num_users=1] = call_module[target=seq_blocks.6](args = (%seq_blocks_5,), kwargs = {})
    %seq_blocks_7 : [num_users=1] = call_module[target=seq_blocks.7](args = (%seq_blocks_6,), kwargs = {})
    return seq_blocks_7
Network overview:
{'placeholder': 1, 'get_attr': 0, 'call_function': 0, 'ca