In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.tools.checkpoint_load import load_model
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")


  from .autonotebook import tqdm as notebook_tqdm
[32mINFO    [0m [34mSet logging level to info[0m


In [2]:
batch_size = 512
model_name = "jsc-tiny"
dataset_name = "jsc"

data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
    # custom_dataset_cache_path="../../chop/dataset"
)
data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False,
    checkpoint = None)

# LAB1_CUSTOM_PATH = "/home/bkt123/dev/advanced-deep-learning-systems/mase/mase_output/lab-1_jsc-custom/software/training_ckpts/best.ckpt"
# model = load_model(load_name=LAB1_CUSTOM_PATH, load_type="pl", model=model)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
    max_batches=1
)

dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)

In [3]:
# from chop.actions import train
# import torch

# # print(isinstance(mg.model, torch.nn.Module))

# model = mg.model
# model_info = get_model_info('jsc-tiny')
# dataset_info = get_dataset_info('jsc')
# task = "cls"

# train_params = {
#     "model": model,
#     "model_info": model_info,
#     "data_module": data_module,
#     "dataset_info": dataset_info,
#     "task": task,
#     "optimizer": "adam",
#     "learning_rate": 1e-3,
#     "weight_decay": 0,
#     "plt_trainer_args": {
#         "max_epochs": 1,
#     }, 
#     "auto_requeue": False,
#     "save_path": None,
#     "visualizer": None,
#     "load_name": None,
#     "load_type": None
# }

# train(**train_params)

In [7]:
from pprint import pprint

from chop.passes.graph.utils import get_mase_op

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

# pprint(mg.meta['mase'].__dict__)

for node in mg.fx_graph.nodes:
    if get_mase_op(node) == 'linear':
        if node.op == "call_module":
            #print(node.name)
            #print(50*'-')
            #print(node.target)
            # print(node.meta['mase'].parameters['common']['args']['data_in_0']['value'])
            # print(node.meta['mase'].parameters['common']['args']['data_in_0']['value'])
            # print(node.meta['mase'].parameters['common']['args']['weight']['value'])
            # print(node.meta['mase'].parameters['common']['results']['data_out_0']['value'])
            pprint(mg.modules[node.target].weight)
            print(50*'-')
            pprint(mg.modules[node.target].state_dict())
            #pprint(mg.mo)
            
#print(mg.modules)
#print(type(mg.model))

Parameter containing:
tensor([[ 0.1452, -0.0541, -0.0288,  0.1386,  0.0652,  0.0067,  0.1749, -0.2419,
         -0.1334,  0.1519,  0.1223, -0.0931, -0.0446,  0.0301,  0.1822, -0.0948],
        [ 0.0662, -0.1325,  0.1769,  0.0744,  0.0389,  0.1381, -0.2016, -0.0752,
         -0.1015,  0.1191,  0.0031, -0.2110, -0.0728, -0.1725,  0.1924, -0.1513],
        [ 0.1350, -0.2261, -0.2043,  0.1860,  0.2388, -0.1962, -0.0749, -0.0336,
         -0.2333,  0.0880,  0.0954,  0.0788, -0.0431,  0.0796, -0.1077,  0.0506],
        [-0.0026, -0.1182,  0.2398,  0.0558,  0.1095, -0.0805,  0.2243,  0.1972,
         -0.1097, -0.1086,  0.0303,  0.1558,  0.0597, -0.2079, -0.0612,  0.1764],
        [ 0.0299, -0.0387,  0.1119,  0.1972, -0.1651, -0.0459,  0.1382,  0.1111,
         -0.2060, -0.1861,  0.0973, -0.0466, -0.2340,  0.0902, -0.1532,  0.0789]],
       requires_grad=True)
OrderedDict([('weight',
              tensor([[ 0.1452, -0.0541, -0.0288,  0.1386,  0.0652,  0.0067,  0.1749, -0.2419,
         -0.1334

In [None]:
from chop.passes.graph.transforms import (
    prune_transform_pass,
    activation_pruning_pass,
)
pass_args = {
    "activation":{
        "scope" : "local",
        "granularity" : "elementwise",
        "method" : "l1-norm",
        "sparsity" : 0.5,
    },
}
 
mg, _ = activation_pruning_pass(mg, pass_args)



In [8]:

for node in mg.fx_graph.nodes:
    if get_mase_op(node) == 'linear':
        if node.op == "call_module":
            #print(node.name)
            #print(50*'-')
            #print(node.target)
            # print(node.meta['mase'].parameters['common']['args']['data_in_0']['value'])
            # print(node.meta['mase'].parameters['common']['args']['data_in_0']['value'])
            # print(node.meta['mase'].parameters['common']['args']['weight']['value'])
            # print(node.meta['mase'].parameters['common']['results']['data_out_0']['value'])
            pprint(mg.modules[node.target].weight)
            print(50*'-')
            pprint(mg.modules[node.target].state_dict())

Parameter containing:
tensor([[ 0.1452, -0.0541, -0.0288,  0.1386,  0.0652,  0.0067,  0.1749, -0.2419,
         -0.1334,  0.1519,  0.1223, -0.0931, -0.0446,  0.0301,  0.1822, -0.0948],
        [ 0.0662, -0.1325,  0.1769,  0.0744,  0.0389,  0.1381, -0.2016, -0.0752,
         -0.1015,  0.1191,  0.0031, -0.2110, -0.0728, -0.1725,  0.1924, -0.1513],
        [ 0.1350, -0.2261, -0.2043,  0.1860,  0.2388, -0.1962, -0.0749, -0.0336,
         -0.2333,  0.0880,  0.0954,  0.0788, -0.0431,  0.0796, -0.1077,  0.0506],
        [-0.0026, -0.1182,  0.2398,  0.0558,  0.1095, -0.0805,  0.2243,  0.1972,
         -0.1097, -0.1086,  0.0303,  0.1558,  0.0597, -0.2079, -0.0612,  0.1764],
        [ 0.0299, -0.0387,  0.1119,  0.1972, -0.1651, -0.0459,  0.1382,  0.1111,
         -0.2060, -0.1861,  0.0973, -0.0466, -0.2340,  0.0902, -0.1532,  0.0789]],
       requires_grad=True)
--------------------------------------------------
OrderedDict([('weight',
              tensor([[ 0.1452, -0.0541, -0.0288,  0.1386,  

In [None]:
from pprint import pprint

from chop.passes.graph.utils import get_mase_op

# pprint(mg.meta['mase'].__dict__)

for node in mg.fx_graph.nodes:
    #if get_mase_op(node) == 'linear':
    if node.op == "call_module":
        #print(f"Layer: {node.name}")
        # pprint(node.meta['mase'].parameters['common'])
        # pprint(node.meta['mase'].parameters['common']['args']['data_in_0']['value'])
        #pprint(node.meta['mase'].parameters['common']['args'])
        pprint(mg.modules[node.target]._forward_pre_hooks.items())
        #pprint(mg.modules[node.target].parametrizations['weight'][0].mask)
       # print(mg.model.state_dict())
        # pprint(node.meta['mase'].parameters['common']['results']['data_out_0']['value'])
        #print(mg.modules[node.target].parametrizations['weight'][0].mask == mg.modules[node.target].parametrizations['weight'][1].mask)
        # total_w = 0
        # pruned_w = 0
        # w = mg.modules[node.target].weight
        # for s in w:
        #     total_w += s.numel()
        #     pruned_w += s.numel() - s.nonzero().numel()

        # pruned_percent = pruned_w / total_w
        # print(f"Pruned percent: {pruned_percent}")

        # print(50*'-')


odict_items([(0, <function get_activation_hook.<locals>.sparsify_input at 0x7f9810134af0>)])
odict_items([])
odict_items([])
odict_items([(1, <function get_activation_hook.<locals>.sparsify_input at 0x7f9810134b80>)])
odict_items([])
odict_items([(2, <function get_activation_hook.<locals>.sparsify_input at 0x7f9810134c10>)])
odict_items([])
odict_items([(3, <function get_activation_hook.<locals>.sparsify_input at 0x7f9810134ca0>)])
odict_items([])
odict_items([])
odict_items([(4, <function get_activation_hook.<locals>.sparsify_input at 0x7f9810134d30>)])


In [None]:
from pprint import pprint

from chop.passes.graph.utils import get_mase_op

# pprint(mg.meta['mase'].__dict__)

for node in mg.fx_graph.nodes:
    if get_mase_op(node) == 'linear':
        print(f"Layer: {node.name}")
        # pprint(node.meta['mase'].parameters['common'])
        # pprint(node.meta['mase'].parameters['common']['args']['data_in_0']['value'])
        pprint(node.meta['mase'].parameters['common']['args'])
        #pprint(mg.modules[node.target].weight)
        # pprint(mg.modules[node.target].parametrizations['weight'][0].mask)
        # pprint(node.meta['mase'].parameters['common']['results']['data_out_0']['value'])

        # total_w = 0
        # pruned_w = 0
        # mask_2= mg.modules[node.target].parametrizations['weight'][0].mask
        # for s in mask_2:
        #     total_w += s.numel()
        #     pruned_w += s.numel() - s.nonzero().numel()

        # pruned_percent = pruned_w / total_w
        # print(f"Pruned percent: {pruned_percent}")

        print(50*'-')
#print(mg.model.state_dict())

Layer: linear
{'bias': {'from': None,
          'precision': [32],
          'shape': [10],
          'type': 'float',
          'value': Parameter containing:
tensor([ 0.1145,  0.0742,  0.1038, -0.0820,  0.0053,  0.0576,  0.0494,  0.0499,
        -0.0856, -0.0203], requires_grad=True)},
 'data_in_0': {'precision': [32],
               'shape': [512, 64],
               'torch_dtype': torch.float32,
               'type': 'float',
               'value': tensor([[0.0163, 0.0360, 0.0481,  ..., 0.0202, 0.0000, 0.0566],
        [0.0308, 0.0515, 0.0366,  ..., 0.0317, 0.0013, 0.0747],
        [0.0456, 0.0404, 0.0565,  ..., 0.0366, 0.0002, 0.0623],
        ...,
        [0.0108, 0.0597, 0.0219,  ..., 0.0206, 0.0006, 0.0814],
        [0.0336, 0.0452, 0.0330,  ..., 0.0347, 0.0012, 0.0821],
        [0.0238, 0.0463, 0.0337,  ..., 0.0249, 0.0008, 0.0804]],
       grad_fn=<ViewBackward0>)},
 'weight': {'from': None,
            'precision': [32],
            'shape': [10, 64],
            'type': '

In [None]:
from chop.passes.graph.transforms import (
    prune_transform_pass,
)
pass_args = {
    "weight":{
        "scope" : "global",
        "granularity" : "elementwise",
        "method" :  "l1-norm",
        "sparsity" : 0.6,
    },
    "activation":{
        "scope" : "local",
        "granularity" : "elementwise",
        "method" : "l1-norm",
        "sparsity" : 0.6,
    },
}
 
mg, _ = prune_transform_pass(mg, pass_args)


In [None]:
print(mg.model.state_dict())
#print(node.meta['mase'].parameters['common']['args']['weight']['value'])

OrderedDict([('block_1.0.bias', tensor([-0.1754, -0.0768, -0.1785,  0.0166,  0.0276,  0.1453, -0.1076,  0.0642])), ('block_1.0.parametrizations.weight.original', tensor([[[[ 5.1704e-02,  1.2624e-01, -7.0364e-02],
          [ 8.7340e-02, -2.5986e-02,  3.7529e-02],
          [ 1.1660e-01,  2.4349e-02,  1.8000e-01]],

         [[-1.0022e-02, -9.4345e-02, -6.9259e-02],
          [ 1.5370e-01,  1.2753e-01,  3.1925e-02],
          [-1.1334e-01,  1.6628e-01,  8.8755e-02]],

         [[-8.4247e-02, -1.4654e-01, -1.2246e-01],
          [ 9.5408e-02,  8.9733e-02,  3.2500e-02],
          [ 3.8013e-02, -2.2334e-02,  1.1118e-01]]],


        [[[-7.5120e-02, -7.2599e-02,  1.7118e-01],
          [ 1.0351e-01,  7.9553e-02, -7.9617e-02],
          [-1.4024e-01,  2.5723e-02, -5.1885e-02]],

         [[-1.7011e-01, -1.9724e-02,  1.7424e-01],
          [-2.9218e-02,  1.9196e-01, -1.2844e-01],
          [-3.1470e-02,  7.6522e-02,  5.6215e-03]],

         [[ 5.6314e-02,  6.3420e-02,  1.8989e-01],
          

In [None]:
# from chop.actions import train
# import torch

# # print(isinstance(mg.model, torch.nn.Module))

# model = mg.model
# model_info = get_model_info('jsc-tiny')
# dataset_info = get_dataset_info('jsc')
# task = "cls"

# train_params = {
#     "model": model,
#     "model_info": model_info,
#     "data_module": data_module,
#     "dataset_info": dataset_info,
#     "task": task,
#     "optimizer": "adam",
#     "learning_rate": 1e-3,
#     "weight_decay": 0,
#     "plt_trainer_args": {
#         "max_epochs": 1,
#     }, 
#     "auto_requeue": False,
#     "save_path": None,
#     "visualizer": None,
#     "load_name": None,
#     "load_type": None
# }

# train(**train_params)

In [None]:
# from pprint import pprint

# from chop.passes.graph.utils import get_mase_op


# # pprint(mg.meta['mase'].__dict__)

for node in mg.fx_graph.nodes:
    if get_mase_op(node) == 'linear':
        print(node.name)
        print(50*'-')
        # pprint(node.meta['mase'].parameters['common'])
        # pprint(node.meta['mase'].parameters['common']['args']['data_in_0']['value'])
        pprint(node.meta['mase'].parameters['common']['args']['weight']['value'])
        #pprint(mg.modules[node.target].weight)
        #pprint(mg.modules[node.target].parametrizations['weight'][0].mask)
        # pprint(node.meta['mase'].parameters['common']['results']['data_out_0']['value'])

        print(50*'-')

linear
--------------------------------------------------
Parameter containing:
tensor([[ 0.1231,  0.1070,  0.0702, -0.0399, -0.0537,  0.0388,  0.0295, -0.0526,
         -0.0089, -0.1205, -0.0649, -0.0419,  0.0919, -0.0557,  0.0272, -0.0472,
         -0.0578, -0.0621, -0.0587,  0.0975,  0.0703,  0.0270, -0.0734,  0.0563,
         -0.0547, -0.1000,  0.0131,  0.0098, -0.0884,  0.0028, -0.0037, -0.0332,
          0.0846,  0.0170, -0.0016, -0.0207,  0.0802, -0.1033,  0.0559,  0.0725,
         -0.0536, -0.0952, -0.0811,  0.0894,  0.0623, -0.0156,  0.0636, -0.0475,
          0.0232,  0.0176, -0.0543, -0.0095,  0.0646, -0.0257,  0.1148, -0.1008,
         -0.0753, -0.0745,  0.1248,  0.0132, -0.0006, -0.0839, -0.0127,  0.0429],
        [ 0.1139,  0.1173, -0.0668,  0.0785,  0.0143,  0.1133, -0.0583, -0.0021,
          0.0836, -0.0153,  0.0491,  0.0481, -0.0030, -0.0932, -0.0828, -0.0406,
          0.1122,  0.1198,  0.0265, -0.1046, -0.0704, -0.0071, -0.0817, -0.0882,
         -0.0713,  0.0616,  